Repository: slackhq/nebula Branch: master Commit: 1aa1a0476f6b Files: 216 Total size: 1.3 MB Directory structure: gitextract_aqr65ra4/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug-report.yml │ │ └── config.yml │ ├── dependabot.yml │ ├── pull_request_template.md │ └── workflows/ │ ├── gofmt.yml │ ├── release.yml │ ├── smoke/ │ │ ├── .gitignore │ │ ├── Dockerfile │ │ ├── build-relay.sh │ │ ├── build.sh │ │ ├── genconfig.sh │ │ ├── smoke-relay.sh │ │ ├── smoke-vagrant.sh │ │ ├── smoke.sh │ │ ├── vagrant-freebsd-amd64/ │ │ │ └── Vagrantfile │ │ ├── vagrant-linux-386/ │ │ │ └── Vagrantfile │ │ ├── vagrant-linux-amd64-ipv6disable/ │ │ │ └── Vagrantfile │ │ ├── vagrant-netbsd-amd64/ │ │ │ └── Vagrantfile │ │ └── vagrant-openbsd-amd64/ │ │ └── Vagrantfile │ ├── smoke-extra.yml │ ├── smoke.yml │ └── test.yml ├── .gitignore ├── .golangci.yaml ├── AUTHORS ├── CHANGELOG.md ├── CODEOWNERS ├── LICENSE ├── LOGGING.md ├── Makefile ├── README.md ├── SECURITY.md ├── allow_list.go ├── allow_list_test.go ├── bits.go ├── bits_test.go ├── boring.go ├── calculated_remote.go ├── calculated_remote_test.go ├── cert/ │ ├── Makefile │ ├── README.md │ ├── asn1.go │ ├── ca_pool.go │ ├── ca_pool_test.go │ ├── cert.go │ ├── cert_v1.go │ ├── cert_v1.pb.go │ ├── cert_v1.proto │ ├── cert_v1_test.go │ ├── cert_v2.asn1 │ ├── cert_v2.go │ ├── cert_v2_test.go │ ├── crypto.go │ ├── crypto_test.go │ ├── errors.go │ ├── helper_test.go │ ├── p256/ │ │ ├── p256.go │ │ └── p256_test.go │ ├── pem.go │ ├── pem_test.go │ ├── sign.go │ └── sign_test.go ├── cert_test/ │ └── cert.go ├── cmd/ │ ├── nebula/ │ │ ├── main.go │ │ ├── notify_linux.go │ │ └── notify_notlinux.go │ ├── nebula-cert/ │ │ ├── ca.go │ │ ├── ca_test.go │ │ ├── keygen.go │ │ ├── keygen_test.go │ │ ├── main.go │ │ ├── main_test.go │ │ ├── p11_cgo.go │ │ ├── p11_stub.go │ │ ├── passwords.go │ │ ├── passwords_test.go │ │ ├── print.go │ │ ├── print_test.go │ │ ├── sign.go │ │ ├── sign_test.go │ │ ├── test_darwin.go │ │ ├── test_linux.go │ │ ├── 
test_windows.go │ │ ├── verify.go │ │ └── verify_test.go │ └── nebula-service/ │ ├── logs_generic.go │ ├── logs_windows.go │ ├── main.go │ └── service.go ├── config/ │ ├── config.go │ └── config_test.go ├── connection_manager.go ├── connection_manager_test.go ├── connection_state.go ├── control.go ├── control_test.go ├── control_tester.go ├── dist/ │ ├── windows/ │ │ └── wintun/ │ │ ├── LICENSE.txt │ │ ├── README.md │ │ └── include/ │ │ └── wintun.h │ └── wireshark/ │ └── nebula.lua ├── dns_server.go ├── dns_server_test.go ├── docker/ │ ├── Dockerfile │ └── README.md ├── e2e/ │ ├── doc.go │ ├── handshakes_test.go │ ├── helpers_test.go │ ├── router/ │ │ ├── doc.go │ │ ├── hostmap.go │ │ └── router.go │ └── tunnels_test.go ├── examples/ │ ├── config.yml │ ├── go_service/ │ │ └── main.go │ └── service_scripts/ │ ├── nebula.init.d.sh │ ├── nebula.open-rc │ ├── nebula.plist │ └── nebula.service ├── firewall/ │ ├── cache.go │ └── packet.go ├── firewall.go ├── firewall_test.go ├── go.mod ├── go.sum ├── handshake_ix.go ├── handshake_manager.go ├── handshake_manager_test.go ├── header/ │ ├── header.go │ └── header_test.go ├── hostmap.go ├── hostmap_test.go ├── hostmap_tester.go ├── inside.go ├── inside_bsd.go ├── inside_generic.go ├── interface.go ├── iputil/ │ ├── packet.go │ └── packet_test.go ├── lighthouse.go ├── lighthouse_test.go ├── logger.go ├── main.go ├── message_metrics.go ├── nebula.pb.go ├── nebula.proto ├── noise.go ├── noiseutil/ │ ├── boring.go │ ├── boring_test.go │ ├── nist.go │ ├── notboring.go │ ├── notboring_test.go │ └── pkcs11.go ├── notboring.go ├── outside.go ├── outside_test.go ├── overlay/ │ ├── device.go │ ├── route.go │ ├── route_test.go │ ├── tun.go │ ├── tun_android.go │ ├── tun_darwin.go │ ├── tun_disabled.go │ ├── tun_freebsd.go │ ├── tun_ios.go │ ├── tun_linux.go │ ├── tun_linux_test.go │ ├── tun_netbsd.go │ ├── tun_notwin.go │ ├── tun_openbsd.go │ ├── tun_tester.go │ ├── tun_windows.go │ └── user.go ├── pkclient/ │ ├── pkclient.go │ ├── 
pkclient_cgo.go │ └── pkclient_stub.go ├── pki.go ├── punchy.go ├── punchy_test.go ├── relay_manager.go ├── remote_list.go ├── remote_list_test.go ├── routing/ │ ├── balance.go │ ├── balance_test.go │ ├── gateway.go │ └── gateway_test.go ├── service/ │ ├── listener.go │ ├── service.go │ └── service_test.go ├── ssh.go ├── sshd/ │ ├── command.go │ ├── server.go │ ├── session.go │ └── writer.go ├── stats.go ├── test/ │ ├── assert.go │ ├── logger.go │ └── tun.go ├── timeout.go ├── timeout_test.go ├── udp/ │ ├── conn.go │ ├── errors.go │ ├── udp_android.go │ ├── udp_bsd.go │ ├── udp_darwin.go │ ├── udp_generic.go │ ├── udp_linux.go │ ├── udp_linux_32.go │ ├── udp_linux_64.go │ ├── udp_netbsd.go │ ├── udp_rio_windows.go │ ├── udp_tester.go │ └── udp_windows.go ├── util/ │ ├── error.go │ └── error_test.go └── wintun/ ├── device.go └── tun.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug-report.yml ================================================ name: "\U0001F41B Bug Report" description: Report an issue or possible bug title: "\U0001F41B BUG:" labels: [] assignees: [] body: - type: markdown attributes: value: | ### Thank you for taking the time to file a bug report! Please fill out this form as completely as possible. - type: input id: version attributes: label: What version of `nebula` are you using? (`nebula -version`) placeholder: 0.0.0 validations: required: true - type: input id: os attributes: label: What operating system are you using? description: iOS and Android specific issues belong in the [mobile_nebula](https://github.com/DefinedNet/mobile_nebula) repo. placeholder: Linux, Mac, Windows validations: required: true - type: textarea id: description attributes: label: Describe the Bug description: A clear and concise description of what the bug is. 
validations: required: true - type: textarea id: logs attributes: label: Logs from affected hosts description: | Please provide logs from ALL affected hosts during the time of the issue. If you do not provide logs we will be unable to assist you! [Learn how to find Nebula logs here.](https://nebula.defined.net/docs/guides/viewing-nebula-logs/) Improve formatting by using ``` at the beginning and end of each log block. value: | ``` ``` validations: required: true - type: textarea id: configs attributes: label: Config files from affected hosts description: | Provide config files for all affected hosts. Improve formatting by using ``` at the beginning and end of each config file. value: | ``` ``` validations: required: true ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: true contact_links: - name: 💨 Performance Issues url: https://github.com/slackhq/nebula/discussions/new/choose about: 'We ask that you create a discussion instead of an issue for performance-related questions. This allows us to have a more open conversation about the issue and helps us to better understand the problem.' - name: 📄 Documentation Issues url: https://github.com/definednet/nebula-docs about: "If you've found an issue with the website documentation, please file it in the nebula-docs repository." - name: 📱 Mobile Nebula Issues url: https://github.com/definednet/mobile_nebula about: "If you're using the mobile Nebula app and have found an issue, please file it in the mobile_nebula repository." - name: 📘 Documentation url: https://nebula.defined.net/docs/ about: 'The documentation is the best place to start if you are new to Nebula.' - name: 💁 Support/Chat url: https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA about: 'For faster support, join us on Slack for assistance!' 
================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" - package-ecosystem: "gomod" directory: "/" schedule: interval: "weekly" groups: golang-x-dependencies: patterns: - "golang.org/x/*" zx2c4-dependencies: patterns: - "golang.zx2c4.com/*" protobuf-dependencies: patterns: - "github.com/golang/protobuf" - "google.golang.org/protobuf" ================================================ FILE: .github/pull_request_template.md ================================================ ================================================ FILE: .github/workflows/gofmt.yml ================================================ name: gofmt on: push: branches: - master pull_request: paths: - '.github/workflows/gofmt.yml' - '**.go' jobs: gofmt: name: Run gofmt runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest - name: gofmt run: | if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ] then find . 
-iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d exit 1 fi ================================================ FILE: .github/workflows/release.yml ================================================ on: push: tags: - 'v[0-9]+.[0-9]+.[0-9]*' name: Create release and upload binaries jobs: build-linux: name: Build Linux/BSD All runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Build run: | make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" release-linux release-freebsd release-openbsd release-netbsd mkdir release mv build/*.tar.gz release - name: Upload artifacts uses: actions/upload-artifact@v6 with: name: linux-latest path: release build-windows: name: Build Windows runs-on: windows-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Build run: | echo $Env:GITHUB_REF.Substring(11) mkdir build\windows-amd64 $Env:GOARCH = "amd64" go build -trimpath -ldflags "-X main.Build=$($Env:GITHUB_REF.Substring(11))" -o build\windows-amd64\nebula.exe ./cmd/nebula-service go build -trimpath -ldflags "-X main.Build=$($Env:GITHUB_REF.Substring(11))" -o build\windows-amd64\nebula-cert.exe ./cmd/nebula-cert mkdir build\windows-arm64 $Env:GOARCH = "arm64" go build -trimpath -ldflags "-X main.Build=$($Env:GITHUB_REF.Substring(11))" -o build\windows-arm64\nebula.exe ./cmd/nebula-service go build -trimpath -ldflags "-X main.Build=$($Env:GITHUB_REF.Substring(11))" -o build\windows-arm64\nebula-cert.exe ./cmd/nebula-cert mkdir build\dist\windows mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts uses: actions/upload-artifact@v6 with: name: windows-latest path: build build-darwin: name: Build Universal Darwin env: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} runs-on: macos-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Import certificates if: 
env.HAS_SIGNING_CREDS == 'true' uses: Apple-Actions/import-codesign-certs@v6 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} - name: Build, sign, and notarize env: AC_USERNAME: ${{ secrets.AC_USERNAME }} AC_PASSWORD: ${{ secrets.AC_PASSWORD }} run: | rm -rf release mkdir release make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" service build/darwin-amd64/nebula build/darwin-amd64/nebula-cert make BUILD_NUMBER="${GITHUB_REF#refs/tags/v}" service build/darwin-arm64/nebula build/darwin-arm64/nebula-cert lipo -create -output ./release/nebula ./build/darwin-amd64/nebula ./build/darwin-arm64/nebula lipo -create -output ./release/nebula-cert ./build/darwin-amd64/nebula-cert ./build/darwin-arm64/nebula-cert if [ -n "$AC_USERNAME" ]; then codesign -s "10BC1FDDEB6CE753550156C0669109FAC49E4D1E" -f -v --timestamp --options=runtime -i "net.defined.nebula" ./release/nebula codesign -s "10BC1FDDEB6CE753550156C0669109FAC49E4D1E" -f -v --timestamp --options=runtime -i "net.defined.nebula-cert" ./release/nebula-cert fi zip -j release/nebula-darwin.zip release/nebula-cert release/nebula if [ -n "$AC_USERNAME" ]; then xcrun notarytool submit ./release/nebula-darwin.zip --team-id "576H3XS7FP" --apple-id "$AC_USERNAME" --password "$AC_PASSWORD" --wait fi - name: Upload artifacts uses: actions/upload-artifact@v6 with: name: darwin-latest path: ./release/* build-docker: name: Create and Upload Docker Images # Technically we only need build-linux to succeed, but if any platforms fail we'll # want to investigate and restart the build needs: [build-linux, build-darwin, build-windows] runs-on: ubuntu-latest env: HAS_DOCKER_CREDS: ${{ vars.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }} # XXX It's not possible to write a conditional here, so instead we do it on every step #if: ${{ env.HAS_DOCKER_CREDS == 'true' }} steps: # Be sure to checkout the code before downloading artifacts, or they 
will # be overwritten - name: Checkout code if: ${{ env.HAS_DOCKER_CREDS == 'true' }} uses: actions/checkout@v6 - name: Download artifacts if: ${{ env.HAS_DOCKER_CREDS == 'true' }} uses: actions/download-artifact@v7 with: name: linux-latest path: artifacts - name: Login to Docker Hub if: ${{ env.HAS_DOCKER_CREDS == 'true' }} uses: docker/login-action@v3 with: username: ${{ vars.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Set up Docker Buildx if: ${{ env.HAS_DOCKER_CREDS == 'true' }} uses: docker/setup-buildx-action@v3 - name: Build and push images if: ${{ env.HAS_DOCKER_CREDS == 'true' }} env: DOCKER_IMAGE_REPO: ${{ vars.DOCKER_IMAGE_REPO || 'nebulaoss/nebula' }} DOCKER_IMAGE_TAG: ${{ vars.DOCKER_IMAGE_TAG || 'latest' }} run: | mkdir -p build/linux-{amd64,arm64} tar -zxvf artifacts/nebula-linux-amd64.tar.gz -C build/linux-amd64/ tar -zxvf artifacts/nebula-linux-arm64.tar.gz -C build/linux-arm64/ docker buildx build . --push -f docker/Dockerfile --platform linux/amd64,linux/arm64 --tag "${DOCKER_IMAGE_REPO}:${DOCKER_IMAGE_TAG}" --tag "${DOCKER_IMAGE_REPO}:${GITHUB_REF#refs/tags/v}" release: name: Create and Upload Release needs: [build-linux, build-darwin, build-windows] runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Download artifacts uses: actions/download-artifact@v7 with: path: artifacts - name: Zip Windows run: | cd artifacts/windows-latest cp windows-amd64/* . zip -r nebula-windows-amd64.zip nebula.exe nebula-cert.exe dist cp windows-arm64/* . 
zip -r nebula-windows-arm64.zip nebula.exe nebula-cert.exe dist - name: Create sha256sum run: | cd artifacts for dir in linux-latest darwin-latest windows-latest do ( cd $dir if [ "$dir" = windows-latest ] then sha256sum lighthouse1.yml <host2.yml <host3.yml HOST="host4" ../genconfig.sh >host4.yml <lighthouse1.yml HOST="host2" \ LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ ../genconfig.sh >host2.yml HOST="host3" \ LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host3.yml HOST="host4" \ LIGHTHOUSES="192.168.100.1 $NET.2:4242" \ OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \ ../genconfig.sh >host4.yml ../../../../nebula-cert ca -curve "${CURVE:-25519}" -name "Smoke Test" ../../../../nebula-cert sign -name "lighthouse1" -groups "lighthouse,lighthouse1" -ip "192.168.100.1/24" ../../../../nebula-cert sign -name "host2" -groups "host,host2" -ip "192.168.100.2/24" ../../../../nebula-cert sign -name "host3" -groups "host,host3" -ip "192.168.100.3/24" ../../../../nebula-cert sign -name "host4" -groups "host,host4" -ip "192.168.100.4/24" ) docker build -t "nebula:${NAME:-smoke}" . 
================================================ FILE: .github/workflows/smoke/genconfig.sh ================================================ #!/bin/sh set -e FIREWALL_ALL='[{"port": "any", "proto": "any", "host": "any"}]' if [ "$STATIC_HOSTS" ] || [ "$LIGHTHOUSES" ] then echo "static_host_map:" echo "$STATIC_HOSTS" | while read -r NEBULA_IP STATIC do [ -z "$NEBULA_IP" ] || echo " '$NEBULA_IP': ['$STATIC']" done echo "$LIGHTHOUSES" | while read -r NEBULA_IP STATIC do [ -z "$NEBULA_IP" ] || echo " '$NEBULA_IP': ['$STATIC']" done echo fi lighthouse_hosts() { if [ "$LIGHTHOUSES" ] then echo echo "$LIGHTHOUSES" | while read -r NEBULA_IP STATIC do echo " - '$NEBULA_IP'" done else echo "[]" fi } cat <&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm nebula:smoke-relay -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 set +x echo echo " *** Testing ping from lighthouse1" echo set -x docker exec lighthouse1 ping -c1 192.168.100.2 docker exec lighthouse1 ping -c1 192.168.100.3 docker exec lighthouse1 ping -c1 192.168.100.4 set +x echo echo " *** Testing ping from host2" echo set -x docker exec host2 ping -c1 192.168.100.1 # Should fail because no relay configured in this direction ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 ! 
docker exec host2 ping -c1 192.168.100.4 -w5 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x docker exec host3 ping -c1 192.168.100.1 docker exec host3 ping -c1 192.168.100.2 docker exec host3 ping -c1 192.168.100.4 set +x echo echo " *** Testing ping from host4" echo set -x docker exec host4 ping -c1 192.168.100.1 # Should fail because relays not allowed ! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 docker exec host4 ping -c1 192.168.100.3 docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' sleep 5 if [ "$(jobs -r)" ] then echo "nebula still running after SIGTERM sent" >&2 exit 1 fi ================================================ FILE: .github/workflows/smoke/smoke-vagrant.sh ================================================ #!/bin/bash set -e -x set -o pipefail export VAGRANT_CWD="$PWD/vagrant-$1" mkdir -p logs cleanup() { echo echo " *** cleanup" echo set +e if [ "$(jobs -r)" ] then docker kill lighthouse1 host2 fi vagrant destroy -f } trap cleanup EXIT CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test vagrant up vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 15 # grab tcpdump pcaps for debugging docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 
2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap & # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap & #docker exec host2 ncat -nklv 0.0.0.0 2000 & #vagrant ssh -c "ncat -nklv 0.0.0.0 2000" & #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" & set +x echo echo " *** Testing ping from lighthouse1" echo set -x docker exec lighthouse1 ping -c1 192.168.100.2 docker exec lighthouse1 ping -c1 192.168.100.3 set +x echo echo " *** Testing ping from host2" echo set -x docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 #set +x #echo #echo " *** Testing ncat from host2" #echo #set -x # Should fail because not allowed by host3 inbound firewall #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 #! 
docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x vagrant ssh -c "ping -c1 192.168.100.1" -- -T vagrant ssh -c "ping -c1 192.168.100.2" -- -T #set +x #echo #echo " *** Testing ncat from host3" #echo #set -x #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000" #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2 vagrant ssh -c "sudo xargs kill &2 exit 1 fi ================================================ FILE: .github/workflows/smoke/smoke.sh ================================================ #!/bin/bash set -e -x set -o pipefail mkdir -p logs cleanup() { echo echo " *** cleanup" echo set +e if [ "$(jobs -r)" ] then docker kill lighthouse1 host2 host3 host4 fi } trap cleanup EXIT CONTAINER="nebula:${NAME:-smoke}" docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test docker run --name host2 --rm "$CONTAINER" -config host2.yml -test docker run --name host3 --rm "$CONTAINER" -config host3.yml -test docker run --name host4 --rm "$CONTAINER" -config host4.yml -test docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/ [lighthouse1] /' & sleep 1 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/ [host2] /' & sleep 1 docker run --name host3 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host3.yml 2>&1 | tee logs/host3 | sed -u 's/^/ [host3] /' & sleep 1 docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host4.yml 2>&1 | tee logs/host4 | sed -u 's/^/ [host4] /' & sleep 1 # grab tcpdump pcaps for debugging docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & docker exec lighthouse1 tcpdump -i 
eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & docker exec host2 ncat -nklv 0.0.0.0 2000 & docker exec host3 ncat -nklv 0.0.0.0 2000 & docker exec host4 ncat -nkluv 0.0.0.0 4000 & docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & set +x echo echo " *** Testing ping from lighthouse1" echo set -x docker exec lighthouse1 ping -c1 192.168.100.2 docker exec lighthouse1 ping -c1 192.168.100.3 set +x echo echo " *** Testing ping from host2" echo set -x docker exec host2 ping -c1 192.168.100.1 # Should fail because not allowed by host3 inbound firewall ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1 set +x echo echo " *** Testing ncat from host2" echo set -x # Should fail because not allowed by host3 inbound firewall ! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1 ! 
docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo echo " *** Testing ping from host3" echo set -x docker exec host3 ping -c1 192.168.100.1 docker exec host3 ping -c1 192.168.100.2 set +x echo echo " *** Testing ncat from host3" echo set -x docker exec host3 ncat -nzv -w5 192.168.100.2 2000 docker exec host3 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 set +x echo echo " *** Testing ping from host4" echo set -x docker exec host4 ping -c1 192.168.100.1 # Should fail because not allowed by host4 outbound firewall ! docker exec host4 ping -c1 192.168.100.2 -w5 || exit 1 ! docker exec host4 ping -c1 192.168.100.3 -w5 || exit 1 set +x echo echo " *** Testing ncat from host4" echo set -x # Should fail because not allowed by host4 outbound firewall ! docker exec host4 ncat -nzv -w5 192.168.100.2 2000 || exit 1 ! docker exec host4 ncat -nzv -w5 192.168.100.3 2000 || exit 1 ! docker exec host4 ncat -nzuv -w5 192.168.100.2 3000 | grep -q host2 || exit 1 ! 
docker exec host4 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1 set +x echo echo " *** Testing conntrack" echo set -x # host2 speaking to host4 on UDP 4000 should allow it to reply, when firewall rules would normally not permit this docker exec host2 sh -c "/usr/bin/echo host2 | ncat -nuv 192.168.100.4 4000" docker exec host2 ncat -e '/usr/bin/echo helloagainfromhost2' -nkluv 0.0.0.0 4000 & docker exec host4 sh -c "/usr/bin/echo host4 | ncat -nuv 192.168.100.2 4000" docker exec host4 sh -c 'kill 1' docker exec host3 sh -c 'kill 1' docker exec host2 sh -c 'kill 1' docker exec lighthouse1 sh -c 'kill 1' sleep 5 if [ "$(jobs -r)" ] then echo "nebula still running after SIGTERM sent" >&2 exit 1 fi ================================================ FILE: .github/workflows/smoke/vagrant-freebsd-amd64/Vagrantfile ================================================ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| config.vm.box = "generic/freebsd14" config.vm.synced_folder "../build", "/nebula", type: "rsync" end ================================================ FILE: .github/workflows/smoke/vagrant-linux-386/Vagrantfile ================================================ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| config.vm.box = "ubuntu/xenial32" config.vm.synced_folder "../build", "/nebula" end ================================================ FILE: .github/workflows/smoke/vagrant-linux-amd64-ipv6disable/Vagrantfile ================================================ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| config.vm.box = "ubuntu/jammy64" config.vm.synced_folder "../build", "/nebula" config.vm.provision :shell do |shell| shell.inline = <<-EOF sed -i 's/GRUB_CMDLINE_LINUX=""/GRUB_CMDLINE_LINUX="ipv6.disable=1"/' /etc/default/grub update-grub EOF shell.privileged = true shell.reboot = true end end ================================================ FILE: 
.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile ================================================ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| config.vm.box = "generic/netbsd9" config.vm.synced_folder "../build", "/nebula", type: "rsync" end ================================================ FILE: .github/workflows/smoke/vagrant-openbsd-amd64/Vagrantfile ================================================ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| config.vm.box = "generic/openbsd7" config.vm.synced_folder "../build", "/nebula", type: "rsync" end ================================================ FILE: .github/workflows/smoke-extra.yml ================================================ name: smoke-extra on: push: branches: - master pull_request: types: [opened, synchronize, labeled, reopened] paths: - '.github/workflows/smoke**' - '**Makefile' - '**.go' - '**.proto' - 'go.mod' - 'go.sum' jobs: smoke-extra: if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') name: Run extra smoke tests runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: add hashicorp source run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list - name: workaround AMD-V issue # https://github.com/cri-o/packaging/pull/306 run: sudo rmmod kvm_amd - name: install vagrant run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox - name: freebsd-amd64 run: make smoke-vagrant/freebsd-amd64 - name: openbsd-amd64 run: make smoke-vagrant/openbsd-amd64 - name: netbsd-amd64 run: make smoke-vagrant/netbsd-amd64 - name: linux-386 run: make 
smoke-vagrant/linux-386 - name: linux-amd64-ipv6disable run: make smoke-vagrant/linux-amd64-ipv6disable timeout-minutes: 30 ================================================ FILE: .github/workflows/smoke.yml ================================================ name: smoke on: push: branches: - master pull_request: paths: - '.github/workflows/smoke**' - '**Makefile' - '**.go' - '**.proto' - 'go.mod' - 'go.sum' jobs: smoke: name: Run multi node smoke test runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: build run: make bin-docker CGO_ENABLED=1 BUILD_ARGS=-race - name: setup docker image working-directory: ./.github/workflows/smoke run: ./build.sh - name: run smoke working-directory: ./.github/workflows/smoke run: ./smoke.sh - name: setup relay docker image working-directory: ./.github/workflows/smoke run: ./build-relay.sh - name: run smoke relay working-directory: ./.github/workflows/smoke run: ./smoke-relay.sh - name: setup docker image for P256 working-directory: ./.github/workflows/smoke run: NAME="smoke-p256" CURVE=P256 ./build.sh - name: run smoke-p256 working-directory: ./.github/workflows/smoke run: NAME="smoke-p256" ./smoke.sh timeout-minutes: 10 ================================================ FILE: .github/workflows/test.yml ================================================ name: Build and test on: push: branches: - master pull_request: paths: - '.github/workflows/test.yml' - '**Makefile' - '**.go' - '**.proto' - 'go.mod' - 'go.sum' jobs: test-linux: name: Build all and test on ubuntu-linux runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Build run: make all - name: Vet run: make vet - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: version: v2.5 - name: Test run: make test - name: End 2 end run: make e2evv - name: Build test mobile run: make build-test-mobile - uses: 
actions/upload-artifact@v6 with: name: e2e packet flow linux-latest path: e2e/mermaid/linux-latest if-no-files-found: warn test-linux-boringcrypto: name: Build and test on linux with boringcrypto runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Build run: make bin-boringcrypto - name: Test run: make test-boringcrypto - name: End 2 end run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0" test-linux-pkcs11: name: Build and test on linux with pkcs11 runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Build run: make bin-pkcs11 - name: Test run: make test-pkcs11 test: name: Build and test on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: os: [windows-latest, macos-latest] steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: '1.25' check-latest: true - name: Build nebula run: go build ./cmd/nebula - name: Build nebula-cert run: go build ./cmd/nebula-cert - name: Vet run: make vet - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: version: v2.5 - name: Test run: make test - name: End 2 end run: make e2evv - uses: actions/upload-artifact@v6 with: name: e2e packet flow ${{ matrix.os }} path: e2e/mermaid/${{ matrix.os }} if-no-files-found: warn ================================================ FILE: .gitignore ================================================ /nebula /nebula-cert /nebula-arm /nebula-arm6 /nebula-darwin /nebula.exe /nebula-cert.exe **/coverage.out **/cover.out /cpu.pprof /build /*.tar.gz /e2e/mermaid/ **.crt **.key **.pem **.pub !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt ================================================ FILE: .golangci.yaml 
================================================ version: "2" linters: default: none enable: - testifylint exclusions: generated: lax presets: - comments - common-false-positives - legacy - std-error-handling paths: - third_party$ - builtin$ - examples$ formatters: exclusions: generated: lax paths: - third_party$ - builtin$ - examples$ ================================================ FILE: AUTHORS ================================================ # This is the official list of Nebula authors for copyright purposes. # Names should be added to this file as: # Name or Organization # The email address is not required for organizations. Slack Technologies, Inc. Nate Brown Ryan Huber ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] ## [1.10.3] - 2026-02-06 ### Security - Fix an issue where blocklist bypass is possible when using curve P256 since the signature can have 2 valid representations. Both fingerprint representations will be tested against the blocklist. Any newly issued P256 based certificates will have their signature clamped to the low-s form. Nebula will assert the low-s signature form when validating certificates in a future version. [GHSA-69x3-g4r3-p962](https://github.com/slackhq/nebula/security/advisories/GHSA-69x3-g4r3-p962) ### Changed - Improve error reporting if nebula fails to start due to a tun device naming issue. (#1588) ## [1.10.2] - 2026-01-21 ### Fixed - Fix panic when using `use_system_route_table` that was introduced in v1.10.1. (#1580) ### Changed - Fix some typos in comments. (#1582) - Dependency updates. 
(#1581) ## [1.10.1] - 2026-01-16 See the [v1.10.1](https://github.com/slackhq/nebula/milestone/26?closed=1) milestone for a complete list of changes. ### Fixed - Fix a bug where an unsafe route derived from the system route table could be lost on a config reload. (#1573) - Fix the PEM banner for ECDSA P256 public keys. (#1552) - Fix a regression on Windows from 1.9.x where nebula could fall back to a less performant UDP listener if non-critical ioctls failed. (#1568) - Fix a bug in handshake processing when a peer sends an unexpected public key. (#1566) ### Added - Add a config option to control accepting `recv_error` packets which defaults to `always`. (#1569) ### Changed - Various dependency updates. (#1541, #1549, #1550, #1557, #1558, #1560, #1561, #1570, #1571) ## [1.10.0] - 2025-12-04 See the [v1.10.0](https://github.com/slackhq/nebula/milestone/16?closed=1) milestone for a complete list of changes. ### Added - Support for ipv6 and multiple ipv4/6 addresses in the overlay. A new v2 ASN.1 based certificate format. Certificates now have a unified interface for external implementations. (#1212, #1216, #1345, #1359, #1381, #1419, #1464, #1466, #1451, #1476, #1467, #1481, #1399, #1488, #1492, #1495, #1468, #1521, #1535, #1538) - Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331) - Add ECMP support for `unsafe_routes`. (#1332) - PKCS11 support for P256 keys when built with `pkcs11` tag (#1153, #1482) ### Changed - **NOTE**: `default_local_cidr_any` now defaults to false, meaning that any firewall rule intended to target an `unsafe_routes` entry must explicitly declare it via the `local_cidr` field. This is almost always the intended behavior. This flag is deprecated and will be removed in a future release. (#1373) - Improve logging when a relay is in use on an inbound packet. (#1533) - Avoid fatal errors if `routines` is > 1 on systems that don't support more than 1 routine.
(#1531) - Log a warning if a firewall rule contains an `any` that negates a more restrictive filter. (#1513) - Accept encrypted CA passphrase from an environment variable. (#1421) - Allow handshaking with any trusted remote. (#1509) - Log only the count of blocklisted certificate fingerprints instead of the entire list. (#1525) - Don't fatal when the ssh server is unable to be configured successfully. (#1520) - Update to build against go v1.25. (#1483) - Allow projects using `nebula` as a library with userspace networking to configure the `logger` and build version. (#1239) - Upgrade to `yaml.v3`. (#1148, #1371, #1438, #1478) ### Fixed - Fix a potential bug with udp ipv4 only on darwin. (#1532) - Improve lost packet statistics. (#1441, #1537) - Honor `remote_allow_list` in hole punch response. (#1186) - Fix a panic when `tun.use_system_route_table` is `true` and a route lacks a destination. (#1437) - Fix an issue when `tun.use_system_route_table: true` could result in heavy CPU utilization when many thousands of routes are present. (#1326) - Fix tests for 32 bit machines. (#1394) - Fix a possible 32bit integer underflow in config handling. (#1353) - Fix moving a udp address from one vpn address to another in the `static_host_map` which could cause rapid re-handshaking with an incorrect remote. (#1259) - Improve smoke tests in environments where the docker network is not the default. (#1347) ## [1.9.7] - 2025-10-10 ### Security - Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494) ### Changed - Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459) - Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1453) ## [1.9.6] - 2025-7-15 ### Added - Support dropping inactive tunnels. 
This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413) ### Fixed - Fix Darwin freeze due to presence of some Network Extensions (#1426) - Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422) - Fix Windows freeze due to ICMP error handling (#1412) - Fix relay migration panic (#1403) ## [1.9.5] - 2024-12-05 ### Added - Gracefully ignore v2 certificates. (#1282) ### Fixed - Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277) ## [1.9.4] - 2024-09-09 ### Added - Support UDP dialing with gVisor. (#1181) ### Changed - Make some Nebula state programmatically available via control object. (#1188) - Switch internal representation of IPs to netip, to prepare for IPv6 support in the overlay. (#1173) - Minor build and cleanup changes. (#1171, #1164, #1162) - Various dependency updates. (#1195, #1190, #1174, #1168, #1167, #1161, #1147, #1146) ### Fixed - Fix a bug on big endian hosts, like mips. (#1194) - Fix a rare panic if a local index collision happens. (#1191) - Fix integer wraparound in the calculation of handshake timeouts on 32-bit targets. (#1185) ## [1.9.3] - 2024-06-06 ### Fixed - Initialize messageCounter to 2 instead of verifying later. (#1156) ## [1.9.2] - 2024-06-03 ### Fixed - Ensure messageCounter is set before handshake is complete. (#1154) ## [1.9.1] - 2024-05-29 ### Fixed - Fixed a potential deadlock in GetOrHandshake. (#1151) ## [1.9.0] - 2024-05-07 ### Deprecated - This release adds a new setting `default_local_cidr_any` that defaults to true to match previous behavior, but will default to false in the next release (1.10). When set to false, `local_cidr` is matched correctly for firewall rules on hosts acting as unsafe routers, and should be set for any firewall rules you want to allow unsafe route hosts to access. See the issue and example config for more details. 
(#1071, #1099) ### Added - Nebula now has an official Docker image `nebulaoss/nebula` that is distroless and contains just the `nebula` and `nebula-cert` binaries. You can find it here: https://hub.docker.com/r/nebulaoss/nebula (#1037) - Experimental binaries for `loong64` are now provided. (#1003) - Added example service script for OpenRC. (#711) - The SSH daemon now supports inlined host keys. (#1054) - The SSH daemon now supports certificates with `sshd.trusted_cas`. (#1098) ### Changed - Config setting `tun.unsafe_routes` is now reloadable. (#1083) - Small documentation and internal improvements. (#1065, #1067, #1069, #1108, #1109, #1111, #1135) - Various dependency updates. (#1139, #1138, #1134, #1133, #1126, #1123, #1110, #1094, #1092, #1087, #1086, #1085, #1072, #1063, #1059, #1055, #1053, #1047, #1046, #1034, #1022) ### Removed - Support for the deprecated `local_range` option has been removed. Please change to `preferred_ranges` (which is also now reloadable). (#1043) - We are now building with go1.22, which means that for Windows you need at least Windows 10 or Windows Server 2016. This is because support for earlier versions was removed in Go 1.21. See https://go.dev/doc/go1.21#windows (#981) - Removed vagrant example, as it was unmaintained. (#1129) - Removed Fedora and Arch nebula.service files, as they are maintained in the upstream repos. (#1128, #1132) - Remove the TCP round trip tracking metrics, as they never had correct data and were an experiment to begin with. (#1114) ### Fixed - Fixed a potential deadlock introduced in 1.8.1. (#1112) - Fixed support for Linux when IPv6 has been disabled at the OS level. (#787) - DNS will return NXDOMAIN now when there are no results. (#845) - Allow `::` in `lighthouse.dns.host`. (#1115) - Capitalization of `NotAfter` fixed in DNS TXT response. (#1127) - Don't log invalid certificates. It is untrusted data and can cause a large volume of logs. 
(#1116) ## [1.8.2] - 2024-01-08 ### Fixed - Fix multiple routines when listen.port is zero. This was a regression introduced in v1.6.0. (#1057) ### Changed - Small dependency update for Noise. (#1038) ## [1.8.1] - 2023-12-19 ### Security - Update `golang.org/x/crypto`, which includes a fix for CVE-2023-48795. (#1048) ### Fixed - Fix a deadlock introduced in v1.8.0 that could occur during handshakes. (#1044) - Fix mobile builds. (#1035) ## [1.8.0] - 2023-12-06 ### Deprecated - The next minor release of Nebula, 1.9.0, will require at least Windows 10 or Windows Server 2016. This is because support for earlier versions was removed in Go 1.21. See https://go.dev/doc/go1.21#windows ### Added - Linux: Notify systemd of service readiness. This should resolve timing issues with services that depend on Nebula being active. For an example of how to enable this, see: `examples/service_scripts/nebula.service`. (#929) - Windows: Use Registered IO (RIO) when possible. Testing on a Windows 11 machine shows ~50x improvement in throughput. (#905) - NetBSD, OpenBSD: Added rudimentary support. (#916, #812) - FreeBSD: Add support for naming tun devices. (#903) ### Changed - `pki.disconnect_invalid` will now default to true. This means that once a certificate expires, the tunnel will be disconnected. If you use SIGHUP to reload certificates without restarting Nebula, you should ensure all of your clients are on 1.7.0 or newer before you enable this feature. (#859) - Limit how often a busy tunnel can requery the lighthouse. The new config option `timers.requery_wait_duration` defaults to `60s`. (#940) - The internal structures for hostmaps were refactored to reduce memory usage and the potential for subtle bugs. (#843, #938, #953, #954, #955) - Lots of dependency updates. ### Fixed - Windows: Retry wintun device creation if it fails the first time. (#985) - Fix issues with firewall reject packets that could cause panics. (#957) - Fix relay migration during re-handshakes. 
(#964) - Various other refactors and fixes. (#935, #952, #972, #961, #996, #1002, #987, #1004, #1030, #1032, ...) ## [1.7.2] - 2023-06-01 ### Fixed - Fix a freeze during config reload if the `static_host_map` config was changed. (#886) ## [1.7.1] - 2023-05-18 ### Fixed - Fix IPv4 addresses returned by `static_host_map` DNS lookup queries being treated as IPv6 addresses. (#877) ## [1.7.0] - 2023-05-17 ### Added - `nebula-cert ca` now supports encrypting the CA's private key with a passphrase. Pass `-encrypt` in order to be prompted for a passphrase. Encryption is performed using AES-256-GCM and Argon2id for KDF. KDF parameters default to RFC recommendations, but can be overridden via CLI flags `-argon-memory`, `-argon-parallelism`, and `-argon-iterations`. (#386) - Support for curve P256 and BoringCrypto has been added. See README section "Curve P256 and BoringCrypto" for more details. (#865, #861, #769, #856, #803) - New firewall rule `local_cidr`. This could be used to filter destinations when using `unsafe_routes`. (#507) - Add `unsafe_route` option `install`. This controls whether the route is installed in the systems routing table. (#831) - Add `tun.use_system_route_table` option. Set to true to manage unsafe routes directly on the system route table with gateway routes instead of in Nebula configuration files. This is only supported on Linux. (#839) - The metric `certificate.ttl_seconds` is now exposed via stats. (#782) - Add `punchy.respond_delay` option. This allows you to change the delay before attempting punchy.respond. Default is 5 seconds. (#721) - Added SSH commands to allow the capture of a mutex profile. (#737) - You can now set `lighthouse.calculated_remotes` to make it possible to do handshakes without a lighthouse in certain configurations. (#759) - The firewall can be configured to send REJECT replies instead of the default DROP behavior. (#738) - For macOS, an example launchd configuration file is now provided. 
(#762) ### Changed - Lighthouses and other `static_host_map` entries that use DNS names will now be automatically refreshed to detect when the IP address changes. (#796) - Lighthouses send ACK replies back to clients so that they do not fall into connection testing as often by clients. (#851, #408) - Allow the `listen.host` option to contain a hostname. (#825) - When Nebula switches to a new certificate (such as via SIGHUP), we now rehandshake with all existing tunnels. This allows firewall groups to be updated and `pki.disconnect_invalid` to know about the new certificate expiration time. (#838, #857, #842, #840, #835, #828, #820, #807) ### Fixed - Always disconnect blocklisted hosts, even if `pki.disconnect_invalid` is not set. (#858) - Dependencies updated and go1.20 required. (#780, #824, #855, #854) - Fix possible race condition with relays. (#827) - FreeBSD: Fix connection to the localhost's own Nebula IP. (#808) - Normalize and document some common log field values. (#837, #811) - Fix crash if you set unlucky values for the firewall timeout configuration options. (#802) - Make DNS queries case insensitive. (#793) - Update example systemd configurations to want `nss-lookup`. (#791) - Errors with SSH commands now go to the SSH tunnel instead of stderr. (#757) - Fix a hang when shutting down Android. (#772) ## [1.6.1] - 2022-09-26 ### Fixed - Refuse to process underlay packets received from overlay IPs. This prevents confusion on hosts that have unsafe routes configured. (#741) - The ssh `reload` command did not work on Windows, since it relied on sending a SIGHUP signal internally. This has been fixed. (#725) - A regression in v1.5.2 that broke unsafe routes on Mobile clients has been fixed. (#729) ## [1.6.0] - 2022-06-30 ### Added - Experimental: nebula clients can be configured to act as relays for other nebula clients. Primarily useful when stubborn NATs make a direct tunnel impossible. 
(#678) - Configuration option to report manually specified `ip:port`s to lighthouses. (#650) - Windows arm64 build. (#638) - `punchy` and most `lighthouse` config options now support hot reloading. (#649) ### Changed - Build against go 1.18. (#656) - Promoted `routines` config from experimental to supported feature. (#702) - Dependencies updated. (#664) ### Fixed - Packets destined for the same host that sent it will be returned on MacOS. This matches the default behavior of other operating systems. (#501) - `unsafe_route` configuration will no longer crash on Windows. (#648) - A few panics that were introduced in 1.5.x. (#657, #658, #675) ### Security - You can set `listen.send_recv_error` to control the conditions in which `recv_error` messages are sent. Sending these messages can expose the fact that Nebula is running on a host, but it speeds up re-handshaking. (#670) ### Removed - `x509` config stanza support has been removed. (#685) ## [1.5.2] - 2021-12-14 ### Added - Warn when a non lighthouse node does not have lighthouse hosts configured. (#587) ### Changed - No longer fatals if expired CA certificates are present in `pki.ca`, as long as 1 valid CA is present. (#599) - `nebula-cert` will now enforce ipv4 addresses. (#604) - Warn on macOS if an unsafe route cannot be created due to a collision with an existing route. (#610) - Warn if you set a route MTU on platforms where we don't support it. (#611) ### Fixed - Rare race condition when tearing down a tunnel due to `recv_error` and sending packets on another thread. (#590) - Bug in `routes` and `unsafe_routes` handling that was introduced in 1.5.0. (#595) - `-test` mode no longer results in a crash. (#602) ### Removed - `x509.ca` config alias for `pki.ca`. (#604) ### Security - Upgraded `golang.org/x/crypto` to address an issue which allowed unauthenticated clients to cause a panic in SSH servers. 
(#603) ## 1.5.1 - 2021-12-13 (This release was skipped due to discovering #610 and #611 after the tag was created.) ## [1.5.0] - 2021-11-11 ### Added - SSH `print-cert` has a new `-raw` flag to get the PEM representation of a certificate. (#483) - New build architecture: Linux `riscv64`. (#542) - New experimental config option `remote_allow_ranges`. (#540) - New config option `pki.disconnect_invalid` that will tear down tunnels when they become invalid (through expiry or removal of root trust). Default is `false`. Note, this will not currently recognize if a remote has changed certificates since the last handshake. (#370) - New config option `unsafe_routes.<route>.metric` will set a metric for a specific unsafe route. It's useful if you have more than one identical route and want to prefer one over the other. (#353) ### Changed - Build against go 1.17. (#553) - Build with `CGO_ENABLED=0` set, to create more portable binaries. This could have an effect on DNS resolution if you rely on anything non-standard. (#421) - Windows now uses the [wintun](https://www.wintun.net/) driver which does not require installation. This driver is a large improvement over the TAP driver that was used in previous versions. If you had a previous version of `nebula` running, you will want to disable the tap driver in Control Panel, or uninstall the `tap0901` driver before running this version. (#289) - Darwin binaries are now universal (works on both amd64 and arm64), signed, and shipped in a notarized zip file. `nebula-darwin.zip` will be the only darwin release artifact. (#571) - Darwin uses syscalls and AF_ROUTE to configure the routing table, instead of using `/sbin/route`. Setting `tun.dev` is now allowed on Darwin as well, it must be in the format `utun[0-9]+` or it will be ignored. (#163) ### Deprecated - The `preferred_ranges` option has been supported as a replacement for `local_range` since v1.0.0. It has now been documented and `local_range` has been officially deprecated.
(#541) ### Fixed - Valid recv_error packets were incorrectly marked as "spoofing" and ignored. (#482) - SSH server handles single `exec` requests correctly. (#483) - Signing a certificate with `nebula-cert sign` now verifies that the supplied ca-key matches the ca-crt. (#503) - If `preferred_ranges` (or the deprecated `local_range`) is configured, we will immediately switch to a preferred remote address after the reception of a handshake packet (instead of waiting until 1,000 packets have been sent). (#532) - A race condition when `punchy.respond` is enabled and ensures the correct vpn ip is sent a punch back response in highly queried node. (#566) - Fix a rare crash during handshake due to a race condition. (#535) ## [1.4.0] - 2021-05-11 ### Added - Ability to output qr code images in `print`, `ca`, and `sign` modes for `nebula-cert`. This is useful when configuring mobile clients. (#297) - Experimental: Nebula can now do work on more than 2 cpu cores in send and receive paths via the new `routines` config option. (#382, #391, #395) - ICMP ping requests can be responded to when the `tun.disabled` is `true`. This is useful so that you can "ping" a lighthouse running in this mode. (#342) - Run smoke tests via `make smoke-docker`. (#287) - More reported stats, udp memory use on linux, build version (when using Prometheus), firewall, handshake, and cached packet stats. (#390, #405, #450, #453) - IPv6 support for the underlay network. (#369) - End to end testing, run with `make e2e`. (#425, #427, #428) ### Changed - Darwin will now log stdout/stderr to a file when using `-service` mode. (#303) - Example systemd unit file now better arranged startup order when using `sshd` and other fixes. (#317, #412, #438) - Reduced memory utilization/garbage collection. (#320, #323, #340) - Reduced CPU utilization. (#329) - Build against go 1.16. (#381) - Refactored handshakes to improve performance and correctness. 
(#401, #402, #404, #416, #451) - Improved roaming support for mobile clients. (#394, #457) - Lighthouse performance and correctness improvements. (#406, #418, #429, #433, #437, #442, #449) - Better ordered startup to enable `sshd`, `stats`, and `dns` subsystems to listen on the nebula interface. (#375) ### Fixed - No longer report handshake packets as `lost` in stats. (#331) - Error handling in the `cert` package. (#339, #373) - Orphaned pending hostmap entries are cleaned up. (#344) - Most known data races are now resolved. (#396, #400, #424) - Refuse to run a lighthouse on an ephemeral port. (#399) - Removed the global references. (#423, #426, #446) - Reloading via ssh command avoids a panic. (#447) - Shutdown is now performed in a cleaner way. (#448) - Logs will now find their way to Windows event viewer when running under `-service` mode in Windows. (#443) ## [1.3.0] - 2020-09-22 ### Added - You can emit statistics about non-message packets by setting the option `stats.message_metrics`. You can similarly emit detailed statistics about lighthouse packets by setting the option `stats.lighthouse_metrics`. See the example config for more details. (#230) - We now support freebsd/amd64. This is experimental, please give us feedback. (#103) - We now release a binary for `linux/mips-softfloat` which has also been stripped to reduce filesize and hopefully have a better chance on running on small mips devices. (#231) - You can set `tun.disabled` to true to run a standalone lighthouse without a tun device (and thus, without root). (#269) - You can set `logging.disable_timestamp` to remove timestamps from log lines, which is useful when output is redirected to a logging system that already adds timestamps. (#288) ### Changed - Handshakes should now trigger faster, as we try to be proactive with sending them instead of waiting for the next timer tick in most cases. 
(#246, #265) - Previously, we would drop the conntrack table whenever firewall rules were changed during a SIGHUP. Now, we will maintain the table and just validate that an entry still matches with the new rule set. (#233) - Debug logs for firewall drops now include the reason. (#220, #239) - Logs for handshakes now include the fingerprint of the remote host. (#262) - Config item `pki.blacklist` is now `pki.blocklist`. (#272) - Better support for older Linux kernels. We now only set `SO_REUSEPORT` if `tun.routines` is greater than 1 (default is 1). We also only use the `recvmmsg` syscall if `listen.batch` is greater than 1 (default is 64). (#275) - It is possible to run Nebula as a library inside of another process now. Note that this is still experimental and the internal APIs around this might change in minor version releases. (#279) ### Deprecated - `pki.blacklist` is deprecated in favor of `pki.blocklist` with the same functionality. Existing configs will continue to load for this release to allow for migrations. (#272) ### Fixed - `advmss` is now set correctly for each route table entry when `tun.routes` is configured to have some routes with higher MTU. (#245) - Packets that arrive on the tun device with an unroutable destination IP are now dropped correctly, instead of wasting time making queries to the lighthouses for IP `0.0.0.0` (#267) ## [1.2.0] - 2020-04-08 ### Added - Add `logging.timestamp_format` config option. The primary purpose of this change is to allow logging timestamps with millisecond precision. (#187) - Support `unsafe_routes` on Windows. (#184) - Add `lighthouse.remote_allow_list` to filter which subnets we will use to handshake with other hosts. See the example config for more details. (#217) - Add `lighthouse.local_allow_list` to filter which local IP addresses and/or interfaces we advertise to the lighthouses. See the example config for more details. (#217) - Wireshark dissector plugin. 
Add this file in `dist/wireshark` to your Wireshark plugins folder to see Nebula packet headers decoded. (#216) - systemd unit for Arch, so it can be built entirely from this repo. (#216) ### Changed - Added a delay to punching via lighthouse signal to deal with race conditions in some linux conntrack implementations. (#210) See deprecated, this also adds a new `punchy.delay` option that defaults to `1s`. - Validate all `lighthouse.hosts` and `static_host_map` VPN IPs are in the subnet defined in our cert. Exit with a fatal error if they are not in our subnet, as this is an invalid configuration (we will not have the proper routes set up to communicate with these hosts). (#170) - Use absolute paths to system binaries on macOS and Windows. (#191) - Add configuration options for `handshakes`. This includes options to tweak `try_interval`, `retries` and `wait_rotation`. See example config for descriptions. (#179) - Allow `-config` file to not end in `.yaml` or `yml`. Useful when using `-test` and automated tools like Ansible that create temporary files without suffixes. (#189) - The config test mode, `-test`, is now more thorough and catches more parsing issues. (#177) - Various documentation and example fixes. (#196) - Improved log messages. (#181, #200) - Dependencies updated. (#188) ### Deprecated - `punchy`, `punch_back` configuration options have been collapsed under the now top level `punchy` config directive. (#210) `punchy.punch` - This is the old `punchy` option. Should we perform NAT hole punching (default false)? `punchy.respond` - This is the old `punch_back` option. Should we respond to hole punching by hole punching back (default false)? ### Fixed - Reduce memory allocations when not using `unsafe_routes`. (#198) - Ignore packets from self to self. (#192) - MTU fixed for `unsafe_routes`. (#209) ## [1.1.0] - 2020-01-17 ### Added - For macOS and Windows, build a special version of the binary that can install and manage its own service configuration. 
You can use this with `nebula -service`. If you are building from source, use `make service` to build this feature. - Support for `mips`, `mips64`, `386` and `ppc64le` processors on Linux. - You can now configure the DNS listen host and port with `lighthouse.dns.host` and `lighthouse.dns.port`. - Subnet and routing support. You can now add a `unsafe_routes` section to your config to allow hosts to act as gateways to other subnets. Read the example config for more details. This is supported on Linux and macOS. ### Changed - Certificates now have more verifications performed, including making sure the certificate lifespan does not exceed the lifespan of the root CA. This could cause issues if you have signed certificates with expirations beyond the expiration of your CA, and you will need to reissue your certificates. - If lighthouse interval is set to `0`, never update the lighthouse (mobile optimization). - Various documentation and example fixes. - Improved error messages. - Dependencies updated. ### Fixed - If you have a firewall rule with `group: ["one-group"]`, this will now be accepted, with a warning to use `group: "one-group"` instead. - The `listen.host` configuration option was previously ignored (the bind host was always 0.0.0.0). This option will now be honored. - The `ca_sha` and `ca_name` firewall rule options should now work correctly. ## [1.0.0] - 2019-11-19 ### Added - Initial public release. 
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.3...HEAD [1.10.3]: https://github.com/slackhq/nebula/releases/tag/v1.10.3 [1.10.2]: https://github.com/slackhq/nebula/releases/tag/v1.10.2 [1.10.1]: https://github.com/slackhq/nebula/releases/tag/v1.10.1 [1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0 [1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7 [1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6 [1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5 [1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 [1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1 [1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0 [1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2 [1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1 [1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0 [1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2 [1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1 [1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0 [1.6.1]: https://github.com/slackhq/nebula/releases/tag/v1.6.1 [1.6.0]: https://github.com/slackhq/nebula/releases/tag/v1.6.0 [1.5.2]: https://github.com/slackhq/nebula/releases/tag/v1.5.2 [1.5.0]: https://github.com/slackhq/nebula/releases/tag/v1.5.0 [1.4.0]: https://github.com/slackhq/nebula/releases/tag/v1.4.0 [1.3.0]: https://github.com/slackhq/nebula/releases/tag/v1.3.0 [1.2.0]: https://github.com/slackhq/nebula/releases/tag/v1.2.0 [1.1.0]: https://github.com/slackhq/nebula/releases/tag/v1.1.0 [1.0.0]: https://github.com/slackhq/nebula/releases/tag/v1.0.0 ================================================ FILE: CODEOWNERS ================================================ #ECCN:Open Source ================================================ FILE: LICENSE 
================================================ MIT License Copyright (c) 2018-2019 Slack Technologies, Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: LOGGING.md ================================================ ### Logging conventions A log message (the string/format passed to `Info`, `Error`, `Debug` etc., as well as their `Sprintf` counterparts) should be a descriptive message about the event and may contain specific identifying characteristics. Regardless of the level of detail in the message, identifying characteristics should always be included via `WithField`, `WithFields` or `WithError`. If an error is being logged use `l.WithError(err)` so that there is better discoverability about the event as well as the specific error condition. #### Common fields - `cert` - a `cert.NebulaCertificate` object, do not `.String()` this manually, `logrus` will marshal objects properly for the formatter it is using.
- `fingerprint` - a single `NebulaCertificate` hex encoded fingerprint - `fingerprints` - an array of `NebulaCertificate` hex encoded fingerprints - `fwPacket` - a FirewallPacket object - `handshake` - an object containing: - `stage` - the current stage counter - `style` - noise handshake style `ix_psk0`, `xx`, etc. - `header` - a nebula header object - `udpAddr` - a `net.UDPAddr` object - `udpIp` - a UDP IP address - `vpnIp` - VPN IP of the host (remote or local) - `relay` - the vpnIp of the relay host that is or should be handling the relay packet - `relayFrom` - the vpnIp of the initial sender of the relayed packet - `relayTo` - the vpnIp of the final destination of a relayed packet #### Example: ``` l.WithError(err). WithField("vpnIp", IntIp(hostinfo.hostId)). WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix"}). Info("Invalid certificate from host") ``` ================================================ FILE: Makefile ================================================ NEBULA_CMD_PATH = "./cmd/nebula" CGO_ENABLED = 0 export CGO_ENABLED # Set up OS specific bits ifeq ($(OS),Windows_NT) NEBULA_CMD_SUFFIX = .exe NULL_FILE = nul # RIO on windows does pointer stuff that makes go vet angry VET_FLAGS = -unsafeptr=false else NEBULA_CMD_SUFFIX = NULL_FILE = /dev/null endif # Only defined the build number if we haven't already ifndef BUILD_NUMBER ifeq ($(shell git describe --exact-match 2>$(NULL_FILE)),) BUILD_NUMBER = $(shell git describe --abbrev=0 --match "v*" | cut -dv -f2)-$(shell git branch --show-current)-$(shell git describe --long --dirty | cut -d- -f2-) else BUILD_NUMBER = $(shell git describe --exact-match --dirty | cut -dv -f2) endif endif DOCKER_IMAGE_REPO ?= nebulaoss/nebula DOCKER_IMAGE_TAG ?= latest LDFLAGS = -X main.Build=$(BUILD_NUMBER) ALL_LINUX = linux-amd64 \ linux-386 \ linux-ppc64le \ linux-arm-5 \ linux-arm-6 \ linux-arm-7 \ linux-arm64 \ linux-mips \ linux-mipsle \ linux-mips64 \ linux-mips64le \ linux-mips-softfloat \
linux-riscv64 \
	linux-loong64

ALL_FREEBSD = freebsd-amd64 \
	freebsd-arm64

ALL_OPENBSD = openbsd-amd64 \
	openbsd-arm64

ALL_NETBSD = netbsd-amd64 \
	netbsd-arm64

ALL = $(ALL_LINUX) \
	$(ALL_FREEBSD) \
	$(ALL_OPENBSD) \
	$(ALL_NETBSD) \
	darwin-amd64 \
	darwin-arm64 \
	windows-amd64 \
	windows-arm64

# End-to-end tests; the v/vv/vvv/vvvv variants add -v and raise TEST_LOGS.
e2e:
	$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e

e2ev: TEST_FLAGS += -v
e2ev: e2e

e2evv: TEST_ENV += TEST_LOGS=1
e2evv: e2ev

e2evvv: TEST_ENV += TEST_LOGS=2
e2evvv: e2ev

e2evvvv: TEST_ENV += TEST_LOGS=3
e2evvvv: e2ev

e2e-bench: TEST_FLAGS = -bench=. -benchmem -run=^$
e2e-bench: e2e

DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert

all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert)

docker: docker/linux-$(shell go env GOARCH)

release: $(ALL:%=build/nebula-%.tar.gz)

release-linux: $(ALL_LINUX:%=build/nebula-%.tar.gz)

release-freebsd: $(ALL_FREEBSD:%=build/nebula-%.tar.gz)

release-openbsd: $(ALL_OPENBSD:%=build/nebula-%.tar.gz)

release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)

release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz

BUILD_ARGS += -trimpath

# bin-<platform>: cross-build both binaries and move them into the repo root.
bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
	mv $? .

bin-windows-arm64: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe
	mv $? .

bin-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert
	mv $? .

bin-freebsd: build/freebsd-amd64/nebula build/freebsd-amd64/nebula-cert
	mv $? .

bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
	mv $? .

bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
	mv $? .

bin-pkcs11: BUILD_ARGS += -tags pkcs11
bin-pkcs11: CGO_ENABLED = 1
bin-pkcs11: bin

bin:
	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert

install:
	go install $(BUILD_ARGS) -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}
	go install $(BUILD_ARGS) -ldflags "$(LDFLAGS)" ./cmd/nebula-cert

build/linux-arm-%: GOENV += GOARM=$(word 3, $(subst -, ,$*))
build/linux-mips-%: GOENV += GOMIPS=$(word 3, $(subst -, ,$*))

# Build an extra small binary for mips-softfloat
build/linux-mips-softfloat/%: LDFLAGS += -s -w

# boringcrypto
build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0

# The target directory name is <GOOS>-<GOARCH>[-<variant>].
build/%/nebula: .FORCE
	GOOS=$(firstword $(subst -, , $*)) \
		GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
		go build $(BUILD_ARGS) -o $@ -ldflags "$(LDFLAGS)" ${NEBULA_CMD_PATH}

build/%/nebula-cert: .FORCE
	GOOS=$(firstword $(subst -, , $*)) \
		GOARCH=$(word 2, $(subst -, ,$*)) $(GOENV) \
		go build $(BUILD_ARGS) -o $@ -ldflags "$(LDFLAGS)" ./cmd/nebula-cert

build/%/nebula.exe: build/%/nebula
	mv $< $@

build/%/nebula-cert.exe: build/%/nebula-cert
	mv $< $@

build/nebula-%.tar.gz: build/%/nebula build/%/nebula-cert
	tar -zcv -C build/$* -f $@ nebula nebula-cert

build/nebula-%.zip: build/%/nebula.exe build/%/nebula-cert.exe
	cd build/$* && zip ../nebula-$*.zip nebula.exe nebula-cert.exe

docker/%: build/%/nebula build/%/nebula-cert
	docker build . $(DOCKER_BUILD_ARGS) -f docker/Dockerfile --platform "$(subst -,/,$*)" --tag "${DOCKER_IMAGE_REPO}:${DOCKER_IMAGE_TAG}" --tag "${DOCKER_IMAGE_REPO}:$(BUILD_NUMBER)"

vet:
	go vet $(VET_FLAGS) -v ./...

test:
	go test -v ./...

test-boringcrypto:
	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...

test-pkcs11:
	CGO_ENABLED=1 go test -v -tags pkcs11 ./...

test-cov-html:
	go test -coverprofile=coverage.out
	go tool cover -html=coverage.out

build-test-mobile:
	GOARCH=amd64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
	GOARCH=arm64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
	GOARCH=amd64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
	GOARCH=arm64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/')

bench:
	go test -bench=.

# Profile the root package's benchmarks and open the profile in pprof.
# With -cpuprofile, `go test` keeps the test binary as `go test -c` would,
# i.e. <package>.test; for this package that is nebula.test. (The previous
# go-audit.test name never existed here.)
bench-cpu:
	go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof
	go tool pprof nebula.test cpu.pprof

bench-cpu-long:
	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
	go tool pprof nebula.test cpu.pprof

proto: nebula.pb.go cert/cert_v1.pb.go

nebula.pb.go: nebula.proto .FORCE
	go build github.com/gogo/protobuf/protoc-gen-gogofaster
	PATH="$(CURDIR):$(PATH)" protoc --gogofaster_out=paths=source_relative:.
$< rm protoc-gen-gogofaster cert/cert.pb.go: cert/cert.proto .FORCE $(MAKE) -C cert cert.pb.go service: @echo > $(NULL_FILE) $(eval NEBULA_CMD_PATH := "./cmd/nebula-service") ifeq ($(words $(MAKECMDGOALS)),1) @$(MAKE) service ${.DEFAULT_GOAL} --no-print-directory endif bin-docker: bin build/linux-amd64/nebula build/linux-amd64/nebula-cert smoke-docker: bin-docker cd .github/workflows/smoke/ && ./build.sh cd .github/workflows/smoke/ && ./smoke.sh cd .github/workflows/smoke/ && NAME="smoke-p256" CURVE="P256" ./build.sh cd .github/workflows/smoke/ && NAME="smoke-p256" ./smoke.sh smoke-relay-docker: bin-docker cd .github/workflows/smoke/ && ./build-relay.sh cd .github/workflows/smoke/ && ./smoke-relay.sh smoke-docker-race: BUILD_ARGS = -race smoke-docker-race: CGO_ENABLED = 1 smoke-docker-race: smoke-docker smoke-vagrant/%: bin-docker build/%/nebula cd .github/workflows/smoke/ && ./build.sh $* cd .github/workflows/smoke/ && ./smoke-vagrant.sh $* .FORCE: .PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% .DEFAULT_GOAL := bin ================================================ FILE: README.md ================================================ ## What is Nebula? Nebula is a scalable overlay networking tool with a focus on performance, simplicity and security. It lets you seamlessly connect computers anywhere in the world. Nebula is portable, and runs on Linux, OSX, Windows, iOS, and Android. It can be used to connect a small number of computers, but is also able to connect tens of thousands of computers. Nebula incorporates a number of existing concepts like encryption, security groups, certificates, and tunneling. What makes Nebula different to existing offerings is that it brings all of these ideas together, resulting in a sum that is greater than its individual parts. Further documentation can be found [here](https://nebula.defined.net/docs/). 
You can read more about Nebula [here](https://medium.com/p/884110a5579). You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-39pk4xopc-CUKlGcb5Z39dQ0cK1v7ehA). ## Supported Platforms #### Desktop and Server Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for downloads or see the [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) section. - Linux - 64 and 32 bit, arm, and others - Windows - MacOS - Freebsd #### Distribution Packages - [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/) ```sh sudo pacman -S nebula ``` - [Fedora Linux](https://src.fedoraproject.org/rpms/nebula) ```sh sudo dnf install nebula ``` - [Debian Linux](https://packages.debian.org/source/stable/nebula) ```sh sudo apt install nebula ``` - [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula) ```sh sudo apk add nebula ``` - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb) ```sh brew install nebula ``` - [Docker](https://hub.docker.com/r/nebulaoss/nebula) ```sh docker pull nebulaoss/nebula ``` #### Mobile - [iOS](https://apps.apple.com/us/app/mobile-nebula/id1509587936?itsct=apps_box&itscg=30200) - [Android](https://play.google.com/store/apps/details?id=net.defined.mobile_nebula&pcampaignid=pcampaignidMKT-Other-global-all-co-prtnr-py-PartBadge-Mar2515-1) ## Technical Overview Nebula is a mutually authenticated peer-to-peer software-defined network based on the [Noise Protocol Framework](https://noiseprotocol.org/). Nebula uses certificates to assert a node's IP address, name, and membership within user-defined groups. Nebula's user-defined groups allow for provider agnostic traffic filtering between nodes. Discovery nodes (aka lighthouses) allow individual peers to find each other and optionally use UDP hole punching to establish connections from behind most firewalls or NATs. 
Users can move data between nodes in any number of cloud service providers, datacenters, and endpoints, without needing to maintain a particular addressing scheme. Nebula uses Elliptic-curve Diffie-Hellman (`ECDH`) key exchange and `AES-256-GCM` in its default configuration. Nebula was created to provide a mechanism for groups of hosts to communicate securely, even across the internet, while enabling expressive firewall definitions similar in style to cloud security groups. ## Getting started (quickly) To set up a Nebula network, you'll need: #### 1. The [Nebula binaries](https://github.com/slackhq/nebula/releases) or [Distribution Packages](https://github.com/slackhq/nebula#distribution-packages) for your specific platform. Specifically, you'll need `nebula-cert` and the specific nebula binary for each platform you use. #### 2. (Optional, but you really should.) At least one discovery node with a routable IP address, which we call a lighthouse. Nebula lighthouses allow nodes to find each other, anywhere in the world. A lighthouse is the only node in a Nebula network whose IP should not change. Running a lighthouse requires very few compute resources, and you can easily use the least expensive option from a cloud hosting provider. If you're not sure which provider to use, a number of us have used $6/mo [DigitalOcean](https://digitalocean.com) droplets as lighthouses. Once you have launched an instance, ensure that Nebula UDP traffic (default port udp/4242) can reach it over the internet. #### 3. A Nebula certificate authority, which will be the root of trust for a particular Nebula network. ```sh ./nebula-cert ca -name "Myorganization, Inc" ``` This will create files named `ca.key` and `ca.crt` in the current directory. The `ca.key` file is the most sensitive file you'll create, because it is the key used to sign the certificates for individual nebula nodes/hosts. Please store this file somewhere safe, preferably with strong encryption.
**Be aware!** By default, certificate authorities have a 1-year lifetime before expiration. See [this guide](https://nebula.defined.net/docs/guides/rotating-certificate-authority/) for details on rotating a CA. #### 4. Nebula host keys and certificates generated from that certificate authority This assumes you have four nodes, named lighthouse1, laptop, server1, host3. You can name the nodes any way you'd like, including FQDN. You'll also need to choose IP addresses and the associated subnet. In this example, we are creating a nebula network that will use 192.168.100.x/24 as its network range. This example also demonstrates nebula groups, which can later be used to define traffic rules in a nebula network. ```sh ./nebula-cert sign -name "lighthouse1" -ip "192.168.100.1/24" ./nebula-cert sign -name "laptop" -ip "192.168.100.2/24" -groups "laptop,home,ssh" ./nebula-cert sign -name "server1" -ip "192.168.100.9/24" -groups "servers" ./nebula-cert sign -name "host3" -ip "192.168.100.10/24" ``` By default, host certificates will expire 1 second before the CA expires. Use the `-duration` flag to specify a shorter lifetime. #### 5. Configuration files for each host Download a copy of the nebula [example configuration](https://github.com/slackhq/nebula/blob/master/examples/config.yml). * On the lighthouse node, you'll need to ensure `am_lighthouse: true` is set. * On the individual hosts, ensure the lighthouse is defined properly in the `static_host_map` section, and is added to the lighthouse `hosts` section. #### 6. Copy nebula credentials, configuration, and binaries to each host For each host, copy the nebula binary to the host, along with `config.yml` from step 5, and the files `ca.crt`, `{host}.crt`, and `{host}.key` from step 4. **DO NOT COPY `ca.key` TO INDIVIDUAL NODES.** #### 7. Run nebula on each host ```sh ./nebula -config /path/to/config.yml ``` For more detailed instructions, [find the full documentation here](https://nebula.defined.net/docs/). 
## Building Nebula from source Make sure you have [go](https://go.dev/doc/install) installed and clone this repo. Change to the nebula directory. To build nebula for all platforms: `make all` To build nebula for a specific platform (ex, Windows): `make bin-windows` See the [Makefile](Makefile) for more details on build targets ## Curve P256 and BoringCrypto The default curve used for cryptographic handshakes and signatures is Curve25519. This is the recommended setting for most users. If your deployment has certain compliance requirements, you have the option of creating your CA using `nebula-cert ca -curve P256` to use NIST Curve P256. The CA will then sign certificates using ECDSA P256, and any hosts using these certificates will use P256 for ECDH handshakes. In addition, Nebula can be built using the [BoringCrypto GOEXPERIMENT](https://github.com/golang/go/blob/go1.20/src/crypto/internal/boring/README.md) by running either of the following make targets: ```sh make bin-boringcrypto make release-boringcrypto ``` This is not the recommended default deployment, but may be useful based on your compliance requirements. ## Credits Nebula was created at Slack Technologies, Inc by Nate Brown and Ryan Huber, with contributions from Oliver Fross, Alan Lam, Wade Simmons, and Lining Wang. ================================================ FILE: SECURITY.md ================================================ Security Policy =============== Reporting a Vulnerability ------------------------- If you believe you have found a security vulnerability with Nebula, please let us know right away. We will investigate all reports and do our best to quickly fix valid issues. You can submit your report on [HackerOne](https://hackerone.com/slack) and our security team will respond as soon as possible. 
================================================ FILE: allow_list.go ================================================ package nebula import ( "fmt" "net/netip" "regexp" "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" ) type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny cidrTree *bart.Table[bool] } type RemoteAllowList struct { AllowList *AllowList // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList insideAllowLists *bart.Table[*AllowList] } type LocalAllowList struct { AllowList *AllowList // To avoid ambiguity, all rules must be true, or all rules must be false. nameRules []AllowListNameRule } type AllowListNameRule struct { Name *regexp.Regexp Allow bool } func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) { var nameRules []AllowListNameRule handleKey := func(key string, value any) (bool, error) { if key == "interfaces" { var err error nameRules, err = getAllowListInterfaces(k, value) if err != nil { return false, err } return true, nil } return false, nil } al, err := newAllowListFromConfig(c, k, handleKey) if err != nil { return nil, err } return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil } func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllowList, error) { al, err := newAllowListFromConfig(c, k, nil) if err != nil { return nil, err } remoteAllowRanges, err := getRemoteAllowRanges(c, rangesKey) if err != nil { return nil, err } return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil } // If the handleKey func returns true, the rest of the parsing is skipped // for this key. This allows parsing of special values like `interfaces`. 
func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) { r := c.Get(k) if r == nil { return nil, nil } return newAllowList(k, r, handleKey) } // If the handleKey func returns true, the rest of the parsing is skipped // for this key. This allows parsing of special values like `interfaces`. func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) { rawMap, ok := raw.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } tree := new(bart.Table[bool]) // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { firstValue bool allValuesMatch bool defaultSet bool allValues bool } rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false} for rawCIDR, rawValue := range rawMap { if handleKey != nil { handled, err := handleKey(rawCIDR, rawValue) if err != nil { return nil, err } if handled { continue } } value, ok := config.AsBool(rawValue) if !ok { return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. 
%w", k, rawCIDR, err) } ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) tree.Insert(ipNet, value) maskBits := ipNet.Bits() var rules *allowListRules if ipNet.Addr().Is4() { rules = &rules4 } else { rules = &rules6 } if rules.firstValue { rules.allValues = value rules.firstValue = false } else { if value != rules.allValues { rules.allValuesMatch = false } } // Check if this is 0.0.0.0/0 or ::/0 if maskBits == 0 { rules.defaultSet = true } } if !rules4.defaultSet { if rules4.allValuesMatch { tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) } } if !rules6.defaultSet { if rules6.allValuesMatch { tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) } } return &AllowList{cidrTree: tree}, nil } func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) { var nameRules []AllowListNameRule rawRules, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v) } firstEntry := true var allValues bool for name, rawAllow := range rawRules { allow, ok := config.AsBool(rawAllow) if !ok { return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow) } nameRE, err := regexp.Compile("^" + name + "$") if err != nil { return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err) } nameRules = append(nameRules, AllowListNameRule{ Name: nameRE, Allow: allow, }) if firstEntry { allValues = allow firstEntry = false } else { if allow != allValues { return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k) } } } return nameRules, nil } func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) { value 
:= c.Get(k) if value == nil { return nil, nil } remoteAllowRanges := new(bart.Table[*AllowList]) rawMap, ok := value.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) } for rawCIDR, rawValue := range rawMap { allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil) if err != nil { return nil, err } ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList) } return remoteAllowRanges, nil } func (al *AllowList) Allow(addr netip.Addr) bool { if al == nil { return true } result, _ := al.cidrTree.Lookup(addr) return result } func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool { if al == nil { return true } return al.AllowList.Allow(udpAddr) } func (al *LocalAllowList) AllowName(name string) bool { if al == nil || len(al.nameRules) == 0 { return true } for _, rule := range al.nameRules { if rule.Name.MatchString(name) { return rule.Allow } } // If no rules match, return the default, which is the inverse of the rules return !al.nameRules[0].Allow } func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool { if al == nil { return true } return al.AllowList.Allow(vpnAddr) } func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool { if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) { return false } return al.AllowList.Allow(udpAddr) } func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool { if !al.AllowList.Allow(udpAddr) { return false } for _, vpnAddr := range vpnAddrs { if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) { return false } } return true } func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList { if al.insideAllowLists != nil { inside, ok := al.insideAllowLists.Lookup(vpnAddr) if ok { return inside } } return nil } 
================================================ FILE: allow_list_test.go ================================================
package nebula

import (
	"net/netip"
	"regexp"
	"testing"

	"github.com/gaissmai/bart"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNewAllowListFromConfig covers CIDR and value parse errors, the
// "mixed rules without a default" error for IPv4 and IPv6, successful
// parses, and `interfaces` rule parsing via NewLocalAllowListFromConfig.
func TestNewAllowListFromConfig(t *testing.T) {
	l := test.NewLogger()
	c := config.NewC(l)
	c.Settings["allowlist"] = map[string]any{
		"192.168.0.0": true,
	}
	r, err := newAllowListFromConfig(c, "allowlist", nil)
	require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
	assert.Nil(t, r)

	c.Settings["allowlist"] = map[string]any{
		"192.168.0.0/16": "abc",
	}
	r, err = newAllowListFromConfig(c, "allowlist", nil)
	require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")

	c.Settings["allowlist"] = map[string]any{
		"192.168.0.0/16": true,
		"10.0.0.0/8":     false,
	}
	r, err = newAllowListFromConfig(c, "allowlist", nil)
	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")

	c.Settings["allowlist"] = map[string]any{
		"0.0.0.0/0":      true,
		"10.0.0.0/8":     false,
		"10.42.42.0/24":  true,
		"fd00::/8":       true,
		"fd00:fd00::/16": false,
	}
	r, err = newAllowListFromConfig(c, "allowlist", nil)
	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")

	c.Settings["allowlist"] = map[string]any{
		"0.0.0.0/0":     true,
		"10.0.0.0/8":    false,
		"10.42.42.0/24": true,
	}
	r, err = newAllowListFromConfig(c, "allowlist", nil)
	if assert.NoError(t, err) {
		assert.NotNil(t, r)
	}

	c.Settings["allowlist"] = map[string]any{
		"0.0.0.0/0":      true,
		"10.0.0.0/8":     false,
		"10.42.42.0/24":  true,
		"::/0":           false,
		"fd00::/8":       true,
		"fd00:fd00::/16": false,
	}
	r, err = newAllowListFromConfig(c, "allowlist", nil)
	if assert.NoError(t, err) {
		assert.NotNil(t, r)
	}

	// Test interface names
	c.Settings["allowlist"] = map[string]any{
		"interfaces": map[string]any{
			`docker.*`: "foo",
		},
	}
	lr, err := NewLocalAllowListFromConfig(c, "allowlist")
	require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")

	c.Settings["allowlist"] = map[string]any{
		"interfaces": map[string]any{
			`docker.*`: false,
			`eth.*`:    true,
		},
	}
	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
	require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")

	c.Settings["allowlist"] = map[string]any{
		"interfaces": map[string]any{
			`docker.*`: false,
		},
	}
	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
	if assert.NoError(t, err) {
		assert.NotNil(t, lr)
	}
}

// TestAllowList_Allow checks that a nil list allows everything and that
// lookups pick the most specific prefix. A re-inserted prefix
// (10.42.42.0/24) takes its latest value.
func TestAllowList_Allow(t *testing.T) {
	assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))

	tree := new(bart.Table[bool])
	tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
	tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
	tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
	tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
	tree.Insert(netip.MustParsePrefix("::1/128"), true)
	tree.Insert(netip.MustParsePrefix("::2/128"), false)
	al := &AllowList{cidrTree: tree}

	assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
	assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
	assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
	assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
	assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
	assert.True(t, al.Allow(netip.MustParseAddr("::1")))
	assert.False(t, al.Allow(netip.MustParseAddr("::2")))
}

// TestLocalAllowList_AllowName checks that a name matching no rule gets the
// inverse of the (uniform) rule value.
func TestLocalAllowList_AllowName(t *testing.T) {
	assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))

	rules := []AllowListNameRule{
		{Name: regexp.MustCompile("^docker.*$"), Allow: false},
		{Name: regexp.MustCompile("^tun.*$"), Allow: false},
	}
	al := &LocalAllowList{nameRules: rules}

	assert.False(t, al.AllowName("docker0"))
	assert.False(t, al.AllowName("tun0"))
	assert.True(t, al.AllowName("eth0"))

	rules = []AllowListNameRule{
		{Name: regexp.MustCompile("^eth.*$"), Allow: true},
		{Name: regexp.MustCompile("^ens.*$"), Allow: true},
	}
	al = &LocalAllowList{nameRules: rules}

	assert.False(t, al.AllowName("docker0"))
	assert.True(t, al.AllowName("eth0"))
	assert.True(t, al.AllowName("ens5"))
}
================================================ FILE: bits.go ================================================
package nebula

import (
	"github.com/rcrowley/go-metrics"
	"github.com/sirupsen/logrus"
)

// Bits is a sliding receive window over message counters.
//
// bits is a ring buffer of `length` slots: slot i%length records whether
// counter i has been seen. current is the highest counter accepted so far.
// The counters record lost, duplicate and out-of-window counters seen by
// Update.
type Bits struct {
	length             uint64
	current            uint64
	bits               []bool
	lostCounter        metrics.Counter
	dupeCounter        metrics.Counter
	outOfWindowCounter metrics.Counter
}

// NewBits returns a window with `bits` slots and current = 0.
//
// The metrics counters come from metrics.GetOrRegisterCounter with a nil
// registry (go-metrics' default registry). Every Bits therefore appears to
// share the same three named counters — confirm if per-tunnel counts matter.
// NOTE(review): bits must be > 0; NewBits(0) would panic on b.bits[0] and on
// the `% b.length` operations in Check/Update.
func NewBits(bits uint64) *Bits {
	b := &Bits{
		length:             bits,
		bits:               make([]bool, bits, bits),
		current:            0,
		lostCounter:        metrics.GetOrRegisterCounter("network.packets.lost", nil),
		dupeCounter:        metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
		outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
	}

	// There is no counter value 0, mark it to avoid counting a lost packet later.
	b.bits[0] = true
	b.current = 0
	return b
}

// Check reports whether counter i would be accepted, without changing the
// window. Any i above current is accepted. An i inside the window is accepted
// only if its slot is unset. Anything older is rejected (debug-logged).
func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
	// If i is beyond the current counter it has not been seen: return true.
	if i > b.current {
		return true
	}

	// If i is within the window, check if it's been set already.
	// While current < length, b.current-b.length wraps around (uint64), so the
	// second clause covers the first window explicitly.
	if i > b.current-b.length || i < b.length && b.current < b.length {
		return !b.bits[i%b.length]
	}

	// Not within the window
	if l.Level >= logrus.DebugLevel {
		l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
	}
	return false
}

// Update records counter i and reports whether it was accepted.
//
//   - i == current+1: mark it, advance current. The slot it overwrites
//     belonged to i-length; if that counter was never seen it is counted as
//     lost (only once i > length).
//   - i > current+1 (a jump): every slot the window moves past is checked for
//     an unseen counter (counted as lost) and cleared. Counters skipped
//     entirely beyond one window (i-current-length of them) are also counted
//     as lost. Then i is marked and current advances.
//   - i inside the window: rejected as a duplicate (dupeCounter) if already
//     marked, otherwise marked and accepted.
//   - anything else is out of window (outOfWindowCounter) and rejected.
func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
	// If i is the next number, return true and update current.
	if i == b.current+1 {
		// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
		// The very first window can only be tracked as lost once we are on the 2nd window or greater
		if b.bits[i%b.length] == false && i > b.length {
			b.lostCounter.Inc(1)
		}
		b.bits[i%b.length] = true
		b.current = i
		return true
	}

	// If i is a jump, adjust the window, record lost, update current, and return true
	if i > b.current {
		lost := int64(0)
		// Zero out the bits between the current and the new counter value, limited by the window size,
		// since the window is shifting
		for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
			if b.bits[n%b.length] == false && n > b.length {
				lost++
			}
			b.bits[n%b.length] = false
		}

		// Only record any skipped packets as a result of the window moving further than the window length
		// Any loss within the new window will be accounted for in future calls
		// When the jump is shorter than the window, the uint64 subtraction wraps
		// and the int64 conversion goes negative, so max clamps it to 0.
		lost += max(0, int64(i-b.current-b.length))
		b.lostCounter.Inc(lost)

		b.bits[i%b.length] = true
		b.current = i
		return true
	}

	// If i is within the current window but below the current counter,
	// Check to see if it's a duplicate
	if i > b.current-b.length || i < b.length && b.current < b.length {
		if b.current == i || b.bits[i%b.length] == true {
			if l.Level >= logrus.DebugLevel {
				l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
					Debug("Receive window")
			}
			b.dupeCounter.Inc(1)
			return false
		}

		b.bits[i%b.length] = true
		return true
	}

	// In all other cases, fail and don't change current.
	b.outOfWindowCounter.Inc(1)
	if l.Level >= logrus.DebugLevel {
		l.WithField("accepted", false).
			WithField("currentCounter", b.current).
			WithField("incomingCounter", i).
			WithField("reason", "nonsense").
			Debug("Receive window")
	}
	return false
}
================================================ FILE: bits_test.go ================================================
package nebula

import (
	"testing"

	"github.com/slackhq/nebula/test"
	"github.com/stretchr/testify/assert"
)

// TestBits walks a 10-slot window through in-order, duplicate, jump,
// in-window, out-of-window and wraparound cases, checking the slot contents.
func TestBits(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(10)

	// make sure it is the right size
	assert.Len(t, b.bits, 10)

	// This is initialized to zero - receive one. This should work.
	assert.True(t, b.Check(l, 1))
	assert.True(t, b.Update(l, 1))
	assert.EqualValues(t, 1, b.current)
	g := []bool{true, true, false, false, false, false, false, false, false, false}
	assert.Equal(t, g, b.bits)

	// Receive two
	assert.True(t, b.Check(l, 2))
	assert.True(t, b.Update(l, 2))
	assert.EqualValues(t, 2, b.current)
	g = []bool{true, true, true, false, false, false, false, false, false, false}
	assert.Equal(t, g, b.bits)

	// Receive two again - it will fail
	assert.False(t, b.Check(l, 2))
	assert.False(t, b.Update(l, 2))
	assert.EqualValues(t, 2, b.current)

	// Jump ahead to 15, which should clear everything and set the 6th element
	assert.True(t, b.Check(l, 15))
	assert.True(t, b.Update(l, 15))
	assert.EqualValues(t, 15, b.current)
	g = []bool{false, false, false, false, false, true, false, false, false, false}
	assert.Equal(t, g, b.bits)

	// Mark 14, which is allowed because it is in the window
	assert.True(t, b.Check(l, 14))
	assert.True(t, b.Update(l, 14))
	assert.EqualValues(t, 15, b.current)
	g = []bool{false, false, false, false, true, true, false, false, false, false}
	assert.Equal(t, g, b.bits)

	// Mark 5, which is not allowed because it is not in the window
	assert.False(t, b.Check(l, 5))
	assert.False(t, b.Update(l, 5))
	assert.EqualValues(t, 15, b.current)
	g = []bool{false, false, false, false, true, true, false, false, false, false}
	assert.Equal(t, g, b.bits)

	// make sure we handle wrapping around once to the current position
	b = NewBits(10)
	assert.True(t, b.Update(l, 1))
	assert.True(t, b.Update(l, 11))
	assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits)

	// Walk through a few windows in order
	b = NewBits(10)
	for i := uint64(1); i <= 100; i++ {
		assert.True(t, b.Check(l, i), "Error while checking %v", i)
		assert.True(t, b.Update(l, i), "Error while updating %v", i)
	}

	assert.False(t, b.Check(l, 1), "Out of window check")
}

// TestBitsLargeJumps checks lost accounting when jumps exceed the window.
func TestBitsLargeJumps(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(10)
	b.lostCounter.Clear()

	b = NewBits(10)
	b.lostCounter.Clear()
	assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
	assert.Equal(t, int64(45), b.lostCounter.Count())

	assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
	assert.Equal(t, int64(89), b.lostCounter.Count())

	assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
	assert.Equal(t, int64(188), b.lostCounter.Count())
}

// TestBitsDupeCounter checks that only repeated counters bump dupeCounter.
func TestBitsDupeCounter(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	assert.True(t, b.Update(l, 1))
	assert.Equal(t, int64(0), b.dupeCounter.Count())

	assert.False(t, b.Update(l, 1))
	assert.Equal(t, int64(1), b.dupeCounter.Count())

	assert.True(t, b.Update(l, 2))
	assert.Equal(t, int64(1), b.dupeCounter.Count())

	assert.True(t, b.Update(l, 3))
	assert.Equal(t, int64(1), b.dupeCounter.Count())

	assert.False(t, b.Update(l, 1))
	assert.Equal(t, int64(0), b.lostCounter.Count())
	assert.Equal(t, int64(2), b.dupeCounter.Count())
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
}

// TestBitsOutOfWindowCounter checks that a counter older than the window
// bumps outOfWindowCounter.
func TestBitsOutOfWindowCounter(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	assert.True(t, b.Update(l, 20))
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())

	assert.True(t, b.Update(l, 21))
	assert.True(t, b.Update(l, 22))
	assert.True(t, b.Update(l, 23))
	assert.True(t, b.Update(l, 24))
	assert.True(t, b.Update(l, 25))
	assert.True(t, b.Update(l, 26))
	assert.True(t, b.Update(l, 27))
	assert.True(t, b.Update(l, 28))
	assert.True(t, b.Update(l, 29))
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())

	assert.False(t, b.Update(l, 0))
	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())

	assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
	assert.Equal(t, int64(0), b.dupeCounter.Count())
	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
}

// TestBitsLostCounter checks lost accounting for a jump followed by in-order
// counters, and for slots left unseen as the first window rolls over.
func TestBitsLostCounter(t *testing.T) {
	l := test.NewLogger()
	b := NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	assert.True(t, b.Update(l, 20))
	assert.True(t, b.Update(l, 21))
	assert.True(t, b.Update(l, 22))
	assert.True(t, b.Update(l, 23))
	assert.True(t, b.Update(l, 24))
	assert.True(t, b.Update(l, 25))
	assert.True(t, b.Update(l, 26))
	assert.True(t, b.Update(l, 27))
	assert.True(t, b.Update(l, 28))
	assert.True(t, b.Update(l, 29))
	assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
	assert.Equal(t, int64(0), b.dupeCounter.Count())
	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())

	b = NewBits(10)
	b.lostCounter.Clear()
	b.dupeCounter.Clear()
	b.outOfWindowCounter.Clear()

	assert.True(t, b.Update(l, 9))
	assert.Equal(t, int64(0), b.lostCounter.Count())
	// 10 will set 0 index, 0 was already set, no lost packets
	assert.True(t, b.Update(l, 10))
	assert.Equal(t, int64(0), b.lostCounter.Count())
	// 11 will set 1 index, 1 was missed, we should see 1 packet lost
	assert.True(t, b.Update(l, 11))
	assert.Equal(t, int64(1), b.lostCounter.Count())
	// Now let's fill in the window, should end up with 8 lost packets
	assert.True(t, b.Update(l, 12))
	assert.True(t, b.Update(l, 13))
	assert.True(t, b.Update(l, 14))
	assert.True(t, b.Update(l, 15))
	assert.True(t, b.Update(l, 16))
	assert.True(t, b.Update(l, 17))
	assert.True(t, b.Update(l, 18))
	assert.True(t, b.Update(l, 19))
	assert.Equal(t, int64(8), b.lostCounter.Count())
	//
Jump ahead by a window size assert.True(t, b.Update(l, 29)) assert.Equal(t, int64(8), b.lostCounter.Count()) // Now lets walk ahead normally through the window, the missed packets should fill in assert.True(t, b.Update(l, 30)) assert.True(t, b.Update(l, 31)) assert.True(t, b.Update(l, 32)) assert.True(t, b.Update(l, 33)) assert.True(t, b.Update(l, 34)) assert.True(t, b.Update(l, 35)) assert.True(t, b.Update(l, 36)) assert.True(t, b.Update(l, 37)) assert.True(t, b.Update(l, 38)) // 39 packets tracked, 22 seen, 17 lost assert.Equal(t, int64(17), b.lostCounter.Count()) // Jump ahead by 2 windows, should have recording 1 full window missing assert.True(t, b.Update(l, 58)) assert.Equal(t, int64(27), b.lostCounter.Count()) // Now lets walk ahead normally through the window, the missed packets should fill in from this window assert.True(t, b.Update(l, 59)) assert.True(t, b.Update(l, 60)) assert.True(t, b.Update(l, 61)) assert.True(t, b.Update(l, 62)) assert.True(t, b.Update(l, 63)) assert.True(t, b.Update(l, 64)) assert.True(t, b.Update(l, 65)) assert.True(t, b.Update(l, 66)) assert.True(t, b.Update(l, 67)) // 68 packets tracked, 32 seen, 36 missed assert.Equal(t, int64(36), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } func TestBitsLostCounterIssue1(t *testing.T) { l := test.NewLogger() b := NewBits(10) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() assert.True(t, b.Update(l, 4)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 1)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 9)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 2)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 3)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 5)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 6)) 
assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 7)) assert.Equal(t, int64(0), b.lostCounter.Count()) // assert.True(t, b.Update(l, 8)) assert.True(t, b.Update(l, 10)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 11)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 14)) assert.Equal(t, int64(0), b.lostCounter.Count()) // Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter assert.True(t, b.Update(l, 19)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 12)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 13)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 15)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 16)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 17)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 18)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 20)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 21)) // We missed packet 8 above assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } func BenchmarkBits(b *testing.B) { z := NewBits(10) for n := 0; n < b.N; n++ { for i := range z.bits { z.bits[i] = true } for i := range z.bits { z.bits[i] = false } } } ================================================ FILE: boring.go ================================================ //go:build boringcrypto package nebula import "crypto/boring" var boringEnabled = boring.Enabled ================================================ FILE: calculated_remote.go ================================================ package nebula import ( "encoding/binary" "fmt" "math" "net" "net/netip" "strconv" "github.com/gaissmai/bart" 
"github.com/slackhq/nebula/config" ) // This allows us to "guess" what the remote might be for a host while we wait // for the lighthouse response. See "lighthouse.calculated_remotes" in the // example config file. type calculatedRemote struct { ipNet netip.Prefix mask netip.Prefix port uint32 } func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) { if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() { return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr) } masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) } return &calculatedRemote{ ipNet: maskCidr, mask: masked, port: uint32(port), }, nil } func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort { // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) mask := binary.BigEndian.Uint32(maskb[:]) b := c.mask.Addr().As4() maskAddr := binary.BigEndian.Uint32(b[:]) b = addr.As4() intAddr := binary.BigEndian.Uint32(b[:]) return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port} } func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort { mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) maskAddr := c.mask.Addr().As16() calcAddr := addr.As16() ap := V6AddrPort{Port: c.port} maskb := binary.BigEndian.Uint64(mask[:8]) maskAddrb := binary.BigEndian.Uint64(maskAddr[:8]) calcAddrb := binary.BigEndian.Uint64(calcAddr[:8]) ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb) maskb = binary.BigEndian.Uint64(mask[8:]) maskAddrb = binary.BigEndian.Uint64(maskAddr[8:]) calcAddrb = binary.BigEndian.Uint64(calcAddr[8:]) ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb) return &ap } func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], 
error) { value := c.Get(k) if value == nil { return nil, nil } calculatedRemotes := new(bart.Table[[]*calculatedRemote]) rawMap, ok := value.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) } for rawCIDR, rawValue := range rawMap { cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } calculatedRemotes.Insert(cidr, entry) } return calculatedRemotes, nil } func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) { rawList, ok := raw.([]any) if !ok { return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) } var l []*calculatedRemote for _, e := range rawList { c, err := newCalculatedRemotesEntryFromConfig(cidr, e) if err != nil { return nil, fmt.Errorf("calculated_remotes entry: %w", err) } l = append(l, c) } return l, nil } func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { rawMap, ok := raw.(map[string]any) if !ok { return nil, fmt.Errorf("invalid type: %T", raw) } rawValue := rawMap["mask"] if rawValue == nil { return nil, fmt.Errorf("missing mask: %v", rawMap) } rawMask, ok := rawValue.(string) if !ok { return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) } maskCidr, err := netip.ParsePrefix(rawMask) if err != nil { return nil, fmt.Errorf("invalid mask: %s", rawMask) } var port int rawValue = rawMap["port"] if rawValue == nil { return nil, fmt.Errorf("missing port: %v", rawMap) } switch v := rawValue.(type) { case int: port = v case string: port, err = strconv.Atoi(v) if err != nil { return nil, fmt.Errorf("invalid port: %s: %w", v, err) } default: return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } return newCalculatedRemote(cidr, maskCidr, port) } 
================================================
FILE: calculated_remote_test.go
================================================
package nebula

import (
	"net/netip"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestCalculatedRemoteApply checks that ApplyV4/ApplyV6 keep the network bits
// of the configured mask prefix and take the remaining bits from the input
// address, for IPv4 and for several IPv6 prefix lengths.
func TestCalculatedRemoteApply(t *testing.T) {
	// Test v4 addresses
	ipNet := netip.MustParsePrefix("192.168.1.0/24")
	c, err := newCalculatedRemote(ipNet, ipNet, 4242)
	require.NoError(t, err)

	input, err := netip.ParseAddr("10.0.10.182")
	require.NoError(t, err)

	expected, err := netip.ParseAddr("192.168.1.182")
	require.NoError(t, err)

	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))

	// Test v6 addresses (/64: the mask boundary falls exactly between the Hi and Lo halves)
	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
	require.NoError(t, err)

	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
	require.NoError(t, err)

	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
	require.NoError(t, err)

	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))

	// Test v6 addresses part 2 (/80: the mask boundary falls inside the Lo half)
	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
	require.NoError(t, err)

	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
	require.NoError(t, err)

	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
	require.NoError(t, err)

	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))

	// Test v6 addresses part 3 (/48: the mask boundary falls inside the Hi half)
	ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
	require.NoError(t, err)

	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
	require.NoError(t, err)

	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
	require.NoError(t, err)

	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
}

// Test_newCalculatedRemote checks that newCalculatedRemote rejects a mask whose
// address family differs from the cidr (in both directions) and accepts
// matching IPv4/IPv4 and IPv6/IPv6 pairs.
func Test_newCalculatedRemote(t *testing.T) {
	c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
	require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
	require.Nil(t, c)

	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
	require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
	require.Nil(t, c)

	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
	require.NoError(t, err)
	require.NotNil(t, c)

	c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
	require.NoError(t, err)
	require.NotNil(t, c)
}

================================================
FILE: cert/Makefile
================================================
GO111MODULE = on
export GO111MODULE

cert_v1.pb.go: cert_v1.proto .FORCE
	go build google.golang.org/protobuf/cmd/protoc-gen-go
	PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
	rm protoc-gen-go

.FORCE:

================================================
FILE: cert/README.md
================================================
## `cert`

This is a library for interacting with `nebula` style certificates and authorities.

There are now 2 versions of `nebula` certificates:

## v1

This version is deprecated.

A `protobuf` definition of the certificate format is included at `cert_v1.proto`

To compile the definition you will need `protoc` installed.

To compile for `go` with the same version of protobuf specified in go.mod:

```bash
make proto
```

## v2

This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate future certificate changes better than v1.

`cert_v2.asn1` defines the wire format and can be used to compile marshalers.
================================================ FILE: cert/asn1.go ================================================ package cert import ( "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" ) // readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value // https://github.com/golang/go/issues/64811#issuecomment-1944446920 func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool { var present bool var child cryptobyte.String if !b.ReadOptionalASN1(&child, &present, tag) { return false } if !present { *out = defaultValue return true } // Ensure we have 1 byte if len(child) == 1 { *out = child[0] > 0 return true } return false } // readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value // Similar issue as with readOptionalASN1Boolean func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool { var present bool var child cryptobyte.String if !b.ReadOptionalASN1(&child, &present, tag) { return false } if !present { *out = defaultValue return true } // Ensure we have 1 byte if len(child) == 1 { *out = child[0] return true } return false } ================================================ FILE: cert/ca_pool.go ================================================ package cert import ( "errors" "fmt" "net/netip" "slices" "strings" "time" ) type CAPool struct { CAs map[string]*CachedCertificate certBlocklist map[string]struct{} } // NewCAPool creates an empty CAPool func NewCAPool() *CAPool { ca := CAPool{ CAs: make(map[string]*CachedCertificate), certBlocklist: make(map[string]struct{}), } return &ca } // NewCAPoolFromPEM will create a new CA pool from the provided // input bytes, which must be a PEM-encoded set of nebula certificates. // If the pool contains any expired certificates, an ErrExpired will be // returned along with the pool. 
The caller must handle any such errors. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { pool := NewCAPool() var err error var expired bool for { caPEMs, err = pool.AddCAFromPEM(caPEMs) if errors.Is(err, ErrExpired) { expired = true err = nil } if err != nil { return nil, err } if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { break } } if expired { return pool, ErrExpired } return pool, nil } // AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool. // Only the first pem encoded object will be consumed, any remaining bytes are returned. // Parsed certificates will be verified and must be a CA func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes) if err != nil { return pemBytes, err } err = ncp.AddCA(c) if err != nil { return pemBytes, err } return pemBytes, nil } // AddCA verifies a Nebula CA certificate and adds it to the pool. func (ncp *CAPool) AddCA(c Certificate) error { if !c.IsCA() { return fmt.Errorf("%s: %w", c.Name(), ErrNotCA) } if !c.CheckSignature(c.PublicKey()) { return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned) } sum, err := c.Fingerprint() if err != nil { return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name()) } cc := &CachedCertificate{ Certificate: c, Fingerprint: sum, InvertedGroups: make(map[string]struct{}), } for _, g := range c.Groups() { cc.InvertedGroups[g] = struct{}{} } ncp.CAs[sum] = cc if c.Expired(time.Now()) { return fmt.Errorf("%s: %w", c.Name(), ErrExpired) } return nil } // BlocklistFingerprint adds a cert fingerprint to the blocklist func (ncp *CAPool) BlocklistFingerprint(f string) { ncp.certBlocklist[f] = struct{}{} } // ResetCertBlocklist removes all previously blocklisted cert fingerprints func (ncp *CAPool) ResetCertBlocklist() { ncp.certBlocklist = make(map[string]struct{}) } // IsBlocklisted tests the provided fingerprint against the pools blocklist. 
// Returns true if the fingerprint is blocked. func (ncp *CAPool) IsBlocklisted(fingerprint string) bool { if _, ok := ncp.certBlocklist[fingerprint]; ok { return true } return false } // VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool. // If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts // to increase performance. func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) { if c == nil { return nil, fmt.Errorf("no certificate") } fp, err := c.Fingerprint() if err != nil { return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err) } signer, err := ncp.verify(c, now, fp, "") if err != nil { return nil, err } // Pre nebula v1.10.3 could generate signatures in either high or low s form and validation // of signatures allowed for either. Nebula v1.10.3 and beyond clamps signature generation to low-s form // but validation still allows for either. Since a change in the signature bytes affects the fingerprint, we // need to test both forms until such a time comes that we enforce low-s form on signature validation. fp2, err := CalculateAlternateFingerprint(c) if err != nil { return nil, fmt.Errorf("could not calculate alternate fingerprint to verify: %w", err) } if fp2 != "" && ncp.IsBlocklisted(fp2) { return nil, ErrBlockListed } cc := CachedCertificate{ Certificate: c, InvertedGroups: make(map[string]struct{}), Fingerprint: fp, fingerprint2: fp2, signerFingerprint: signer.Fingerprint, } for _, g := range c.Groups() { cc.InvertedGroups[g] = struct{}{} } return &cc, nil } // VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and // is a cheaper operation to perform as a result. 
func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { // Check any available alternate fingerprint forms for this certificate, re P256 high-s/low-s if c.fingerprint2 != "" && ncp.IsBlocklisted(c.fingerprint2) { return ErrBlockListed } _, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint) return err } func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) { if ncp.IsBlocklisted(certFp) { return nil, ErrBlockListed } signer, err := ncp.GetCAForCert(c) if err != nil { return nil, err } if signer.Certificate.Expired(now) { return nil, ErrRootExpired } if c.Expired(now) { return nil, ErrExpired } // If we are checking a cached certificate then we can bail early here // Either the root is no longer trusted or everything is fine if len(signerFp) > 0 { if signerFp != signer.Fingerprint { return nil, ErrFingerprintMismatch } return signer, nil } if !c.CheckSignature(signer.Certificate.PublicKey()) { return nil, ErrSignatureMismatch } err = CheckCAConstraints(signer.Certificate, c) if err != nil { return nil, err } return signer, nil } // GetCAForCert attempts to return the signing certificate for the provided certificate. // No signature validation is performed func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { issuer := c.Issuer() if issuer == "" { return nil, fmt.Errorf("no issuer in certificate") } signer, ok := ncp.CAs[issuer] if ok { return signer, nil } return nil, ErrCaNotFound } // GetFingerprints returns an array of trusted CA fingerprints func (ncp *CAPool) GetFingerprints() []string { fp := make([]string, len(ncp.CAs)) i := 0 for k := range ncp.CAs { fp[i] = k i++ } return fp } // CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate. 
func CheckCAConstraints(signer Certificate, sub Certificate) error { return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks()) } // checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested. func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error { // Make sure this cert isn't valid after the root if notAfter.After(signer.NotAfter()) { return fmt.Errorf("certificate expires after signing certificate") } // Make sure this cert wasn't valid before the root if notBefore.Before(signer.NotBefore()) { return fmt.Errorf("certificate is valid before the signing certificate") } // If the signer has a limited set of groups make sure the cert only contains a subset signerGroups := signer.Groups() if len(signerGroups) > 0 { for _, g := range groups { if !slices.Contains(signerGroups, g) { return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) } } } // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset signingNetworks := signer.Networks() if len(signingNetworks) > 0 { for _, certNetwork := range networks { found := false for _, signingNetwork := range signingNetworks { if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() { found = true break } } if !found { return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String()) } } } // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset signingUnsafeNetworks := signer.UnsafeNetworks() if len(signingUnsafeNetworks) > 0 { for _, certUnsafeNetwork := range unsafeNetworks { found := false for _, caNetwork := range signingUnsafeNetworks { if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= 
certUnsafeNetwork.Bits() { found = true break } } if !found { return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String()) } } } return nil } ================================================ FILE: cert/ca_pool_test.go ================================================ package cert import ( "net/netip" "testing" "time" "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewCAPoolFromBytes(t *testing.T) { noNewLines := ` # Current provisional, Remove once everything moves over to the real root. -----BEGIN NEBULA CERTIFICATE----- Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf 2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== -----END NEBULA CERTIFICATE----- # root-ca01 -----BEGIN NEBULA CERTIFICATE----- CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== -----END NEBULA CERTIFICATE----- ` withNewLines := ` # Current provisional, Remove once everything moves over to the real root. 
-----BEGIN NEBULA CERTIFICATE----- Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf 2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== -----END NEBULA CERTIFICATE----- # root-ca01 -----BEGIN NEBULA CERTIFICATE----- CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== -----END NEBULA CERTIFICATE----- ` expired := ` # expired certificate -----BEGIN NEBULA CERTIFICATE----- CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA 7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8 Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0= -----END NEBULA CERTIFICATE----- ` p256 := ` # p256 certificate -----BEGIN NEBULA CERTIFICATE----- CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe +0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq 75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA== -----END NEBULA CERTIFICATE----- ` rootCA := certificateV1{ details: detailsV1{ name: "nebula root ca", }, } rootCA01 := certificateV1{ details: detailsV1{ name: "nebula root ca 01", }, } rootCAP256 := certificateV1{ details: detailsV1{ name: "nebula P256 test", }, } p, err := NewCAPoolFromPEM([]byte(noNewLines)) require.NoError(t, err) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) require.NoError(t, err) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, 
pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) assert.Equal(t, ErrExpired, err) assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name()) assert.Len(t, pppp.CAs, 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) require.NoError(t, err) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Len(t, ppppp.CAs, 1) } func TestCertificateV1_Verify(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing 
certificate", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) }) // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) }) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") // Create a copy of the cert and swap to the alternate form for the signature nc := c.Copy() b, err := p256.Swap(c.Signature()) require.NoError(t, err) require.NoError(t, nc.(*certificateV1).setSignature(b)) _, err = caPool.VerifyCertificate(time.Now(), nc) require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) require.EqualError(t, err, 
"root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) }) // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool = NewCAPool() b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) }) c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) cc, err := caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Reset the blocklist and block the alternate form fingerprint caPool.ResetCertBlocklist() caPool.BlocklistFingerprint(cc.fingerprint2) err = caPool.VerifyCachedCertificate(time.Now(), cc) require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } func TestCertificateV1_Verify_IPs(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) // ip is outside the network cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained a network 
assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, 
Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } func TestCertificateV1_Verify_Subnets(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) // ip is outside the network cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", 
func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) require.NoError(t, err) _, err = 
caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } func TestCertificateV2_Verify(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) }) // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) }) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), 
time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) require.EqualError(t, err, "certificate is in the block list") // Create a copy of the cert and swap to the alternate form for the signature nc := c.Copy() b, err := p256.Swap(c.Signature()) require.NoError(t, err) require.NoError(t, nc.(*certificateV2).setSignature(b)) _, err = caPool.VerifyCertificate(time.Now(), nc) require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) }) // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool = NewCAPool() b, err = caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) }) c, _, _, _ 
= NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) cc, err := caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Reset the blocklist and block the alternate form fingerprint caPool.ResetCertBlocklist() caPool.BlocklistFingerprint(cc.fingerprint2) err = caPool.VerifyCachedCertificate(time.Now(), cc) require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() err = caPool.VerifyCachedCertificate(time.Now(), cc) require.NoError(t, err) } func TestCertificateV2_Verify_IPs(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) // ip is outside the network cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = 
mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) require.NoError(t, err) _, err = 
caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } func TestCertificateV2_Verify_Subnets(t *testing.T) { caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) require.NoError(t, err) assert.Empty(t, b) // ip is outside the network cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip is outside the network reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip is within the network but mask is outside cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip is within the network but mask is outside reversed order of above cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") cIp2 = 
mustParsePrefixUnmapped("10.0.1.0/15") assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) }) // ip and mask are within the network cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } ================================================ FILE: cert/cert.go ================================================ package cert import ( "fmt" "net/netip" "time" "github.com/slackhq/nebula/cert/p256" ) type Version uint8 const ( VersionPre1 Version = 0 Version1 Version = 1 Version2 Version = 2 ) type Certificate interface { // Version defines the underlying certificate structure and wire 
protocol version // Version1 certificates are ipv4 only and uses protobuf serialization // Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization Version() Version // Name is the human-readable name that identifies this certificate. Name() string // Networks is a list of ip addresses and network sizes assigned to this certificate. // If IsCA is true then certificates signed by this CA can only have ip addresses and // networks that are contained by an entry in this list. Networks() []netip.Prefix // UnsafeNetworks is a list of networks that this host can act as an unsafe router for. // If IsCA is true then certificates signed by this CA can only have networks that are // contained by an entry in this list. UnsafeNetworks() []netip.Prefix // Groups is a list of identities that can be used to write more general firewall rule // definitions. // If IsCA is true then certificates signed by this CA can only use groups that are // in this list. Groups() []string // IsCA signifies if this is a certificate authority (true) or a host certificate (false). // It is invalid to use a CA certificate as a host certificate. IsCA() bool // NotBefore is the time at which this certificate becomes valid. // If IsCA is true then certificate signed by this CA can not have a time before this. NotBefore() time.Time // NotAfter is the time at which this certificate becomes invalid. // If IsCA is true then certificate signed by this CA can not have a time after this. NotAfter() time.Time // Issuer is the fingerprint of the CA that signed this certificate. // If IsCA is true then this will be empty. Issuer() string // PublicKey is the raw bytes to be used in asymmetric cryptographic operations. PublicKey() []byte // MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM MarshalPublicKeyPEM() []byte // Curve identifies which curve was used for the PublicKey and Signature. Curve() Curve // Signature is the cryptographic seal for all the details of this certificate. 
// CheckSignature can be used to verify that the details of this certificate are valid. Signature() []byte // CheckSignature will check that the certificate Signature() matches the // computed signature. A true result means this certificate has not been tampered with. CheckSignature(signingPublicKey []byte) bool // Fingerprint returns the hex encoded sha256 sum of the certificate. // This acts as a unique fingerprint and can be used to blocklist certificates. Fingerprint() (string, error) // Expired tests if the certificate is valid for the provided time. Expired(t time.Time) bool // VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key. VerifyPrivateKey(curve Curve, privateKey []byte) error // Marshal will return the byte representation of this certificate // This is primarily the format transmitted on the wire. Marshal() ([]byte, error) // MarshalForHandshakes prepares the bytes needed to use directly in a handshake MarshalForHandshakes() ([]byte, error) // MarshalPEM will return a PEM encoded representation of this certificate // This is primarily the format stored on disk MarshalPEM() ([]byte, error) // MarshalJSON will return the json representation of this certificate MarshalJSON() ([]byte, error) // String will return a human-readable representation of this certificate String() string // Copy creates a copy of the certificate Copy() Certificate } // CachedCertificate represents a verified certificate with some cached fields to improve // performance. type CachedCertificate struct { Certificate Certificate InvertedGroups map[string]struct{} Fingerprint string signerFingerprint string // A place to store a 2nd fingerprint if the certificate could have one, such as with P256 fingerprint2 string } func (cc *CachedCertificate) String() string { return cc.Certificate.String() } // Recombine will attempt to unmarshal a certificate received in a handshake. 
// Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. // Implementations MUST assert the public key is not in the raw certificate bytes if the passed in public key is not empty. func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) { if publicKey == nil { return nil, ErrNoPeerStaticKey } if rawCertBytes == nil { return nil, ErrNoPayload } var c Certificate var err error switch v { // Implementations must ensure the result is a valid cert! case VersionPre1, Version1: c, err = unmarshalCertificateV1(rawCertBytes, publicKey) case Version2: c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve) default: return nil, ErrUnknownVersion } if err != nil { return nil, err } if c.Curve() != curve { return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) } return c, nil } // CalculateAlternateFingerprint calculates a 2nd fingerprint representation for P256 certificates // CAPool blocklist testing through `VerifyCertificate` and `VerifyCachedCertificate` automatically performs this step. 
func CalculateAlternateFingerprint(c Certificate) (string, error) { if c.Curve() != Curve_P256 { return "", nil } nc := c.Copy() b, err := p256.Swap(nc.Signature()) if err != nil { return "", err } switch v := nc.(type) { case *certificateV1: err = v.setSignature(b) case *certificateV2: err = v.setSignature(b) default: return "", ErrUnknownVersion } if err != nil { return "", err } return nc.Fingerprint() } ================================================ FILE: cert/cert_v1.go ================================================ package cert import ( "bytes" "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "encoding/pem" "fmt" "net" "net/netip" "time" "golang.org/x/crypto/curve25519" "google.golang.org/protobuf/proto" ) const publicKeyLen = 32 type certificateV1 struct { details detailsV1 signature []byte } type detailsV1 struct { name string networks []netip.Prefix unsafeNetworks []netip.Prefix groups []string notBefore time.Time notAfter time.Time publicKey []byte isCA bool issuer string curve Curve } type m = map[string]any func (c *certificateV1) Version() Version { return Version1 } func (c *certificateV1) Curve() Curve { return c.details.curve } func (c *certificateV1) Groups() []string { return c.details.groups } func (c *certificateV1) IsCA() bool { return c.details.isCA } func (c *certificateV1) Issuer() string { return c.details.issuer } func (c *certificateV1) Name() string { return c.details.name } func (c *certificateV1) Networks() []netip.Prefix { return c.details.networks } func (c *certificateV1) NotAfter() time.Time { return c.details.notAfter } func (c *certificateV1) NotBefore() time.Time { return c.details.notBefore } func (c *certificateV1) PublicKey() []byte { return c.details.publicKey } func (c *certificateV1) MarshalPublicKeyPEM() []byte { return marshalCertPublicKeyToPEM(c) } func (c *certificateV1) Signature() []byte { return c.signature } func (c *certificateV1) 
UnsafeNetworks() []netip.Prefix { return c.details.unsafeNetworks } func (c *certificateV1) Fingerprint() (string, error) { b, err := c.Marshal() if err != nil { return "", err } sum := sha256.Sum256(b) return hex.EncodeToString(sum[:]), nil } func (c *certificateV1) CheckSignature(key []byte) bool { b, err := proto.Marshal(c.getRawDetails()) if err != nil { return false } switch c.details.curve { case Curve_CURVE25519: return ed25519.Verify(key, b, c.signature) case Curve_P256: pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) if err != nil { return false } hashed := sha256.Sum256(b) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: return false } } func (c *certificateV1) Expired(t time.Time) bool { return c.details.notBefore.After(t) || c.details.notAfter.Before(t) } func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { if curve != c.details.curve { return fmt.Errorf("curve in cert and private key supplied don't match") } if c.details.isCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise if len(key) != ed25519.PrivateKeySize { return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { return fmt.Errorf("public key in cert and private key supplied don't match") } case Curve_P256: privkey, err := ecdh.P256().NewPrivateKey(key) if err != nil { return fmt.Errorf("cannot parse private key as P256: %w", err) } pub := privkey.PublicKey().Bytes() if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } default: return fmt.Errorf("invalid curve: %s", curve) } return nil } var pub []byte switch curve { case Curve_CURVE25519: var err error pub, err = curve25519.X25519(key, curve25519.Basepoint) if err != nil { return err } case Curve_P256: privkey, err := ecdh.P256().NewPrivateKey(key) 
if err != nil { return err } pub = privkey.PublicKey().Bytes() default: return fmt.Errorf("invalid curve: %s", curve) } if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } return nil } // getRawDetails marshals the raw details into protobuf ready struct func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { rd := &RawNebulaCertificateDetails{ Name: c.details.name, Groups: c.details.groups, NotBefore: c.details.notBefore.Unix(), NotAfter: c.details.notAfter.Unix(), PublicKey: make([]byte, len(c.details.publicKey)), IsCA: c.details.isCA, Curve: c.details.curve, } for _, ipNet := range c.details.networks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) } for _, ipNet := range c.details.unsafeNetworks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) } copy(rd.PublicKey, c.details.publicKey[:]) // I know, this is terrible rd.Issuer, _ = hex.DecodeString(c.details.issuer) return rd } func (c *certificateV1) String() string { b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") if err != nil { return fmt.Sprintf("", err) } return string(b) } func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { pubKey := c.details.publicKey c.details.publicKey = nil rawCertNoKey, err := c.Marshal() if err != nil { return nil, err } c.details.publicKey = pubKey return rawCertNoKey, nil } func (c *certificateV1) Marshal() ([]byte, error) { rc := RawNebulaCertificate{ Details: c.getRawDetails(), Signature: c.signature, } return proto.Marshal(&rc) } func (c *certificateV1) MarshalPEM() ([]byte, error) { b, err := c.Marshal() if err != nil { return nil, err } return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil } func (c *certificateV1) MarshalJSON() ([]byte, error) { return json.Marshal(c.marshalJSON()) } func (c 
*certificateV1) marshalJSON() m { fp, _ := c.Fingerprint() return m{ "version": Version1, "details": m{ "name": c.details.name, "networks": c.details.networks, "unsafeNetworks": c.details.unsafeNetworks, "groups": c.details.groups, "notBefore": c.details.notBefore, "notAfter": c.details.notAfter, "publicKey": fmt.Sprintf("%x", c.details.publicKey), "isCa": c.details.isCA, "issuer": c.details.issuer, "curve": c.details.curve.String(), }, "fingerprint": fp, "signature": fmt.Sprintf("%x", c.Signature()), } } func (c *certificateV1) Copy() Certificate { nc := &certificateV1{ details: detailsV1{ name: c.details.name, notBefore: c.details.notBefore, notAfter: c.details.notAfter, publicKey: make([]byte, len(c.details.publicKey)), isCA: c.details.isCA, issuer: c.details.issuer, curve: c.details.curve, }, signature: make([]byte, len(c.signature)), } if c.details.groups != nil { nc.details.groups = make([]string, len(c.details.groups)) copy(nc.details.groups, c.details.groups) } if c.details.networks != nil { nc.details.networks = make([]netip.Prefix, len(c.details.networks)) copy(nc.details.networks, c.details.networks) } if c.details.unsafeNetworks != nil { nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) } copy(nc.signature, c.signature) copy(nc.details.publicKey, c.details.publicKey) return nc } func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { c.details = detailsV1{ name: t.Name, networks: t.Networks, unsafeNetworks: t.UnsafeNetworks, groups: t.Groups, notBefore: t.NotBefore, notAfter: t.NotAfter, publicKey: t.PublicKey, isCA: t.IsCA, curve: t.Curve, issuer: t.issuer, } return c.validate() } func (c *certificateV1) validate() error { // Empty names are allowed if len(c.details.publicKey) == 0 { return ErrInvalidPublicKey } // Original v1 rules allowed multiple networks to be present but ignored all but the first one. 
// Continue to allow this behavior if !c.details.isCA && len(c.details.networks) == 0 { return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network") } for _, network := range c.details.networks { if !network.IsValid() || !network.Addr().IsValid() { return NewErrInvalidCertificateProperties("invalid network: %s", network) } if network.Addr().Is6() { return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network) } if network.Addr().IsUnspecified() { return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) } if network.Addr().Zone() != "" { return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network) } } for _, network := range c.details.unsafeNetworks { if !network.IsValid() || !network.Addr().IsValid() { return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) } if network.Addr().Is6() { return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network) } if network.Addr().Zone() != "" { return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) } } // v1 doesn't bother with sort order or uniqueness of networks or unsafe networks. // We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered // unsafe networks would result in a different signature. 
return nil } func (c *certificateV1) marshalForSigning() ([]byte, error) { b, err := proto.Marshal(c.getRawDetails()) if err != nil { return nil, err } return b, nil } func (c *certificateV1) setSignature(b []byte) error { if len(b) == 0 { return ErrEmptySignature } c.signature = b return nil } // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert // if the publicKey is provided here then it is not required to be present in `b` func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) { if len(b) == 0 { return nil, fmt.Errorf("nil byte array") } var rc RawNebulaCertificate err := proto.Unmarshal(b, &rc) if err != nil { return nil, err } if rc.Details == nil { return nil, fmt.Errorf("encoded Details was nil") } if len(rc.Details.Ips)%2 != 0 { return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") } if len(rc.Details.Subnets)%2 != 0 { return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") } nc := certificateV1{ details: detailsV1{ name: rc.Details.Name, groups: make([]string, len(rc.Details.Groups)), networks: make([]netip.Prefix, len(rc.Details.Ips)/2), unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2), notBefore: time.Unix(rc.Details.NotBefore, 0), notAfter: time.Unix(rc.Details.NotAfter, 0), publicKey: nil, isCA: rc.Details.IsCA, curve: rc.Details.Curve, }, signature: make([]byte, len(rc.Signature)), } copy(nc.signature, rc.Signature) copy(nc.details.groups, rc.Details.Groups) nc.details.issuer = hex.EncodeToString(rc.Details.Issuer) // If a public key is passed in as an argument, the certificate pubkey must be empty // and the passed-in pubkey copied into the cert. 
if len(publicKey) > 0 { if len(rc.Details.PublicKey) != 0 { return nil, ErrCertPubkeyPresent } nc.details.publicKey = make([]byte, len(publicKey)) copy(nc.details.publicKey, publicKey) } else { nc.details.publicKey = make([]byte, len(rc.Details.PublicKey)) copy(nc.details.publicKey, rc.Details.PublicKey) } var ip netip.Addr for i, rawIp := range rc.Details.Ips { if i%2 == 0 { ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() nc.details.networks[i/2] = netip.PrefixFrom(ip, ones) } } for i, rawIp := range rc.Details.Subnets { if i%2 == 0 { ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones) } } err = nc.validate() if err != nil { return nil, err } return &nc, nil } func ip2int(ip []byte) uint32 { if len(ip) == 16 { return binary.BigEndian.Uint32(ip[12:16]) } return binary.BigEndian.Uint32(ip) } func int2ip(nn uint32) net.IP { ip := make(net.IP, net.IPv4len) binary.BigEndian.PutUint32(ip, nn) return ip } func addr2int(addr netip.Addr) uint32 { b := addr.Unmap().As4() return binary.BigEndian.Uint32(b[:]) } func int2addr(nn uint32) netip.Addr { ip := [4]byte{} binary.BigEndian.PutUint32(ip[:], nn) return netip.AddrFrom4(ip).Unmap() } ================================================ FILE: cert/cert_v1.pb.go ================================================ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 // protoc v3.21.5 // source: cert_v1.proto package cert import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" ) const ( // Verify that this generated code is sufficiently up-to-date. _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) // Verify that runtime/protoimpl is sufficiently up-to-date. 
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) type Curve int32 const ( Curve_CURVE25519 Curve = 0 Curve_P256 Curve = 1 ) // Enum value maps for Curve. var ( Curve_name = map[int32]string{ 0: "CURVE25519", 1: "P256", } Curve_value = map[string]int32{ "CURVE25519": 0, "P256": 1, } ) func (x Curve) Enum() *Curve { p := new(Curve) *p = x return p } func (x Curve) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } func (Curve) Descriptor() protoreflect.EnumDescriptor { return file_cert_v1_proto_enumTypes[0].Descriptor() } func (Curve) Type() protoreflect.EnumType { return &file_cert_v1_proto_enumTypes[0] } func (x Curve) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } // Deprecated: Use Curve.Descriptor instead. func (Curve) EnumDescriptor() ([]byte, []int) { return file_cert_v1_proto_rawDescGZIP(), []int{0} } type RawNebulaCertificate struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Details *RawNebulaCertificateDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"` Signature []byte `protobuf:"bytes,2,opt,name=Signature,proto3" json:"Signature,omitempty"` } func (x *RawNebulaCertificate) Reset() { *x = RawNebulaCertificate{} if protoimpl.UnsafeEnabled { mi := &file_cert_v1_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *RawNebulaCertificate) String() string { return protoimpl.X.MessageStringOf(x) } func (*RawNebulaCertificate) ProtoMessage() {} func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message { mi := &file_cert_v1_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead. 
func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
	return file_cert_v1_proto_rawDescGZIP(), []int{0}
}

// NOTE(review): everything in this section is protoc-gen-go output (the file
// header says "DO NOT EDIT"). Change cert_v1.proto and regenerate rather than
// editing by hand; manual edits are lost on regeneration.

func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
	if x != nil {
		return x.Details
	}
	return nil
}

func (x *RawNebulaCertificate) GetSignature() []byte {
	if x != nil {
		return x.Signature
	}
	return nil
}

type RawNebulaCertificateDetails struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
	// Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask
	Ips       []uint32 `protobuf:"varint,2,rep,packed,name=Ips,proto3" json:"Ips,omitempty"`
	Subnets   []uint32 `protobuf:"varint,3,rep,packed,name=Subnets,proto3" json:"Subnets,omitempty"`
	Groups    []string `protobuf:"bytes,4,rep,name=Groups,proto3" json:"Groups,omitempty"`
	NotBefore int64    `protobuf:"varint,5,opt,name=NotBefore,proto3" json:"NotBefore,omitempty"`
	NotAfter  int64    `protobuf:"varint,6,opt,name=NotAfter,proto3" json:"NotAfter,omitempty"`
	PublicKey []byte   `protobuf:"bytes,7,opt,name=PublicKey,proto3" json:"PublicKey,omitempty"`
	IsCA      bool     `protobuf:"varint,8,opt,name=IsCA,proto3" json:"IsCA,omitempty"`
	// sha-256 of the issuer certificate, if this field is blank the cert is self-signed
	Issuer []byte `protobuf:"bytes,9,opt,name=Issuer,proto3" json:"Issuer,omitempty"`
	Curve  Curve  `protobuf:"varint,100,opt,name=curve,proto3,enum=cert.Curve" json:"curve,omitempty"`
}

func (x *RawNebulaCertificateDetails) Reset() {
	*x = RawNebulaCertificateDetails{}
	if protoimpl.UnsafeEnabled {
		mi := &file_cert_v1_proto_msgTypes[1]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *RawNebulaCertificateDetails) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RawNebulaCertificateDetails) ProtoMessage() {}

func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
	mi := &file_cert_v1_proto_msgTypes[1]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
	return file_cert_v1_proto_rawDescGZIP(), []int{1}
}

func (x *RawNebulaCertificateDetails) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

func (x *RawNebulaCertificateDetails) GetIps() []uint32 {
	if x != nil {
		return x.Ips
	}
	return nil
}

func (x *RawNebulaCertificateDetails) GetSubnets() []uint32 {
	if x != nil {
		return x.Subnets
	}
	return nil
}

func (x *RawNebulaCertificateDetails) GetGroups() []string {
	if x != nil {
		return x.Groups
	}
	return nil
}

func (x *RawNebulaCertificateDetails) GetNotBefore() int64 {
	if x != nil {
		return x.NotBefore
	}
	return 0
}

func (x *RawNebulaCertificateDetails) GetNotAfter() int64 {
	if x != nil {
		return x.NotAfter
	}
	return 0
}

func (x *RawNebulaCertificateDetails) GetPublicKey() []byte {
	if x != nil {
		return x.PublicKey
	}
	return nil
}

func (x *RawNebulaCertificateDetails) GetIsCA() bool {
	if x != nil {
		return x.IsCA
	}
	return false
}

func (x *RawNebulaCertificateDetails) GetIssuer() []byte {
	if x != nil {
		return x.Issuer
	}
	return nil
}

func (x *RawNebulaCertificateDetails) GetCurve() Curve {
	if x != nil {
		return x.Curve
	}
	return Curve_CURVE25519
}

type RawNebulaEncryptedData struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	EncryptionMetadata *RawNebulaEncryptionMetadata `protobuf:"bytes,1,opt,name=EncryptionMetadata,proto3" json:"EncryptionMetadata,omitempty"`
	Ciphertext         []byte                       `protobuf:"bytes,2,opt,name=Ciphertext,proto3" json:"Ciphertext,omitempty"`
}

func (x *RawNebulaEncryptedData) Reset() {
	*x = RawNebulaEncryptedData{}
	if protoimpl.UnsafeEnabled {
		mi := &file_cert_v1_proto_msgTypes[2]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *RawNebulaEncryptedData) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RawNebulaEncryptedData) ProtoMessage() {}

func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
	mi := &file_cert_v1_proto_msgTypes[2]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
	return file_cert_v1_proto_rawDescGZIP(), []int{2}
}

func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
	if x != nil {
		return x.EncryptionMetadata
	}
	return nil
}

func (x *RawNebulaEncryptedData) GetCiphertext() []byte {
	if x != nil {
		return x.Ciphertext
	}
	return nil
}

type RawNebulaEncryptionMetadata struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	EncryptionAlgorithm string                     `protobuf:"bytes,1,opt,name=EncryptionAlgorithm,proto3" json:"EncryptionAlgorithm,omitempty"`
	Argon2Parameters    *RawNebulaArgon2Parameters `protobuf:"bytes,2,opt,name=Argon2Parameters,proto3" json:"Argon2Parameters,omitempty"`
}

func (x *RawNebulaEncryptionMetadata) Reset() {
	*x = RawNebulaEncryptionMetadata{}
	if protoimpl.UnsafeEnabled {
		mi := &file_cert_v1_proto_msgTypes[3]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *RawNebulaEncryptionMetadata) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RawNebulaEncryptionMetadata) ProtoMessage() {}

func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
	mi := &file_cert_v1_proto_msgTypes[3]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
	return file_cert_v1_proto_rawDescGZIP(), []int{3}
}

func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
	if x != nil {
		return x.EncryptionAlgorithm
	}
	return ""
}

func (x *RawNebulaEncryptionMetadata) GetArgon2Parameters() *RawNebulaArgon2Parameters {
	if x != nil {
		return x.Argon2Parameters
	}
	return nil
}

type RawNebulaArgon2Parameters struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Version     int32  `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` // rune in Go
	Memory      uint32 `protobuf:"varint,2,opt,name=memory,proto3" json:"memory,omitempty"`
	Parallelism uint32 `protobuf:"varint,4,opt,name=parallelism,proto3" json:"parallelism,omitempty"` // uint8 in Go
	Iterations  uint32 `protobuf:"varint,3,opt,name=iterations,proto3" json:"iterations,omitempty"`
	Salt        []byte `protobuf:"bytes,5,opt,name=salt,proto3" json:"salt,omitempty"`
}

func (x *RawNebulaArgon2Parameters) Reset() {
	*x = RawNebulaArgon2Parameters{}
	if protoimpl.UnsafeEnabled {
		mi := &file_cert_v1_proto_msgTypes[4]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *RawNebulaArgon2Parameters) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*RawNebulaArgon2Parameters) ProtoMessage() {}

func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
	mi := &file_cert_v1_proto_msgTypes[4]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
	return file_cert_v1_proto_rawDescGZIP(), []int{4}
}

// NOTE(review): everything in this section is protoc-gen-go output for
// cert_v1.proto. Regenerate instead of hand-editing; in particular, the raw
// descriptor bytes below must match the .proto exactly.

func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
	if x != nil {
		return x.Version
	}
	return 0
}

func (x *RawNebulaArgon2Parameters) GetMemory() uint32 {
	if x != nil {
		return x.Memory
	}
	return 0
}

func (x *RawNebulaArgon2Parameters) GetParallelism() uint32 {
	if x != nil {
		return x.Parallelism
	}
	return 0
}

func (x *RawNebulaArgon2Parameters) GetIterations() uint32 {
	if x != nil {
		return x.Iterations
	}
	return 0
}

func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
	if x != nil {
		return x.Salt
	}
	return nil
}

var File_cert_v1_proto protoreflect.FileDescriptor

// file_cert_v1_proto_rawDesc is the encoded file descriptor handed to
// protoimpl.DescBuilder as RawDescriptor in file_cert_v1_proto_init.
var file_cert_v1_proto_rawDesc = []byte{
	0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53, 0x75,
	0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 
0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
	0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 
0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74,
	0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}

var (
	file_cert_v1_proto_rawDescOnce sync.Once
	file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc
)

// file_cert_v1_proto_rawDescGZIP compresses the raw descriptor on first use
// (guarded by sync.Once) and returns the cached compressed bytes afterwards.
func file_cert_v1_proto_rawDescGZIP() []byte {
	file_cert_v1_proto_rawDescOnce.Do(func() {
		file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData)
	})
	return file_cert_v1_proto_rawDescData
}

var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_cert_v1_proto_goTypes = []any{
	(Curve)(0),                          // 0: cert.Curve
	(*RawNebulaCertificate)(nil),        // 1: cert.RawNebulaCertificate
	(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
	(*RawNebulaEncryptedData)(nil),      // 3: cert.RawNebulaEncryptedData
	(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
	(*RawNebulaArgon2Parameters)(nil),   // 5: cert.RawNebulaArgon2Parameters
}
var file_cert_v1_proto_depIdxs = []int32{
	2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
	0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
	4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
	5, // 3: cert.RawNebulaEncryptionMetadata.Argon2Parameters:type_name -> cert.RawNebulaArgon2Parameters
	4, // [4:4] is the sub-list for method output_type
	4, // [4:4] is the sub-list for method input_type
	4, // [4:4] is the sub-list for extension type_name
	4, // [4:4] is the sub-list for extension extendee
	0, // [0:4] is the sub-list for field type_name
}

func init() { file_cert_v1_proto_init() }

// file_cert_v1_proto_init registers the cert_v1.proto types with the protobuf
// runtime exactly once. It also clears the raw tables afterwards, presumably so
// they can be garbage collected.
func file_cert_v1_proto_init() {
	if File_cert_v1_proto != nil {
		return
	}
	if !protoimpl.UnsafeEnabled {
		file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any {
			switch v := v.(*RawNebulaCertificate); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any {
			switch v := v.(*RawNebulaCertificateDetails); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any {
			switch v := v.(*RawNebulaEncryptedData); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any {
			switch v := v.(*RawNebulaEncryptionMetadata); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any {
			switch v := v.(*RawNebulaArgon2Parameters); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_cert_v1_proto_rawDesc,
			NumEnums:      1,
			NumMessages:   5,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_cert_v1_proto_goTypes,
		DependencyIndexes: file_cert_v1_proto_depIdxs,
		EnumInfos:         file_cert_v1_proto_enumTypes,
		MessageInfos:      file_cert_v1_proto_msgTypes,
	}.Build()
	File_cert_v1_proto = out.File
	file_cert_v1_proto_rawDesc = nil
	file_cert_v1_proto_goTypes = nil
	file_cert_v1_proto_depIdxs = nil
}

================================================
FILE: cert/cert_v1.proto
================================================ syntax = "proto3"; package cert; option go_package = "github.com/slackhq/nebula/cert"; //import "google/protobuf/timestamp.proto"; enum Curve { CURVE25519 = 0; P256 = 1; } message RawNebulaCertificate { RawNebulaCertificateDetails Details = 1; bytes Signature = 2; } message RawNebulaCertificateDetails { string Name = 1; // Ips and Subnets are in big endian 32 bit pairs, 1st the ip, 2nd the mask repeated uint32 Ips = 2; repeated uint32 Subnets = 3; repeated string Groups = 4; int64 NotBefore = 5; int64 NotAfter = 6; bytes PublicKey = 7; bool IsCA = 8; // sha-256 of the issuer certificate, if this field is blank the cert is self-signed bytes Issuer = 9; Curve curve = 100; } message RawNebulaEncryptedData { RawNebulaEncryptionMetadata EncryptionMetadata = 1; bytes Ciphertext = 2; } message RawNebulaEncryptionMetadata { string EncryptionAlgorithm = 1; RawNebulaArgon2Parameters Argon2Parameters = 2; } message RawNebulaArgon2Parameters { int32 version = 1; // rune in Go uint32 memory = 2; uint32 parallelism = 4; // uint8 in Go uint32 iterations = 3; bytes salt = 5; } ================================================ FILE: cert/cert_v1_test.go ================================================ package cert import ( "crypto/ed25519" "fmt" "net/netip" "testing" "time" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" ) func TestCertificateV1_Marshal(t *testing.T) { t.Parallel() before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") nc := certificateV1{ details: detailsV1{ name: "testing", networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), }, unsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, groups: 
[]string{"test-group1", "test-group2", "test-group3"}, notBefore: before, notAfter: after, publicKey: pubKey, isCA: false, issuer: "1234567890abcedfghij1234567890ab", }, signature: []byte("1234567890abcedfghij1234567890ab"), } b, err := nc.Marshal() require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV1(b, nil) require.NoError(t, err) assert.Equal(t, Version1, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) assert.Equal(t, nc.IsCA(), nc2.IsCA()) assert.Equal(t, nc.Networks(), nc2.Networks()) assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) assert.Equal(t, nc.Groups(), nc2.Groups()) } func TestCertificateV1_Unmarshal(t *testing.T) { t.Parallel() before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") invalidPubkey := []byte("00000000000000000000000000000000") nc := certificateV1{ details: detailsV1{ name: "testing", networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), }, unsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, groups: []string{"test-group1", "test-group2", "test-group3"}, notBefore: before, notAfter: after, publicKey: pubKey, isCA: false, issuer: "1234567890abcedfghij1234567890ab", }, signature: []byte("1234567890abcedfghij1234567890ab"), } // This certificate has a pubkey included certWithPubkey, err := nc.Marshal() require.NoError(t, err) // This certificate is missing the pubkey section certWithoutPubkey, err := nc.MarshalForHandshakes() require.NoError(t, err) // Cert has no pubkey and no pubkey passed in must fail to validate isNil, 
err := unmarshalCertificateV1(certWithoutPubkey, nil)
	require.Error(t, err)

	// Cert has different pubkey than one passed in must fail
	isNil, err = unmarshalCertificateV1(certWithPubkey, invalidPubkey)
	require.Nil(t, isNil)
	require.Error(t, err)

	// Cert has pubkey and no pubkey argument works ok
	_, err = unmarshalCertificateV1(certWithPubkey, nil)
	require.NoError(t, err)

	// Cert has no pubkey and valid, correctly signed pubkey passed in
	nc2, err := unmarshalCertificateV1(certWithoutPubkey, pubKey)
	require.NoError(t, err)
	assert.Equal(t, pubKey, nc2.PublicKey())
}

// TestCertificateV1_PublicKeyPem checks that MarshalPublicKeyPEM picks the PEM
// banner from both the curve and the CA flag of a v1 certificate.
func TestCertificateV1_PublicKeyPem(t *testing.T) {
	t.Parallel()
	before := time.Now().Add(time.Second * -60).Round(time.Second)
	after := time.Now().Add(time.Second * 60).Round(time.Second)
	pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")

	nc := certificateV1{
		details: detailsV1{
			name:           "testing",
			networks:       []netip.Prefix{},
			unsafeNetworks: []netip.Prefix{},
			groups:         []string{"test-group1", "test-group2", "test-group3"},
			notBefore:      before,
			notAfter:       after,
			publicKey:      pubKey,
			isCA:           false,
			issuer:         "1234567890abcedfghij1234567890ab",
		},
		signature: []byte("1234567890abcedfghij1234567890ab"),
	}

	assert.Equal(t, Version1, nc.Version())
	assert.Equal(t, Curve_CURVE25519, nc.Curve())
	pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
	// testify convention: expected value first, actual value second.
	assert.Equal(t, pubPem, string(nc.MarshalPublicKeyPEM()))
	assert.False(t, nc.IsCA())

	// A CA on curve25519 carries an ed25519 signing key, so the banner changes.
	nc.details.isCA = true
	assert.Equal(t, Curve_CURVE25519, nc.Curve())
	pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
	assert.Equal(t, pubPem, string(nc.MarshalPublicKeyPEM()))
	assert.True(t, nc.IsCA())

	pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA P256 PUBLIC KEY-----
`)
	pubP256KeyPemCA := []byte(`-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAA=
-----END NEBULA ECDSA P256 PUBLIC KEY-----
`)
	pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
	require.NoError(t, err)
	nc.details.curve = Curve_P256
	nc.details.publicKey = pubP256Key
	assert.Equal(t, Curve_P256, nc.Curve())
	assert.Equal(t, string(pubP256KeyPemCA), string(nc.MarshalPublicKeyPEM()))
	assert.True(t, nc.IsCA())

	nc.details.isCA = false
	assert.Equal(t, Curve_P256, nc.Curve())
	assert.Equal(t, string(pubP256KeyPem), string(nc.MarshalPublicKeyPEM()))
	assert.False(t, nc.IsCA())
}

// TestCertificateV1_Expired checks both edges of the validity window.
func TestCertificateV1_Expired(t *testing.T) {
	nc := certificateV1{
		details: detailsV1{
			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
		},
	}

	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
	assert.False(t, nc.Expired(time.Now()))
}

// TestCertificateV1_MarshalJSON pins the JSON shape of a v1 certificate.
// It sets time.Local globally, so it must not run in parallel.
func TestCertificateV1_MarshalJSON(t *testing.T) {
	time.Local = time.UTC
	pubKey := []byte("1234567890abcedfghij1234567890ab")

	nc := certificateV1{
		details: detailsV1{
			name: "testing",
			networks: []netip.Prefix{
				mustParsePrefixUnmapped("10.1.1.1/24"),
				mustParsePrefixUnmapped("10.1.1.2/16"),
			},
			unsafeNetworks: []netip.Prefix{
				mustParsePrefixUnmapped("9.1.1.2/24"),
				mustParsePrefixUnmapped("9.1.1.3/16"),
			},
			groups:    []string{"test-group1", "test-group2", "test-group3"},
			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
			publicKey: pubKey,
			isCA:      false,
			issuer:    "1234567890abcedfghij1234567890ab",
		},
		signature: []byte("1234567890abcedfghij1234567890ab"),
	}

	b, err := nc.MarshalJSON()
	require.NoError(t, err)
	assert.JSONEq(
		t,
		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
		string(b),
	)
}

// TestCertificateV1_VerifyPrivateKey checks matching and mismatching
// curve25519 keys for both a CA and a leaf certificate.
func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
	require.NoError(t, err)

	// An unrelated CA key must not verify against ca. NewTestCaCert returns no
	// error, so there is nothing to check between these two calls.
	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
	require.Error(t, err)

	c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
	require.NoError(t, err)
	assert.Empty(t, b)
	assert.Equal(t, Curve_CURVE25519, curve)
	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
	require.NoError(t, err)

	_, priv2 := X25519Keypair()
	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
	require.Error(t, err)
}

// TestCertificateV1_VerifyPrivateKeyP256 is the P256 counterpart of
// TestCertificateV1_VerifyPrivateKey.
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
	err := ca.VerifyPrivateKey(Curve_P256, caKey)
	require.NoError(t, err)

	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
	require.Error(t, err)

	c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
	require.NoError(t, err)
	assert.Empty(t, b)
	assert.Equal(t, Curve_P256, curve)
	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
	require.NoError(t, err)

	_, priv2 := P256Keypair()
	err = c.VerifyPrivateKey(Curve_P256, priv2)
	require.Error(t, err)
}

// Ensure that upgrading the protobuf library does not change how certificates
// are marshalled, since this would break signature verification
func TestMarshalingCertificateV1Consistency(t *testing.T) {
	before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
	after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
	pubKey := []byte("1234567890abcedfghij1234567890ab")

	nc := certificateV1{
		details: detailsV1{
			name: "testing",
			networks: []netip.Prefix{
				mustParsePrefixUnmapped("10.1.1.2/16"),
				mustParsePrefixUnmapped("10.1.1.1/24"),
			},
			unsafeNetworks: []netip.Prefix{
				mustParsePrefixUnmapped("9.1.1.3/16"),
				mustParsePrefixUnmapped("9.1.1.2/24"),
			},
			groups:    []string{"test-group1", "test-group2", "test-group3"},
			notBefore: before,
			notAfter:  after,
			publicKey: pubKey,
			isCA:      false,
			issuer:    "1234567890abcedfghij1234567890ab",
		},
		signature: []byte("1234567890abcedfghij1234567890ab"),
	}

	b, err := nc.Marshal()
	require.NoError(t, err)
	assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))

	b, err = proto.Marshal(nc.getRawDetails())
	require.NoError(t, err)
	assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
}

// TestCertificateV1_Copy checks that Copy produces an equal but fully
// independent certificate.
func TestCertificateV1_Copy(t *testing.T) {
	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
	cc := c.Copy()
	test.AssertDeepCopyEqual(t, c, cc)
}

func TestUnmarshalCertificateV1(t *testing.T) {
	// Test that we don't panic with an invalid certificate (#332)
	data := []byte("\x98\x00\x00")
	_, err := unmarshalCertificateV1(data, nil)
	require.EqualError(t, err, "encoded Details was nil")
}

// appendByteSlices concatenates all of b into a newly allocated slice.
func appendByteSlices(b ...[]byte) []byte {
	retSlice := []byte{}
	for _, v := range b {
		retSlice = append(retSlice, v...)
	}
	return retSlice
}

// mustParsePrefixUnmapped parses s like netip.MustParsePrefix (panicking on
// bad input) and unmaps any 4in6 address to its plain IPv4 form.
func mustParsePrefixUnmapped(s string) netip.Prefix {
	prefix := netip.MustParsePrefix(s)
	return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
}

================================================
FILE: cert/cert_v2.asn1
================================================
Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN

Name ::= UTF8String (SIZE (1..253))
Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
Curve ::= ENUMERATED {
    curve25519 (0),
    p256 (1)
}

-- The maximum size of a certificate must not exceed 65536 bytes
Certificate ::= SEQUENCE {
    details OCTET STRING,
    curve Curve DEFAULT curve25519,
    publicKey OCTET STRING,
    -- signature(details + curve + publicKey) using the appropriate method for curve
    signature OCTET STRING
}

Details ::= SEQUENCE {
    name Name,

    -- At least 1 ipv4 or ipv6 address must be present if isCA is false
    networks SEQUENCE OF Network OPTIONAL,
    unsafeNetworks SEQUENCE OF Network OPTIONAL,
    groups SEQUENCE OF Name OPTIONAL,
    isCA BOOLEAN DEFAULT false,
    notBefore Time,
    notAfter Time,

    -- issuer is only required if isCA is false, if isCA is true then it must not be present
    issuer OCTET STRING OPTIONAL,
    ...
-- New fields can be added below here } END ================================================ FILE: cert/cert_v2.go ================================================ package cert import ( "bytes" "crypto/ecdh" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/sha256" "encoding/hex" "encoding/json" "encoding/pem" "fmt" "net/netip" "slices" "time" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" "golang.org/x/crypto/curve25519" ) const ( classConstructed = 0x20 classContextSpecific = 0x80 TagCertDetails = 0 | classConstructed | classContextSpecific TagCertCurve = 1 | classContextSpecific TagCertPublicKey = 2 | classContextSpecific TagCertSignature = 3 | classContextSpecific TagDetailsName = 0 | classContextSpecific TagDetailsNetworks = 1 | classConstructed | classContextSpecific TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific TagDetailsGroups = 3 | classConstructed | classContextSpecific TagDetailsIsCA = 4 | classContextSpecific TagDetailsNotBefore = 5 | classContextSpecific TagDetailsNotAfter = 6 | classContextSpecific TagDetailsIssuer = 7 | classContextSpecific ) const ( // MaxCertificateSize is the maximum length a valid certificate can be MaxCertificateSize = 65536 // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems MaxNameLength = 253 // MaxNetworkLength is the maximum length a network value can be. // 16 bytes for an ipv6 address + 1 byte for the prefix length MaxNetworkLength = 17 ) type certificateV2 struct { details detailsV2 // RawDetails contains the entire asn.1 DER encoded Details struct // This is to benefit forwards compatibility in signature checking. 
// signature(RawDetails + Curve + PublicKey) == Signature rawDetails []byte curve Curve publicKey []byte signature []byte } type detailsV2 struct { name string networks []netip.Prefix // MUST BE SORTED unsafeNetworks []netip.Prefix // MUST BE SORTED groups []string isCA bool notBefore time.Time notAfter time.Time issuer string } func (c *certificateV2) Version() Version { return Version2 } func (c *certificateV2) Curve() Curve { return c.curve } func (c *certificateV2) Groups() []string { return c.details.groups } func (c *certificateV2) IsCA() bool { return c.details.isCA } func (c *certificateV2) Issuer() string { return c.details.issuer } func (c *certificateV2) Name() string { return c.details.name } func (c *certificateV2) Networks() []netip.Prefix { return c.details.networks } func (c *certificateV2) NotAfter() time.Time { return c.details.notAfter } func (c *certificateV2) NotBefore() time.Time { return c.details.notBefore } func (c *certificateV2) PublicKey() []byte { return c.publicKey } func (c *certificateV2) MarshalPublicKeyPEM() []byte { return marshalCertPublicKeyToPEM(c) } func (c *certificateV2) Signature() []byte { return c.signature } func (c *certificateV2) UnsafeNetworks() []netip.Prefix { return c.details.unsafeNetworks } func (c *certificateV2) Fingerprint() (string, error) { if len(c.rawDetails) == 0 { return "", ErrMissingDetails } b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature)) copy(b, c.rawDetails) b[len(c.rawDetails)] = byte(c.curve) copy(b[len(c.rawDetails)+1:], c.publicKey) copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature) sum := sha256.Sum256(b) return hex.EncodeToString(sum[:]), nil } func (c *certificateV2) CheckSignature(key []byte) bool { if len(c.rawDetails) == 0 { return false } b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) copy(b, c.rawDetails) b[len(c.rawDetails)] = byte(c.curve) copy(b[len(c.rawDetails)+1:], c.publicKey) switch c.curve { case Curve_CURVE25519: return 
ed25519.Verify(key, b, c.signature) case Curve_P256: pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) if err != nil { return false } hashed := sha256.Sum256(b) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: return false } } func (c *certificateV2) Expired(t time.Time) bool { return c.details.notBefore.After(t) || c.details.notAfter.Before(t) } func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { if curve != c.curve { return ErrPublicPrivateCurveMismatch } if c.details.isCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise if len(key) != ed25519.PrivateKeySize { return ErrInvalidPrivateKey } if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { return ErrPublicPrivateKeyMismatch } case Curve_P256: privkey, err := ecdh.P256().NewPrivateKey(key) if err != nil { return ErrInvalidPrivateKey } pub := privkey.PublicKey().Bytes() if !bytes.Equal(pub, c.publicKey) { return ErrPublicPrivateKeyMismatch } default: return fmt.Errorf("invalid curve: %s", curve) } return nil } var pub []byte switch curve { case Curve_CURVE25519: var err error pub, err = curve25519.X25519(key, curve25519.Basepoint) if err != nil { return ErrInvalidPrivateKey } case Curve_P256: privkey, err := ecdh.P256().NewPrivateKey(key) if err != nil { return ErrInvalidPrivateKey } pub = privkey.PublicKey().Bytes() default: return fmt.Errorf("invalid curve: %s", curve) } if !bytes.Equal(pub, c.publicKey) { return ErrPublicPrivateKeyMismatch } return nil } func (c *certificateV2) String() string { mb, err := c.marshalJSON() if err != nil { return fmt.Sprintf("", err) } b, err := json.MarshalIndent(mb, "", "\t") if err != nil { return fmt.Sprintf("", err) } return string(b) } func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { if c.rawDetails == nil { return nil, ErrEmptyRawDetails } var b cryptobyte.Builder // Outermost certificate 
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { // Add the cert details which is already marshalled b.AddBytes(c.rawDetails) // Skipping the curve and public key since those come across in a different part of the handshake // Add the signature b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { b.AddBytes(c.signature) }) }) return b.Bytes() } func (c *certificateV2) Marshal() ([]byte, error) { if c.rawDetails == nil { return nil, ErrEmptyRawDetails } var b cryptobyte.Builder // Outermost certificate b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { // Add the cert details which is already marshalled b.AddBytes(c.rawDetails) // Add the curve only if its not the default value if c.curve != Curve_CURVE25519 { b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) { b.AddBytes([]byte{byte(c.curve)}) }) } // Add the public key if it is not empty if c.publicKey != nil { b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) { b.AddBytes(c.publicKey) }) } // Add the signature b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { b.AddBytes(c.signature) }) }) return b.Bytes() } func (c *certificateV2) MarshalPEM() ([]byte, error) { b, err := c.Marshal() if err != nil { return nil, err } return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil } func (c *certificateV2) MarshalJSON() ([]byte, error) { b, err := c.marshalJSON() if err != nil { return nil, err } return json.Marshal(b) } func (c *certificateV2) marshalJSON() (m, error) { fp, err := c.Fingerprint() if err != nil { return nil, err } return m{ "details": m{ "name": c.details.name, "networks": c.details.networks, "unsafeNetworks": c.details.unsafeNetworks, "groups": c.details.groups, "notBefore": c.details.notBefore, "notAfter": c.details.notAfter, "isCa": c.details.isCA, "issuer": c.details.issuer, }, "version": Version2, "publicKey": fmt.Sprintf("%x", c.publicKey), "curve": c.curve.String(), "fingerprint": fp, "signature": fmt.Sprintf("%x", c.Signature()), }, nil } func 
(c *certificateV2) Copy() Certificate { nc := &certificateV2{ details: detailsV2{ name: c.details.name, notBefore: c.details.notBefore, notAfter: c.details.notAfter, isCA: c.details.isCA, issuer: c.details.issuer, }, curve: c.curve, publicKey: make([]byte, len(c.publicKey)), signature: make([]byte, len(c.signature)), rawDetails: make([]byte, len(c.rawDetails)), } if c.details.groups != nil { nc.details.groups = make([]string, len(c.details.groups)) copy(nc.details.groups, c.details.groups) } if c.details.networks != nil { nc.details.networks = make([]netip.Prefix, len(c.details.networks)) copy(nc.details.networks, c.details.networks) } if c.details.unsafeNetworks != nil { nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) } copy(nc.rawDetails, c.rawDetails) copy(nc.signature, c.signature) copy(nc.publicKey, c.publicKey) return nc } func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { c.details = detailsV2{ name: t.Name, networks: t.Networks, unsafeNetworks: t.UnsafeNetworks, groups: t.Groups, isCA: t.IsCA, notBefore: t.NotBefore, notAfter: t.NotAfter, issuer: t.issuer, } c.curve = t.Curve c.publicKey = t.PublicKey return c.validate() } func (c *certificateV2) validate() error { // Empty names are allowed if len(c.publicKey) == 0 { return ErrInvalidPublicKey } if !c.details.isCA && len(c.details.networks) == 0 { return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network") } hasV4Networks := false hasV6Networks := false for _, network := range c.details.networks { if !network.IsValid() || !network.Addr().IsValid() { return NewErrInvalidCertificateProperties("invalid network: %s", network) } if network.Addr().IsUnspecified() { return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network) } if network.Addr().Zone() != "" { return NewErrInvalidCertificateProperties("networks 
may not contain zones: %s", network) } if network.Addr().Is4In6() { return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network) } hasV4Networks = hasV4Networks || network.Addr().Is4() hasV6Networks = hasV6Networks || network.Addr().Is6() } slices.SortFunc(c.details.networks, comparePrefix) err := findDuplicatePrefix(c.details.networks) if err != nil { return err } for _, network := range c.details.unsafeNetworks { if !network.IsValid() || !network.Addr().IsValid() { return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network) } if network.Addr().Zone() != "" { return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network) } if !c.details.isCA { if network.Addr().Is6() { if !hasV6Networks { return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) } } else if network.Addr().Is4() { if !hasV4Networks { return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) } } } } slices.SortFunc(c.details.unsafeNetworks, comparePrefix) err = findDuplicatePrefix(c.details.unsafeNetworks) if err != nil { return err } return nil } func (c *certificateV2) marshalForSigning() ([]byte, error) { d, err := c.details.Marshal() if err != nil { return nil, fmt.Errorf("marshalling certificate details failed: %w", err) } c.rawDetails = d b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) copy(b, c.rawDetails) b[len(c.rawDetails)] = byte(c.curve) copy(b[len(c.rawDetails)+1:], c.publicKey) return b, nil } func (c *certificateV2) setSignature(b []byte) error { if len(b) == 0 { return ErrEmptySignature } c.signature = b return nil } func (d *detailsV2) Marshal() ([]byte, error) { var b cryptobyte.Builder var err error // Details are a structure b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) { // Add the name b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) { b.AddBytes([]byte(d.name)) }) // 
Add the networks if any exist if len(d.networks) > 0 { b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) { for _, n := range d.networks { sb, innerErr := n.MarshalBinary() if innerErr != nil { // MarshalBinary never returns an error err = fmt.Errorf("unable to marshal network: %w", innerErr) return } b.AddASN1OctetString(sb) } }) } // Add the unsafe networks if any exist if len(d.unsafeNetworks) > 0 { b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) { for _, n := range d.unsafeNetworks { sb, innerErr := n.MarshalBinary() if innerErr != nil { // MarshalBinary never returns an error err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr) return } b.AddASN1OctetString(sb) } }) } // Add groups if any exist if len(d.groups) > 0 { b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) { for _, group := range d.groups { b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) { b.AddBytes([]byte(group)) }) } }) } // Add IsCA only if true if d.isCA { b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) { b.AddUint8(0xff) }) } // Add not before b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore) // Add not after b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter) // Add the issuer if present if d.issuer != "" { issuerBytes, innerErr := hex.DecodeString(d.issuer) if innerErr != nil { err = fmt.Errorf("failed to decode issuer: %w", innerErr) return } b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) { b.AddBytes(issuerBytes) }) } }) if err != nil { return nil, err } return b.Bytes() } func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { l := len(b) if l == 0 || l > MaxCertificateSize { return nil, ErrBadFormat } input := cryptobyte.String(b) // Open the envelope if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() { return nil, ErrBadFormat } // Grab the cert details, we need to preserve the tag and length var rawDetails cryptobyte.String if 
!input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() { return nil, ErrBadFormat } //Maybe grab the curve var rawCurve byte if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { return nil, ErrBadFormat } curve = Curve(rawCurve) // Maybe grab the public key var rawPublicKey cryptobyte.String if len(publicKey) > 0 { // If a public key is passed in, then the handshake certificate must // not have a public key present if input.PeekASN1Tag(TagCertPublicKey) { return nil, ErrCertPubkeyPresent } rawPublicKey = make(cryptobyte.String, len(publicKey)) copy(rawPublicKey, publicKey) } else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) { return nil, ErrBadFormat } if len(rawPublicKey) == 0 { return nil, ErrBadFormat } // Grab the signature var rawSignature cryptobyte.String if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() { return nil, ErrBadFormat } // Finally unmarshal the details details, err := unmarshalDetails(rawDetails) if err != nil { return nil, err } c := &certificateV2{ details: details, rawDetails: rawDetails, curve: curve, publicKey: rawPublicKey, signature: rawSignature, } err = c.validate() if err != nil { return nil, err } return c, nil } func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { // Open the envelope if !b.ReadASN1(&b, TagCertDetails) || b.Empty() { return detailsV2{}, ErrBadFormat } // Read the name var name cryptobyte.String if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength { return detailsV2{}, ErrBadFormat } // Read the network addresses var subString cryptobyte.String var found bool if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) { return detailsV2{}, ErrBadFormat } var networks []netip.Prefix var val cryptobyte.String if found { for !subString.Empty() { if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { return detailsV2{}, ErrBadFormat } var n netip.Prefix 
if err := n.UnmarshalBinary(val); err != nil { return detailsV2{}, ErrBadFormat } networks = append(networks, n) } } // Read out any unsafe networks if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) { return detailsV2{}, ErrBadFormat } var unsafeNetworks []netip.Prefix if found { for !subString.Empty() { if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { return detailsV2{}, ErrBadFormat } var n netip.Prefix if err := n.UnmarshalBinary(val); err != nil { return detailsV2{}, ErrBadFormat } unsafeNetworks = append(unsafeNetworks, n) } } // Read out any groups if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) { return detailsV2{}, ErrBadFormat } var groups []string if found { for !subString.Empty() { if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() { return detailsV2{}, ErrBadFormat } groups = append(groups, string(val)) } } // Read out IsCA var isCa bool if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) { return detailsV2{}, ErrBadFormat } // Read not before and not after var notBefore int64 if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) { return detailsV2{}, ErrBadFormat } var notAfter int64 if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) { return detailsV2{}, ErrBadFormat } // Read issuer var issuer cryptobyte.String if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) { return detailsV2{}, ErrBadFormat } return detailsV2{ name: string(name), networks: networks, unsafeNetworks: unsafeNetworks, groups: groups, isCA: isCa, notBefore: time.Unix(notBefore, 0), notAfter: time.Unix(notAfter, 0), issuer: hex.EncodeToString(issuer), }, nil } ================================================ FILE: cert/cert_v2_test.go ================================================ package cert import ( "crypto/ed25519" "crypto/rand" "encoding/hex" "net/netip" "slices" "testing" "time" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" ) func TestCertificateV2_Marshal(t *testing.T) { t.Parallel() before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") nc := certificateV2{ details: detailsV2{ name: "testing", networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.2/16"), mustParsePrefixUnmapped("10.1.1.1/24"), }, unsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.3/16"), mustParsePrefixUnmapped("9.1.1.2/24"), }, groups: []string{"test-group1", "test-group2", "test-group3"}, notBefore: before, notAfter: after, isCA: false, issuer: "1234567890abcdef1234567890abcdef", }, signature: []byte("1234567890abcdef1234567890abcdef"), publicKey: pubKey, } db, err := nc.details.Marshal() require.NoError(t, err) nc.rawDetails = db b, err := nc.Marshal() require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) require.NoError(t, err) assert.Equal(t, Version2, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, nc.Signature(), nc2.Signature()) assert.Equal(t, nc.Name(), nc2.Name()) assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) assert.Equal(t, nc.IsCA(), nc2.IsCA()) assert.Equal(t, nc.Issuer(), nc2.Issuer()) // unmarshalling will sort networks and unsafeNetworks, we need to do the same // but first make sure it fails assert.NotEqual(t, nc.Networks(), nc2.Networks()) assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) slices.SortFunc(nc.details.networks, comparePrefix) slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) assert.Equal(t, nc.Networks(), nc2.Networks()) assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) assert.Equal(t, nc.Groups(), nc2.Groups()) } func TestCertificateV2_Unmarshal(t *testing.T) { t.Parallel() before := 
time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") nc := certificateV2{ details: detailsV2{ name: "testing", networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.2/16"), mustParsePrefixUnmapped("10.1.1.1/24"), }, unsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.3/16"), mustParsePrefixUnmapped("9.1.1.2/24"), }, groups: []string{"test-group1", "test-group2", "test-group3"}, notBefore: before, notAfter: after, isCA: false, issuer: "1234567890abcdef1234567890abcdef", }, signature: []byte("1234567890abcdef1234567890abcdef"), publicKey: pubKey, } db, err := nc.details.Marshal() require.NoError(t, err) nc.rawDetails = db certWithPubkey, err := nc.Marshal() require.NoError(t, err) //t.Log("Cert size:", len(b)) certWithoutPubkey, err := nc.MarshalForHandshakes() require.NoError(t, err) // Cert must not have a pubkey if one is passed in as an argument _, err = unmarshalCertificateV2(certWithPubkey, pubKey, Curve_CURVE25519) require.ErrorIs(t, err, ErrCertPubkeyPresent) // Certs must have pubkeys _, err = unmarshalCertificateV2(certWithoutPubkey, nil, Curve_CURVE25519) require.ErrorIs(t, err, ErrBadFormat) // Ensure proper unmarshal if a pubkey is passed in nc2, err := unmarshalCertificateV2(certWithoutPubkey, pubKey, Curve_CURVE25519) require.NoError(t, err) assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) } func TestCertificateV2_PublicKeyPem(t *testing.T) { t.Parallel() before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab") nc := certificateV2{ details: detailsV2{ name: "testing", networks: []netip.Prefix{}, unsafeNetworks: []netip.Prefix{}, groups: []string{"test-group1", "test-group2", "test-group3"}, notBefore: before, notAfter: after, isCA: false, issuer: "1234567890abcedfghij1234567890ab", }, publicKey: 
pubKey, signature: []byte("1234567890abcedfghij1234567890ab"), } assert.Equal(t, Version2, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n" assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem) assert.False(t, nc.IsCA()) nc.details.isCA = true assert.Equal(t, Curve_CURVE25519, nc.Curve()) pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n" assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem) assert.True(t, nc.IsCA()) pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA P256 PUBLIC KEY----- `) pubP256KeyPemCA := []byte(`-----BEGIN NEBULA ECDSA P256 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA ECDSA P256 PUBLIC KEY----- `) pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem) require.NoError(t, err) nc.curve = Curve_P256 nc.publicKey = pubP256Key assert.Equal(t, Curve_P256, nc.Curve()) assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPemCA)) assert.True(t, nc.IsCA()) nc.details.isCA = false assert.Equal(t, Curve_P256, nc.Curve()) assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) assert.False(t, nc.IsCA()) } func TestCertificateV2_Expired(t *testing.T) { nc := certificateV2{ details: detailsV2{ notBefore: time.Now().Add(time.Second * -60).Round(time.Second), notAfter: time.Now().Add(time.Second * 60).Round(time.Second), }, } assert.True(t, nc.Expired(time.Now().Add(time.Hour))) assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) assert.False(t, nc.Expired(time.Now())) } func TestCertificateV2_MarshalJSON(t *testing.T) { time.Local = time.UTC pubKey := 
[]byte("1234567890abcedf1234567890abcedf") nc := certificateV2{ details: detailsV2{ name: "testing", networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), }, unsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, groups: []string{"test-group1", "test-group2", "test-group3"}, notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), isCA: false, issuer: "1234567890abcedf1234567890abcedf", }, publicKey: pubKey, signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), } b, err := nc.MarshalJSON() require.ErrorIs(t, err, ErrMissingDetails) rd, err := nc.details.Marshal() require.NoError(t, err) nc.rawDetails = rd b, err = nc.MarshalJSON() require.NoError(t, err) assert.JSONEq( t, "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", string(b), ) } func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) require.ErrorIs(t, err, ErrInvalidPrivateKey) _, caKey2, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) err = 
ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) require.ErrorIs(t, err, ErrInvalidPrivateKey) ac, ok := c.(*certificateV2) require.True(t, ok) ac.curve = Curve(99) err = c.VerifyPrivateKey(Curve(99), priv2) require.EqualError(t, err, "invalid curve: 99") ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) require.NoError(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) require.ErrorIs(t, err, ErrInvalidPrivateKey) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) err = c.VerifyPrivateKey(Curve_P256, priv[:16]) require.ErrorIs(t, err, ErrInvalidPrivateKey) err = c.VerifyPrivateKey(Curve_P256, priv) require.ErrorIs(t, err, ErrInvalidPrivateKey) aCa, ok := ca2.(*certificateV2) require.True(t, ok) aCa.curve = Curve(99) err = aCa.VerifyPrivateKey(Curve(99), priv2) require.EqualError(t, err, "invalid curve: 99") } func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version2, 
Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) require.Error(t, err) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) require.Error(t, err) } func TestCertificateV2_Copy(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) cc := c.Copy() test.AssertDeepCopyEqual(t, c, cc) } func TestUnmarshalCertificateV2(t *testing.T) { data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) require.EqualError(t, err, "bad wire format") } func TestCertificateV2_marshalForSigningStability(t *testing.T) { before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC) after := before.Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") nc := certificateV2{ details: detailsV2{ name: "testing", networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.2/16"), mustParsePrefixUnmapped("10.1.1.1/24"), }, unsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.3/16"), mustParsePrefixUnmapped("9.1.1.2/24"), }, groups: []string{"test-group1", "test-group2", "test-group3"}, notBefore: before, notAfter: after, isCA: false, issuer: "1234567890abcdef1234567890abcdef", }, signature: []byte("1234567890abcdef1234567890abcdef"), publicKey: pubKey, } const expectedRawDetailsStr = 
"a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef" expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr) require.NoError(t, err) db, err := nc.details.Marshal() require.NoError(t, err) assert.Equal(t, expectedRawDetails, db) expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") b, err := nc.marshalForSigning() require.NoError(t, err) assert.Equal(t, expectedForSigning, b) } ================================================ FILE: cert/crypto.go ================================================ package cert import ( "crypto/aes" "crypto/cipher" "crypto/ed25519" "crypto/rand" "encoding/pem" "fmt" "io" "math" "golang.org/x/crypto/argon2" "google.golang.org/protobuf/proto" ) type NebulaEncryptedData struct { EncryptionMetadata NebulaEncryptionMetadata Ciphertext []byte } type NebulaEncryptionMetadata struct { EncryptionAlgorithm string Argon2Parameters Argon2Parameters } // Argon2Parameters KDF factors type Argon2Parameters struct { version rune Memory uint32 // KiB Parallelism uint8 Iterations uint32 salt []byte } // NewArgon2Parameters Returns a new Argon2Parameters object with current version set func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters { return &Argon2Parameters{ version: argon2.Version, Memory: memory, // KiB Parallelism: parallelism, Iterations: iterations, } } // Encrypts data using AES-256-GCM and the Argon2id key derivation function func aes256Encrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { key, err := aes256DeriveKey(passphrase, kdfParams) if err != nil { return nil, err } // this should never happen, but since this dictates how our calls into the // aes package behave and could be catastraphic, 
let's sanity check this if len(key) != 32 { return nil, fmt.Errorf("invalid AES-256 key length (%d) - cowardly refusing to encrypt", len(key)) } block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } ciphertext := gcm.Seal(nil, nonce, data, nil) blob := joinNonceCiphertext(nonce, ciphertext) return blob, nil } // Decrypts data using AES-256-GCM and the Argon2id key derivation function // Expects the data to include an Argon2id parameter string before the encrypted data func aes256Decrypt(passphrase []byte, kdfParams *Argon2Parameters, data []byte) ([]byte, error) { key, err := aes256DeriveKey(passphrase, kdfParams) if err != nil { return nil, err } block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce, ciphertext, err := splitNonceCiphertext(data, gcm.NonceSize()) if err != nil { return nil, err } plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, fmt.Errorf("invalid passphrase or corrupt private key") } return plaintext, nil } func aes256DeriveKey(passphrase []byte, params *Argon2Parameters) ([]byte, error) { if params.salt == nil { params.salt = make([]byte, 32) if _, err := rand.Read(params.salt); err != nil { return nil, err } } // keySize of 32 bytes will result in AES-256 encryption key, err := deriveKey(passphrase, 32, params) if err != nil { return nil, err } return key, nil } // Derives a key from a passphrase using Argon2id func deriveKey(passphrase []byte, keySize uint32, params *Argon2Parameters) ([]byte, error) { if params.version != argon2.Version { return nil, fmt.Errorf("incompatible Argon2 version: %d", params.version) } if params.salt == nil { return nil, fmt.Errorf("salt must be set in argon2Parameters") } else if len(params.salt) < 16 { 
return nil, fmt.Errorf("salt must be at least 128 bits") } key := argon2.IDKey(passphrase, params.salt, params.Iterations, params.Memory, params.Parallelism, keySize) return key, nil } // Prepends nonce to ciphertext func joinNonceCiphertext(nonce []byte, ciphertext []byte) []byte { return append(nonce, ciphertext...) } // Splits nonce from ciphertext func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) { if len(blob) <= nonceSize { return nil, nil, fmt.Errorf("invalid ciphertext blob - blob shorter than nonce length") } return blob[:nonceSize], blob[nonceSize:], nil } // EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) { ciphertext, err := aes256Encrypt(passphrase, kdfParams, b) if err != nil { return nil, err } b, err = proto.Marshal(&RawNebulaEncryptedData{ EncryptionMetadata: &RawNebulaEncryptionMetadata{ EncryptionAlgorithm: "AES-256-GCM", Argon2Parameters: &RawNebulaArgon2Parameters{ Version: kdfParams.version, Memory: kdfParams.Memory, Parallelism: uint32(kdfParams.Parallelism), Iterations: kdfParams.Iterations, Salt: kdfParams.salt, }, }, Ciphertext: ciphertext, }) if err != nil { return nil, err } switch curve { case Curve_CURVE25519: return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil case Curve_P256: return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil default: return nil, fmt.Errorf("invalid curve: %v", curve) } } // UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its // protobuf-generated struct. 
func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) { if len(b) == 0 { return nil, fmt.Errorf("nil byte array") } var rned RawNebulaEncryptedData err := proto.Unmarshal(b, &rned) if err != nil { return nil, err } if rned.EncryptionMetadata == nil { return nil, fmt.Errorf("encoded EncryptionMetadata was nil") } if rned.EncryptionMetadata.Argon2Parameters == nil { return nil, fmt.Errorf("encoded Argon2Parameters was nil") } params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters) if err != nil { return nil, err } ned := NebulaEncryptedData{ EncryptionMetadata: NebulaEncryptionMetadata{ EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm, Argon2Parameters: *params, }, Ciphertext: rned.Ciphertext, } return &ned, nil } func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) { if params.Version < math.MinInt32 || params.Version > math.MaxInt32 { return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32) } if params.Memory <= 0 || params.Memory > math.MaxUint32 { return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) } if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 { return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8) } if params.Iterations <= 0 || params.Iterations > math.MaxUint32 { return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) } return &Argon2Parameters{ version: params.Version, Memory: params.Memory, Parallelism: uint8(params.Parallelism), Iterations: params.Iterations, salt: params.Salt, }, nil } // DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with // the given passphrase, returning any other bytes b or an error on failure func 
DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) { var curve Curve k, r := pem.Decode(b) if k == nil { return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block") } switch k.Type { case EncryptedEd25519PrivateKeyBanner: curve = Curve_CURVE25519 case EncryptedECDSAP256PrivateKeyBanner: curve = Curve_P256 default: return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") } ned, err := UnmarshalNebulaEncryptedData(k.Bytes) if err != nil { return curve, nil, r, err } var bytes []byte switch ned.EncryptionMetadata.EncryptionAlgorithm { case "AES-256-GCM": bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext) if err != nil { return curve, nil, r, err } default: return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm) } switch curve { case Curve_CURVE25519: if len(bytes) != ed25519.PrivateKeySize { return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize) } case Curve_P256: if len(bytes) != 32 { return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") } } return curve, bytes, r, nil } ================================================ FILE: cert/crypto_test.go ================================================ package cert import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/argon2" ) func TestNewArgon2Parameters(t *testing.T) { p := NewArgon2Parameters(64*1024, 4, 3) assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 64 * 1024, Parallelism: 4, Iterations: 3, }, p) p = NewArgon2Parameters(2*1024*1024, 2, 1) assert.Equal(t, &Argon2Parameters{ version: argon2.Version, Memory: 2 * 1024 * 1024, Parallelism: 2, Iterations: 1, }, p) } func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) 
{ passphrase := []byte("DO NOT USE") privKey := []byte(`# A good key -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiCPoDfGQiosxNPTbPn5EsMlc2MI c0Bt4oz6gTrFQhX3aBJcimhHKeAuhyTGvllD0Z19fe+DFPcLH3h5VrdjVfIAajg0 KrbV3n9UHif/Au5skWmquNJzoW1E4MTdRbvpti6o+WdQ49DxjBFhx0YH8LBqrbPU 0BGkUHmIO7daP24= -----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- `) shortKey := []byte(`# A key which, once decrypted, is too short -----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- CjsKC0FFUy0yNTYtR0NNEiwIExCAgAQYAyAEKiAVJwdfl3r+eqi/vF6S7OMdpjfo hAzmTCRnr58Su4AqmBJbCv3zleYCEKYJP6UI3S8ekLMGISsgO4hm5leukCCyqT0Z cQ76yrberpzkJKoPLGisX8f+xdy4aXSZl7oEYWQte1+vqbtl/eY9PGZhxUQdcyq7 hqzIyrRqfUgVuA== -----END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- `) invalidBanner := []byte(`# Invalid banner (not encrypted) -----BEGIN NEBULA ED25519 PRIVATE KEY----- bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG XgLvodMXZJuaFPssp+WwtA== -----END NEBULA ED25519 PRIVATE KEY----- `) invalidPem := []byte(`# Not a valid PEM format -BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY----- CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl +Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB qrlJ69wer3ZUHFXA -END NEBULA ED25519 ENCRYPTED PRIVATE KEY----- `) keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem) // Success test case curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) require.NoError(t, err) assert.Equal(t, Curve_CURVE25519, curve) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) // Fail due to invalid 
banner curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to invalid PEM format, because // it's missing the requisite pre-encapsulation boundary. curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) require.EqualError(t, err, "input did not contain a valid PEM encoded block") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to invalid passphrase curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) require.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) assert.Equal(t, []byte{}, rest) } func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { // Having proved that decryption works correctly above, we can test the // encryption function produces a value which can be decrypted passphrase := []byte("passphrase") bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") kdfParams := NewArgon2Parameters(64*1024, 4, 3) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) require.NoError(t, err) // Verify the "key" can be decrypted successfully curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) assert.Len(t, k, 64) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, []byte{}, rest) require.NoError(t, err) // EncryptAndMarshalEd25519PrivateKey does not create any errors itself } ================================================ FILE: cert/errors.go ================================================ package cert import ( "errors" "fmt" ) var ( ErrBadFormat = errors.New("bad wire format") ErrRootExpired = errors.New("root certificate is expired") ErrExpired = errors.New("certificate is expired") ErrNotCA = errors.New("certificate is not a CA") ErrNotSelfSigned = 
errors.New("certificate is not self-signed") ErrBlockListed = errors.New("certificate is in the block list") ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") ErrSignatureMismatch = errors.New("certificate signature did not match") ErrInvalidPublicKey = errors.New("invalid public key") ErrInvalidPrivateKey = errors.New("invalid private key") ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") ErrCaNotFound = errors.New("could not find ca for the certificate") ErrUnknownVersion = errors.New("certificate version unrecognized") ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") ErrInvalidPEMX25519PublicKeyBanner = errors.New("bytes did not contain a proper X25519 public key banner") ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") ErrNoPeerStaticKey = errors.New("no peer static key was present") ErrNoPayload = errors.New("provided payload was empty") ErrMissingDetails = errors.New("certificate did not contain details") ErrEmptySignature = errors.New("empty signature") ErrEmptyRawDetails = errors.New("empty rawDetails not allowed") ) type ErrInvalidCertificateProperties struct { str string } func NewErrInvalidCertificateProperties(format string, a ...any) error { return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)} } func (e *ErrInvalidCertificateProperties) 
Error() string { return e.str } ================================================ FILE: cert/helper_test.go ================================================ package cert import ( "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "io" "net/netip" "time" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" ) // NewTestCaCert will create a new ca certificate func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { var err error var pub, priv []byte switch curve { case Curve_CURVE25519: pub, priv, err = ed25519.GenerateKey(rand.Reader) case Curve_P256: privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { panic(err) } pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) priv = privk.D.FillBytes(make([]byte, 32)) default: // There is no default to allow the underlying lib to respond with an error } if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } if after.IsZero() { after = time.Now().Add(time.Second * 60).Round(time.Second) } t := &TBSCertificate{ Curve: curve, Version: version, Name: "test ca", NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), PublicKey: pub, Networks: networks, UnsafeNetworks: unsafeNetworks, Groups: groups, IsCA: true, } c, err := t.Sign(nil, curve, priv) if err != nil { panic(err) } pem, err := c.MarshalPEM() if err != nil { panic(err) } return c, pub, priv, pem } // NewTestCert will generate a signed certificate with the provided details. 
// Expiry times are defaulted if you do not pass them in func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } if after.IsZero() { after = time.Now().Add(time.Second * 60).Round(time.Second) } if len(networks) == 0 { networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} } var pub, priv []byte switch curve { case Curve_CURVE25519: pub, priv = X25519Keypair() case Curve_P256: pub, priv = P256Keypair() default: panic("unknown curve") } nc := &TBSCertificate{ Version: v, Curve: curve, Name: name, Networks: networks, UnsafeNetworks: unsafeNetworks, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), PublicKey: pub, IsCA: false, } c, err := nc.Sign(ca, ca.Curve(), key) if err != nil { panic(err) } pem, err := c.MarshalPEM() if err != nil { panic(err) } return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem } func X25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { panic(err) } pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) if err != nil { panic(err) } return pubkey, privkey } func P256Keypair() ([]byte, []byte) { privkey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { panic(err) } pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } ================================================ FILE: cert/p256/p256.go ================================================ package p256 import ( "crypto/elliptic" "errors" "math/big" "filippo.io/bigmod" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" ) var halfN = new(big.Int).Rsh(elliptic.P256().Params().N, 1) var nMod *bigmod.Modulus func init() { n, err := bigmod.NewModulus(elliptic.P256().Params().N.Bytes()) if err 
!= nil { panic(err) } nMod = n } func IsNormalized(sig []byte) (bool, error) { r, s, err := parseSignature(sig) if err != nil { return false, err } return checkLowS(r, s), nil } func checkLowS(_, s []byte) bool { bigS := new(big.Int).SetBytes(s) // Check if S <= (N/2), because we want to include the midpoint in the set of low-s return bigS.Cmp(halfN) <= 0 } func swap(r, s []byte) ([]byte, []byte, error) { var err error bigS, err := bigmod.NewNat().SetBytes(s, nMod) if err != nil { return nil, nil, err } sNormalized := nMod.Nat().Sub(bigS, nMod) return r, sNormalized.Bytes(nMod), nil } func Normalize(sig []byte) ([]byte, error) { r, s, err := parseSignature(sig) if err != nil { return nil, err } if checkLowS(r, s) { return sig, nil } newR, newS, err := swap(r, s) if err != nil { return nil, err } return encodeSignature(newR, newS) } // Swap will change sig between its current form to the opposite high or low form. func Swap(sig []byte) ([]byte, error) { r, s, err := parseSignature(sig) if err != nil { return nil, err } newR, newS, err := swap(r, s) if err != nil { return nil, err } return encodeSignature(newR, newS) } // parseSignature taken exactly from crypto/ecdsa/ecdsa.go func parseSignature(sig []byte) (r, s []byte, err error) { var inner cryptobyte.String input := cryptobyte.String(sig) if !input.ReadASN1(&inner, asn1.SEQUENCE) || !input.Empty() || !inner.ReadASN1Integer(&r) || !inner.ReadASN1Integer(&s) || !inner.Empty() { return nil, nil, errors.New("invalid ASN.1") } return r, s, nil } func encodeSignature(r, s []byte) ([]byte, error) { var b cryptobyte.Builder b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { addASN1IntBytes(b, r) addASN1IntBytes(b, s) }) return b.Bytes() } // addASN1IntBytes encodes in ASN.1 a positive integer represented as // a big-endian byte slice with zero or more leading zeroes. 
func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { for len(bytes) > 0 && bytes[0] == 0 { bytes = bytes[1:] } if len(bytes) == 0 { b.SetError(errors.New("invalid integer")) return } b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) { if bytes[0]&0x80 != 0 { c.AddUint8(0) } c.AddBytes(bytes) }) } ================================================ FILE: cert/p256/p256_test.go ================================================ package p256 import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "testing" "github.com/stretchr/testify/require" ) func TestFlipping(t *testing.T) { priv, err1 := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err1) out, err := ecdsa.SignASN1(rand.Reader, priv, []byte("big chungus")) require.NoError(t, err) r, s, err := parseSignature(out) require.NoError(t, err) r, s1, err := swap(r, s) require.NoError(t, err) r, s2, err := swap(r, s1) require.NoError(t, err) require.Equal(t, s, s2) require.NotEqual(t, s, s1) } ================================================ FILE: cert/pem.go ================================================ package cert import ( "encoding/pem" "fmt" "golang.org/x/crypto/ed25519" ) const ( //cert banners CertificateBanner = "NEBULA CERTIFICATE" CertificateV2Banner = "NEBULA CERTIFICATE V2" ) const ( //key-agreement-key banners X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" ) /* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */ const ( //signing key banners EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY" EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" Ed25519PublicKeyBanner = 
"NEBULA ED25519 PUBLIC KEY" ) // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed // data or an error on failure func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { p, r := pem.Decode(b) if p == nil { return nil, r, ErrInvalidPEMBlock } var c Certificate var err error switch p.Type { // Implementations must validate the resulting certificate contains valid information case CertificateBanner: c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) default: return nil, r, ErrInvalidPEMCertificateBanner } if err != nil { return nil, r, err } return c, r, nil } func marshalCertPublicKeyToPEM(c Certificate) []byte { if c.IsCA() { return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) } else { return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey()) } } // MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH. // if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes! func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { switch curve { case Curve_CURVE25519: return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b}) case Curve_P256: return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) default: return nil } } // MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing. // if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes! 
func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte { switch curve { case Curve_CURVE25519: return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b}) case Curve_P256: return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PublicKeyBanner, Bytes: b}) default: return nil } } func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { k, r := pem.Decode(b) if k == nil { return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") } var expectedLen int var curve Curve switch k.Type { case X25519PublicKeyBanner, Ed25519PublicKeyBanner: expectedLen = 32 curve = Curve_CURVE25519 case P256PublicKeyBanner, ECDSAP256PublicKeyBanner: // Uncompressed expectedLen = 65 curve = Curve_P256 default: return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner") } if len(k.Bytes) != expectedLen { return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve) } return k.Bytes, r, curve, nil } func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte { switch curve { case Curve_CURVE25519: return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b}) case Curve_P256: return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b}) default: return nil } } func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte { switch curve { case Curve_CURVE25519: return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b}) case Curve_P256: return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b}) default: return nil } } // UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non // consumed data or an error on failure func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { k, r := pem.Decode(b) if k == nil { return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") } var expectedLen int var curve Curve switch k.Type { case 
X25519PrivateKeyBanner: expectedLen = 32 curve = Curve_CURVE25519 case P256PrivateKeyBanner: expectedLen = 32 curve = Curve_P256 default: return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner") } if len(k.Bytes) != expectedLen { return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve) } return k.Bytes, r, curve, nil } func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { k, r := pem.Decode(b) if k == nil { return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block") } var curve Curve switch k.Type { case EncryptedEd25519PrivateKeyBanner: return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted case EncryptedECDSAP256PrivateKeyBanner: return nil, nil, Curve_P256, ErrPrivateKeyEncrypted case Ed25519PrivateKeyBanner: curve = Curve_CURVE25519 if len(k.Bytes) != ed25519.PrivateKeySize { return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize) } case ECDSAP256PrivateKeyBanner: curve = Curve_P256 if len(k.Bytes) != 32 { return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key") } default: return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner") } return k.Bytes, r, curve, nil } ================================================ FILE: cert/pem_test.go ================================================ package cert import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestUnmarshalCertificateFromPEM(t *testing.T) { goodCert := []byte(` # A good cert -----BEGIN NEBULA CERTIFICATE----- CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB -----END NEBULA CERTIFICATE----- `) badBanner := []byte(`# A bad banner -----BEGIN NOT A NEBULA CERTIFICATE----- 
CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB -----END NOT A NEBULA CERTIFICATE----- `) invalidPem := []byte(`# Not a valid PEM format -BEGIN NEBULA CERTIFICATE----- CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB -END NEBULA CERTIFICATE----`) certBundle := appendByteSlices(goodCert, badBanner, invalidPem) // Success test case cert, rest, err := UnmarshalCertificateFromPEM(certBundle) assert.NotNil(t, cert) assert.Equal(t, rest, append(badBanner, invalidPem...)) require.NoError(t, err) // Fail due to invalid banner. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "bytes did not contain a proper certificate banner") // Fail due to invalid PEM format, because // it's missing the requisite pre-encapsulation boundary. 
cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { privKey := []byte(`# A good key -----BEGIN NEBULA ED25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -----END NEBULA ED25519 PRIVATE KEY----- `) privP256Key := []byte(`# A good key -----BEGIN NEBULA ECDSA P256 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA ECDSA P256 PRIVATE KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA ED25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA -----END NEBULA ED25519 PRIVATE KEY----- `) invalidBanner := []byte(`# Invalid banner -----BEGIN NOT A NEBULA PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -----END NOT A NEBULA PRIVATE KEY----- `) invalidPem := []byte(`# Not a valid PEM format -BEGIN NEBULA ED25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -END NEBULA ED25519 PRIVATE KEY-----`) keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) // Success test case k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, 
appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") // Fail due to invalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { privKey := []byte(`# A good key -----BEGIN NEBULA X25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA X25519 PRIVATE KEY----- `) privP256Key := []byte(`# A good key -----BEGIN NEBULA P256 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA P256 PRIVATE KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA X25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -----END NEBULA X25519 PRIVATE KEY----- `) invalidBanner := []byte(`# Invalid banner -----BEGIN NOT A NEBULA PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NOT A NEBULA PRIVATE KEY----- `) invalidPem := []byte(`# Not a valid PEM format -BEGIN NEBULA X25519 PRIVATE KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -END NEBULA X25519 PRIVATE KEY-----`) keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem) // Success test case k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, 
rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "bytes did not contain a proper private key banner") // Fail due to invalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPublicKeyFromPEM(t *testing.T) { t.Parallel() pubKey := []byte(`# A good key -----BEGIN NEBULA ED25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA ED25519 PUBLIC KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA ED25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -----END NEBULA ED25519 PUBLIC KEY----- `) invalidBanner := []byte(`# Invalid banner -----BEGIN NOT A NEBULA PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NOT A NEBULA PUBLIC KEY----- `) invalidPem := []byte(`# Not a valid PEM format -BEGIN NEBULA ED25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -END NEBULA ED25519 PUBLIC KEY-----`) keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem) // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) assert.Equal(t, Curve_CURVE25519, curve) require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) 
assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to invalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalX25519PublicKey(t *testing.T) { t.Parallel() pubKey := []byte(`# A good key -----BEGIN NEBULA X25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA X25519 PUBLIC KEY----- `) pubP256Key := []byte(`# A good key -----BEGIN NEBULA P256 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA P256 PUBLIC KEY----- `) oldPubP256Key := []byte(`# A good key -----BEGIN NEBULA ECDSA P256 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA ECDSA P256 PUBLIC KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA X25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA== -----END NEBULA X25519 PUBLIC KEY----- `) invalidBanner := []byte(`# Invalid banner -----BEGIN NOT A NEBULA PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -----END NOT A NEBULA PUBLIC KEY----- `) invalidPem := []byte(`# Not a valid PEM format -BEGIN NEBULA X25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -END NEBULA X25519 PUBLIC KEY-----`) keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, 
shortKey, invalidBanner, invalidPem) // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Len(t, k, 65) require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Len(t, k, 65) require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) // Fail due to short key k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to invalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) require.EqualError(t, err, "input did not contain a valid PEM encoded block") } ================================================ FILE: cert/sign.go ================================================ package cert import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/sha256" "fmt" "net/netip" "time" "github.com/slackhq/nebula/cert/p256" ) // TBSCertificate represents a certificate intended to be signed. // It is invalid to use this structure as a Certificate. 
type TBSCertificate struct { Version Version Name string Networks []netip.Prefix UnsafeNetworks []netip.Prefix Groups []string IsCA bool NotBefore time.Time NotAfter time.Time PublicKey []byte Curve Curve issuer string } type beingSignedCertificate interface { // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation // Implementations must validate the resulting certificate contains valid information fromTBSCertificate(*TBSCertificate) error // marshalForSigning returns the bytes that should be signed marshalForSigning() ([]byte, error) // setSignature sets the signature for the certificate that has just been signed. The signature must not be blank. setSignature([]byte) error } type SignerLambda func(certBytes []byte) ([]byte, error) // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // details do not violate constraints of the signing certificate. // If the TBSCertificate is a CA then signer must be nil. func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { switch t.Curve { case Curve_CURVE25519: pk := ed25519.PrivateKey(key) sp := func(certBytes []byte) ([]byte, error) { sig := ed25519.Sign(pk, certBytes) return sig, nil } return t.SignWith(signer, curve, sp) case Curve_P256: pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key) if err != nil { return nil, err } sp := func(certBytes []byte) ([]byte, error) { // We need to hash first for ECDSA // - https://pkg.go.dev/crypto/ecdsa#SignASN1 hashed := sha256.Sum256(certBytes) return ecdsa.SignASN1(rand.Reader, pk, hashed[:]) } return t.SignWith(signer, curve, sp) default: return nil, fmt.Errorf("invalid curve: %s", t.Curve) } } // SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature. // You should only use SignWith if you do not have direct access to your private key. 
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) { if curve != t.Curve { return nil, fmt.Errorf("curve in cert and private key supplied don't match") } if signer != nil { if t.IsCA { return nil, fmt.Errorf("can not sign a CA certificate with another") } err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks) if err != nil { return nil, err } issuer, err := signer.Fingerprint() if err != nil { return nil, fmt.Errorf("error computing issuer: %v", err) } t.issuer = issuer } else { if !t.IsCA { return nil, fmt.Errorf("self signed certificates must have IsCA set to true") } } var c beingSignedCertificate switch t.Version { case Version1: c = &certificateV1{} err := c.fromTBSCertificate(t) if err != nil { return nil, err } case Version2: c = &certificateV2{} err := c.fromTBSCertificate(t) if err != nil { return nil, err } default: return nil, fmt.Errorf("unknown cert version %d", t.Version) } certBytes, err := c.marshalForSigning() if err != nil { return nil, err } sig, err := sp(certBytes) if err != nil { return nil, err } if curve == Curve_P256 { sig, err = p256.Normalize(sig) if err != nil { return nil, err } } err = c.setSignature(sig) if err != nil { return nil, err } sc, ok := c.(Certificate) if !ok { return nil, fmt.Errorf("invalid certificate") } return sc, nil } func comparePrefix(a, b netip.Prefix) int { addr := a.Addr().Compare(b.Addr()) if addr == 0 { return a.Bits() - b.Bits() } return addr } // findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error { if len(sortedPrefixes) < 2 { return nil } for i := 1; i < len(sortedPrefixes); i++ { if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 { return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i]) } } return nil } 
================================================ FILE: cert/sign_test.go ================================================ package cert import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "net/netip" "testing" "time" "github.com/slackhq/nebula/cert/p256" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCertificateV1_Sign(t *testing.T) { before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") tbs := TBSCertificate{ Version: Version1, Name: "testing", Networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), }, UnsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/24"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, NotAfter: after, PublicKey: pubKey, IsCA: false, } pub, priv, err := ed25519.GenerateKey(rand.Reader) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) require.NoError(t, err) assert.NotNil(t, uc) } func TestCertificateV1_SignP256(t *testing.T) { before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") tbs := TBSCertificate{ Version: Version1, Name: "testing", Networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), }, UnsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, 
NotAfter: after, PublicKey: pubKey, IsCA: false, Curve: Curve_P256, } priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) require.NoError(t, err) assert.NotNil(t, uc) } func TestCertificate_SignP256_AlwaysNormalized(t *testing.T) { before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") tbs := TBSCertificate{ Version: Version1, Name: "testing", Networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), }, UnsafeNetworks: []netip.Prefix{ mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, NotAfter: after, PublicKey: pubKey, IsCA: true, Curve: Curve_P256, } priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) for i := 0; i < 1000; i++ { if i&1 == 1 { tbs.Version = Version1 } else { tbs.Version = Version2 } c, err := tbs.Sign(nil, Curve_P256, rawPriv) require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) normie, err := p256.IsNormalized(c.Signature()) require.NoError(t, err) assert.True(t, normie) } } ================================================ FILE: cert_test/cert.go ================================================ package cert_test import ( 
"crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "io" "net/netip" "time" "github.com/slackhq/nebula/cert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" ) // NewTestCaCert will create a new ca certificate func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { var err error var pub, priv []byte switch curve { case cert.Curve_CURVE25519: pub, priv, err = ed25519.GenerateKey(rand.Reader) case cert.Curve_P256: privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { panic(err) } pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) priv = privk.D.FillBytes(make([]byte, 32)) default: // There is no default to allow the underlying lib to respond with an error } if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } if after.IsZero() { after = time.Now().Add(time.Second * 60).Round(time.Second) } t := &cert.TBSCertificate{ Curve: curve, Version: version, Name: "test ca", NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), PublicKey: pub, Networks: networks, UnsafeNetworks: unsafeNetworks, Groups: groups, IsCA: true, } c, err := t.Sign(nil, curve, priv) if err != nil { panic(err) } pem, err := c.MarshalPEM() if err != nil { panic(err) } return c, pub, priv, pem } // NewTestCert will generate a signed certificate with the provided details. 
// Expiry times are defaulted if you do not pass them in func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } if after.IsZero() { after = time.Now().Add(time.Second * 60).Round(time.Second) } var pub, priv []byte switch curve { case cert.Curve_CURVE25519: pub, priv = X25519Keypair() case cert.Curve_P256: pub, priv = P256Keypair() default: panic("unknown curve") } nc := &cert.TBSCertificate{ Version: v, Curve: curve, Name: name, Networks: networks, UnsafeNetworks: unsafeNetworks, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), PublicKey: pub, IsCA: false, } c, err := nc.Sign(ca, ca.Curve(), key) if err != nil { panic(err) } pem, err := c.MarshalPEM() if err != nil { panic(err) } return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem } func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) { nc := &cert.TBSCertificate{ Version: v, Curve: c.Curve(), Name: c.Name(), Networks: c.Networks(), UnsafeNetworks: c.UnsafeNetworks(), Groups: c.Groups(), NotBefore: time.Unix(c.NotBefore().Unix(), 0), NotAfter: time.Unix(c.NotAfter().Unix(), 0), PublicKey: c.PublicKey(), IsCA: false, } c, err := nc.Sign(ca, ca.Curve(), key) if err != nil { panic(err) } pem, err := c.MarshalPEM() if err != nil { panic(err) } return c, pem } func X25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { panic(err) } pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) if err != nil { panic(err) } return pubkey, privkey } func P256Keypair() ([]byte, []byte) { privkey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { panic(err) } pubkey 
:= privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } ================================================ FILE: cmd/nebula/main.go ================================================ package main import ( "flag" "fmt" "os" "runtime/debug" "strings" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) // A version string that can be set with // // -ldflags "-X main.Build=SOMEVERSION" // // at compile-time. var Build string func init() { if Build == "" { info, ok := debug.ReadBuildInfo() if !ok { return } Build = strings.TrimPrefix(info.Main.Version, "v") } } func main() { configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config") printVersion := flag.Bool("version", false, "Print version") printUsage := flag.Bool("help", false, "Print command line usage") flag.Parse() if *printVersion { fmt.Printf("Version: %s\n", Build) os.Exit(0) } if *printUsage { flag.Usage() os.Exit(0) } if *configPath == "" { fmt.Println("-config flag must be set") flag.Usage() os.Exit(1) } l := logrus.New() l.Out = os.Stdout c := config.NewC(l) err := c.Load(*configPath) if err != nil { fmt.Printf("failed to load config: %s", err) os.Exit(1) } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } if !*configTest { ctrl.Start() notifyReady(l) ctrl.ShutdownBlock() } os.Exit(0) } ================================================ FILE: cmd/nebula/notify_linux.go ================================================ package main import ( "net" "os" "time" "github.com/sirupsen/logrus" ) // SdNotifyReady tells systemd the service is ready and dependent services can now be started // https://www.freedesktop.org/software/systemd/man/sd_notify.html // 
https://www.freedesktop.org/software/systemd/man/systemd.service.html const SdNotifyReady = "READY=1" func notifyReady(l *logrus.Logger) { sockName := os.Getenv("NOTIFY_SOCKET") if sockName == "" { l.Debugln("NOTIFY_SOCKET systemd env var not set, not sending ready signal") return } conn, err := net.DialTimeout("unixgram", sockName, time.Second) if err != nil { l.WithError(err).Error("failed to connect to systemd notification socket") return } defer conn.Close() err = conn.SetWriteDeadline(time.Now().Add(time.Second)) if err != nil { l.WithError(err).Error("failed to set the write deadline for the systemd notification socket") return } if _, err = conn.Write([]byte(SdNotifyReady)); err != nil { l.WithError(err).Error("failed to signal the systemd notification socket") return } l.Debugln("notified systemd the service is ready") } ================================================ FILE: cmd/nebula/notify_notlinux.go ================================================ //go:build !linux // +build !linux package main import "github.com/sirupsen/logrus" func notifyReady(_ *logrus.Logger) { // No init service to notify } ================================================ FILE: cmd/nebula-cert/ca.go ================================================ package main import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "flag" "fmt" "io" "math" "net/netip" "os" "strings" "time" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/ed25519" ) type caFlags struct { set *flag.FlagSet name *string duration *time.Duration outKeyPath *string outCertPath *string outQRPath *string groups *string networks *string unsafeNetworks *string argonMemory *uint argonIterations *uint argonParallelism *uint encryption *bool version *uint curve *string p11url *string // Deprecated options ips *string subnets *string } func newCaFlags() *caFlags { cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf.set.Usage = func() {} 
cf.name = cf.set.String("name", "", "Required: name of the certificate authority") cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks") cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. 
This will limit which ip addresses and networks subordinate certs can use in unsafe networks") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") cf.p11url = p11Flag(cf.set) cf.ips = cf.set.String("ips", "", "Deprecated, see -networks") cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &cf } func parseArgonParameters(memory uint, parallelism uint, iterations uint) (*cert.Argon2Parameters, error) { if memory <= 0 || memory > math.MaxUint32 { return nil, newHelpErrorf("-argon-memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32)) } if parallelism <= 0 || parallelism > math.MaxUint8 { return nil, newHelpErrorf("-argon-parallelism must be be greater than 0 and no more than %d", math.MaxUint8) } if iterations <= 0 || iterations > math.MaxUint32 { return nil, newHelpErrorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32)) } return cert.NewArgon2Parameters(uint32(memory), uint8(parallelism), uint32(iterations)), nil } func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { cf := newCaFlags() err := cf.set.Parse(args) if err != nil { return err } isP11 := len(*cf.p11url) > 0 if err := mustFlagString("name", cf.name); err != nil { return err } if !isP11 { if err = mustFlagString("out-key", cf.outKeyPath); err != nil { return err } } if err := mustFlagString("out-crt", 
cf.outCertPath); err != nil { return err } var kdfParams *cert.Argon2Parameters if !isP11 && *cf.encryption { if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil { return err } } if *cf.duration <= 0 { return &helpError{"-duration must be greater than 0"} } var groups []string if *cf.groups != "" { for _, rg := range strings.Split(*cf.groups, ",") { g := strings.TrimSpace(rg) if g != "" { groups = append(groups, g) } } } version := cert.Version(*cf.version) if version != cert.Version1 && version != cert.Version2 { return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } var networks []netip.Prefix if *cf.networks == "" && *cf.ips != "" { // Pull up deprecated -ips flag if needed *cf.networks = *cf.ips } if *cf.networks != "" { for _, rs := range strings.Split(*cf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid -networks definition: %s", rs) } if version == cert.Version1 && !n.Addr().Is4() { return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) } networks = append(networks, n) } } } var unsafeNetworks []netip.Prefix if *cf.unsafeNetworks == "" && *cf.subnets != "" { // Pull up deprecated -subnets flag if needed *cf.unsafeNetworks = *cf.subnets } if *cf.unsafeNetworks != "" { for _, rs := range strings.Split(*cf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } if version == cert.Version1 && !n.Addr().Is4() { return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs) } unsafeNetworks = append(unsafeNetworks, n) } } } var passphrase []byte if !isP11 && *cf.encryption { passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) if len(passphrase) == 0 { for i := 0; i < 5; 
i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() if err == ErrNoTerminal { return fmt.Errorf("out-key must be encrypted interactively") } else if err != nil { return fmt.Errorf("error reading passphrase: %s", err) } if len(passphrase) > 0 { break } } if len(passphrase) == 0 { return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") } } } var curve cert.Curve var pub, rawPriv []byte var p11Client *pkclient.PKClient if isP11 { switch *cf.curve { case "P256": curve = cert.Curve_P256 default: return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) } p11Client, err = pkclient.FromUrl(*cf.p11url) if err != nil { return fmt.Errorf("error while creating PKCS#11 client: %w", err) } defer func(client *pkclient.PKClient) { _ = client.Close() }(p11Client) pub, err = p11Client.GetPubKey() if err != nil { return fmt.Errorf("error while getting public key with PKCS#11: %w", err) } } else { switch *cf.curve { case "25519", "X25519", "Curve25519", "CURVE25519": curve = cert.Curve_CURVE25519 pub, rawPriv, err = ed25519.GenerateKey(rand.Reader) if err != nil { return fmt.Errorf("error while generating ed25519 keys: %s", err) } case "P256": var key *ecdsa.PrivateKey curve = cert.Curve_P256 key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return fmt.Errorf("error while generating ecdsa keys: %s", err) } // ecdh.PrivateKey lets us get at the encoded bytes, even though // we aren't using ECDH here. 
eKey, err := key.ECDH() if err != nil { return fmt.Errorf("error while converting ecdsa key: %s", err) } rawPriv = eKey.Bytes() pub = eKey.PublicKey().Bytes() default: return fmt.Errorf("invalid curve: %s", *cf.curve) } } t := &cert.TBSCertificate{ Version: version, Name: *cf.name, Groups: groups, Networks: networks, UnsafeNetworks: unsafeNetworks, NotBefore: time.Now(), NotAfter: time.Now().Add(*cf.duration), PublicKey: pub, IsCA: true, Curve: curve, } if !isP11 { if _, err := os.Stat(*cf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) } } if _, err := os.Stat(*cf.outCertPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } var c cert.Certificate var b []byte if isP11 { c, err = t.SignWith(nil, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } } else { c, err = t.Sign(nil, curve, rawPriv) if err != nil { return fmt.Errorf("error while signing: %s", err) } if *cf.encryption { b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams) if err != nil { return fmt.Errorf("error while encrypting out-key: %s", err) } } else { b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) } err = os.WriteFile(*cf.outKeyPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } b, err = c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } err = os.WriteFile(*cf.outCertPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } if *cf.outQRPath != "" { b, err = qrcode.Encode(string(b), qrcode.Medium, -5) if err != nil { return fmt.Errorf("error while generating qr code: %s", err) } err = os.WriteFile(*cf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } } return nil } func caSummary() string { return "ca : create a self signed certificate 
authority" } /* caHelp writes the ca usage line followed by the defaults of every ca flag to out. */ func caHelp(out io.Writer) { cf := newCaFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n")) cf.set.SetOutput(out) cf.set.PrintDefaults() } ================================================ FILE: cmd/nebula-cert/ca_test.go ================================================ //go:build !windows // +build !windows package main import ( "bytes" "encoding/pem" "errors" "os" "strings" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_caSummary(t *testing.T) { assert.Equal(t, "ca : create a self signed certificate authority", caSummary()) } func Test_caHelp(t *testing.T) { ob := &bytes.Buffer{} caHelp(ob) assert.Equal( t, "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ " -argon-iterations uint\n"+ " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ " -argon-memory uint\n"+ " \tOptional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase (default 2097152)\n"+ " -argon-parallelism uint\n"+ " \tOptional: Argon2 parallelism parameter used for encrypted private key passphrase (default 4)\n"+ " -curve string\n"+ " \tEdDSA/ECDSA Curve (25519, P256) (default \"25519\")\n"+ " -duration duration\n"+ " \tOptional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\" (default 8760h0m0s)\n"+ " -encrypt\n"+ " \tOptional: prompt for passphrase and write out-key in an encrypted format\n"+ " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ " Deprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the certificate authority\n"+ " -networks string\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. 
This will limit which ip addresses and networks subordinate certs can use in networks\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " -out-key string\n"+ " \tOptional: path to write the private key to (default \"ca.key\")\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ " \tDeprecated, see -unsafe-networks\n"+ " -unsafe-networks string\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+ " -version uint\n"+ " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } /* Test_ca drives ca() end to end: required flags, v1 ipv4-only validation, write failures, plaintext and encrypted key output, passphrase prompting, and refusal to overwrite existing files. */ func Test_ca(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} nopw := &StubPasswordReader{ password: []byte(""), err: nil, } errpw := &StubPasswordReader{ password: []byte(""), err: errors.New("stub error"), } passphrase := []byte("DO NOT USE THIS KEY") testpw := &StubPasswordReader{ password: passphrase, err: nil, } pwPromptOb := "Enter passphrase: " // required args assertHelpError(t, ca( []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "-duration", "100m"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // ipv4 only ips assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // ipv4 only subnets assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 
certificates can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed key write 
ob.Reset() eb.Reset() args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") require.NoError(t, err) require.NoError(t, keyF.Close()) require.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") require.NoError(t, err) require.NoError(t, crtF.Close()) require.NoError(t, os.Remove(crtF.Name())) require.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, c) assert.Empty(t, b) require.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Empty(t, lCrt.Networks()) assert.True(t, lCrt.IsCA()) assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) 
assert.Empty(t, lCrt.UnsafeNetworks()) assert.Len(t, lCrt.PublicKey(), 32) assert.Equal(t, 
time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) assert.Empty(t, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(lCrt.PublicKey())) // test encrypted key os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Empty(t, eb.String()) // test encrypted key with passphrase environment variable os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) require.NoError(t, ca(args, ob, eb, testpw)) assert.Empty(t, eb.String()) os.Setenv("NEBULA_CA_PASSPHRASE", "") // read encrypted key file and verify default params rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) require.NoError(t, err) // the salt is random, so only the deterministic Argon2 parameters are checked assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) assert.Equal(t, uint32(1), ned.EncryptionMetadata.Argon2Parameters.Iterations) // verify the key is valid and decrypt-able var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) require.NoError(t, err) assert.Empty(t, b) assert.Len(t, lKey, 64) // test when reading password results in an error os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} 
require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Empty(t, eb.String()) // test when user fails to enter a password os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing key file ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) /* the refused run above leaves the cert file in place, so remove both temp paths */ os.Remove(keyF.Name()) os.Remove(crtF.Name()) } ================================================ FILE: cmd/nebula-cert/keygen.go ================================================ package main import ( "flag" "fmt" "io" "os" 
"github.com/slackhq/nebula/pkclient" 
"github.com/slackhq/nebula/cert" ) type keygenFlags struct { set *flag.FlagSet outKeyPath *string outPubPath *string curve *string p11url *string } func newKeygenFlags() *keygenFlags { cf := keygenFlags{set: flag.NewFlagSet("keygen", flag.ContinueOnError)} cf.set.Usage = func() {} cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to") cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to") cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)") cf.p11url = p11Flag(cf.set) return &cf } func keygen(args []string, out io.Writer, errOut io.Writer) error { cf := newKeygenFlags() err := cf.set.Parse(args) if err != nil { return err } isP11 := len(*cf.p11url) > 0 if !isP11 { if err = mustFlagString("out-key", cf.outKeyPath); err != nil { return err } } if err = mustFlagString("out-pub", cf.outPubPath); err != nil { return err } var pub, rawPriv []byte var curve cert.Curve if isP11 { switch *cf.curve { case "P256": curve = cert.Curve_P256 default: return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve) } } else { switch *cf.curve { case "25519", "X25519", "Curve25519", "CURVE25519": pub, rawPriv = x25519Keypair() curve = cert.Curve_CURVE25519 case "P256": pub, rawPriv = p256Keypair() curve = cert.Curve_P256 default: return fmt.Errorf("invalid curve: %s", *cf.curve) } } if isP11 { p11Client, err := pkclient.FromUrl(*cf.p11url) if err != nil { return fmt.Errorf("error while creating PKCS#11 client: %w", err) } defer func(client *pkclient.PKClient) { _ = client.Close() }(p11Client) pub, err = p11Client.GetPubKey() if err != nil { return fmt.Errorf("error while getting public key: %w", err) } } else { err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) if err != nil { return fmt.Errorf("error while 
writing out-pub: %s", err) } return nil } func keygenSummary() string { return "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`" } func keygenHelp(out io.Writer) { cf := newKeygenFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) cf.set.SetOutput(out) cf.set.PrintDefaults() } ================================================ FILE: cmd/nebula-cert/keygen_test.go ================================================ package main import ( "bytes" "os" "testing" "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_keygenSummary(t *testing.T) { assert.Equal(t, "keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary()) } func Test_keygenHelp(t *testing.T) { ob := &bytes.Buffer{} keygenHelp(ob) assert.Equal( t, "Usage of "+os.Args[0]+" keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+ " -curve string\n"+ " \tECDH Curve (25519, P256) (default \"25519\")\n"+ " -out-key string\n"+ " \tRequired: path to write the private key to\n"+ " -out-pub string\n"+ " \tRequired: path to write the public key to\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n"), ob.String(), ) } func Test_keygen(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} // required args assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: 
"+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") require.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write ob.Reset() eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") require.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} require.NoError(t, keygen(args, ob, eb)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) require.NoError(t, err) assert.Len(t, lPub, 32) } ================================================ FILE: cmd/nebula-cert/main.go ================================================ package main import ( "flag" "fmt" "io" "os" "runtime/debug" "strings" ) // A version string that can be set with // // -ldflags "-X main.Build=SOMEVERSION" // // at compile-time. 
var Build string /* init fills Build from the main module version in the embedded build info, minus its leading v, when Build is empty. */ func init() { if Build == "" { info, ok := debug.ReadBuildInfo() if !ok { return } Build = strings.TrimPrefix(info.Main.Version, "v") } } /* helpError marks an error after which handleError also prints the usage of the current mode. */ type helpError struct { s string } func (he *helpError) Error() string { return he.s } /* newHelpErrorf formats its arguments into a helpError. */ func newHelpErrorf(s string, v ...any) error { return &helpError{s: fmt.Sprintf(s, v...)} } /* main parses the global -version, -help and -h flags, dispatches the first positional argument to the matching mode, and exits with the code returned by handleError on failure. */ func main() { flag.Usage = func() { help("", os.Stderr) os.Exit(1) } printVersion := flag.Bool("version", false, "Print version") flagHelp := flag.Bool("help", false, "Print command line usage") flagH := flag.Bool("h", false, "Print command line usage") printUsage := false flag.Parse() if *flagH || *flagHelp { printUsage = true } args := flag.Args() if *printVersion { fmt.Printf("Version: %v\n", Build) os.Exit(0) } if len(args) < 1 { if printUsage { help("", os.Stderr) os.Exit(0) } help("No mode was provided", os.Stderr) os.Exit(1) } else if printUsage { handleError(args[0], &helpError{}, os.Stderr) os.Exit(0) } var err error switch args[0] { case "ca": err = ca(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "keygen": err = keygen(args[1:], os.Stdout, os.Stderr) case "sign": err = signCert(args[1:], os.Stdout, os.Stderr, StdinPasswordReader{}) case "print": err = printCert(args[1:], os.Stdout, os.Stderr) case "verify": err = verify(args[1:], os.Stdout, os.Stderr) default: err = fmt.Errorf("unknown mode: %s", args[0]) } if err != nil { os.Exit(handleError(args[0], err, os.Stderr)) } } /* handleError prints e unless it is flag.ErrHelp or has an empty message, prints the usage of mode when e is a 
helpError (flag.ErrHelp is treated as one), and returns the exit code: 0 for flag.ErrHelp, 1 otherwise. */ func handleError(mode string, e error, out io.Writer) int { code := 1 // Handle -help, -h flags properly if e == flag.ErrHelp { code = 0 e = &helpError{} } else if e != nil && e.Error() != "" { fmt.Fprintln(out, "Error:", e) } switch e.(type) { case *helpError: switch mode { case "ca": caHelp(out) case "keygen": keygenHelp(out) case "sign": signHelp(out) case "print": printHelp(out) case "verify": verifyHelp(out) } } return code } /* help prints the global usage to out, preceded by err when it is non-empty. */ func help(err string, out io.Writer) { if err != "" { fmt.Fprintln(out, "Error:", err) fmt.Fprintln(out, "") } 
fmt.Fprintf(out, "Usage of %s :\n", os.Args[0]) fmt.Fprintln(out, " Global flags:") fmt.Fprintln(out, " -version: Prints the version") fmt.Fprintln(out, " -h, -help: Prints this help message") fmt.Fprintln(out, "") fmt.Fprintln(out, " Modes:") fmt.Fprintln(out, " "+caSummary()) fmt.Fprintln(out, " "+keygenSummary()) fmt.Fprintln(out, " "+signSummary()) fmt.Fprintln(out, " "+printSummary()) fmt.Fprintln(out, " "+verifySummary()) fmt.Fprintln(out, "") fmt.Fprintf(out, " To see usage for a given mode, use %s <mode> -h\n", os.Args[0]) } /* mustFlagString returns a helpError naming the flag when *val is empty. */ func mustFlagString(name string, val *string) error { if *val == "" { return newHelpErrorf("-%s is required", name) } return nil } ================================================ FILE: cmd/nebula-cert/main_test.go ================================================ package main import ( "bytes" "errors" "io" "os" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_help(t *testing.T) { expected := "Usage of " + os.Args[0] + " :\n" + " Global flags:\n" + " -version: Prints the version\n" + " -h, -help: Prints this help message\n\n" + " Modes:\n" + " " + caSummary() + "\n" + " " + keygenSummary() + "\n" + " " + signSummary() + "\n" + " " + printSummary() + "\n" + " " + verifySummary() + "\n" + "\n" + " To see usage for a given mode, use " + os.Args[0] + " <mode> -h\n" ob := &bytes.Buffer{} // No error test help("", ob) assert.Equal( t, expected, ob.String(), ) // Error test ob.Reset() help("test error", ob) assert.Equal( t, "Error: test error\n\n"+expected, ob.String(), ) } func Test_handleError(t *testing.T) { ob := &bytes.Buffer{} // normal error handleError("", errors.New("test error"), ob) assert.Equal(t, "Error: test error\n", ob.String()) // unknown mode help error ob.Reset() handleError("", newHelpErrorf("test %s", "error"), ob) assert.Equal(t, "Error: test error\n", ob.String()) // test all modes with help error modes := map[string]func(io.Writer){"ca": 
caHelp, "keygen": keygenHelp, "print": printHelp, "sign": 
signHelp, "verify": verifyHelp} eb := &bytes.Buffer{} for mode, fn := range modes { ob.Reset() eb.Reset() fn(eb) handleError(mode, newHelpErrorf("test %s", "error"), ob) assert.Equal(t, "Error: test error\n"+eb.String(), ob.String()) } } /* assertHelpError fails the test unless err is a *helpError whose message is exactly msg; t.Helper makes failures point at the caller. */ func assertHelpError(t *testing.T, err error, msg string) { t.Helper() if _, ok := err.(*helpError); !ok { t.Fatalf("err was not a helpError: %q, expected %q", err, msg) } require.EqualError(t, err, msg) } /* optionalPkcs11String returns msg when PKCS#11 support is compiled in and an empty string otherwise. */ func optionalPkcs11String(msg string) string { if p11Supported() { return msg } return "" } ================================================ FILE: cmd/nebula-cert/p11_cgo.go ================================================ //go:build cgo && pkcs11 package main import ( "flag" ) func p11Supported() bool { return true } func p11Flag(set *flag.FlagSet) *string { return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key") } ================================================ FILE: cmd/nebula-cert/p11_stub.go ================================================ //go:build !cgo || !pkcs11 package main import ( "flag" ) func p11Supported() bool { return false } /* p11Flag registers no flag in this build and returns a pointer to an empty string, so PKCS#11 mode is never selected. */ func p11Flag(set *flag.FlagSet) *string { var ret = "" return &ret } ================================================ FILE: cmd/nebula-cert/passwords.go ================================================ package main import ( "errors" "fmt" "os" "golang.org/x/term" ) /* ErrNoTerminal is returned by StdinPasswordReader when stdin is not a terminal. */ var ErrNoTerminal = errors.New("cannot read password from nonexistent terminal") /* PasswordReader abstracts passphrase input so tests can substitute a stub. */ type PasswordReader interface { ReadPassword() ([]byte, error) } type 
StdinPasswordReader struct{} /* ReadPassword reads a password from the stdin terminal without echo and then prints a newline to stdout. */ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) { if !term.IsTerminal(int(os.Stdin.Fd())) { return nil, ErrNoTerminal } password, err := term.ReadPassword(int(os.Stdin.Fd())) fmt.Println() return password, err } ================================================ FILE: cmd/nebula-cert/passwords_test.go ================================================ package main type StubPasswordReader 
struct { password []byte err error } /* ReadPassword returns the configured password and error unchanged. */ func (pr *StubPasswordReader) ReadPassword() ([]byte, error) { return pr.password, pr.err } ================================================ FILE: cmd/nebula-cert/print.go ================================================ package main import ( "encoding/json" "flag" "fmt" "io" "os" "strings" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" ) /* printFlags holds the parsed flag values for the print subcommand. */ type printFlags struct { set *flag.FlagSet json *bool outQRPath *string path *string } /* newPrintFlags registers the print flags on a fresh ContinueOnError flag set with a no-op Usage. */ func newPrintFlags() *printFlags { pf := printFlags{set: flag.NewFlagSet("print", flag.ContinueOnError)} pf.set.Usage = func() {} pf.json = pf.set.Bool("json", false, "Optional: outputs certificates in json format") pf.outQRPath = pf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") pf.path = pf.set.String("path", "", "Required: path to the certificate") return &pf } /* printCert decodes every PEM certificate in -path. Each one is written to out via String() followed by a newline, or, with -json, all of them are written as a single JSON array. With -out-qr the re-encoded PEM of all certificates is also written as a PNG QR code with mode 0600. errOut is currently unused. */ func printCert(args []string, out io.Writer, errOut io.Writer) error { pf := newPrintFlags() err := pf.set.Parse(args) if err != nil { return err } if err := mustFlagString("path", pf.path); err != nil { return err } rawCert, err := os.ReadFile(*pf.path) if err != nil { return fmt.Errorf("unable to read cert; %s", err) } var c cert.Certificate var qrBytes []byte var jsonCerts []cert.Certificate for { c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return fmt.Errorf("error while unmarshaling cert: %s", err) } if *pf.json { jsonCerts = append(jsonCerts, c) } else { _, _ = 
out.Write([]byte(c.String())) _, _ = out.Write([]byte("\n")) } if *pf.outQRPath != "" { b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling cert to PEM: %s", err) } qrBytes = append(qrBytes, b...) 
} /* TrimSpace also covers a nil or empty remainder: stop once no further PEM data is left. */ if strings.TrimSpace(string(rawCert)) == "" { break } } if *pf.json { b, err := json.Marshal(jsonCerts) if err != nil { return fmt.Errorf("error while marshalling certs to json: %s", err) } _, _ = out.Write(b) _, _ = out.Write([]byte("\n")) } if *pf.outQRPath != "" { b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) if err != nil { return fmt.Errorf("error while generating qr code: %s", err) } err = os.WriteFile(*pf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } } return nil } func printSummary() string { return "print : prints details about a certificate" } /* printHelp writes the print usage line followed by the defaults of every print flag to out. */ func printHelp(out io.Writer) { pf := newPrintFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n")) pf.set.SetOutput(out) pf.set.PrintDefaults() } ================================================ FILE: cmd/nebula-cert/print_test.go ================================================ package main import ( "bytes" "crypto/ed25519" "crypto/rand" "encoding/hex" "net/netip" "os" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_printSummary(t *testing.T) { assert.Equal(t, "print : prints details about a certificate", printSummary()) } func Test_printHelp(t *testing.T) { ob := &bytes.Buffer{} printHelp(ob) assert.Equal( t, "Usage of "+os.Args[0]+" print : prints details about a certificate\n"+ " -json\n"+ " \tOptional: outputs certificates in json format\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ " -path string\n"+ " \tRequired: path to the certificate\n", ob.String(), ) } func Test_printCert(t *testing.T) { // Orient our local time and avoid headaches time.Local = time.UTC ob := &bytes.Buffer{} eb := &bytes.Buffer{} // no path err := 
printCert([]string{}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, err, "-path is required") // no cert at path ob.Reset() eb.Reset() err = printCert([]string{"-path", 
"does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") require.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs ob.Reset() eb.Reset() tf.Truncate(0) tf.Seek(0, 0) ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil) c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"}) p, _ := c.MarshalPEM() tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-path", tf.Name()}, ob, eb) fp, _ := c.Fingerprint() pk := hex.EncodeToString(c.PublicKey()) sig := hex.EncodeToString(c.Signature()) require.NoError(t, err) assert.Equal( t, //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 
00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", `{ "details": { "curve": "CURVE25519", "groups": [ "hi" ], "isCa": false, "issuer": "`+c.Issuer()+`", "name": "test", "networks": [ "10.0.0.123/8" ], "notAfter": "0001-01-01T00:00:00Z", "notBefore": "0001-01-01T00:00:00Z", "publicKey": "`+pk+`", "unsafeNetworks": [] }, "fingerprint": "`+fp+`", "signature": "`+sig+`", "version": 1 } { "details": { "curve": "CURVE25519", "groups": [ "hi" ], "isCa": false, "issuer": "`+c.Issuer()+`", "name": "test", "networks": [ "10.0.0.123/8" ], "notAfter": "0001-01-01T00:00:00Z", "notBefore": "0001-01-01T00:00:00Z", "publicKey": "`+pk+`", "unsafeNetworks": [] }, "fingerprint": "`+fp+`", "signature": "`+sig+`", "version": 1 } { "details": { "curve": "CURVE25519", "groups": [ "hi" ], "isCa": false, "issuer": "`+c.Issuer()+`", "name": "test", "networks": [ "10.0.0.123/8" ], "notAfter": "0001-01-01T00:00:00Z", "notBefore": "0001-01-01T00:00:00Z", "publicKey": "`+pk+`", "unsafeNetworks": [] }, "fingerprint": "`+fp+`", "signature": "`+sig+`", "version": 1 } `, ob.String(), ) assert.Empty(t, eb.String()) // test json ob.Reset() eb.Reset() tf.Truncate(0) tf.Seek(0, 0) tf.Write(p) tf.Write(p) tf.Write(p) err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb) fp, _ = c.Fingerprint() pk = hex.EncodeToString(c.PublicKey()) sig = hex.EncodeToString(c.Signature()) require.NoError(t, err) assert.Equal( t, 
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] `, ob.String(), ) assert.Empty(t, eb.String()) } // NewTestCaCert will generate a CA cert func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { var err error if pubKey == nil || privKey == nil { pubKey, privKey, err = ed25519.GenerateKey(rand.Reader) if err != nil { panic(err) } } t := &cert.TBSCertificate{ Version: cert.Version1, Name: name, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), PublicKey: pubKey, Networks: networks, UnsafeNetworks: unsafeNetworks, Groups: groups, IsCA: true, } c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey) if err != nil { panic(err) } return c, privKey } func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) { if before.IsZero() { before = ca.NotBefore() } if after.IsZero() { after = ca.NotAfter() } if len(networks) == 0 { networks = 
[]netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")} } pub, rawPriv := x25519Keypair() nc := &cert.TBSCertificate{ Version: cert.Version1, Name: name, Networks: networks, UnsafeNetworks: unsafeNetworks, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), PublicKey: pub, IsCA: false, } c, err := nc.Sign(ca, ca.Curve(), signerKey) if err != nil { panic(err) } return c, rawPriv } ================================================ FILE: cmd/nebula-cert/sign.go ================================================ package main import ( "crypto/ecdh" "crypto/rand" "errors" "flag" "fmt" "io" "net/netip" "os" "strings" "time" "github.com/skip2/go-qrcode" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" ) type signFlags struct { set *flag.FlagSet version *uint caKeyPath *string caCertPath *string name *string networks *string unsafeNetworks *string duration *time.Duration inPubPath *string outKeyPath *string outCertPath *string outQRPath *string groups *string p11url *string // Deprecated options ip *string subnets *string } func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert") sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. 
Unsafe networks this cert can route for") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") sf.p11url = p11Flag(sf.set) sf.ip = sf.set.String("ip", "", "Deprecated, see -networks") sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &sf } func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error { sf := newSignFlags() err := sf.set.Parse(args) if err != nil { return err } isP11 := len(*sf.p11url) > 0 if !isP11 { if err := mustFlagString("ca-key", sf.caKeyPath); err != nil { return err } } if err := mustFlagString("ca-crt", sf.caCertPath); err != nil { return err } if err := mustFlagString("name", sf.name); err != nil { return err } if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } var v4Networks []netip.Prefix var v6Networks []netip.Prefix if *sf.networks == "" && *sf.ip != "" { // Pull up deprecated -ip flag if needed *sf.networks = *sf.ip } if len(*sf.networks) == 0 { return newHelpErrorf("-networks is required") } version := cert.Version(*sf.version) if version != 0 && version != cert.Version1 && version != cert.Version2 { return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } var curve cert.Curve var caKey []byte if !isP11 { 
var rawCAKey []byte rawCAKey, err := os.ReadFile(*sf.caKeyPath) if err != nil { return fmt.Errorf("error while reading ca-key: %s", err) } // naively attempt to decode the private key as though it is not encrypted caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) if errors.Is(err, cert.ErrPrivateKeyEncrypted) { var passphrase []byte passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) if len(passphrase) == 0 { // ask for a passphrase until we get one for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() if errors.Is(err, ErrNoTerminal) { return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") } else if err != nil { return fmt.Errorf("error reading password: %s", err) } if len(passphrase) > 0 { break } } if len(passphrase) == 0 { return fmt.Errorf("cannot open encrypted ca-key without passphrase") } } curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) if err != nil { return fmt.Errorf("error while parsing encrypted ca-key: %s", err) } } else if err != nil { return fmt.Errorf("error while parsing ca-key: %s", err) } } rawCACert, err := os.ReadFile(*sf.caCertPath) if err != nil { return fmt.Errorf("error while reading ca-crt: %s", err) } caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert) if err != nil { return fmt.Errorf("error while parsing ca-crt: %s", err) } if !isP11 { if err := caCert.VerifyPrivateKey(curve, caKey); err != nil { return fmt.Errorf("refusing to sign, root certificate does not match private key") } } if caCert.Expired(time.Now()) { return fmt.Errorf("ca certificate is expired") } if version == 0 { version = caCert.Version() } // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } if *sf.networks != "" { for _, rs := range strings.Split(*sf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := 
netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid -networks definition: %s", rs) } if n.Addr().Is4() { v4Networks = append(v4Networks, n) } else { v6Networks = append(v6Networks, n) } } } } var v4UnsafeNetworks []netip.Prefix var v6UnsafeNetworks []netip.Prefix if *sf.unsafeNetworks == "" && *sf.subnets != "" { // Pull up deprecated -subnets flag if needed *sf.unsafeNetworks = *sf.subnets } if *sf.unsafeNetworks != "" { for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } if n.Addr().Is4() { v4UnsafeNetworks = append(v4UnsafeNetworks, n) } else { v6UnsafeNetworks = append(v6UnsafeNetworks, n) } } } } var groups []string if *sf.groups != "" { for _, rg := range strings.Split(*sf.groups, ",") { g := strings.TrimSpace(rg) if g != "" { groups = append(groups, g) } } } var pub, rawPriv []byte var p11Client *pkclient.PKClient if isP11 { curve = cert.Curve_P256 p11Client, err = pkclient.FromUrl(*sf.p11url) if err != nil { return fmt.Errorf("error while creating PKCS#11 client: %w", err) } defer func(client *pkclient.PKClient) { _ = client.Close() }(p11Client) } if *sf.inPubPath != "" { var pubCurve cert.Curve rawPub, err := os.ReadFile(*sf.inPubPath) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub) if err != nil { return fmt.Errorf("error while parsing in-pub: %s", err) } if pubCurve != curve { return fmt.Errorf("curve of in-pub does not match ca") } } else if isP11 { pub, err = p11Client.GetPubKey() if err != nil { return fmt.Errorf("error while getting public key with PKCS#11: %w", err) } } else { pub, rawPriv = newKeypair(curve) } if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" } if *sf.outCertPath == "" { *sf.outCertPath = *sf.name + ".crt" } if _, err := os.Stat(*sf.outCertPath); err 
== nil { return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } var crts []cert.Certificate notBefore := time.Now() notAfter := notBefore.Add(*sf.duration) switch version { case cert.Version1: // Make sure we have only one ipv4 address if len(v4Networks) != 1 { return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } if len(v6Networks) > 0 { return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses") } if len(v6UnsafeNetworks) > 0 { return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") } t := &cert.TBSCertificate{ Version: cert.Version1, Name: *sf.name, Networks: []netip.Prefix{v4Networks[0]}, Groups: groups, UnsafeNetworks: v4UnsafeNetworks, NotBefore: notBefore, NotAfter: notAfter, PublicKey: pub, IsCA: false, Curve: curve, } var nc cert.Certificate if p11Client == nil { nc, err = t.Sign(caCert, curve, caKey) if err != nil { return fmt.Errorf("error while signing: %w", err) } } else { nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } } crts = append(crts, nc) case cert.Version2: t := &cert.TBSCertificate{ Version: cert.Version2, Name: *sf.name, Networks: append(v4Networks, v6Networks...), Groups: groups, UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...), NotBefore: notBefore, NotAfter: notAfter, PublicKey: pub, IsCA: false, Curve: curve, } var nc cert.Certificate if p11Client == nil { nc, err = t.Sign(caCert, curve, caKey) if err != nil { return fmt.Errorf("error while signing: %w", err) } } else { nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } } crts = append(crts, nc) default: // this should be unreachable return fmt.Errorf("invalid version: %d", version) } if !isP11 && *sf.inPubPath == "" { if _, err := 
os.Stat(*sf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) } err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } var b []byte for _, c := range crts { sb, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } b = append(b, sb...) } err = os.WriteFile(*sf.outCertPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } if *sf.outQRPath != "" { b, err = qrcode.Encode(string(b), qrcode.Medium, -5) if err != nil { return fmt.Errorf("error while generating qr code: %s", err) } err = os.WriteFile(*sf.outQRPath, b, 0600) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } } return nil } func newKeypair(curve cert.Curve) ([]byte, []byte) { switch curve { case cert.Curve_CURVE25519: return x25519Keypair() case cert.Curve_P256: return p256Keypair() default: return nil, nil } } func x25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { panic(err) } pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) if err != nil { panic(err) } return pubkey, privkey } func p256Keypair() ([]byte, []byte) { privkey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { panic(err) } pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } func signSummary() string { return "sign : create and sign a certificate" } func signHelp(out io.Writer) { sf := newSignFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n")) sf.set.SetOutput(out) sf.set.PrintDefaults() } ================================================ FILE: cmd/nebula-cert/sign_test.go ================================================ //go:build !windows // +build !windows package main import ( "bytes" "crypto/rand" "errors" "os" "testing" "time" 
"github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) func Test_signSummary(t *testing.T) { assert.Equal(t, "sign : create and sign a certificate", signSummary()) } func Test_signHelp(t *testing.T) { ob := &bytes.Buffer{} signHelp(ob) assert.Equal( t, "Usage of "+os.Args[0]+" sign : create and sign a certificate\n"+ " -ca-crt string\n"+ " \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+ " -ca-key string\n"+ " \tOptional: path to the signing CA key (default \"ca.key\")\n"+ " -duration duration\n"+ " \tOptional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"\n"+ " -groups string\n"+ " \tOptional: comma separated list of groups\n"+ " -in-pub string\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+ " -ip string\n"+ " \tDeprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the cert, usually a hostname\n"+ " -networks string\n"+ " \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to\n"+ " -out-key string\n"+ " \tOptional (if in-pub not set): path to write the private key to\n"+ " -out-qr string\n"+ " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ " \tDeprecated, see -unsafe-networks\n"+ " -unsafe-networks string\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " -version uint\n"+ " \tOptional: version of the certificate format to use. 
The default is to match the version of the signing CA\n", ob.String(), ) } func Test_signCert(t *testing.T) { ob := &bytes.Buffer{} eb := &bytes.Buffer{} nopw := &StubPasswordReader{ password: []byte(""), err: nil, } errpw := &StubPasswordReader{ password: []byte(""), err: errors.New("stub error"), } passphrase := []byte("DO NOT USE THIS KEY") testpw := &StubPasswordReader{ password: passphrase, err: nil, } // required args assertHelpError(t, signCert( []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, signCert( []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-networks is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key assertHelpError(t, signCert( []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed to read key ob.Reset() eb.Reset() args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") require.NoError(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} require.EqualError(t, signCert(args, ob, eb, nopw), 
"error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // Write a proper ca key for later ob.Reset() eb.Reset() caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed to unmarshal cert ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") require.NoError(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // write a proper ca cert for later ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) b, _ := ca.MarshalPEM() caCrtF.Write(b) // failed to read pub args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed to unmarshal pub ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") require.NoError(t, err) defer os.Remove(inPubF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), 
"-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // write a proper pub for later ob.Reset() eb.Reset() inPub, _ := x25519Keypair() inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub)) // bad ip cidr ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // bad subnet cidr ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a") 
assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") require.NoError(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // failed key write ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") require.NoError(t, err) os.Remove(keyF.Name()) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", 
"-subnets", "10.1.1.1/32"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") require.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert and key files rb, _ := os.ReadFile(keyF.Name()) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) assert.Len(t, lCrt.Networks(), 1) assert.False(t, lCrt.IsCA()) assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups()) assert.Len(t, lCrt.UnsafeNetworks(), 3) assert.Len(t, lCrt.PublicKey(), 32) assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore())) sns := []string{} for _, sn := range lCrt.UnsafeNetworks() { sns = append(sns, sn.String()) } assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns) issuer, _ := ca.Fingerprint() assert.Equal(t, issuer, lCrt.Issuer()) assert.True(t, lCrt.CheckSignature(caPub)) // test proper cert with in-pub os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", 
"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // read cert file and check pub key matches in-pub rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) require.NoError(t, err) assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root ob.Reset() eb.Reset() os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: 
"+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create valid cert/key using encrypted CA key os.Remove(caKeyF.Name()) os.Remove(caCrtF.Name()) os.Remove(keyF.Name()) os.Remove(crtF.Name()) ob.Reset() eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") require.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") require.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key caPub, caPriv, _ = ed25519.GenerateKey(rand.Reader) kdfParams := cert.NewArgon2Parameters(64*1024, 4, 3) b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams) caKeyF.Write(b) ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) b, _ = ca.MarshalPEM() caCrtF.Write(b) // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), 
"-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) // test with the proper password in the environment os.Remove(crtF.Name()) os.Remove(keyF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) require.NoError(t, signCert(args, ob, eb, testpw)) assert.Empty(t, eb.String()) os.Setenv("NEBULA_CA_PASSPHRASE", "") // test with the wrong password ob.Reset() eb.Reset() testpw.password = []byte("invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) // test with the wrong password in environment ob.Reset() eb.Reset() os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Setenv("NEBULA_CA_PASSPHRASE", "") // test with the 
user not entering a password ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) // test an error condition ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) } ================================================ FILE: cmd/nebula-cert/test_darwin.go ================================================ package main const NoSuchFileError = "no such file or directory" const NoSuchDirError = "no such file or directory" ================================================ FILE: cmd/nebula-cert/test_linux.go ================================================ package main const NoSuchFileError = "no such file or directory" const NoSuchDirError = "no such file or directory" ================================================ FILE: cmd/nebula-cert/test_windows.go ================================================ package main const NoSuchFileError = "The system cannot find the file specified." const NoSuchDirError = "The system cannot find the path specified." 
================================================ FILE: cmd/nebula-cert/verify.go ================================================ package main import ( "errors" "flag" "fmt" "io" "os" "strings" "time" "github.com/slackhq/nebula/cert" ) type verifyFlags struct { set *flag.FlagSet caPath *string certPath *string } func newVerifyFlags() *verifyFlags { vf := verifyFlags{set: flag.NewFlagSet("verify", flag.ContinueOnError)} vf.set.Usage = func() {} vf.caPath = vf.set.String("ca", "", "Required: path to a file containing one or more ca certificates") vf.certPath = vf.set.String("crt", "", "Required: path to a file containing a single certificate") return &vf } func verify(args []string, out io.Writer, errOut io.Writer) error { vf := newVerifyFlags() err := vf.set.Parse(args) if err != nil { return err } if err := mustFlagString("ca", vf.caPath); err != nil { return err } if err := mustFlagString("crt", vf.certPath); err != nil { return err } rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } caPool := cert.NewCAPool() for { rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { return fmt.Errorf("error while adding ca cert to pool: %w", err) } if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { break } } rawCert, err := os.ReadFile(*vf.certPath) if err != nil { return fmt.Errorf("unable to read crt: %w", err) } var errs []error for { if len(rawCert) == 0 { break } c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { return fmt.Errorf("error while parsing crt: %w", err) } rawCert = extra _, err = caPool.VerifyCertificate(time.Now(), c) if err != nil { switch { case errors.Is(err, cert.ErrCaNotFound): errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err)) default: errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err)) } } } return 
errors.Join(errs...) } func verifySummary() string { return "verify : verifies a certificate isn't expired and was signed by a trusted authority." } func verifyHelp(out io.Writer) { vf := newVerifyFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) vf.set.SetOutput(out) vf.set.PrintDefaults() } ================================================ FILE: cmd/nebula-cert/verify_test.go ================================================ package main import ( "bytes" "crypto/rand" "os" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) func Test_verifySummary(t *testing.T) { assert.Equal(t, "verify : verifies a certificate isn't expired and was signed by a trusted authority.", verifySummary()) } func Test_verifyHelp(t *testing.T) { ob := &bytes.Buffer{} verifyHelp(ob) assert.Equal( t, "Usage of "+os.Args[0]+" verify : verifies a certificate isn't expired and was signed by a trusted authority.\n"+ " -ca string\n"+ " \tRequired: path to a file containing one or more ca certificates\n"+ " -crt string\n"+ " \tRequired: path to a file containing a single certificate\n", ob.String(), ) } func Test_verify(t *testing.T) { time.Local = time.UTC ob := &bytes.Buffer{} eb := &bytes.Buffer{} // required args assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // no ca at path ob.Reset() eb.Reset() err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") 
require.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil) b, _ := ca.MarshalPEM() caFile.Truncate(0) caFile.Seek(0, 0) caFile.Write(b) // no crt at path err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") require.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) // Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature pub := crt.PublicKey() for i, _ := range pub { pub[i] = 0 } b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), 
time.Now().Add(time.Hour), nil, nil, nil) b, _ = crt.MarshalPEM() certFile.Truncate(0) certFile.Seek(0, 0) certFile.Write(b) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) require.NoError(t, err) } ================================================ FILE: cmd/nebula-service/logs_generic.go ================================================ //go:build !windows // +build !windows package main import "github.com/sirupsen/logrus" func HookLogger(l *logrus.Logger) { // Do nothing, let the logs flow to stdout/stderr } ================================================ FILE: cmd/nebula-service/logs_windows.go ================================================ package main import ( "fmt" "io/ioutil" "os" "github.com/kardianos/service" "github.com/sirupsen/logrus" ) // HookLogger routes the logrus logs through the service logger so that they end up in the Windows Event Viewer // logrus output will be discarded func HookLogger(l *logrus.Logger) { l.AddHook(newLogHook(logger)) l.SetOutput(ioutil.Discard) } type logHook struct { sl service.Logger } func newLogHook(sl service.Logger) *logHook { return &logHook{sl: sl} } func (h *logHook) Fire(entry *logrus.Entry) error { line, err := entry.String() if err != nil { fmt.Fprintf(os.Stderr, "Unable to read entry, %v", err) return err } switch entry.Level { case logrus.PanicLevel: return h.sl.Error(line) case logrus.FatalLevel: return h.sl.Error(line) case logrus.ErrorLevel: return h.sl.Error(line) case logrus.WarnLevel: return h.sl.Warning(line) case logrus.InfoLevel: return h.sl.Info(line) case logrus.DebugLevel: return h.sl.Info(line) default: return nil } } func (h *logHook) Levels() []logrus.Level { return logrus.AllLevels } ================================================ FILE: cmd/nebula-service/main.go ================================================ package main import ( "flag" "fmt" "os" "runtime/debug" "strings" "github.com/sirupsen/logrus" 
"github.com/slackhq/nebula" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) // A version string that can be set with // // -ldflags "-X main.Build=SOMEVERSION" // // at compile-time. var Build string func init() { if Build == "" { info, ok := debug.ReadBuildInfo() if !ok { return } Build = strings.TrimPrefix(info.Main.Version, "v") } } func main() { serviceFlag := flag.String("service", "", "Control the system service.") configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config") printVersion := flag.Bool("version", false, "Print version") printUsage := flag.Bool("help", false, "Print command line usage") flag.Parse() if *printVersion { fmt.Printf("Version: %s\n", Build) os.Exit(0) } if *printUsage { flag.Usage() os.Exit(0) } if *serviceFlag != "" { doService(configPath, configTest, Build, serviceFlag) os.Exit(1) } if *configPath == "" { fmt.Println("-config flag must be set") flag.Usage() os.Exit(1) } l := logrus.New() l.Out = os.Stdout c := config.NewC(l) err := c.Load(*configPath) if err != nil { fmt.Printf("failed to load config: %s", err) os.Exit(1) } ctrl, err := nebula.Main(c, *configTest, Build, l, nil) if err != nil { util.LogWithContextIfNeeded("Failed to start", err, l) os.Exit(1) } if !*configTest { ctrl.Start() ctrl.ShutdownBlock() } os.Exit(0) } ================================================ FILE: cmd/nebula-service/service.go ================================================ package main import ( "fmt" "log" "os" "path/filepath" "github.com/kardianos/service" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" ) var logger service.Logger type program struct { configPath *string configTest *bool build string control *nebula.Control } func (p *program) Start(s service.Service) error { // Start should not block. 
logger.Info("Nebula service starting.") l := logrus.New() HookLogger(l) c := config.NewC(l) err := c.Load(*p.configPath) if err != nil { return fmt.Errorf("failed to load config: %s", err) } p.control, err = nebula.Main(c, *p.configTest, Build, l, nil) if err != nil { return err } p.control.Start() return nil } func (p *program) Stop(s service.Service) error { logger.Info("Nebula service stopping.") p.control.Stop() return nil } func fileExists(filename string) bool { _, err := os.Stat(filename) if os.IsNotExist(err) { return false } return true } func doService(configPath *string, configTest *bool, build string, serviceFlag *string) { if *configPath == "" { ex, err := os.Executable() if err != nil { panic(err) } *configPath = filepath.Dir(ex) + "/config.yaml" if !fileExists(*configPath) { *configPath = filepath.Dir(ex) + "/config.yml" } } svcConfig := &service.Config{ Name: "Nebula", DisplayName: "Nebula Network Service", Description: "Nebula network connectivity daemon for encrypted communications", Arguments: []string{"-service", "run", "-config", *configPath}, } prg := &program{ configPath: configPath, configTest: configTest, build: build, } // Here are what the different loggers are doing: // - `log` is the standard go log utility, meant to be used while the process is still attached to stdout/stderr // - `logger` is the service log utility that may be attached to a special place depending on OS (Windows will have it attached to the event log) // - above, in `Run` we create a `logrus.Logger` which is what nebula expects to use s, err := service.New(prg, svcConfig) if err != nil { log.Fatal(err) } errs := make(chan error, 5) logger, err = s.Logger(errs) if err != nil { log.Fatal(err) } go func() { for { err := <-errs if err != nil { // Route any errors from the system logger to stdout as a best effort to notice issues there log.Print(err) } } }() switch *serviceFlag { case "run": err = s.Run() if err != nil { // Route any errors to the system logger 
logger.Error(err) } default: err := service.Control(s, *serviceFlag) if err != nil { log.Printf("Valid actions: %q\n", service.ControlAction) log.Fatal(err) } return } } ================================================ FILE: config/config.go ================================================ package config import ( "context" "errors" "fmt" "math" "os" "os/signal" "path/filepath" "sort" "strconv" "strings" "sync" "syscall" "time" "dario.cat/mergo" "github.com/sirupsen/logrus" "go.yaml.in/yaml/v3" ) type C struct { path string files []string Settings map[string]any oldSettings map[string]any callbacks []func(*C) l *logrus.Logger reloadLock sync.Mutex } func NewC(l *logrus.Logger) *C { return &C{ Settings: make(map[string]any), l: l, } } // Load will find all yaml files within path and load them in lexical order func (c *C) Load(path string) error { c.path = path c.files = make([]string, 0) err := c.resolve(path, true) if err != nil { return err } if len(c.files) == 0 { return fmt.Errorf("no config files found at %s", path) } sort.Strings(c.files) err = c.parse() if err != nil { return err } return nil } func (c *C) LoadString(raw string) error { if raw == "" { return errors.New("Empty configuration") } return c.parseRaw([]byte(raw)) } // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered // here should decide if they need to make a change to the current process before making the change. HasChanged can be // used to help decide if a change is necessary. // These functions should return quickly or spawn their own go routine if they will take a while func (c *C) RegisterReloadCallback(f func(*C)) { c.callbacks = append(c.callbacks, f) } // InitialLoad returns true if this is the first load of the config, and ReloadConfig has not been called yet. func (c *C) InitialLoad() bool { return c.oldSettings == nil } // HasChanged checks if the underlying structure of the provided key has changed after a config reload. 
The value of // k in both the old and new settings will be serialized, the result of the string comparison is returned. // If k is an empty string the entire config is tested. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating // there is change when there actually wasn't any. func (c *C) HasChanged(k string) bool { if c.oldSettings == nil { return false } var ( nv any ov any ) if k == "" { nv = c.Settings ov = c.oldSettings k = "all settings" } else { nv = c.get(k, c.Settings) ov = c.get(k, c.oldSettings) } newVals, err := yaml.Marshal(nv) if err != nil { c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config") } oldVals, err := yaml.Marshal(ov) if err != nil { c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config") } return string(newVals) != string(oldVals) } // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the // original path provided to Load. The old settings are shallow copied for change detection after the reload. 
func (c *C) CatchHUP(ctx context.Context) { if c.path == "" { return } ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGHUP) go func() { for { select { case <-ctx.Done(): signal.Stop(ch) close(ch) return case <-ch: c.l.Info("Caught HUP, reloading config") c.ReloadConfig() } } }() } func (c *C) ReloadConfig() { c.reloadLock.Lock() defer c.reloadLock.Unlock() c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } err := c.Load(c.path) if err != nil { c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config") return } for _, v := range c.callbacks { v(c) } } func (c *C) ReloadConfigString(raw string) error { c.reloadLock.Lock() defer c.reloadLock.Unlock() c.oldSettings = make(map[string]any) for k, v := range c.Settings { c.oldSettings[k] = v } err := c.LoadString(raw) if err != nil { return err } for _, v := range c.callbacks { v(c) } return nil } // GetString will get the string for k or return the default d if not found or invalid func (c *C) GetString(k, d string) string { r := c.Get(k) if r == nil { return d } return fmt.Sprintf("%v", r) } // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid func (c *C) GetStringSlice(k string, d []string) []string { r := c.Get(k) if r == nil { return d } rv, ok := r.([]any) if !ok { return d } v := make([]string, len(rv)) for i := 0; i < len(v); i++ { v[i] = fmt.Sprintf("%v", rv[i]) } return v } // GetMap will get the map for k or return the default d if not found or invalid func (c *C) GetMap(k string, d map[string]any) map[string]any { r := c.Get(k) if r == nil { return d } v, ok := r.(map[string]any) if !ok { return d } return v } // GetInt will get the int for k or return the default d if not found or invalid func (c *C) GetInt(k string, d int) int { r := c.GetString(k, strconv.Itoa(d)) v, err := strconv.Atoi(r) if err != nil { return d } return v } // GetUint32 will get the uint32 for k or 
return the default d if not found or invalid func (c *C) GetUint32(k string, d uint32) uint32 { r := c.GetInt(k, int(d)) if r < 0 || uint64(r) > uint64(math.MaxUint32) { return d } return uint32(r) } // GetBool will get the bool for k or return the default d if not found or invalid func (c *C) GetBool(k string, d bool) bool { r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d))) v, err := strconv.ParseBool(r) if err != nil { switch r { case "y", "yes": return true case "n", "no": return false } return d } return v } func AsBool(v any) (value bool, ok bool) { switch x := v.(type) { case bool: return x, true case string: switch x { case "y", "yes": return true, true case "n", "no": return false, true } } return false, false } // GetDuration will get the duration for k or return the default d if not found or invalid func (c *C) GetDuration(k string, d time.Duration) time.Duration { r := c.GetString(k, "") v, err := time.ParseDuration(r) if err != nil { return d } return v } func (c *C) Get(k string) any { return c.get(k, c.Settings) } func (c *C) IsSet(k string) bool { return c.get(k, c.Settings) != nil } func (c *C) get(k string, v any) any { parts := strings.Split(k, ".") for _, p := range parts { m, ok := v.(map[string]any) if !ok { return nil } v, ok = m[p] if !ok { return nil } } return v } // direct signifies if this is the config path directly specified by the user, // versus a file/dir found by recursing into that path func (c *C) resolve(path string, direct bool) error { i, err := os.Stat(path) if err != nil { return nil } if !i.IsDir() { c.addFile(path, direct) return nil } paths, err := readDirNames(path) if err != nil { return fmt.Errorf("problem while reading directory %s: %s", path, err) } for _, p := range paths { err := c.resolve(filepath.Join(path, p), false) if err != nil { return err } } return nil } func (c *C) addFile(path string, direct bool) error { ext := filepath.Ext(path) if !direct && ext != ".yaml" && ext != ".yml" { return nil } ap, 
err := filepath.Abs(path) if err != nil { return err } c.files = append(c.files, ap) return nil } func (c *C) parseRaw(b []byte) error { var m map[string]any err := yaml.Unmarshal(b, &m) if err != nil { return err } c.Settings = m return nil } func (c *C) parse() error { var m map[string]any for _, path := range c.files { b, err := os.ReadFile(path) if err != nil { return err } var nm map[string]any err = yaml.Unmarshal(b, &nm) if err != nil { return err } // We need to use WithAppendSlice so that firewall rules in separate // files are appended together err = mergo.Merge(&nm, m, mergo.WithAppendSlice) m = nm if err != nil { return err } } c.Settings = m return nil } func readDirNames(path string) ([]string, error) { f, err := os.Open(path) if err != nil { return nil, err } paths, err := f.Readdirnames(-1) f.Close() if err != nil { return nil, err } sort.Strings(paths) return paths, nil } ================================================ FILE: config/config_test.go ================================================ package config import ( "os" "path/filepath" "testing" "time" "dario.cat/mergo" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" ) func TestConfig_Load(t *testing.T) { l := test.NewLogger() dir, err := os.MkdirTemp("", "config-test") // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}") // simple multi config merge c = NewC(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) require.NoError(t, c.Load(dir)) expected := map[string]any{ "outer": map[string]any{ "inner": "override", }, "new": "hi", } assert.Equal(t, 
expected, c.Settings) } func TestConfig_Get(t *testing.T) { l := test.NewLogger() // test simple type c := NewC(l) c.Settings["firewall"] = map[string]any{"outbound": "hi"} assert.Equal(t, "hi", c.Get("firewall.outbound")) // test complex type inner := []map[string]any{{"port": "1", "code": "2"}} c.Settings["firewall"] = map[string]any{"outbound": inner} assert.EqualValues(t, inner, c.Get("firewall.outbound")) // test missing assert.Nil(t, c.Get("firewall.nope")) } func TestConfig_GetStringSlice(t *testing.T) { l := test.NewLogger() c := NewC(l) c.Settings["slice"] = []any{"one", "two"} assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{})) } func TestConfig_GetBool(t *testing.T) { l := test.NewLogger() c := NewC(l) c.Settings["bool"] = true assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "true" assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = false assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "false" assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "Y" assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "yEs" assert.True(t, c.GetBool("bool", false)) c.Settings["bool"] = "N" assert.False(t, c.GetBool("bool", true)) c.Settings["bool"] = "nO" assert.False(t, c.GetBool("bool", true)) } func TestConfig_HasChanged(t *testing.T) { l := test.NewLogger() // No reload has occurred, return false c := NewC(l) c.Settings["test"] = "hi" assert.False(t, c.HasChanged("")) // Test key change c = NewC(l) c.Settings["test"] = "hi" c.oldSettings = map[string]any{"test": "no"} assert.True(t, c.HasChanged("test")) assert.True(t, c.HasChanged("")) // No key change c = NewC(l) c.Settings["test"] = "hi" c.oldSettings = map[string]any{"test": "hi"} assert.False(t, c.HasChanged("test")) assert.False(t, c.HasChanged("")) } func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") require.NoError(t, err) 
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) require.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) assert.False(t, c.HasChanged("")) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644) c.RegisterReloadCallback(func(c *C) { done <- true }) c.ReloadConfig() assert.True(t, c.HasChanged("outer.inner")) assert.True(t, c.HasChanged("outer")) assert.True(t, c.HasChanged("")) // Make sure we call the callbacks select { case <-done: case <-time.After(1 * time.Second): panic("timeout") } } // Ensure mergo merges are done the way we expect. // This is needed to test for potential regressions, like: // - https://github.com/imdario/mergo/issues/187 func TestConfig_MergoMerge(t *testing.T) { configs := [][]byte{ []byte(` listen: port: 1234 `), []byte(` firewall: inbound: - port: 443 proto: tcp groups: - server - port: 443 proto: tcp groups: - webapp `), []byte(` listen: host: 0.0.0.0 port: 4242 firewall: outbound: - port: any proto: any host: any inbound: - port: any proto: icmp host: any `), } var m map[string]any // merge the same way config.parse() merges for _, b := range configs { var nm map[string]any err := yaml.Unmarshal(b, &nm) require.NoError(t, err) // We need to use WithAppendSlice so that firewall rules in separate // files are appended together err = mergo.Merge(&nm, m, mergo.WithAppendSlice) m = nm require.NoError(t, err) } t.Logf("Merged Config: %#v", m) mYaml, err := yaml.Marshal(m) require.NoError(t, err) t.Logf("Merged Config as YAML:\n%s", mYaml) // If a bug is present, some items might be replaced instead of merged like we expect expected := map[string]any{ "firewall": map[string]any{ "inbound": []any{ map[string]any{"host": "any", "port": "any", "proto": "icmp"}, map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"}, map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}}, "outbound": 
[]any{ map[string]any{"host": "any", "port": "any", "proto": "any"}}}, "listen": map[string]any{ "host": "0.0.0.0", "port": 4242, }, } assert.Equal(t, expected, m) } ================================================ FILE: connection_manager.go ================================================ package nebula import ( "bytes" "context" "encoding/binary" "fmt" "net/netip" "sync" "sync/atomic" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type trafficDecision int const ( doNothing trafficDecision = 0 deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote swapPrimary trafficDecision = 3 migrateRelays trafficDecision = 4 tryRehandshake trafficDecision = 5 sendTestPacket trafficDecision = 6 ) type connectionManager struct { // relayUsed holds which relay localIndexs are in use relayUsed map[uint32]struct{} relayUsedLock *sync.RWMutex hostMap *HostMap trafficTimer *LockingTimerWheel[uint32] intf *Interface punchy *Punchy // Configuration settings checkInterval time.Duration pendingDeletionInterval time.Duration inactivityTimeout atomic.Int64 dropInactive atomic.Bool metricsTxPunchy metrics.Counter l *logrus.Logger } func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ hostMap: hm, l: l, punchy: p, relayUsed: make(map[uint32]struct{}), relayUsedLock: &sync.RWMutex{}, metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), } cm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { cm.reload(c, false) }) return cm } func (cm *connectionManager) reload(c *config.C, initial bool) { if initial { cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second cm.pendingDeletionInterval = 
time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second // We want at least a minimum resolution of 500ms per tick so that we can hit these intervals // pretty close to their configured duration. // The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it. minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval) maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval) cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration) } if initial || c.HasChanged("tunnels.inactivity_timeout") { old := cm.getInactivityTimeout() cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) if !initial { cm.l.WithField("oldDuration", old). WithField("newDuration", cm.getInactivityTimeout()). Info("Inactivity timeout has changed") } } if initial || c.HasChanged("tunnels.drop_inactive") { old := cm.dropInactive.Load() cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) if !initial { cm.l.WithField("oldBool", old). WithField("newBool", cm.dropInactive.Load()). 
Info("Drop inactive setting has changed") } } } func (cm *connectionManager) getInactivityTimeout() time.Duration { return (time.Duration)(cm.inactivityTimeout.Load()) } func (cm *connectionManager) In(h *HostInfo) { h.in.Store(true) } func (cm *connectionManager) Out(h *HostInfo) { h.out.Store(true) } func (cm *connectionManager) RelayUsed(localIndex uint32) { cm.relayUsedLock.RLock() // If this already exists, return if _, ok := cm.relayUsed[localIndex]; ok { cm.relayUsedLock.RUnlock() return } cm.relayUsedLock.RUnlock() cm.relayUsedLock.Lock() cm.relayUsed[localIndex] = struct{}{} cm.relayUsedLock.Unlock() } // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and // resets the state for this local index func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) { in := h.in.Swap(false) out := h.out.Swap(false) if in || out { h.lastUsed = now } return in, out } // AddTrafficWatch must be called for every new HostInfo. // We will continue to monitor the HostInfo until the tunnel is dropped. 
func (cm *connectionManager) AddTrafficWatch(h *HostInfo) { if h.out.Swap(true) == false { cm.trafficTimer.Add(h.localIndexId, cm.checkInterval) } } func (cm *connectionManager) Start(ctx context.Context) { clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration) defer clockSource.Stop() p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) for { select { case <-ctx.Done(): return case now := <-clockSource.C: cm.trafficTimer.Advance(now) for { localIndex, has := cm.trafficTimer.Purge() if !has { break } cm.doTrafficCheck(localIndex, p, nb, out, now) } } } } func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now) switch decision { case deleteTunnel: if cm.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: cm.intf.sendCloseTunnel(hostinfo) cm.intf.closeTunnel(hostinfo) case swapPrimary: cm.swapPrimary(hostinfo, primary) case migrateRelays: cm.migrateRelayUsed(hostinfo, primary) case tryRehandshake: cm.tryRehandshake(hostinfo) case sendTestPacket: cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) } cm.resetRelayTrafficCheck(hostinfo) } func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { if hostinfo != nil { cm.relayUsedLock.Lock() defer cm.relayUsedLock.Unlock() // No need to migrate any relays, delete usage info now. 
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { delete(cm.relayUsed, idx) } } } func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr) var index uint32 var relayFrom netip.Addr var relayTo netip.Addr switch { case ok: switch existing.State { case Established, PeerRequested, Disestablished: // This relay already exists in newhostinfo, then do nothing. continue case Requested: // The relay exists in a Requested state; re-send the request index = existing.LocalIndex switch r.Type { case TerminalType: relayFrom = cm.intf.myVpnAddrs[0] relayTo = existing.PeerAddr case ForwardingType: relayFrom = existing.PeerAddr relayTo = newhostinfo.vpnAddrs[0] default: // should never happen panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type)) } } case !ok: cm.relayUsedLock.RLock() if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed { // The relay hasn't been used; don't migrate it. cm.relayUsedLock.RUnlock() continue } cm.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: relayFrom = cm.intf.myVpnAddrs[0] relayTo = r.PeerAddr case ForwardingType: relayFrom = r.PeerAddr relayTo = newhostinfo.vpnAddrs[0] default: // should never happen panic(fmt.Sprintf("Migrating unknown relay type: %v", r.Type)) } } // Send a CreateRelayRequest to the peer. 
req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, } switch newhostinfo.GetCert().Certificate.Version() { case cert.Version1: if !relayFrom.Is4() { cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !relayTo.Is4() { cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } b := relayFrom.As4() req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) b = relayTo.As4() req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) case cert.Version2: req.RelayFromAddr = netAddrToProtoAddr(relayFrom) req.RelayToAddr = netAddrToProtoAddr(relayTo) default: newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay") continue } msg, err := req.Marshal() if err != nil { cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) cm.l.WithFields(logrus.Fields{ "relayFrom": req.RelayFromAddr, "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnAddrs": newhostinfo.vpnAddrs}). 
Info("send CreateRelayRequest") } } } func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { // Read lock the main hostmap to order decisions based on tunnels being the primary tunnel cm.hostMap.RLock() defer cm.hostMap.RUnlock() hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") return doNothing, nil, nil } if cm.isInvalidCertificate(now, hostinfo) { return closeTunnel, hostinfo, nil } primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false } // Check for traffic on this hostinfo inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now) // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing if cm.l.Level >= logrus.DebugLevel { hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). Debug("Tunnel status") } hostinfo.pendingDeletion.Store(false) if mainHostInfo { decision = tryRehandshake } else { if cm.shouldSwapPrimary(hostinfo) { decision = swapPrimary } else { // migrate the relays to the primary, if in use. decision = migrateRelays } } cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) if !outTraffic { // Send a punch packet to keep the NAT state alive cm.sendPunch(hostinfo) } return decision, hostinfo, primary } if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "dead", "method": "active"}). Info("Tunnel status") return deleteTunnel, hostinfo, nil } decision := doNothing if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { inactiveFor, isInactive := cm.isInactive(hostinfo, now) if isInactive { // Tunnel is inactive, tear it down hostinfo.logger(cm.l). 
WithField("inactiveDuration", inactiveFor). WithField("primary", mainHostInfo). Info("Dropping tunnel due to inactivity") return closeTunnel, hostinfo, primary } // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. cm.sendPunch(hostinfo) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) return doNothing, nil, nil } if cm.punchy.GetTargetEverything() { // This is similar to the old punchy behavior with a slight optimization. // We aren't receiving traffic but we are sending it, punch on all known // ips in case we need to re-prime NAT state cm.sendPunch(hostinfo) } if cm.l.Level >= logrus.DebugLevel { hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "testing", "method": "active"}). Debug("Tunnel status") } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues decision = sendTestPacket } else { if cm.l.Level >= logrus.DebugLevel { hostinfo.logger(cm.l).Debugf("Hostinfo sadness") } } hostinfo.pendingDeletion.Store(true) cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval) return decision, hostinfo, nil } func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) { if cm.dropInactive.Load() == false { // We aren't configured to drop inactive tunnels return 0, false } inactiveDuration := now.Sub(hostinfo.lastUsed) if inactiveDuration < cm.getInactivityTimeout() { // It's not considered inactive return inactiveDuration, false } // The tunnel is inactive return inactiveDuration, true } func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool { // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. 
// Only one side should swap because if both swap then we may never resolve to a single tunnel. // vpn addr is static across all tunnels for this host pair so lets // use that to determine if we should consider swapping. if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 { // Their primary vpn addr is less than mine. Do not swap. return false } crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) if crt == nil { //my cert was reloaded away. We should definitely swap from this tunnel return true } // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // settle down. return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } func (cm *connectionManager) swapPrimary(current, primary *HostInfo) { cm.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary { cm.hostMap.unlockedMakePrimary(current) } cm.hostMap.Unlock() } // isInvalidCertificate decides if we should destroy a tunnel. // returns true if pki.disconnect_invalid is true and the certificate is no longer valid. // Blocklisted certificates will skip the pki.disconnect_invalid check and return true. func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { remoteCert := hostinfo.GetCert() if remoteCert == nil { return false //don't tear down tunnels for handshakes in progress } caPool := cm.intf.pki.GetCAPool() err := caPool.VerifyCachedCertificate(now, remoteCert) if err == nil { return false //cert is still valid! yay! } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected hostinfo.logger(cm.l).WithError(err). WithField("fingerprint", remoteCert.Fingerprint). 
Info("Remote certificate is blocked, tearing down the tunnel") return true } else if cm.intf.disconnectInvalid.Load() { hostinfo.logger(cm.l).WithError(err). WithField("fingerprint", remoteCert.Fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") return true } else { //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open return false } } func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { if !cm.punchy.GetPunch() { // Punching is disabled return } if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { // Do not punch to lighthouses, we assume our lighthouse update interval is good enough. // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse // would lose the ability to notify us and punchy.respond would become unreliable. return } if cm.punchy.GetTargetEverything() { hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { cm.metricsTxPunchy.Inc(1) cm.intf.outside.WriteTo([]byte{1}, addr) }) } else if hostinfo.remote.IsValid() { cm.metricsTxPunchy.Inc(1) cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } } func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert curCrtVersion := curCrt.Version() myCrt := cs.getCertificate(curCrtVersion) if myCrt == nil { cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("version", curCrtVersion). WithField("reason", "local certificate removed"). Info("Re-handshaking with remote") cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } peerCrt := hostinfo.ConnectionState.peerCert if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { // if our certificate version is less than theirs, and we have a matching version available, rehandshake? 
if cs.getCertificate(peerCrt.Certificate.Version()) != nil { cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("version", curCrtVersion). WithField("peerVersion", peerCrt.Certificate.Version()). WithField("reason", "local certificate version lower than peer, attempting to correct"). Info("Re-handshaking with remote") cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { hh.initiatingVersionOverride = peerCrt.Certificate.Version() }) return } } if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } if curCrtVersion < cs.initiatingVersion { cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "current cert version < pki.initiatingVersion"). Info("Re-handshaking with remote") cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } } ================================================ FILE: connection_manager_test.go ================================================ package nebula import ( "crypto/ed25519" "crypto/rand" "net/netip" "testing" "time" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func newTestLighthouse() *LightHouse { lh := &LightHouse{ l: test.NewLogger(), addrMap: map[netip.Addr]*RemoteList{}, queryChan: make(chan netip.Addr, 10), } lighthouses := []netip.Addr{} staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) lh.staticList.Store(&staticList) return lh } func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := 
netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, pki: &PKI{}, handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } ifce.pki.cs.Store(cs) // Create manager conf := config.NewC(l) punchy := NewPunchyFromConfig(l, conf) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp nc.Out(hostinfo) nc.In(hostinfo) assert.False(t, hostinfo.pendingDeletion.Load()) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.True(t, hostinfo.out.Load()) assert.True(t, hostinfo.in.Load()) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now nc.Out(hostinfo) assert.True(t, hostinfo.out.Load()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.True(t, 
hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, pki: &PKI{}, handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } ifce.pki.cs.Store(cs) // Create manager conf := config.NewC(l) punchy := NewPunchyFromConfig(l, conf) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp nc.Out(hostinfo) nc.In(hostinfo) assert.True(t, hostinfo.in.Load()) assert.True(t, hostinfo.out.Load()) assert.False(t, 
hostinfo.pendingDeletion.Load()) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now nc.Out(hostinfo) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.True(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion nc.In(hostinfo) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { l := test.NewLogger() localrange := netip.MustParsePrefix("10.1.1.1/24") vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")} preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, pki: &PKI{}, handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, } 
ifce.pki.cs.Store(cs) // Create manager conf := config.NewC(l) conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } punchy := NewPunchyFromConfig(l, conf) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ vpnAddrs: vpnAddrs, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // Do a traffic check tick, in and out should be cleared but should not be pending deletion nc.Out(hostinfo) nc.In(hostinfo) assert.True(t, hostinfo.out.Load()) assert.True(t, hostinfo.in.Load()) now := time.Now() decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now) assert.Equal(t, tryRehandshake, decision) assert.Equal(t, now, hostinfo.lastUsed) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5)) assert.Equal(t, doNothing, decision) assert.Equal(t, now, hostinfo.lastUsed) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, should still not be pending deletion decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10)) assert.Equal(t, doNothing, decision) assert.Equal(t, now, hostinfo.lastUsed) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Finally advance beyond the inactivity timeout decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10)) 
assert.Equal(t, closeTunnel, decision) assert.Equal(t, now, hostinfo.lastUsed) assert.False(t, hostinfo.pendingDeletion.Load()) assert.False(t, hostinfo.out.Load()) assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } // Check if we can disconnect the peer. // Validate if the peer's certificate is invalid (expired, etc.) // Disconnect only if disconnectInvalid: true is set. func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { now := time.Now() l := test.NewLogger() vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) tbs := &cert.TBSCertificate{ Version: 1, Name: "ca", IsCA: true, NotBefore: now, NotAfter: now.Add(1 * time.Hour), PublicKey: pubCA, } caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) require.NoError(t, err) ncp := cert.NewCAPool() require.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) tbs = &cert.TBSCertificate{ Version: 1, Name: "host", Networks: []netip.Prefix{vpncidr}, NotBefore: now, NotAfter: now.Add(60 * time.Second), PublicKey: pubCrt, } peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) require.NoError(t, err) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ privateKey: []byte{}, v1Cert: &dummyCert{}, v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() ifce := &Interface{ hostMap: hostMap, inside: &test.NoopTun{}, outside: &udp.NoopConn{}, firewall: &Firewall{}, lightHouse: lh, handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), l: l, pki: &PKI{}, } ifce.pki.cs.Store(cs) 
ifce.pki.caPool.Store(ncp) ifce.disconnectInvalid.Store(true) // Create manager conf := config.NewC(l) punchy := NewPunchyFromConfig(l, conf) nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc hostinfo := &HostInfo{ vpnAddrs: []netip.Addr{vpnIp}, ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, H: &noise.HandshakeState{}, }, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // Move ahead 45s. // Check if to disconnect with invalid certificate. // Should be alive. nextTick := now.Add(45 * time.Second) invalid := nc.isInvalidCertificate(nextTick, hostinfo) assert.False(t, invalid) // Move ahead 61s. // Check if to disconnect with invalid certificate. // Should be disconnected. nextTick = now.Add(61 * time.Second) invalid = nc.isInvalidCertificate(nextTick, hostinfo) assert.True(t, invalid) } type dummyCert struct { version cert.Version curve cert.Curve groups []string isCa bool issuer string name string networks []netip.Prefix notAfter time.Time notBefore time.Time publicKey []byte signature []byte unsafeNetworks []netip.Prefix } func (d *dummyCert) Version() cert.Version { return d.version } func (d *dummyCert) Curve() cert.Curve { return d.curve } func (d *dummyCert) Groups() []string { return d.groups } func (d *dummyCert) IsCA() bool { return d.isCa } func (d *dummyCert) Issuer() string { return d.issuer } func (d *dummyCert) Name() string { return d.name } func (d *dummyCert) Networks() []netip.Prefix { return d.networks } func (d *dummyCert) NotAfter() time.Time { return d.notAfter } func (d *dummyCert) NotBefore() time.Time { return d.notBefore } func (d *dummyCert) PublicKey() []byte { return d.publicKey } func (d *dummyCert) MarshalPublicKeyPEM() []byte { return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey) } func (d *dummyCert) Signature() []byte { return d.signature } func (d *dummyCert) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks } func (d *dummyCert) 
MarshalForHandshakes() ([]byte, error) { return nil, nil } func (d *dummyCert) Sign(curve cert.Curve, key []byte) error { return nil } func (d *dummyCert) CheckSignature(key []byte) bool { return true } func (d *dummyCert) Expired(t time.Time) bool { return false } func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error { return nil } func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error { return nil } func (d *dummyCert) String() string { return "" } func (d *dummyCert) Marshal() ([]byte, error) { return nil, nil } func (d *dummyCert) MarshalPEM() ([]byte, error) { return nil, nil } func (d *dummyCert) Fingerprint() (string, error) { return "", nil } func (d *dummyCert) MarshalJSON() ([]byte, error) { return nil, nil } func (d *dummyCert) Copy() cert.Certificate { return d } ================================================ FILE: connection_state.go ================================================ package nebula import ( "crypto/rand" "encoding/json" "fmt" "sync" "sync/atomic" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool messageCounter atomic.Uint64 window *Bits writeLock sync.Mutex } func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc switch crt.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: if cs.pkcs11Backed { dhFunc = noiseutil.DHP256PKCS11 } else { dhFunc = noiseutil.DHP256 } default: return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) } var ncs noise.CipherSuite if cs.cipher == "chachapoly" { ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } else { ncs = 
noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} hs, err := noise.NewHandshakeState(noise.Config{ CipherSuite: ncs, Random: rand.Reader, Pattern: pattern, Initiator: initiator, StaticKeypair: static, //NOTE: These should come from CertState (pki.go) when we finally implement it PresharedKey: []byte{}, PresharedKeyPlacement: 0, }) if err != nil { return nil, fmt.Errorf("NewConnectionState: %s", err) } // The queue and ready params prevent a counter race that would happen when // sending stored packets and simultaneously accepting new traffic. ci := &ConnectionState{ H: hs, initiator: initiator, window: NewBits(ReplayWindow), myCert: crt, } // always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(2) return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { return json.Marshal(m{ "certificate": cs.peerCert, "initiator": cs.initiator, "message_counter": cs.messageCounter.Load(), }) } func (cs *ConnectionState) Curve() cert.Curve { return cs.myCert.Curve() } ================================================ FILE: control.go ================================================ package nebula import ( "context" "net/netip" "os" "os/signal" "syscall" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. 
This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc type controlEach func(h *HostInfo) type controlHostLister interface { QueryVpnAddr(vpnAddr netip.Addr) *HostInfo ForEachIndex(each controlEach) ForEachVpnAddr(each controlEach) GetPreferredRanges() []netip.Prefix } type Control struct { f *Interface l *logrus.Logger ctx context.Context cancel context.CancelFunc sshStart func() statsStart func() dnsStart func() lighthouseStart func() connectionManagerStart func(context.Context) } type ControlHostInfo struct { VpnAddrs []netip.Addr `json:"vpnAddrs"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` Cert cert.Certificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` CurrentRemote netip.AddrPort `json:"currentRemote"` CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() func (c *Control) Start() { // Activate the interface c.f.activate() // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { go c.sshStart() } if c.statsStart != nil { go c.statsStart() } if c.dnsStart != nil { go c.dnsStart() } if c.connectionManagerStart != nil { go c.connectionManagerStart(c.ctx) } if c.lighthouseStart != nil { c.lighthouseStart() } // Start reading packets. c.f.run() } func (c *Control) Context() context.Context { return c.ctx } // Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete func (c *Control) Stop() { // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. 
c.cancel() c.CloseAllTunnels(false) if err := c.f.Close(); err != nil { c.l.WithError(err).Error("Close interface failed") } c.l.Info("Goodbye") } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled func (c *Control) ShutdownBlock() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT) rawSig := <-sigChan sig := rawSig.String() c.l.WithField("signal", sig).Info("Caught signal, shutting down") c.Stop() } // RebindUDPServer asks the UDP listener to rebind it's listener. Mainly used on mobile clients when interfaces change func (c *Control) RebindUDPServer() { _ = c.f.outside.Rebind() // Trigger a lighthouse update, useful for mobile clients that should have an update interval of 0 c.f.lightHouse.SendUpdate() // Let the main interface know that we rebound so that underlying tunnels know to trigger punches from their remotes c.f.rebindCount++ } // ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { if pendingMap { return listHostMapHosts(c.f.handshakeManager) } else { return listHostMapHosts(c.f.hostMap) } } // ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { if pendingMap { return listHostMapIndexes(c.f.handshakeManager) } else { return listHostMapIndexes(c.f.hostMap) } } // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { if c.f.myVpnAddrsTable.Contains(vpnIp) { // Only returning the default certificate since its impossible // for any other host but ourselves to have more than 1 return c.f.pki.getCertState().GetDefaultCertificate().Copy() } hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { 
return nil } return hi.GetCert().Certificate.Copy() } // CreateTunnel creates a new tunnel to the given vpn ip. func (c *Control) CreateTunnel(vpnIp netip.Addr) { c.f.handshakeManager.StartHandshake(vpnIp, nil) } // PrintTunnel creates a new tunnel to the given vpn ip. func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges()) return &chi } // QueryLighthouse queries the lighthouse. func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { hi := c.f.lightHouse.Query(vpnIp) if hi == nil { return nil } return hi.CopyCache() } // GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager } else { hl = c.f.hostMap } h := hl.QueryVpnAddr(vpnAddr) if h == nil { return nil } ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges()) return &ch } // SetRemoteForTunnel forces a tunnel to use a specific remote // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return nil } hostInfo.SetRemote(addr) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // Caller should take care to Unmap() any 4in6 addresses prior to calling. 
func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return false } if !localOnly { c.f.send( header.CloseTunnel, 0, hostInfo.ConnectionState, hostInfo, []byte{}, make([]byte, 12, 12), make([]byte, mtu), ) } c.f.closeTunnel(hostInfo) return true } // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { shutdown := func(h *HostInfo) { if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { return } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } // Learn which hosts are being used as relays, so we can shut them down last. 
relayingHosts := map[netip.Addr]*HostInfo{} // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost } c.f.hostMap.Unlock() hostInfos := []*HostInfo{} // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() for _, relayHost := range c.f.hostMap.Indexes { if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok { hostInfos = append(hostInfos, relayHost) } } c.f.hostMap.Unlock() for _, h := range hostInfos { shutdown(h) } for _, h := range relayingHosts { shutdown(h) } return } func (c *Control) Device() overlay.Device { return c.f.inside } func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { chi := ControlHostInfo{ VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), CurrentRemote: h.remote, } for i, a := range h.vpnAddrs { chi.VpnAddrs[i] = a } if h.ConnectionState != nil { chi.MessageCounter = h.ConnectionState.messageCounter.Load() } if c := h.GetCert(); c != nil { chi.Cert = c.Certificate.Copy() } return chi } func listHostMapHosts(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() hl.ForEachVpnAddr(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts } func listHostMapIndexes(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() hl.ForEachIndex(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts } ================================================ FILE: control_test.go ================================================ package nebula import ( "net" "net/netip" "reflect" "testing" "github.com/sirupsen/logrus" 
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) func TestControl_GetHostInfoByVpnIp(t *testing.T) { //TODO: CERT-V2 with multiple certificate versions we have a problem with this test // Some certs versions have different characteristics and each version implements their own Copy() func // which means this is not a good place to test for exposing memory l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller hm := newHostMap(l) hm.preferredRanges.Store(&[]netip.Prefix{}) remote1 := netip.MustParseAddrPort("0.0.0.100:4444") remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") ipNet := net.IPNet{ IP: remote1.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } ipNet2 := net.IPNet{ IP: remote2.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) vpnIp, ok := netip.AddrFromSlice(ipNet.IP) assert.True(t, ok) crt := &dummyCert{} hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ peerCert: &cert.CachedCertificate{Certificate: crt}, }, remoteIndexId: 200, localIndexId: 201, vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ relays: nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP) assert.True(t, ok) hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ peerCert: nil, }, remoteIndexId: 200, localIndexId: 201, vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ relays: 
nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) c := Control{ f: &Interface{ hostMap: hm, }, l: logrus.New(), } thi := c.GetHostInfoByVpnAddr(vpnIp, false) expectedInfo := ControlHostInfo{ VpnAddrs: []netip.Addr{vpnIp}, LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []netip.AddrPort{remote2, remote1}, Cert: crt.Copy(), MessageCounter: 0, CurrentRemote: remote1, CurrentRelaysToMe: []netip.Addr{}, CurrentRelaysThroughMe: []netip.Addr{}, } // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assert.Equal(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { thi = c.GetHostInfoByVpnAddr(vpnIp2, false) }) } func assertFields(t *testing.T, expected []string, actualStruct any) { val := reflect.ValueOf(actualStruct).Elem() fields := make([]string, val.NumField()) for i := 0; i < val.NumField(); i++ { fields[i] = val.Type().Field(i).Name } assert.Equal(t, expected, fields) } ================================================ FILE: control_tester.go ================================================ //go:build e2e_testing package nebula import ( "net/netip" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) // WaitForType will pipe all messages from this control device into the pipeTo control device // returning after a message matching the criteria has been piped func (c *Control) WaitForType(msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} for { p := c.f.outside.(*udp.TesterConn).Get(true) if err := h.Parse(p.Data); err != nil { panic(err) } pipeTo.InjectUDPPacket(p) if h.Type 
== msgType && h.Subtype == subType { return } } } // WaitForTypeByIndex is similar to WaitForType except it adds an index check // Useful if you have many nodes communicating and want to wait to find a specific nodes packet func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, subType header.MessageSubType, pipeTo *Control) { h := &header.H{} for { p := c.f.outside.(*udp.TesterConn).Get(true) if err := h.Parse(p.Data); err != nil { panic(err) } pipeTo.InjectUDPPacket(p) if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType { return } } } // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() if toAddr.Addr().Is4() { remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port())) } else { remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port())) } } // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // This is necessary to inform an initiator of possible relays for communicating with a responder func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() remoteList.unlockedSetRelay(vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula func (c *Control) GetFromTun(block bool) []byte { return c.f.inside.(*overlay.TestTun).Get(block) } // GetFromUDP will pull a udp packet off the udp side of nebula func (c *Control) GetFromUDP(block bool) *udp.Packet { return 
c.f.outside.(*udp.TesterConn).Get(block) } func (c *Control) GetUDPTxChan() <-chan *udp.Packet { return c.f.outside.(*udp.TesterConn).TxPackets } func (c *Control) GetTunTxChan() <-chan []byte { return c.f.inside.(*overlay.TestTun).TxPackets } // InjectUDPPacket will inject a packet into the udp side of nebula func (c *Control) InjectUDPPacket(p *udp.Packet) { c.f.outside.(*udp.TesterConn).Send(p) } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { serialize := make([]gopacket.SerializableLayer, 0) var netLayer gopacket.NetworkLayer if toAddr.Is6() { if !fromAddr.Is6() { panic("Cant send ipv6 to ipv4") } ip := &layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolUDP, SrcIP: fromAddr.Unmap().AsSlice(), DstIP: toAddr.Unmap().AsSlice(), } serialize = append(serialize, ip) netLayer = ip } else { if !fromAddr.Is4() { panic("Cant send ipv4 to ipv6") } ip := &layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, SrcIP: fromAddr.Unmap().AsSlice(), DstIP: toAddr.Unmap().AsSlice(), } serialize = append(serialize, ip) netLayer = ip } udp := layers.UDP{ SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } err := udp.SetNetworkLayerForChecksum(netLayer) if err != nil { panic(err) } buffer := gopacket.NewSerializeBuffer() opt := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } serialize = append(serialize, &udp, gopacket.Payload(data)) err = gopacket.SerializeLayers(buffer, opt, serialize...) 
if err != nil { panic(err) } c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } func (c *Control) GetVpnAddrs() []netip.Addr { return c.f.myVpnAddrs } func (c *Control) GetUDPAddr() netip.AddrPort { return c.f.outside.(*udp.TesterConn).Addr } func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp) if hostinfo == nil { return false } c.f.handshakeManager.DeleteHostInfo(hostinfo) return true } func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } func (c *Control) GetF() *Interface { return c.f } func (c *Control) GetCertState() *CertState { return c.f.pki.getCertState() } func (c *Control) ReHandshake(vpnIp netip.Addr) { c.f.handshakeManager.StartHandshake(vpnIp, nil) } ================================================ FILE: dist/windows/wintun/LICENSE.txt ================================================ Prebuilt Binaries License ------------------------- 1. DEFINITIONS. "Software" means the precise contents of the "wintun.dll" files that are included in the .zip file that contains this document as downloaded from wintun.net/builds. 2. LICENSE GRANT. WireGuard LLC grants to you a non-exclusive and non-transferable right to use Software for lawful purposes under certain obligations and limited rights as set forth in this agreement. 3. RESTRICTIONS. Software is owned and copyrighted by WireGuard LLC. It is licensed, not sold. Title to Software and all associated intellectual property rights are retained by WireGuard. You must not: a. reverse engineer, decompile, disassemble, extract from, or otherwise modify the Software; b. modify or create derivative work based upon Software in whole or in parts, except insofar as only the API interfaces of the "wintun.h" file distributed alongside the Software (the "Permitted API") are used; c. remove any proprietary notices, labels, or copyrights from the Software; d. 
resell, redistribute, lease, rent, transfer, sublicense, or otherwise transfer rights of the Software without the prior written consent of WireGuard LLC, except insofar as the Software is distributed alongside other software that uses the Software only via the Permitted API; e. use the name of WireGuard LLC, the WireGuard project, the Wintun project, or the names of its contributors to endorse or promote products derived from the Software without specific prior written consent. 4. LIMITED WARRANTY. THE SOFTWARE IS PROVIDED "AS IS" AND WITHOUT WARRANTY OF ANY KIND. WIREGUARD LLC HEREBY EXCLUDES AND DISCLAIMS ALL IMPLIED OR STATUTORY WARRANTIES, INCLUDING ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, QUALITY, NON-INFRINGEMENT, TITLE, RESULTS, EFFORTS, OR QUIET ENJOYMENT. THERE IS NO WARRANTY THAT THE PRODUCT WILL BE ERROR-FREE OR WILL FUNCTION WITHOUT INTERRUPTION. YOU ASSUME THE ENTIRE RISK FOR THE RESULTS OBTAINED USING THE PRODUCT. TO THE EXTENT THAT WIREGUARD LLC MAY NOT DISCLAIM ANY WARRANTY AS A MATTER OF APPLICABLE LAW, THE SCOPE AND DURATION OF SUCH WARRANTY WILL BE THE MINIMUM PERMITTED UNDER SUCH LAW. ALL EXPRESS OR IMPLIED CONDITIONS, REPRESENTATIONS AND WARRANTIES, INCLUDING ANY IMPLIED WARRANTY OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT ARE DISCLAIMED, EXCEPT TO THE EXTENT THAT THESE DISCLAIMERS ARE HELD TO BE LEGALLY INVALID. 5. LIMITATION OF LIABILITY. To the extent not prohibited by law, in no event WireGuard LLC or any third-party-developer will be liable for any lost revenue, profit or data or for special, indirect, consequential, incidental or punitive damages, however caused regardless of the theory of liability, arising out of or related to the use of or inability to use Software, even if WireGuard LLC has been advised of the possibility of such damages. 
Solely you are responsible for determining the appropriateness of using Software and accept full responsibility for all risks associated with its exercise of rights under this agreement, including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations. The foregoing limitations will apply even if the above stated warranty fails of its essential purpose. You acknowledge, that it is in the nature of software that software is complex and not completely free of errors. In no event shall WireGuard LLC or any third-party-developer be liable to you under any theory for any damages suffered by you or any user of Software or for any special, incidental, indirect, consequential or similar damages (including without limitation damages for loss of business profits, business interruption, loss of business information or any other pecuniary loss) arising out of the use or inability to use Software, even if WireGuard LLC has been advised of the possibility of such damages and regardless of the legal or quitable theory (contract, tort, or otherwise) upon which the claim is based. 6. TERMINATION. This agreement is affected until terminated. You may terminate this agreement at any time. This agreement will terminate immediately without notice from WireGuard LLC if you fail to comply with the terms and conditions of this agreement. Upon termination, you must delete Software and all copies of Software and cease all forms of distribution of Software. 7. SEVERABILITY. If any provision of this agreement is held to be unenforceable, this agreement will remain in effect with the provision omitted, unless omission would frustrate the intent of the parties, in which case this agreement will immediately terminate. 8. RESERVATION OF RIGHTS. All rights not expressly granted in this agreement are reserved by WireGuard LLC. 
For example, WireGuard LLC reserves the right at any time to cease development of Software, to alter distribution details, features, specifications, capabilities, functions, licensing terms, release dates, APIs, ABIs, general availability, or other characteristics of the Software. ================================================ FILE: dist/windows/wintun/README.md ================================================ # [Wintun Network Adapter](https://www.wintun.net/) ### TUN Device Driver for Windows This is a layer 3 TUN driver for Windows 7, 8, 8.1, and 10. Originally created for [WireGuard](https://www.wireguard.com/), it is intended to be useful to a wide variety of projects that require layer 3 tunneling devices with implementations primarily in userspace. ## Installation Wintun is deployed as a platform-specific `wintun.dll` file. Install the `wintun.dll` file side-by-side with your application. Download the dll from [wintun.net](https://www.wintun.net/), alongside the header file for your application described below. ## Usage Include the [`wintun.h` file](https://git.zx2c4.com/wintun/tree/api/wintun.h) in your project simply by copying it there and dynamically load the `wintun.dll` using [`LoadLibraryEx()`](https://docs.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa) and [`GetProcAddress()`](https://docs.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-getprocaddress) to resolve each function, using the typedefs provided in the header file. The [`InitializeWintun` function in the example.c code](https://git.zx2c4.com/wintun/tree/example/example.c) provides this in a function that you can simply copy and paste. With the library setup, Wintun can then be used by first creating an adapter, configuring it, and then setting its status to "up". Adapters have names (e.g. "OfficeNet") and types (e.g. "Wintun"). 
```C WINTUN_ADAPTER_HANDLE Adapter1 = WintunCreateAdapter(L"OfficeNet", L"Wintun", &SomeFixedGUID1); WINTUN_ADAPTER_HANDLE Adapter2 = WintunCreateAdapter(L"HomeNet", L"Wintun", &SomeFixedGUID2); WINTUN_ADAPTER_HANDLE Adapter3 = WintunCreateAdapter(L"Data Center", L"Wintun", &SomeFixedGUID3); ``` After creating an adapter, we can use it by starting a session: ```C WINTUN_SESSION_HANDLE Session = WintunStartSession(Adapter2, 0x400000); ``` Then, the `WintunAllocateSendPacket` and `WintunSendPacket` functions can be used for sending packets ([used by `SendPackets` in the example.c code](https://git.zx2c4.com/wintun/tree/example/example.c)): ```C BYTE *OutgoingPacket = WintunAllocateSendPacket(Session, PacketDataSize); if (OutgoingPacket) { memcpy(OutgoingPacket, PacketData, PacketDataSize); WintunSendPacket(Session, OutgoingPacket); } else if (GetLastError() != ERROR_BUFFER_OVERFLOW) // Silently drop packets if the ring is full Log(L"Packet write failed"); ``` And the `WintunReceivePacket` and `WintunReleaseReceivePacket` functions can be used for receiving packets ([used by `ReceivePackets` in the example.c code](https://git.zx2c4.com/wintun/tree/example/example.c)): ```C for (;;) { DWORD IncomingPacketSize; BYTE *IncomingPacket = WintunReceivePacket(Session, &IncomingPacketSize); if (IncomingPacket) { DoSomethingWithPacket(IncomingPacket, IncomingPacketSize); WintunReleaseReceivePacket(Session, IncomingPacket); } else if (GetLastError() == ERROR_NO_MORE_ITEMS) WaitForSingleObject(WintunGetReadWaitEvent(Session), INFINITE); else { Log(L"Packet read failed"); break; } } ``` Some high performance use cases may want to spin on `WintunReceivePackets` for a number of cycles before falling back to waiting on the read-wait event. You are **highly encouraged** to read the [**example.c short example**](https://git.zx2c4.com/wintun/tree/example/example.c) to see how to put together a simple userspace network tunnel. 
The various functions and definitions are [documented in the reference below](#Reference). ## Reference ### Macro Definitions #### WINTUN\_MAX\_POOL `#define WINTUN_MAX_POOL 256` Maximum pool name length including zero terminator #### WINTUN\_MIN\_RING\_CAPACITY `#define WINTUN_MIN_RING_CAPACITY 0x20000 /* 128kiB */` Minimum ring capacity. #### WINTUN\_MAX\_RING\_CAPACITY `#define WINTUN_MAX_RING_CAPACITY 0x4000000 /* 64MiB */` Maximum ring capacity. #### WINTUN\_MAX\_IP\_PACKET\_SIZE `#define WINTUN_MAX_IP_PACKET_SIZE 0xFFFF` Maximum IP packet size ### Typedefs #### WINTUN\_ADAPTER\_HANDLE `typedef void* WINTUN_ADAPTER_HANDLE` A handle representing a Wintun adapter #### WINTUN\_ENUM\_CALLBACK `typedef BOOL(* WINTUN_ENUM_CALLBACK) (WINTUN_ADAPTER_HANDLE Adapter, LPARAM Param)` Called by WintunEnumAdapters for each adapter in the pool. **Parameters** - *Adapter*: Adapter handle, which will be freed when this function returns. - *Param*: An application-defined value passed to WintunEnumAdapters. **Returns** Non-zero to continue iterating adapters; zero to stop. #### WINTUN\_LOGGER\_CALLBACK `typedef void(* WINTUN_LOGGER_CALLBACK) (WINTUN_LOGGER_LEVEL Level, DWORD64 Timestamp, const WCHAR *Message)` Called by the internal logger to report diagnostic messages. **Parameters** - *Level*: Message level. - *Timestamp*: Message timestamp in 100ns intervals since 1601-01-01 UTC. - *Message*: Message text. #### WINTUN\_SESSION\_HANDLE `typedef void* WINTUN_SESSION_HANDLE` A handle representing a Wintun session ### Enumeration Types #### WINTUN\_LOGGER\_LEVEL `enum WINTUN_LOGGER_LEVEL` Determines the level of logging, passed to WINTUN\_LOGGER\_CALLBACK. - *WINTUN\_LOG\_INFO*: Informational - *WINTUN\_LOG\_WARN*: Warning - *WINTUN\_LOG\_ERR*: Error Enumerator ### Functions #### WintunCreateAdapter() `WINTUN_ADAPTER_HANDLE WintunCreateAdapter (const WCHAR * Name, const WCHAR * TunnelType, const GUID * RequestedGUID)` Creates a new Wintun adapter.
**Parameters** - *Name*: The requested name of the adapter. Zero-terminated string of up to MAX\_ADAPTER\_NAME-1 characters. - *TunnelType*: Name of the adapter tunnel type. Zero-terminated string of up to MAX\_ADAPTER\_NAME-1 characters. - *RequestedGUID*: The GUID of the created network adapter, which then influences NLA generation deterministically. If it is set to NULL, the GUID is chosen by the system at random, and hence a new NLA entry is created for each new adapter. It is called "requested" GUID because the API it uses is completely undocumented, and so there could be minor interesting complications with its usage. **Returns** If the function succeeds, the return value is the adapter handle. Must be released with WintunCloseAdapter. If the function fails, the return value is NULL. To get extended error information, call GetLastError. #### WintunOpenAdapter() `WINTUN_ADAPTER_HANDLE WintunOpenAdapter (const WCHAR * Name)` Opens an existing Wintun adapter. **Parameters** - *Name*: The requested name of the adapter. Zero-terminated string of up to MAX\_ADAPTER\_NAME-1 characters. **Returns** If the function succeeds, the return value is the adapter handle. Must be released with WintunCloseAdapter. If the function fails, the return value is NULL. To get extended error information, call GetLastError. #### WintunCloseAdapter() `void WintunCloseAdapter (WINTUN_ADAPTER_HANDLE Adapter)` Releases Wintun adapter resources and, if adapter was created with WintunCreateAdapter, removes adapter. **Parameters** - *Adapter*: Adapter handle obtained with WintunCreateAdapter or WintunOpenAdapter. #### WintunDeleteDriver() `BOOL WintunDeleteDriver ()` Deletes the Wintun driver if there are no more adapters in use. **Returns** If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To get extended error information, call GetLastError.
#### WintunGetAdapterLuid() `void WintunGetAdapterLuid (WINTUN_ADAPTER_HANDLE Adapter, NET_LUID * Luid)` Returns the LUID of the adapter. **Parameters** - *Adapter*: Adapter handle obtained with WintunOpenAdapter or WintunCreateAdapter - *Luid*: Pointer to LUID to receive adapter LUID. #### WintunGetRunningDriverVersion() `DWORD WintunGetRunningDriverVersion (void )` Determines the version of the Wintun driver currently loaded. **Returns** If the function succeeds, the return value is the version number. If the function fails, the return value is zero. To get extended error information, call GetLastError. Possible errors include the following: ERROR\_FILE\_NOT\_FOUND Wintun not loaded #### WintunSetLogger() `void WintunSetLogger (WINTUN_LOGGER_CALLBACK NewLogger)` Sets logger callback function. **Parameters** - *NewLogger*: Pointer to callback function to use as a new global logger. NewLogger may be called from various threads concurrently. Should the logging require serialization, you must handle serialization in NewLogger. Set to NULL to disable. #### WintunStartSession() `WINTUN_SESSION_HANDLE WintunStartSession (WINTUN_ADAPTER_HANDLE Adapter, DWORD Capacity)` Starts Wintun session. **Parameters** - *Adapter*: Adapter handle obtained with WintunOpenAdapter or WintunCreateAdapter - *Capacity*: Rings capacity. Must be between WINTUN\_MIN\_RING\_CAPACITY and WINTUN\_MAX\_RING\_CAPACITY (incl.) Must be a power of two. **Returns** Wintun session handle. Must be released with WintunEndSession. If the function fails, the return value is NULL. To get extended error information, call GetLastError. #### WintunEndSession() `void WintunEndSession (WINTUN_SESSION_HANDLE Session)` Ends Wintun session. **Parameters** - *Session*: Wintun session handle obtained with WintunStartSession #### WintunGetReadWaitEvent() `HANDLE WintunGetReadWaitEvent (WINTUN_SESSION_HANDLE Session)` Gets Wintun session's read-wait event handle. 
**Parameters** - *Session*: Wintun session handle obtained with WintunStartSession **Returns** Event handle to wait on for available data when reading. Should WintunReceivePacket return ERROR\_NO\_MORE\_ITEMS (after spinning on it for a while under heavy load), wait for this event to become signaled before retrying WintunReceivePacket. Do not call CloseHandle on this event - it is managed by the session. #### WintunReceivePacket() `BYTE* WintunReceivePacket (WINTUN_SESSION_HANDLE Session, DWORD * PacketSize)` Retrieves one packet. After the packet content is consumed, call WintunReleaseReceivePacket with Packet returned from this function to release internal buffer. This function is thread-safe. **Parameters** - *Session*: Wintun session handle obtained with WintunStartSession - *PacketSize*: Pointer to receive packet size. **Returns** Pointer to layer 3 IPv4 or IPv6 packet. Client may modify its content at will. If the function fails, the return value is NULL. To get extended error information, call GetLastError. Possible errors include the following: ERROR\_HANDLE\_EOF Wintun adapter is terminating; ERROR\_NO\_MORE\_ITEMS Wintun buffer is exhausted; ERROR\_INVALID\_DATA Wintun buffer is corrupt #### WintunReleaseReceivePacket() `void WintunReleaseReceivePacket (WINTUN_SESSION_HANDLE Session, const BYTE * Packet)` Releases internal buffer after the received packet has been processed by the client. This function is thread-safe. **Parameters** - *Session*: Wintun session handle obtained with WintunStartSession - *Packet*: Packet obtained with WintunReceivePacket #### WintunAllocateSendPacket() `BYTE* WintunAllocateSendPacket (WINTUN_SESSION_HANDLE Session, DWORD PacketSize)` Allocates memory for a packet to send. After the memory is filled with packet data, call WintunSendPacket to send and release internal buffer. WintunAllocateSendPacket is thread-safe and the WintunAllocateSendPacket order of calls defines the packet sending order.
**Parameters** - *Session*: Wintun session handle obtained with WintunStartSession - *PacketSize*: Exact packet size. Must be less than or equal to WINTUN\_MAX\_IP\_PACKET\_SIZE. **Returns** Pointer to memory in which to prepare the layer 3 IPv4 or IPv6 packet for sending. If the function fails, the return value is NULL. To get extended error information, call GetLastError. Possible errors include the following: ERROR\_HANDLE\_EOF Wintun adapter is terminating; ERROR\_BUFFER\_OVERFLOW Wintun buffer is full; #### WintunSendPacket() `void WintunSendPacket (WINTUN_SESSION_HANDLE Session, const BYTE * Packet)` Sends the packet and releases internal buffer. WintunSendPacket is thread-safe, but the WintunAllocateSendPacket order of calls defines the packet sending order. This means the packet is not guaranteed to have been sent when WintunSendPacket returns. **Parameters** - *Session*: Wintun session handle obtained with WintunStartSession - *Packet*: Packet obtained with WintunAllocateSendPacket ## Building **Do not distribute drivers or files named "Wintun", as they will most certainly clash with official deployments. Instead distribute [`wintun.dll` as downloaded from wintun.net](https://www.wintun.net).** General requirements: - [Visual Studio 2019](https://visualstudio.microsoft.com/downloads/) with Windows SDK - [Windows Driver Kit](https://docs.microsoft.com/en-us/windows-hardware/drivers/download-the-wdk) `wintun.sln` may be opened in Visual Studio for development and building. Be sure to run `bcdedit /set testsigning on` and then reboot to enable unsigned driver loading. The default run sequence (F5) in Visual Studio will build the example project and its dependencies. ## License The entire contents of [the repository](https://git.zx2c4.com/wintun/), including all documentation and example code, is "Copyright © 2018-2021 WireGuard LLC. All Rights Reserved." Source code is licensed under the [GPLv2](COPYING).
Prebuilt binaries from [wintun.net](https://www.wintun.net/) are released under a more permissive license suitable for more forms of software contained inside of the .zip files distributed there. ================================================ FILE: dist/windows/wintun/include/wintun.h ================================================ /* SPDX-License-Identifier: GPL-2.0 OR MIT * * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. */ #pragma once #include #include #include #include #include #ifdef __cplusplus extern "C" { #endif #ifndef ALIGNED # if defined(_MSC_VER) # define ALIGNED(n) __declspec(align(n)) # elif defined(__GNUC__) # define ALIGNED(n) __attribute__((aligned(n))) # else # error "Unable to define ALIGNED" # endif #endif /* MinGW is missing this one, unfortunately. */ #ifndef _Post_maybenull_ # define _Post_maybenull_ #endif #pragma warning(push) #pragma warning(disable : 4324) /* structure was padded due to alignment specifier */ /** * A handle representing Wintun adapter */ typedef struct _WINTUN_ADAPTER *WINTUN_ADAPTER_HANDLE; /** * Creates a new Wintun adapter. * * @param Name The requested name of the adapter. Zero-terminated string of up to MAX_ADAPTER_NAME-1 * characters. * * @param TunnelType Name of the adapter tunnel type. Zero-terminated string of up to MAX_ADAPTER_NAME-1 * characters. * * @param RequestedGUID The GUID of the created network adapter, which then influences NLA generation deterministically. * If it is set to NULL, the GUID is chosen by the system at random, and hence a new NLA entry is * created for each new adapter. It is called "requested" GUID because the API it uses is * completely undocumented, and so there could be minor interesting complications with its usage. * * @return If the function succeeds, the return value is the adapter handle. Must be released with * WintunCloseAdapter. If the function fails, the return value is NULL. To get extended error information, call * GetLastError. 
*/ typedef _Must_inspect_result_ _Return_type_success_(return != NULL) _Post_maybenull_ WINTUN_ADAPTER_HANDLE(WINAPI WINTUN_CREATE_ADAPTER_FUNC) (_In_z_ LPCWSTR Name, _In_z_ LPCWSTR TunnelType, _In_opt_ const GUID *RequestedGUID); /** * Opens an existing Wintun adapter. * * @param Name The requested name of the adapter. Zero-terminated string of up to MAX_ADAPTER_NAME-1 * characters. * * @return If the function succeeds, the return value is the adapter handle. Must be released with * WintunCloseAdapter. If the function fails, the return value is NULL. To get extended error information, call * GetLastError. */ typedef _Must_inspect_result_ _Return_type_success_(return != NULL) _Post_maybenull_ WINTUN_ADAPTER_HANDLE(WINAPI WINTUN_OPEN_ADAPTER_FUNC)(_In_z_ LPCWSTR Name); /** * Releases Wintun adapter resources and, if adapter was created with WintunCreateAdapter, removes adapter. * * @param Adapter Adapter handle obtained with WintunCreateAdapter or WintunOpenAdapter. */ typedef VOID(WINAPI WINTUN_CLOSE_ADAPTER_FUNC)(_In_opt_ WINTUN_ADAPTER_HANDLE Adapter); /** * Deletes the Wintun driver if there are no more adapters in use. * * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To * get extended error information, call GetLastError. */ typedef _Return_type_success_(return != FALSE) BOOL(WINAPI WINTUN_DELETE_DRIVER_FUNC)(VOID); /** * Returns the LUID of the adapter. * * @param Adapter Adapter handle obtained with WintunCreateAdapter or WintunOpenAdapter * * @param Luid Pointer to LUID to receive adapter LUID. */ typedef VOID(WINAPI WINTUN_GET_ADAPTER_LUID_FUNC)(_In_ WINTUN_ADAPTER_HANDLE Adapter, _Out_ NET_LUID *Luid); /** * Determines the version of the Wintun driver currently loaded. * * @return If the function succeeds, the return value is the version number. If the function fails, the return value is * zero. To get extended error information, call GetLastError. 
Possible errors include the following: * ERROR_FILE_NOT_FOUND Wintun not loaded */ typedef _Return_type_success_(return != 0) DWORD(WINAPI WINTUN_GET_RUNNING_DRIVER_VERSION_FUNC)(VOID); /** * Determines the level of logging, passed to WINTUN_LOGGER_CALLBACK. */ typedef enum { WINTUN_LOG_INFO, /**< Informational */ WINTUN_LOG_WARN, /**< Warning */ WINTUN_LOG_ERR /**< Error */ } WINTUN_LOGGER_LEVEL; /** * Called by internal logger to report diagnostic messages * * @param Level Message level. * * @param Timestamp Message timestamp in in 100ns intervals since 1601-01-01 UTC. * * @param Message Message text. */ typedef VOID(CALLBACK *WINTUN_LOGGER_CALLBACK)( _In_ WINTUN_LOGGER_LEVEL Level, _In_ DWORD64 Timestamp, _In_z_ LPCWSTR Message); /** * Sets logger callback function. * * @param NewLogger Pointer to callback function to use as a new global logger. NewLogger may be called from various * threads concurrently. Should the logging require serialization, you must handle serialization in * NewLogger. Set to NULL to disable. */ typedef VOID(WINAPI WINTUN_SET_LOGGER_FUNC)(_In_ WINTUN_LOGGER_CALLBACK NewLogger); /** * Minimum ring capacity. */ #define WINTUN_MIN_RING_CAPACITY 0x20000 /* 128kiB */ /** * Maximum ring capacity. */ #define WINTUN_MAX_RING_CAPACITY 0x4000000 /* 64MiB */ /** * A handle representing Wintun session */ typedef struct _TUN_SESSION *WINTUN_SESSION_HANDLE; /** * Starts Wintun session. * * @param Adapter Adapter handle obtained with WintunOpenAdapter or WintunCreateAdapter * * @param Capacity Rings capacity. Must be between WINTUN_MIN_RING_CAPACITY and WINTUN_MAX_RING_CAPACITY (incl.) * Must be a power of two. * * @return Wintun session handle. Must be released with WintunEndSession. If the function fails, the return value is * NULL. To get extended error information, call GetLastError. 
*/
typedef _Must_inspect_result_
_Return_type_success_(return != NULL)
_Post_maybenull_
WINTUN_SESSION_HANDLE(WINAPI WINTUN_START_SESSION_FUNC)(_In_ WINTUN_ADAPTER_HANDLE Adapter, _In_ DWORD Capacity);

/**
 * Ends Wintun session.
 *
 * @param Session       Wintun session handle obtained with WintunStartSession
 */
typedef VOID(WINAPI WINTUN_END_SESSION_FUNC)(_In_ WINTUN_SESSION_HANDLE Session);

/**
 * Gets Wintun session's read-wait event handle.
 *
 * @param Session       Wintun session handle obtained with WintunStartSession
 *
 * @return Event handle to wait on for available data when reading. Should
 *         WintunReceivePacket return ERROR_NO_MORE_ITEMS (after spinning on it for a while under heavy
 *         load), wait for this event to become signaled before retrying WintunReceivePacket. Do not call
 *         CloseHandle on this event - it is managed by the session.
 */
typedef HANDLE(WINAPI WINTUN_GET_READ_WAIT_EVENT_FUNC)(_In_ WINTUN_SESSION_HANDLE Session);

/**
 * Maximum IP packet size
 */
#define WINTUN_MAX_IP_PACKET_SIZE 0xFFFF

/**
 * Retrieves one packet. After the packet content is consumed, call WintunReleaseReceivePacket with Packet returned
 * from this function to release internal buffer. This function is thread-safe.
 *
 * @param Session       Wintun session handle obtained with WintunStartSession
 *
 * @param PacketSize    Pointer to receive packet size.
 *
 * @return Pointer to layer 3 IPv4 or IPv6 packet. Client may modify its content at will. If the function fails, the
 *         return value is NULL. To get extended error information, call GetLastError. Possible errors include the
 *         following:
 *         ERROR_HANDLE_EOF     Wintun adapter is terminating;
 *         ERROR_NO_MORE_ITEMS  Wintun buffer is exhausted;
 *         ERROR_INVALID_DATA   Wintun buffer is corrupt
 */
typedef _Must_inspect_result_
_Return_type_success_(return != NULL)
_Post_maybenull_
_Post_writable_byte_size_(*PacketSize)
BYTE *(WINAPI WINTUN_RECEIVE_PACKET_FUNC)(_In_ WINTUN_SESSION_HANDLE Session, _Out_ DWORD *PacketSize);

/**
 * Releases internal buffer after the received packet has been processed by the client. This function is thread-safe.
 *
 * @param Session       Wintun session handle obtained with WintunStartSession
 *
 * @param Packet        Packet obtained with WintunReceivePacket
 */
typedef VOID(
    WINAPI WINTUN_RELEASE_RECEIVE_PACKET_FUNC)(_In_ WINTUN_SESSION_HANDLE Session, _In_ const BYTE *Packet);

/**
 * Allocates memory for a packet to send. After the memory is filled with packet data, call WintunSendPacket to send
 * and release internal buffer. WintunAllocateSendPacket is thread-safe and the order of WintunAllocateSendPacket
 * calls defines the packet sending order.
 *
 * @param Session       Wintun session handle obtained with WintunStartSession
 *
 * @param PacketSize    Exact packet size. Must be less than or equal to WINTUN_MAX_IP_PACKET_SIZE.
 *
 * @return Returns pointer to memory where to prepare layer 3 IPv4 or IPv6 packet for sending. If the function fails,
 *         the return value is NULL. To get extended error information, call GetLastError. Possible errors include the
 *         following:
 *         ERROR_HANDLE_EOF       Wintun adapter is terminating;
 *         ERROR_BUFFER_OVERFLOW  Wintun buffer is full;
 */
typedef _Must_inspect_result_
_Return_type_success_(return != NULL)
_Post_maybenull_
_Post_writable_byte_size_(PacketSize)
BYTE *(WINAPI WINTUN_ALLOCATE_SEND_PACKET_FUNC)(_In_ WINTUN_SESSION_HANDLE Session, _In_ DWORD PacketSize);

/**
 * Sends the packet and releases internal buffer. WintunSendPacket is thread-safe, but the order of
 * WintunAllocateSendPacket calls defines the packet sending order. This means the packet is not guaranteed
 * to have been sent by the time WintunSendPacket returns.
 *
 * @param Session       Wintun session handle obtained with WintunStartSession
 *
 * @param Packet        Packet obtained with WintunAllocateSendPacket
 */
typedef VOID(WINAPI WINTUN_SEND_PACKET_FUNC)(_In_ WINTUN_SESSION_HANDLE Session, _In_ const BYTE *Packet);

#pragma warning(pop)

#ifdef __cplusplus
}
#endif

================================================
FILE: dist/wireshark/nebula.lua
================================================
-- Wireshark dissector for the nebula UDP overlay protocol header.
local nebula = Proto("nebula", "nebula")

-- Initial preference values. prefs_changed() compares the live prefs against this table
-- to decide whether the udp.port registrations need to be redone.
local default_settings = {
    port = 4242,
    all_ports = false,
}

nebula.prefs.port = Pref.uint("Port number", default_settings.port, "The UDP port number for Nebula")
nebula.prefs.all_ports = Pref.bool("All ports", default_settings.all_ports, "Assume nebula packets on any port, useful when dealing with hole punching")

-- Version and type share the first header byte: version is the high nibble (mask 0xF0),
-- type is the low nibble (mask 0x0F).
local pf_version = ProtoField.new("version", "nebula.version", ftypes.UINT8, nil, base.DEC, 0xF0)
local pf_type = ProtoField.new("type", "nebula.type", ftypes.UINT8,
    {
        [0] = "handshake",
        [1] = "message",
        [2] = "recvError",
        [3] = "lightHouse",
        [4] = "test",
        [5] = "closeTunnel",
    },
    base.DEC, 0x0F)

-- Three fields share the "nebula.subtype" filter name; the dissector adds whichever one has
-- the value table matching the decoded message type (generic, test, or handshake).
local pf_subtype = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8, nil, base.DEC)
local pf_subtype_test = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8,
    {
        [0] = "request",
        [1] = "reply",
    },
    base.DEC)

local pf_subtype_handshake = ProtoField.new("subtype", "nebula.subtype", ftypes.UINT8,
    {
        [0] = "ix_psk0",
    },
    base.DEC)

local pf_reserved = ProtoField.new("reserved", "nebula.reserved", ftypes.UINT16, nil, base.HEX)
local pf_remote_index = ProtoField.new("remote index", "nebula.remote_index", ftypes.UINT32, nil, base.DEC)
local pf_message_counter = ProtoField.new("counter", "nebula.counter", ftypes.UINT64, nil, base.DEC)
local pf_payload = ProtoField.new("payload", "nebula.payload", ftypes.BYTES, nil, base.NONE)

nebula.fields = { pf_version, pf_type, pf_subtype, pf_subtype_handshake, pf_subtype_test,
pf_reserved, pf_remote_index, pf_message_counter, pf_payload }

local ef_holepunch = ProtoExpert.new("nebula.holepunch.expert", "Nebula hole punch packet", expert.group.PROTOCOL, expert.severity.NOTE)
local ef_punchy = ProtoExpert.new("nebula.punchy.expert", "Nebula punchy keepalive packet", expert.group.PROTOCOL, expert.severity.NOTE)
-- Flags packets longer than a punchy keepalive but shorter than the fixed header; without
-- this guard the fixed-offset range() calls in the dissector raise Lua errors.
local ef_truncated = ProtoExpert.new("nebula.truncated.expert", "Nebula packet shorter than the 16 byte header", expert.group.MALFORMED, expert.severity.ERROR)

nebula.experts = { ef_holepunch, ef_punchy, ef_truncated }

local type_field = Field.new("nebula.type")
local subtype_field = Field.new("nebula.subtype")

-- Length of the fixed header decoded below: version/type (1), subtype (1), reserved (2),
-- remote index (4), counter (8). Anything after it is shown as payload.
local HEADER_LEN = 16

-- Decodes one nebula UDP datagram.
-- Zero-length packets are hole punches and one-byte packets are punchy keepalives. Anything
-- else must carry the 16 byte header.
function nebula.dissector(tvbuf, pktinfo, root)
    -- set the protocol column to show our protocol name
    pktinfo.cols.protocol:set("NEBULA")

    local pktlen = tvbuf:reported_length_remaining()
    local tree = root:add(nebula, tvbuf:range(0,pktlen))

    if pktlen == 0 then
        tree:add_proto_expert_info(ef_holepunch)
        pktinfo.cols.info:append(" (holepunch)")
        return
    elseif pktlen == 1 then
        tree:add_proto_expert_info(ef_punchy)
        pktinfo.cols.info:append(" (punchy)")
        return
    end

    -- Every field below sits at a fixed offset inside the header, so stop before indexing
    -- past the end of a short (or short-captured) packet.
    if tvbuf:len() < HEADER_LEN then
        tree:add_proto_expert_info(ef_truncated)
        pktinfo.cols.info:append(" (truncated)")
        return
    end

    tree:add(pf_version, tvbuf:range(0,1))
    local type_item = tree:add(pf_type, tvbuf:range(0,1))
    -- The message type is the low nibble of the first byte. Plain arithmetic avoids depending
    -- on the bit32 library, which standard Lua removed in 5.4.
    local nebula_type = tvbuf:range(0,1):uint() % 16

    if nebula_type == 0 then
        -- For handshakes the counter field carries the handshake stage.
        local stage = tvbuf(8,8):uint64()
        tree:add(pf_subtype_handshake, tvbuf:range(1,1))
        type_item:append_text(" stage " .. stage)
        pktinfo.cols.info:append(" (" .. type_field().display .. ", stage " .. stage .. ", " .. subtype_field().display .. ")")
    elseif nebula_type == 4 then
        tree:add(pf_subtype_test, tvbuf:range(1,1))
        pktinfo.cols.info:append(" (" .. type_field().display .. ", " .. subtype_field().display .. ")")
    else
        tree:add(pf_subtype, tvbuf:range(1,1))
        pktinfo.cols.info:append(" (" .. type_field().display .. ")")
    end

    tree:add(pf_reserved, tvbuf:range(2,2))
    tree:add(pf_remote_index, tvbuf:range(4,4))
    tree:add(pf_message_counter, tvbuf:range(8,8))
    tree:add(pf_payload, tvbuf:range(HEADER_LEN, tvbuf:len() - HEADER_LEN))
end

-- Re-registers the dissector on udp.port whenever the port or all_ports preference changes.
-- default_settings always mirrors what is currently registered.
function nebula.prefs_changed()
    if default_settings.all_ports == nebula.prefs.all_ports and default_settings.port == nebula.prefs.port then
        -- Nothing changed, bail
        return
    end

    local udp_table = DissectorTable.get("udp.port")

    -- Remove our old registrations before applying the new settings
    udp_table:remove_all(nebula)

    -- Record what is about to be registered so the next call can detect changes
    default_settings.all_ports = nebula.prefs.all_ports
    default_settings.port = nebula.prefs.port

    if default_settings.all_ports then
        for i = 0, 65535 do
            udp_table:add(i, nebula)
        end
    else
        udp_table:add(default_settings.port, nebula)
    end
end

DissectorTable.get("udp.port"):add(default_settings.port, nebula)

================================================
FILE: dns_server.go
================================================
package nebula

import (
	"fmt"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"

	"github.com/gaissmai/bart"
	"github.com/miekg/dns"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
)

// This whole thing should be rewritten to use context
var dnsR *dnsRecords
var dnsServer *dns.Server
var dnsAddr string

// dnsRecords stores host name records for the lighthouse DNS responder and answers queries
// against them.
type dnsRecords struct {
	sync.RWMutex    // guards dnsMap4 and dnsMap6
	l               *logrus.Logger
	dnsMap4         map[string]netip.Addr // lower-cased host -> IPv4 record
	dnsMap6         map[string]netip.Addr // lower-cased host -> IPv6 record
	hostMap         *HostMap              // used by QueryCert to find a host by VPN address
	myVpnAddrsTable *bart.Lite            // taken from CertState; consulted to authorize TXT queries
}

// newDnsRecords returns an empty record set that logs to l, resolves certificates through
// hostMap, and authorizes TXT queries against cs.myVpnAddrsTable.
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
	return &dnsRecords{
		l:               l,
		dnsMap4:         make(map[string]netip.Addr),
		dnsMap6:         make(map[string]netip.Addr),
		hostMap:         hostMap,
		myVpnAddrsTable: cs.myVpnAddrsTable,
	}
}

// Query returns the record of type q (dns.TypeA or dns.TypeAAAA) for the name data, matched
// case-insensitively, or the zero netip.Addr when there is none.
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
	data = strings.ToLower(data)
	d.RLock()
	defer d.RUnlock()
	switch q {
	case dns.TypeA:
		if r, ok :=
d.dnsMap4[data]; ok { return r } case dns.TypeAAAA: if r, ok := d.dnsMap6[data]; ok { return r } } return netip.Addr{} } func (d *dnsRecords) QueryCert(data string) string { ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" } hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" } q := hostinfo.GetCert() if q == nil { return "" } b, err := q.Certificate.MarshalJSON() if err != nil { return "" } return string(b) } // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` func (d *dnsRecords) Add(host string, addresses []netip.Addr) { host = strings.ToLower(host) d.Lock() defer d.Unlock() haveV4 := false haveV6 := false for _, addr := range addresses { if addr.Is4() && !haveV4 { d.dnsMap4[host] = addr haveV4 = true } else if addr.Is6() && !haveV6 { d.dnsMap6[host] = addr haveV6 = true } if haveV4 && haveV6 { break } } } func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { a, _, _ := net.SplitHostPort(addr) b, err := netip.ParseAddr(a) if err != nil { return false } if b.IsLoopback() { return true } //if we found it in this table, it's good return d.myVpnAddrsTable.Contains(b) } func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: qType := dns.TypeToString[q.Qtype] d.l.Debugf("Query for %s %s", qType, q.Name) ip := d.Query(q.Qtype, q.Name) if ip.IsValid() { rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } case dns.TypeTXT: // We only answer these queries from nebula nodes or localhost if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } d.l.Debugf("Query for TXT %s", q.Name) ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } } } if len(m.Answer) == 0 { m.Rcode = dns.RcodeNameError } } func (d *dnsRecords) 
handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: d.parseQuery(m, w) } w.WriteMsg(m) } func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) }) return func() { startDns(l, c) } } func getDnsServerAddr(c *config.C) string { dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. if dnsHost == "[::]" { dnsHost = "::" } return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } func startDns(l *logrus.Logger, c *config.C) { dnsAddr = getDnsServerAddr(c) dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") err := dnsServer.ListenAndServe() defer dnsServer.Shutdown() if err != nil { l.Errorf("Failed to start server: %s\n ", err.Error()) } } func reloadDns(l *logrus.Logger, c *config.C) { if dnsAddr == getDnsServerAddr(c) { l.Debug("No DNS server config change detected") return } l.Debug("Restarting DNS server") dnsServer.Shutdown() go startDns(l, c) } ================================================ FILE: dns_server_test.go ================================================ package nebula import ( "net/netip" "testing" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { l := logrus.New() hostMap := &HostMap{} ds := newDnsRecords(l, &CertState{}, hostMap) addrs := []netip.Addr{ netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5"), netip.MustParseAddr("fd01::24"), netip.MustParseAddr("fd01::25"), } ds.Add("test.com.com", addrs) m := &dns.Msg{} 
m.SetQuestion("test.com.com", dns.TypeA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) m = &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeAAAA) ds.parseQuery(m, nil) assert.NotNil(t, m.Answer) assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } func Test_getDnsServerAddr(t *testing.T) { c := config.NewC(nil) c.Settings["lighthouse"] = map[string]any{ "dns": map[string]any{ "host": "0.0.0.0", "port": "1", }, } assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c)) c.Settings["lighthouse"] = map[string]any{ "dns": map[string]any{ "host": "::", "port": "1", }, } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) c.Settings["lighthouse"] = map[string]any{ "dns": map[string]any{ "host": "[::]", "port": "1", }, } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) // Make sure whitespace doesn't mess us up c.Settings["lighthouse"] = map[string]any{ "dns": map[string]any{ "host": "[::] ", "port": "1", }, } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) } ================================================ FILE: docker/Dockerfile ================================================ FROM gcr.io/distroless/static:latest ARG TARGETOS TARGETARCH COPY build/$TARGETOS-$TARGETARCH/nebula /nebula COPY build/$TARGETOS-$TARGETARCH/nebula-cert /nebula-cert VOLUME ["/config"] ENTRYPOINT ["/nebula"] # Allow users to override the args passed to nebula CMD ["-config", "/config/config.yml"] ================================================ FILE: docker/README.md ================================================ # NebulaOSS/nebula Docker Image ## Building From the root of the repository, run `make docker`. 
## Running To run the built image, use the following command: ``` docker run \ --name nebula \ --network host \ --cap-add NET_ADMIN \ --volume ./config:/config \ --rm \ nebulaoss/nebula ``` A few notes: - The `NET_ADMIN` capability is necessary to create the tun adapter on the host (this is unnecessary if the tun device is disabled.) - `--volume ./config:/config` should point to a directory that contains your `config.yml` and any other necessary files. ================================================ FILE: e2e/doc.go ================================================ package e2e // This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before ================================================ FILE: e2e/handshakes_test.go ================================================ //go:build e2e_testing // +build e2e_testing package e2e import ( "net/netip" "slices" "testing" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" ) func BenchmarkHotPath(b *testing.B) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() theirControl.Start() r := router.NewR(b, myControl, 
theirControl) r.CancelFlowLogs() assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) b.ResetTimer() for n := 0; n < b.N; n++ { myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } myControl.Stop() theirControl.Stop() } func BenchmarkHotPathRelay(b *testing.B) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(b, myControl, relayControl, theirControl) r.CancelFlowLogs() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) b.ResetTimer() for n := 0; n < b.N; n++ { myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } myControl.Stop() theirControl.Stop() relayControl.Stop() } func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), 
time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) t.Log("Get their stage 1 packet so that we can play with it") stage1Packet := theirControl.GetFromUDP(true) t.Log("I consume a garbage packet with a proper nebula header for our tunnel") // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel badPacket := stage1Packet.Copy() badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len] myControl.InjectUDPPacket(badPacket) t.Log("Have me consume their real stage 1 packet. 
I have a tunnel now") myControl.InjectUDPPacket(stage1Packet) t.Log("Wait until we see my cached packet come through") myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestGoodHandshakeNoOverlap(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack! // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() empty := []byte{} t.Log("do something to cause a handshake") myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) t.Log("Get their stage 1 packet") stage1Packet := theirControl.GetFromUDP(true) t.Log("Have me consume their stage 1 packet. 
I have a tunnel now") myControl.InjectUDPPacket(stage1Packet) t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete") myControl.WaitForType(header.Test, 0, theirControl) t.Log("Make sure our host infos are correct") assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestWrongResponderHandshake(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. 
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { err := h.Parse(p.Data) if err != nil { panic(err) } if h.Type == header.CloseTunnel && p.To == evilUdpAddr { return router.RouteAndExit } return router.KeepRouting }) t.Log("Evil tunnel is closed, inject the correct udp addr for them") myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) t.Log("Route until we see the cached packet") r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { err := h.Parse(p.Data) if err != nil { panic(err) } if p.To == theirUdpAddr && h.Type == 1 { return router.RouteAndExit } return router.KeepRouting }) t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") 
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() theirControl.Stop() } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) o := m{ "static_host_map": m{ theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()}, }, } myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o) // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. 
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { err := h.Parse(p.Data) if err != nil { panic(err) } if h.Type == header.CloseTunnel && p.To == evilUdpAddr { return router.RouteAndExit } return router.KeepRouting }) t.Log("Evil tunnel is closed, inject the correct udp addr for them") myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) t.Log("Route until we see the cached packet") r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { err := h.Parse(p.Data) if err != nil { panic(err) } if p.To == theirUdpAddr && h.Type == 1 { return router.RouteAndExit } return router.KeepRouting }) t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") 
assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) t.Log("Success!") myControl.Stop() theirControl.Stop() } func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() t.Log("Trigger a handshake to start on both me and them") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) theirHsForMe := theirControl.GetFromUDP(true) r.Log("Now inject both stage 1 handshake packets") r.InjectUDPPacket(theirControl, myControl, theirHsForMe) r.InjectUDPPacket(myControl, theirControl, myHsForThem) r.Log("Route until they receive a message packet") myCachedPacket := 
r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) theirHostmapHosts := theirControl.ListHostmapHosts(false) theirHostmapIndexes := theirControl.ListHostmapIndexes(false) // We should have two tunnels on both sides assert.Len(t, myHostmapHosts, 1) assert.Len(t, theirHostmapHosts, 1) assert.Len(t, myHostmapIndexes, 2) assert.Len(t, theirHostmapIndexes, 2) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) assert.Len(t, theirFinalHostmapHosts, 1) assert.Len(t, myFinalHostmapIndexes, 1) assert.Len(t, theirFinalHostmapIndexes, 1) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestUncleanShutdownRaceLoser(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, 
time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() r.Log("Trigger a handshake from me to them") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Assert the tunnel works") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } time.Sleep(time.Second) } r.RenderHostmaps("Final hostmaps", myControl, theirControl) } func TestUncleanShutdownRaceWinner(t *testing.T) { ca, _, 
caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() r.Log("Trigger a handshake from me to them") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") theirHostmap := theirControl.GetHostmap() theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) 
if len(myControl.GetHostmap().Indexes) < start { break } time.Sleep(time.Second) } r.RenderHostmaps("Final hostmaps", myControl, theirControl) } func TestRelays(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) } func TestRelaysDontCareAboutIps(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", 
"10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) } func TestReestablishRelays(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached 
via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Ensure packet traversal from them to me via the relay") theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from them"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) // If we break the relay's connection to 'them', 'me' needs to detect and recover the connection r.Log("Close the tunnel") relayControl.CloseTunnel(theirVpnIpNet[0].Addr(), true) start := len(myControl.GetHostmap().Indexes) curIndexes := len(myControl.GetHostmap().Indexes) for curIndexes >= start { curIndexes = len(myControl.GetHostmap().Indexes) r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { return router.RouteAndExit }) time.Sleep(2 * time.Second) } r.Log("Dead index went away. 
Woot!") r.RenderHostmaps("Me removed hostinfo", myControl, relayControl, theirControl) // Next packet should re-establish a relayed connection and work just great. t.Logf("Assert the tunnel...") for { t.Log("RouteForAllUntilTxTun") myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p = r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if slices.Compare(v4.SrcIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") continue } if slices.Compare(v4.DstIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { t.Logf("DstIP is unexpected...this is not the packet I'm looking for. Keep looking") continue } udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) if udp == nil { t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") continue } data := packet.ApplicationLayer() if data == nil { t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") continue } if string(data.Payload()) != "Hi from me" { t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) continue } t.Log("I found my lost packet. 
I am so happy.") break } t.Log("Assert the tunnel works the other way, too") for { t.Log("RouteForAllUntilTxTun") theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if slices.Compare(v4.DstIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { t.Logf("Dst is unexpected...this is not the packet I'm looking for. Keep looking") continue } if slices.Compare(v4.SrcIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") continue } udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) if udp == nil { t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") continue } data := packet.ApplicationLayer() if data == nil { t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") continue } if string(data.Payload()) != "Hi from them" { t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) continue } t.Log("I found my lost packet. 
I am so happy.") break } r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() r.Log("Get a tunnel between me and relay") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") 
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))

	r.Log("Wait for a packet from them to me")
	r.RouteForAllUntilTxTun(myControl)

	r.FlushAll()
	myControl.Stop()
	theirControl.Stop()
	relayControl.Stop()
}

// TestStage1RaceRelays2 races relayed handshakes from me and them (each already tunneled to the relay),
// checks the resulting relayed tunnel carries traffic both ways, and then waits (bounded) for the
// combined hostinfo count across all three nodes to settle back down.
func TestStage1RaceRelays2(t *testing.T) {
	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
	l := NewTestLogger()

	// Teach me and them how to get to the relay and that each can reach the other via the relay
	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
	theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)

	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
	theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})

	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Build a router so we don't have to reason who gets which packet
	r := router.NewR(t, myControl, relayControl, theirControl)
	defer r.RenderFlow()

	// Start the servers
	myControl.Start()
	relayControl.Start()
	theirControl.Start()

	r.Log("Get a tunnel between me and relay")
	l.Info("Get a tunnel between me and relay")
	assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)

	r.Log("Get a tunnel between them and relay")
	l.Info("Get a tunnel between them and relay")
	assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)

	r.Log("Trigger a handshake from both them and me via relay to them and me")
	l.Info("Trigger a handshake from both them and me via relay to them and me")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))

	//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
	//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)

	r.Log("Wait for a packet from them to me")
	l.Info("Wait for a packet from them to me; myControl")
	r.RouteForAllUntilTxTun(myControl)
	l.Info("Wait for a packet from them to me; theirControl")
	r.RouteForAllUntilTxTun(theirControl)

	r.Log("Assert the tunnel works")
	l.Info("Assert the tunnel works")
	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)

	t.Log("Wait until we remove extra tunnels")
	l.Info("Wait until we remove extra tunnels")

	// hostInfoCount is the total number of hostmap indexes across all three nodes.
	hostInfoCount := func() int {
		return len(myControl.GetHostmap().Indexes) +
			len(theirControl.GetHostmap().Indexes) +
			len(relayControl.GetHostmap().Indexes)
	}
	// logHostInfos logs the per-node hostmap index counts.
	logHostInfos := func() {
		l.WithFields(
			logrus.Fields{
				"myControl":    len(myControl.GetHostmap().Indexes),
				"theirControl": len(theirControl.GetHostmap().Indexes),
				"relayControl": len(relayControl.GetHostmap().Indexes),
			}).Info("Waiting for hostinfos to be removed...")
	}

	// Poll once a second, up to 60 times, until the combined count drops to 6 (presumably two per node,
	// one for each of its peers). The count is re-read on every check so the loop stops as soon as it settles.
	logHostInfos()
	retries := 60
	for hostInfoCount() > 6 && retries > 0 {
		logHostInfos()
		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
		t.Log("Connection manager hasn't ticked yet")
		time.Sleep(time.Second)
		retries--
	}
	if remaining := hostInfoCount(); remaining > 6 {
		// Best effort: the test has never failed on this, but make it visible in the log.
		t.Logf("Gave up waiting for extra hostinfos to be removed; %d remain", remaining)
	}

	r.Log("Assert the tunnel works")
	l.Info("Assert the tunnel works")
	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)

	myControl.Stop()
	theirControl.Stop()
	relayControl.Stop()
}

// TestRehandshakingRelays renews the relay's certificate (adding "new group") while a me<->relay<->them
// tunnel is in use, and checks the relayed tunnel keeps working while every node settles back to two hostinfos.
func TestRehandshakingRelays(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})

	// Teach me how to get to the relay and that they can be reached via the relay
	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)

	// Build a router so we don't have to reason who gets which packet
	r := router.NewR(t, myControl, relayControl, theirControl)
	defer r.RenderFlow()

	// Start the servers
	myControl.Start()
	relayControl.Start()
	theirControl.Start()

	t.Log("Trigger a handshake from me to them via the relay")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))

	p := r.RouteForAllUntilTxTun(theirControl)
	r.Log("Assert the tunnel works")
	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { panic(err) } relayConfig.Settings["pki"] = m{ "ca": string(caB), "cert": string(myNextPEM), "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break } time.Sleep(time.Second) } for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break } time.Sleep(time.Second) } r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of 
use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("myControl hostinfos got cleaned up!") for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("theirControl hostinfos got cleaned up!") for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("relayControl hostinfos got cleaned up!") } func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), 
relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. 
r.Log("Renew relay certificate and spin until me and them sees it") _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { panic(err) } relayConfig.Settings["pki"] = m{ "ca": string(caB), "cert": string(myNextPEM), "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break } time.Sleep(time.Second) } for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break } time.Sleep(time.Second) } r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("myControl hostinfos got cleaned up!") for 
len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("theirControl hostinfos got cleaned up!") for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } t.Logf("relayControl hostinfos got cleaned up!") } func TestRehandshaking(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() t.Log("Stand up a tunnel between me and them") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, 
caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { panic(err) } myConfig.Settings["pki"] = m{ "ca": string(caB), "cert": string(myNextPEM), "key": string(myNextPrivKey), } rc, err := yaml.Marshal(myConfig.Settings) require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now break } time.Sleep(time.Second) } r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) require.NoError(t, err) var theirNewConfig m require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) theirFirewall := theirNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", "group": "new group", }} rc, err = yaml.Marshal(theirNewConfig) require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides 
assert.Len(t, myFinalHostmapHosts, 1) assert.Len(t, theirFinalHostmapHosts, 1) assert.Len(t, myFinalHostmapIndexes, 1) assert.Len(t, theirFinalHostmapIndexes, 1) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() theirControl.Start() t.Log("Stand up a tunnel between me and them") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalPEM() if err != nil { panic(err) } theirConfig.Settings["pki"] = m{ "ca": string(caB), "cert": string(theirNextPEM), "key": 
string(theirNextPrivKey), } rc, err := yaml.Marshal(theirConfig.Settings) require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break } time.Sleep(time.Second) } // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(myConfig.Settings) require.NoError(t, err) var myNewConfig m require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) theirFirewall := myNewConfig["firewall"].(map[string]any) theirFirewall["inbound"] = []m{{ "proto": "any", "port": "any", "group": "their new group", }} rc, err = yaml.Marshal(myNewConfig) require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) assert.Len(t, theirFinalHostmapHosts, 1) assert.Len(t, myFinalHostmapIndexes, 1) assert.Len(t, theirFinalHostmapIndexes, 1) r.RenderHostmaps("Final hostmaps", myControl, theirControl) 
myControl.Stop()
	theirControl.Stop()
}

// TestRaceRegression replays a specific interleaving of crossing handshakes (both sides initiate at the same
// time) and then checks that the resulting tunnel still carries traffic in both directions.
func TestRaceRegression(t *testing.T) {
	// This test forces stage 1, stage 2, stage 1 to be received by me from them
	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
	// caused a cross-linked hostinfo
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)

	// Put their info in our lighthouse
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)

	// Start the servers
	myControl.Start()
	theirControl.Start()

	//them rx stage:1 initiatorIndex=642843150 responderIndex=0
	//me rx stage:1 initiatorIndex=120607833 responderIndex=0
	//them rx stage:1 initiatorIndex=642843150 responderIndex=0
	//me rx stage:2 initiatorIndex=642843150 responderIndex=3701775874
	//me rx stage:1 initiatorIndex=120607833 responderIndex=0
	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089

	// Both sides send a tun packet, so each side produces its own stage 1 below.
	t.Log("Start both handshakes")
	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))

	t.Log("Get both stage 1")
	myStage1ForThem := myControl.GetFromUDP(true)
	theirStage1ForMe := theirControl.GetFromUDP(true)

	// They receive my stage 1 twice; I receive theirs once.
	t.Log("Inject them in a special way")
	theirControl.InjectUDPPacket(myStage1ForThem)
	myControl.InjectUDPPacket(theirStage1ForMe)
	theirControl.InjectUDPPacket(myStage1ForThem)

	t.Log("Get both stage 2")
	myStage2ForThem := myControl.GetFromUDP(true)
	theirStage2ForMe := theirControl.GetFromUDP(true)

	// I receive their stage 2 and then their stage 1 a second time; they receive my stage 2.
	t.Log("Inject them in a special way again")
	myControl.InjectUDPPacket(theirStage2ForMe)
	myControl.InjectUDPPacket(theirStage1ForMe)
	theirControl.InjectUDPPacket(myStage2ForThem)

	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()

	// Route until a packet appears on each side's tun. Presumably these are the "Hi from" packets queued by the
	// InjectTunUDPPacket calls above while the handshakes were pending.
	t.Log("Flush the packets")
	r.RouteForAllUntilTxTun(myControl)
	r.RouteForAllUntilTxTun(theirControl)
	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)

	t.Log("Make sure the tunnel still works")
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	myControl.Stop()
	theirControl.Stop()
}

// TestV2NonPrimaryWithLighthouse gives every node a v2 certificate with an ipv4 and an ipv6 network. The hosts
// reach the lighthouse through its second (ipv6) vpn address, and the test asserts that me and them can tunnel
// over their second (ipv6) vpn addresses.
func TestV2NonPrimaryWithLighthouse(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})

	// Shared config for both hosts: address the lighthouse by its ipv6 vpn address only.
	o := m{
		"static_host_map": m{
			lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()},
		},
		"lighthouse": m{
			"hosts": []string{lhVpnIpNet[1].Addr().String()},
			"local_allow_list": m{
				// Try and block our lighthouse updates from using the actual addresses assigned to this computer
				// If we start discovering addresses the test router doesn't know about then test traffic cant flow
				"10.0.0.0/24": true,
				"::/0":        false,
			},
		},
	}

	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
	theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)

	// Build a router so we don't have to reason who gets which packet
	r := router.NewR(t, lhControl, myControl, theirControl)
	defer r.RenderFlow()

	// Start the servers
	lhControl.Start()
	myControl.Start()
	theirControl.Start()

	t.Log("Stand up an ipv6 tunnel between me and them")
	assert.True(t, myVpnIpNet[1].Addr().Is6())
	assert.True(t, theirVpnIpNet[1].Addr().Is6())
	assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)

	lhControl.Stop()
	myControl.Stop()
	theirControl.Stop()
}

// TestV2NonPrimaryWithOffNetLighthouse mirrors TestV2NonPrimaryWithLighthouse. The difference is that the
// lighthouse's only vpn network (2001::1/64) is not contained in either host's networks.
func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})

	o := m{
		"static_host_map": m{
			lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
		},
		"lighthouse": m{
			"hosts": []string{lhVpnIpNet[0].Addr().String()},
			"local_allow_list": m{
				// Try and block our lighthouse updates from using the actual addresses assigned to this computer
				// If we start discovering addresses the test router doesn't know about then test traffic cant flow
				"10.0.0.0/24": true,
				"::/0":        false,
			},
		},
	}

	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o)
	theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)

	// Build a router so we don't have to reason who gets which packet
	r := router.NewR(t, lhControl, myControl, theirControl)
	defer r.RenderFlow()

	// Start the servers
	lhControl.Start()
	myControl.Start()
	theirControl.Start()

	t.Log("Stand up an ipv6 tunnel between me and them")
	assert.True(t, myVpnIpNet[1].Addr().Is6())
	assert.True(t, theirVpnIpNet[1].Addr().Is6())
	assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)

	lhControl.Stop()
	myControl.Stop()
	theirControl.Stop()
}

// TestGoodHandshakeUnsafeDest covers traffic for an unsafe network (192.168.6.0/24). That network is carried in
// their certificate and routed via their vpn address in my config. Before the real stage 1 packet, I am fed a
// truncated copy of it; the test checks that the handshake still completes and traffic flows both ways.
func TestGoodHandshakeUnsafeDest(t *testing.T) {
	unsafePrefix := "192.168.6.0/24"
	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix,
nil)
	// Route the unsafe network through their vpn address on my side.
	route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
	myCfg := m{
		"tun": m{
			"unsafe_routes": []m{route},
		},
	}
	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
	t.Logf("my config %v", myConfig)

	// Put their info in our lighthouse
	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
	// An address inside unsafePrefix, reached through them rather than owned by any nebula node.
	spookyDest := netip.MustParseAddr("192.168.6.4")

	// Start the servers
	myControl.Start()
	theirControl.Start()

	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
	myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))

	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))

	t.Log("Get their stage 1 packet so that we can play with it")
	stage1Packet := theirControl.GetFromUDP(true)

	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
	// Trimming header.Len bytes off the end keeps the leading header but corrupts what follows it.
	badPacket := stage1Packet.Copy()
	badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
	myControl.InjectUDPPacket(badPacket)

	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
	myControl.InjectUDPPacket(stage1Packet)

	t.Log("Wait until we see my cached packet come through")
	myControl.WaitForType(1, 0, theirControl)

	t.Log("Make sure our host infos are correct")
	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)

	t.Log("Get that cached packet and make sure it looks right")
	myCachedPacket := theirControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)

	//reply
	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
	//wait for reply
	theirControl.WaitForType(1, 0, myControl)
	theirCachedPacket := myControl.GetFromTun(true)
	assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)

	t.Log("Do a bidirectional tunnel test")
	r := router.NewR(t, myControl, theirControl)
	defer r.RenderFlow()
	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)

	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
	myControl.Stop()
	theirControl.Stop()
}
================================================
FILE: e2e/helpers_test.go
================================================
//go:build e2e_testing
// +build e2e_testing

package e2e

import (
	"fmt"
	"io"
	"net/netip"
	"os"
	"strings"
	"testing"
	"time"

	"dario.cat/mergo"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula"
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/cert_test"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/e2e/router"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.yaml.in/yaml/v3"
)

// m is shorthand for the nested config maps these helpers build and marshal to yaml.
type m = map[string]any

// newSimpleServer creates a nebula instance with many assumptions
//
// sVpnNetworks is a comma separated list of prefixes; a prefix that fails to parse panics. The underlay udp address
// is derived from the first vpn address, and the port is always 4242:
//   - ipv4: the second octet is reduced by 128 (10.128.0.1 -> 10.0.0.1)
//   - ipv6: bytes 2 and 3 are set to 190 and 239 (0xbe, 0xef)
//
// Everything else is delegated to newSimpleServerWithUdp, which parses sVpnNetworks again.
func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	var vpnNetworks []netip.Prefix
	for _, sn := range strings.Split(sVpnNetworks, ",") {
		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
		if err != nil {
			panic(err)
		}
		vpnNetworks = append(vpnNetworks, vpnIpNet)
	}

	if len(vpnNetworks) == 0 {
		panic("no vpn networks")
	}

	var udpAddr netip.AddrPort
	if vpnNetworks[0].Addr().Is4() {
		budpIp := vpnNetworks[0].Addr().As4()
		budpIp[1] -= 128
		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
	} else {
		budpIp := vpnNetworks[0].Addr().As16()
		// beef for funsies
		budpIp[2] = 190
		budpIp[3] = 239
		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
	}
	return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
}

// newSimpleServerWithUdp is newSimpleServerWithUdpAndUnsafeNetworks with no unsafe networks.
func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
}

// newSimpleServerWithUdpAndUnsafeNetworks does the following:
//   - signs a fresh certificate for name, valid for 5 minutes, covering the given vpn and unsafe networks
//   - builds an allow-all firewall config that listens on udpAddr
//   - merges overrides into that config
//   - passes the result to nebula.Main
//
// It returns the control, the parsed vpn networks, udpAddr and the config. Any failure panics.
//
// When sUnsafeNetworks is non-empty, the inbound rule also sets local_cidr to 0.0.0.0/0.
// NOTE(review): that prefix only covers ipv4 unsafe networks; confirm whether ipv6 unsafe routes should work here.
func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	l := NewTestLogger()

	var vpnNetworks []netip.Prefix
	for _, sn := range strings.Split(sVpnNetworks, ",") {
		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
		if err != nil {
			panic(err)
		}
		vpnNetworks = append(vpnNetworks, vpnIpNet)
	}

	if len(vpnNetworks) == 0 {
		panic("no vpn networks")
	}

	firewallInbound := []m{{
		"proto": "any",
		"port":  "any",
		"host":  "any",
	}}

	var unsafeNetworks []netip.Prefix
	if sUnsafeNetworks != "" {
		firewallInbound = []m{{
			"proto":      "any",
			"port":       "any",
			"host":       "any",
			"local_cidr": "0.0.0.0/0",
		}}

		for _, sn := range strings.Split(sUnsafeNetworks, ",") {
			x, err := netip.ParsePrefix(strings.TrimSpace(sn))
			if err != nil {
				panic(err)
			}
			unsafeNetworks = append(unsafeNetworks, x)
		}
	}

	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})

	caB, err := caCrt.MarshalPEM()
	if err != nil {
		panic(err)
	}

	mc := m{
		"pki": m{
			"ca":   string(caB),
			"cert": string(myPEM),
			"key":  string(myPrivKey),
		},
		//"tun": m{"disabled": true},
		"firewall": m{
			"outbound": []m{{
				"proto": "any",
				"port":  "any",
				"host":  "any",
			}},
			"inbound": firewallInbound,
		},
		//"handshakes": m{
		//	"try_interval": "1s",
		//},
		"listen": m{
			"host": udpAddr.Addr().String(),
			"port": udpAddr.Port(),
		},
		"logging": m{
			// Prefix every log timestamp with the node name so interleaved logs can be told apart.
			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
			"level":            l.Level.String(),
		},
		"timers": m{
			"pending_deletion_interval": 2,
			"connection_alive_interval": 2,
		},
	}

	if overrides != nil {
		// Overrides are merged into an empty map first and the defaults second. Presumably mergo keeps keys
		// that are already present (no WithOverride), so override values win, except that slices present in
		// both maps are appended (WithAppendSlice). Verify against the mergo docs if that matters.
		final := m{}
		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
		if err != nil {
			panic(err)
		}
		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
		if err != nil {
			panic(err)
		}
		mc = final
	}

	cb, err := yaml.Marshal(mc)
	if err != nil {
		panic(err)
	}

	c := config.NewC(l)
	// NOTE(review): whatever LoadString returns is not checked here.
	c.LoadString(string(cb))

	control, err := nebula.Main(c, false, "e2e-test", l, nil)

	if err != nil {
		panic(err)
	}

	return control, vpnNetworks, udpAddr, c
}

// newServer creates a nebula instance with fewer assumptions
//
// The vpn networks come from the last certificate in certs. The underlay udp address is derived from the first of
// them in the same way as newSimpleServer. Every certificate in caCrt and in certs is PEM encoded and concatenated
// into the pki config.
func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
	l := NewTestLogger()

	vpnNetworks := certs[len(certs)-1].Networks()

	var udpAddr netip.AddrPort
	if vpnNetworks[0].Addr().Is4() {
		budpIp := vpnNetworks[0].Addr().As4()
		budpIp[1] -= 128
		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
	} else {
		budpIp := vpnNetworks[0].Addr().As16()
		// beef for funsies
		budpIp[2] = 190
		budpIp[3] = 239
		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
	}

	caStr := ""
	for _, ca := range caCrt {
		x, err := ca.MarshalPEM()
		if err != nil {
panic(err) } caStr += string(x) } certStr := "" for _, c := range certs { x, err := c.MarshalPEM() if err != nil { panic(err) } certStr += string(x) } mc := m{ "pki": m{ "ca": caStr, "cert": certStr, "key": string(key), }, //"tun": m{"disabled": true}, "firewall": m{ "outbound": []m{{ "proto": "any", "port": "any", "host": "any", }}, "inbound": []m{{ "proto": "any", "port": "any", "host": "any", }}, }, //"handshakes": m{ // "try_interval": "1s", //}, "listen": m{ "host": udpAddr.Addr().String(), "port": udpAddr.Port(), }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), "level": l.Level.String(), }, "timers": m{ "pending_deletion_interval": 2, "connection_alive_interval": 2, }, } if overrides != nil { final := m{} err := mergo.Merge(&final, overrides, mergo.WithAppendSlice) if err != nil { panic(err) } err = mergo.Merge(&final, mc, mergo.WithAppendSlice) if err != nil { panic(err) } mc = final } cb, err := yaml.Marshal(mc) if err != nil { panic(err) } c := config.NewC(l) cStr := string(cb) c.LoadString(cStr) control, err := nebula.Main(c, false, "e2e-test", l, nil) if err != nil { panic(err) } return control, vpnNetworks, udpAddr, c } type doneCb func() func deadline(t *testing.T, seconds time.Duration) doneCb { timeout := time.After(seconds * time.Second) done := make(chan bool) go func() { select { case <-timeout: t.Fatal("Test did not finish in time") case <-done: } }() return func() { done <- true } } func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), 
aPacket, vpnIpA, vpnIpB, 90, 80) } func assertHostInfoPair(t testing.TB, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { // Get both host infos //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") } func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { if toIp.Is6() { assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) } else { assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort) } } func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) assert.NotNil(t, v6, "No ipv6 data found") assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect") assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp 
data found") assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") data := packet.ApplicationLayer() assert.NotNil(t, data) assert.Equal(t, expected, data.Payload(), "Data was incorrect") } func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect") assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp data found") assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") data := packet.ApplicationLayer() assert.NotNil(t, data) assert.Equal(t, expected, data.Payload(), "Data was incorrect") } func getAddrs(ns []netip.Prefix) []netip.Addr { var a []netip.Addr for _, n := range ns { a = append(a, n.Addr()) } return a } func NewTestLogger() *logrus.Logger { l := logrus.New() v := os.Getenv("TEST_LOGS") if v == "" { l.SetOutput(io.Discard) l.SetLevel(logrus.PanicLevel) return l } switch v { case "2": l.SetLevel(logrus.DebugLevel) case "3": l.SetLevel(logrus.TraceLevel) default: l.SetLevel(logrus.InfoLevel) } return l } ================================================ FILE: e2e/router/doc.go ================================================ package router // This file exists to allow `go fmt` to traverse here on its own. 
The build tags were keeping it out before ================================================ FILE: e2e/router/hostmap.go ================================================ //go:build e2e_testing // +build e2e_testing package router import ( "fmt" "net/netip" "sort" "strings" "github.com/slackhq/nebula" ) type edge struct { from string to string dual bool } func renderHostmaps(controls ...*nebula.Control) string { var lines []*edge r := "graph TB\n" for _, c := range controls { sr, se := renderHostmap(c) r += sr for _, e := range se { add := true // Collapse duplicate edges into a bi-directionally connected edge for _, ge := range lines { if e.to == ge.from && e.from == ge.to { add = false ge.dual = true break } } if add { lines = append(lines, e) } } } for _, line := range lines { if line.dual { r += fmt.Sprintf("\t%v <--> %v\n", line.from, line.to) } else { r += fmt.Sprintf("\t%v --> %v\n", line.from, line.to) } } return r } func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge crt := c.GetCertState().GetDefaultCertificate() clusterName := strings.Trim(crt.Name(), " ") clusterVpnIp := crt.Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() hm.RLock() defer hm.RUnlock() // Draw the vpn to index nodes r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName) hosts := sortedHosts(hm.Hosts) for _, vpnIp := range hosts { hi := hm.Hosts[vpnIp] r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp) lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex())) rs := hi.GetRelayState() for _, relayIp := range rs.CopyRelayIps() { lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, relayIp)) } for _, relayIp := range rs.CopyRelayForIdxs() { lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, relayIp)) } 
} r += "\t\tend\n" // Draw the relay hostinfos if len(hm.Relays) > 0 { r += fmt.Sprintf("\t\tsubgraph %s.relays[\"Relays (relay index to hostinfo)\"]\n", clusterName) for relayIndex, hi := range hm.Relays { r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, relayIndex, relayIndex) lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, relayIndex, clusterName, hi.GetLocalIndex())) } r += "\t\tend\n" } // Draw the local index to relay or remote index nodes r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName) indexes := sortedIndexes(hm.Indexes) for _, idx := range indexes { hi, ok := hm.Indexes[idx] if ok { r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs()) remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi } } r += "\t\tend\n" // Add the edges inside this host for _, line := range lines { r += fmt.Sprintf("\t\t%v\n", line) } r += "\tend\n" return r, globalLines } func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr { keys := make([]netip.Addr, 0, len(hosts)) for key := range hosts { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { return keys[i].Compare(keys[j]) > 0 }) return keys } func sortedIndexes(indexes map[uint32]*nebula.HostInfo) []uint32 { keys := make([]uint32, 0, len(indexes)) for key := range indexes { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { return keys[i] > keys[j] }) return keys } ================================================ FILE: e2e/router/router.go ================================================ //go:build e2e_testing // +build e2e_testing package router import ( "context" "fmt" "net/netip" "os" "path/filepath" "reflect" "regexp" "sort" "sync" "testing" "time" "github.com/google/gopacket" 
"github.com/google/gopacket/layers" "github.com/slackhq/nebula" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "golang.org/x/exp/maps" ) type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? controls map[netip.AddrPort]*nebula.Control // A map for inbound packets for a control that doesn't know about this address inNat map[netip.AddrPort]*nebula.Control // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver outNat map[string]netip.AddrPort // A map of vpn ip to the nebula control it belongs to vpnControls map[netip.Addr]*nebula.Control ignoreFlows []ignoreFlow flow []flowEntry // A set of additional mermaid graphs to draw in the flow log markdown file // Currently consisting only of hostmap renders additionalGraphs []mermaidGraph // All interactions are locked to help serialize behavior sync.Mutex fn string cancelRender context.CancelFunc t testing.TB } type ignoreFlow struct { tun NullBool messageType header.MessageType subType header.MessageSubType //from //to } type mermaidGraph struct { title string content string } type NullBool struct { HasValue bool IsTrue bool } type flowEntry struct { note string packet *packet } type packet struct { from *nebula.Control to *nebula.Control packet *udp.Packet tun bool // a packet pulled off a tun device rx bool // the packet was received by a udp device } func (p *packet) WasReceived() { if p != nil { p.rx = true } } type ExitType int const ( // KeepRouting the function will get called again on the next packet KeepRouting ExitType = 0 // ExitNow does not route this packet and exits immediately ExitNow ExitType = 1 // RouteAndExit routes this packet and exits immediately afterwards RouteAndExit ExitType = 2 ) type ExitFunc func(packet *udp.Packet, receiver 
*nebula.Control) ExitType // NewR creates a new router to pass packets in a controlled fashion between the provided controllers. // The packet flow will be recorded in a file within the mermaid directory under the same name as the test. // Renders will occur automatically, roughly every 100ms, until a call to RenderFlow() is made func NewR(t testing.TB, controls ...*nebula.Control) *R { ctx, cancel := context.WithCancel(context.Background()) if err := os.MkdirAll("mermaid", 0755); err != nil { panic(err) } r := &R{ controls: make(map[netip.AddrPort]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control), inNat: make(map[netip.AddrPort]*nebula.Control), outNat: make(map[string]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), t: t, cancelRender: cancel, } // Try to remove our render file os.Remove(r.fn) for _, c := range controls { addr := c.GetUDPAddr() if _, ok := r.controls[addr]; ok { panic("Duplicate listen address: " + addr.String()) } for _, vpnAddr := range c.GetVpnAddrs() { r.vpnControls[vpnAddr] = c } r.controls[addr] = c } // Spin the renderer in case we go nuts and the test never completes go func() { clockSource := time.NewTicker(time.Millisecond * 100) defer clockSource.Stop() for { select { case <-ctx.Done(): return case <-clockSource.C: r.renderHostmaps("clock tick") r.renderFlow() } } }() return r } // AddRoute will place the nebula controller at the ip and port specified. // It does not look at the addr attached to the instance. // If a route is used, this will behave like a NAT for the return path. 
// Rewriting the source ip:port to what was last sent to from the origin func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { r.Lock() defer r.Unlock() inAddr := netip.AddrPortFrom(ip, port) if _, ok := r.inNat[inAddr]; ok { panic("Duplicate listen address inNat: " + inAddr.String()) } r.inNat[inAddr] = c } // RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening. func (r *R) RenderFlow() { r.cancelRender() r.renderFlow() } // CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected func (r *R) CancelFlowLogs() { r.cancelRender() r.flow = nil } func (r *R) renderFlow() { if r.flow == nil { return } f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644) if err != nil { panic(err) } var participants = map[netip.AddrPort]struct{}{} var participantsVals []string fmt.Fprintln(f, "```mermaid") fmt.Fprintln(f, "sequenceDiagram") // Assemble participants for _, e := range r.flow { if e.packet == nil { continue } addr := e.packet.from.GetUDPAddr() if _, ok := participants[addr]; ok { continue } participants[addr] = struct{}{} sanAddr := normalizeName(addr.String()) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n",
			sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
		)
	}

	if len(participantsVals) > 2 {
		// Get the first and last participantVals for notes
		participantsVals = []string{participantsVals[0], participantsVals[len(participantsVals)-1]}
	}

	// Print packets
	h := &header.H{}
	for _, e := range r.flow {
		if e.packet == nil {
			//fmt.Fprintf(f, " note over %s: %s\n", strings.Join(participantsVals, ", "), e.note)
			continue
		}

		p := e.packet
		if p.tun {
			fmt.Fprintln(f, r.formatUdpPacket(p))

		} else {
			if err := h.Parse(p.packet.Data); err != nil {
				panic(err)
			}

			// A crossed arrow marks a packet that was recorded but never marked as received.
			line := "--x"
			if p.rx {
				line = "->>"
			}

			fmt.Fprintf(f,
				" %s%s%s: %s(%s), index %v, counter: %v\n",
				normalizeName(p.from.GetUDPAddr().String()),
				line,
				normalizeName(p.to.GetUDPAddr().String()),
				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
			)
		}
	}
	fmt.Fprintln(f, "```")

	for _, g := range r.additionalGraphs {
		fmt.Fprintf(f, "## %s\n", g.title)
		fmt.Fprintln(f, "```mermaid")
		fmt.Fprintln(f, g.content)
		fmt.Fprintln(f, "```")
	}
}

// normalizeNameRx matches '[', ']' and ':'. It is compiled once here rather than on every normalizeName call.
var normalizeNameRx = regexp.MustCompile("[\\[\\]\\:]")

// normalizeName replaces every '[', ']' and ':' in s with '_' so that an ip:port string can be used as a mermaid
// identifier, e.g. "[::1]:4242" -> "___1__4242".
func normalizeName(s string) string {
	return normalizeNameRx.ReplaceAllLiteralString(s, "_")
}

// IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
// messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
// NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
func (r *R) IgnoreFlow(messageType header.MessageType, subType header.MessageSubType, tun NullBool) {
	r.Lock()
	defer r.Unlock()
	r.ignoreFlows = append(r.ignoreFlows, ignoreFlow{
		tun,
		messageType,
		subType,
	})
}

// RenderHostmaps appends a titled render of the given controls' hostmaps to the flow log. The render is skipped
// when both its title and its content match the most recent graph.
func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
	r.Lock()
	defer r.Unlock()

	s := renderHostmaps(controls...)
	if len(r.additionalGraphs) > 0 {
		lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1]
		if lastGraph.content == s && lastGraph.title == title {
			// Ignore this rendering if it matches the last rendering added
			// This is useful if you want to track rendering changes
			return
		}
	}

	r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{
		title:   title,
		content: s,
	})
}

// renderHostmaps appends a render of every registered control's hostmap, ordered by descending first vpn
// address. The render is skipped when its content matches the most recent graph, whatever that graph's title.
func (r *R) renderHostmaps(title string) {
	c := maps.Values(r.controls)
	sort.SliceStable(c, func(i, j int) bool {
		return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
	})

	s := renderHostmaps(c...)
	if len(r.additionalGraphs) > 0 {
		lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1]
		if lastGraph.content == s {
			// Ignore this rendering if it matches the last rendering added
			// This is useful if you want to track rendering changes
			return
		}
	}

	r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{
		title:   title,
		content: s,
	})
}

// InjectFlow can be used to record packet flow if the test is handling the routing on its own.
// The packet is assumed to have been received, so the recorded entry is marked as received.
func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) {
	r.Lock()
	defer r.Unlock()
	// WasReceived is nil-safe, which covers cancelled or ignored flows.
	r.unlockedInjectFlow(from, to, p, false).WasReceived()
}

// Log records a note in the flow log and writes the same message to the test log.
// Nothing is recorded or logged once flow logs have been cancelled.
func (r *R) Log(arg ...any) {
	r.Lock()
	defer r.Unlock()

	// The routing methods append to r.flow while holding the lock, so it must also be read under the lock.
	if r.flow == nil {
		return
	}
	r.flow = append(r.flow, flowEntry{note: fmt.Sprint(arg...)})
	r.t.Log(arg...)
}

// Logf is the formatted variant of Log.
func (r *R) Logf(format string, arg ...any) {
	r.Lock()
	defer r.Unlock()

	// See Log: r.flow is only read under the lock.
	if r.flow == nil {
		return
	}
	r.flow = append(r.flow, flowEntry{note: fmt.Sprintf(format, arg...)})
	r.t.Logf(format, arg...)
}

// unlockedInjectFlow is used by the router to record a packet has been transmitted, the packet is returned and
// should be marked as received AFTER it has been placed on the receivers channel.
// If flow logs have been disabled this function will return nil func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool) *packet { if r.flow == nil { return nil } r.renderHostmaps(fmt.Sprintf("Packet %v", len(r.flow))) if len(r.ignoreFlows) > 0 { var h header.H err := h.Parse(p.Data) if err != nil { panic(err) } for _, i := range r.ignoreFlows { if !tun { if i.messageType == h.Type && i.subType == h.Subtype { return nil } } else if i.tun.HasValue && i.tun.IsTrue { return nil } } } fp := &packet{ from: from, to: to, packet: p.Copy(), tun: tun, } r.flow = append(r.flow, flowEntry{packet: fp}) return fp } // OnceFrom will route a single packet from sender then return // If the router doesn't have the nebula controller for that address, we panic func (r *R) OnceFrom(sender *nebula.Control) { r.RouteExitFunc(sender, func(*udp.Packet, *nebula.Control) ExitType { return RouteAndExit }) } // RouteUntilTxTun will route for sender and return when a packet is seen on receivers tun // If the router doesn't have the nebula controller for that address, we panic func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []byte { tunTx := receiver.GetTunTxChan() udpTx := sender.GetUDPTxChan() for { select { // Maybe we already have something on the tun for us case b := <-tunTx: r.Lock() np := udp.Packet{Data: make([]byte, len(b))} copy(np.Data, b) r.unlockedInjectFlow(receiver, receiver, &np, true) r.Unlock() return b // Nope, lets push the sender along case p := <-udpTx: r.Lock() a := sender.GetUDPAddr() c := r.getControl(a, p.To, p) if c == nil { r.Unlock() panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) c.InjectUDPPacket(p) fp.WasReceived() r.Unlock() } } } // RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun // If the router doesn't have the nebula controller for that address, we panic func (r *R) RouteForAllUntilTxTun(receiver 
*nebula.Control) []byte { sc := make([]reflect.SelectCase, len(r.controls)+1) cm := make([]*nebula.Control, len(r.controls)+1) i := 0 sc[i] = reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan()), Send: reflect.Value{}, } cm[i] = receiver i++ for _, c := range r.controls { sc[i] = reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan()), Send: reflect.Value{}, } cm[i] = c i++ } for { x, rx, _ := reflect.Select(sc) r.Lock() if x == 0 { // we are the tun tx, we can exit p := rx.Interface().([]byte) np := udp.Packet{Data: make([]byte, len(p))} copy(np.Data, p) r.unlockedInjectFlow(cm[x], cm[x], &np, true) r.Unlock() return p } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) a := cm[x].GetUDPAddr() c := r.getControl(a, p.To, p) if c == nil { r.Unlock() panic(fmt.Sprintf("No control for udp tx %s", p.To)) } fp := r.unlockedInjectFlow(cm[x], c, p, false) c.InjectUDPPacket(p) fp.WasReceived() } r.Unlock() } } // RouteExitFunc will call the whatDo func with each udp packet from sender. 
// whatDo can return: // - exitNow: the packet will not be routed and this call will return immediately // - routeAndExit: this call will return immediately after routing the last packet from sender // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { h := &header.H{} for { p := sender.GetFromUDP(true) r.Lock() if err := h.Parse(p.Data); err != nil { panic(err) } receiver := r.getControl(sender.GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() panic("Can't RouteExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) switch e { case ExitNow: r.Unlock() return case RouteAndExit: fp := r.unlockedInjectFlow(sender, receiver, p, false) receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() return case KeepRouting: fp := r.unlockedInjectFlow(sender, receiver, p, false) receiver.InjectUDPPacket(p) fp.WasReceived() default: panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) } r.Unlock() } } // RouteUntilAfterMsgType will route for sender until a message type is seen and sent from sender // If the router doesn't have the nebula controller for that address, we panic func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType header.MessageType, subType header.MessageSubType) { h := &header.H{} r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { if err := h.Parse(p.Data); err != nil { panic(err) } if h.Type == msgType && h.Subtype == subType { return RouteAndExit } return KeepRouting }) } func (r *R) RouteForAllUntilAfterMsgTypeTo(receiver *nebula.Control, msgType header.MessageType, subType header.MessageSubType) { h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, r *nebula.Control) ExitType { if r != receiver { return KeepRouting } if err := h.Parse(p.Data); err != nil { panic(err) } if h.Type == msgType && h.Subtype == subType { return RouteAndExit } return KeepRouting }) } func (r *R) 
InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet) { r.Lock() defer r.Unlock() fp := r.unlockedInjectFlow(sender, receiver, packet, false) receiver.InjectUDPPacket(packet) fp.WasReceived() } // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) { if finish == KeepRouting { finish = RouteAndExit } r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { if p.To == toAddr { return finish } return KeepRouting }) } // RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet from // whatDo can return: // - exitNow: the packet will not be routed and this call will return immediately // - routeAndExit: this call will return immediately after routing the last packet from sender // - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { sc := make([]reflect.SelectCase, len(r.controls)) cm := make([]*nebula.Control, len(r.controls)) i := 0 for _, c := range r.controls { sc[i] = reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan()), Send: reflect.Value{}, } cm[i] = c i++ } for { x, rx, _ := reflect.Select(sc) r.Lock() p := rx.Interface().(*udp.Packet) receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() panic("Can't RouteForAllExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) switch e { case ExitNow: r.Unlock() return case RouteAndExit: fp := r.unlockedInjectFlow(cm[x], receiver, p, false) receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() return case KeepRouting: fp 
:= r.unlockedInjectFlow(cm[x], receiver, p, false) receiver.InjectUDPPacket(p) fp.WasReceived() default: panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) } r.Unlock() } } // FlushAll will route for every registered controller, exiting once there are no packets left to route func (r *R) FlushAll() { sc := make([]reflect.SelectCase, len(r.controls)) cm := make([]*nebula.Control, len(r.controls)) i := 0 for _, c := range r.controls { sc[i] = reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan()), Send: reflect.Value{}, } cm[i] = c i++ } // Add a default case to exit when nothing is left to send sc = append(sc, reflect.SelectCase{ Dir: reflect.SelectDefault, Chan: reflect.Value{}, Send: reflect.Value{}, }) for { x, rx, ok := reflect.Select(sc) if !ok { return } r.Lock() p := rx.Interface().(*udp.Packet) receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() panic("Can't FlushAll for host: " + p.To.String()) } receiver.InjectUDPPacket(p) r.Unlock() } } // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { p.From = newAddr } c, ok := r.inNat[toAddr] if ok { r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr return c } return r.controls[toAddr] } func (r *R) formatUdpPacket(p *packet) string { var packet gopacket.Packet var srcAddr netip.Addr packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy) if packet.ErrorLayer() == nil { v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) if v6 == nil { panic("not an ipv6 packet") } srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } else { packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) v6 := 
packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if v6 == nil { panic("not an ipv6 packet") } srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } from := "unknown" if c, ok := r.vpnControls[srcAddr]; ok { from = c.GetUDPAddr().String() } udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) if udpLayer == nil { panic("not a udp packet") } data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", normalizeName(from), normalizeName(p.to.GetUDPAddr().String()), udpLayer.SrcPort, udpLayer.DstPort, string(data.Payload()), ) } ================================================ FILE: e2e/tunnels_test.go ================================================ //go:build e2e_testing // +build e2e_testing package e2e import ( "fmt" "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v3" ) func TestDropInactiveTunnels(t *testing.T) { // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) // Share our underlay information myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() theirControl.Start() r := router.NewR(t, myControl, theirControl) r.Log("Assert the tunnel between me and them works") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Go inactive and wait for the tunnels to get dropped") waitStart := time.Now() for { myIndexes := len(myControl.GetHostmap().Indexes) theirIndexes := len(theirControl.GetHostmap().Indexes) if myIndexes == 0 && theirIndexes == 0 { break } since := time.Since(waitStart) r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, 
since) if since > time.Second*30 { t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds") } time.Sleep(1 * time.Second) r.FlushAll() } r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart)) myControl.Stop() theirControl.Stop() } func TestCertUpgrade(t *testing.T) { // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) caB, err := ca.MarshalPEM() if err != nil { panic(err) } ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca2B, err := ca2.MarshalPEM() if err != nil { panic(err) } caStr := fmt.Sprintf("%s\n%s", caB, ca2B) myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) _, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) // Share our underlay information myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the 
servers myControl.Start() theirControl.Start() r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() r.Log("Assert the tunnel between me and them works") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("yay") //todo ??? time.Sleep(1 * time.Second) r.FlushAll() mc := m{ "pki": m{ "ca": caStr, "cert": string(myCert2Pem), "key": string(myPrivKey), }, //"tun": m{"disabled": true}, "firewall": myC.Settings["firewall"], //"handshakes": m{ // "try_interval": "1s", //}, "listen": myC.Settings["listen"], "logging": myC.Settings["logging"], "timers": myC.Settings["timers"], } cb, err := yaml.Marshal(mc) if err != nil { panic(err) } r.Logf("reload new v2-only config") err = myC.ReloadConfigString(string(cb)) assert.NoError(t, err) r.Log("yay, spin until their sees it") waitStart := time.Now() for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) if c == nil { r.Log("nil") } else { version := c.Cert.Version() r.Logf("version %d", version) if version == cert.Version2 { break } } since := time.Since(waitStart) if since > time.Second*10 { t.Fatal("Cert should be new by now") } time.Sleep(time.Second) } r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestCertDowngrade(t *testing.T) { // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) caB, err := ca.MarshalPEM() if err != nil { panic(err) } ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca2B, err := ca2.MarshalPEM() if err != nil { panic(err) } caStr := fmt.Sprintf("%s\n%s", caB, 
ca2B) myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) // Share our underlay information myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() theirControl.Start() r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() r.Log("Assert the tunnel between me and them works") //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) //r.Log("yay") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("yay") //todo ??? 
time.Sleep(1 * time.Second) r.FlushAll() mc := m{ "pki": m{ "ca": caStr, "cert": string(myCertPem), "key": string(myPrivKey), }, "firewall": myC.Settings["firewall"], "listen": myC.Settings["listen"], "logging": myC.Settings["logging"], "timers": myC.Settings["timers"], } cb, err := yaml.Marshal(mc) if err != nil { panic(err) } r.Logf("reload new v1-only config") err = myC.ReloadConfigString(string(cb)) assert.NoError(t, err) r.Log("yay, spin until their sees it") waitStart := time.Now() for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) if c == nil || c2 == nil { r.Log("nil") } else { version := c.Cert.Version() theirVersion := c2.Cert.Version() r.Logf("version %d,%d", version, theirVersion) if version == cert.Version1 { break } } since := time.Since(waitStart) if since > time.Second*5 { r.Log("it is unusual that the cert is not new yet, but not a failure yet") } if since > time.Second*10 { r.Log("wtf") t.Fatal("Cert should be new by now") } time.Sleep(time.Second) } r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestCertMismatchCorrection(t *testing.T) { // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides // under ideal conditions ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) myCert2, _ := 
cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) // Share our underlay information myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() theirControl.Start() r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() r.Log("Assert the tunnel between me and them works") //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) //r.Log("yay") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("yay") //todo ??? 
time.Sleep(1 * time.Second) r.FlushAll() waitStart := time.Now() for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) if c == nil || c2 == nil { r.Log("nil") } else { version := c.Cert.Version() theirVersion := c2.Cert.Version() r.Logf("version %d,%d", version, theirVersion) if version == theirVersion { break } } since := time.Since(waitStart) if since > time.Second*5 { r.Log("wtf") } if since > time.Second*10 { r.Log("wtf") t.Fatal("Cert should be new by now") } time.Sleep(time.Second) } r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() theirControl.Stop() } func TestCrossStackRelaysWork(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) theirUdp := netip.MustParseAddrPort("10.0.0.2:4242") theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}}) //myVpnV4 := myVpnIpNet[0] myVpnV6 := myVpnIpNet[1] relayVpnV4 := relayVpnIpNet[0] relayVpnV6 := relayVpnIpNet[1] theirVpnV6 := theirVpnIpNet[0] // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr) myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()}) relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr) // Build a 
router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) defer r.RenderFlow() // Start the servers myControl.Start() relayControl.Start() theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) t.Log("reply?") theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //t.Log("finish up") //myControl.Stop() //theirControl.Stop() //relayControl.Stop() } ================================================ FILE: examples/config.yml ================================================ # This is the nebula example configuration file. You must edit, at a minimum, the static_host_map, lighthouse, and firewall sections # Some options in this file are HUPable, including the pki section. (A HUP will reload credentials from disk without affecting existing tunnels) # PKI defines the location of credentials for this node. Each of these can also be inlined by using the yaml ": |" syntax. pki: # The CAs that are accepted by this node. Must contain one or more certificates created by 'nebula-cert ca' ca: /etc/nebula/ca.crt cert: /etc/nebula/host.crt key: /etc/nebula/host.key # blocklist is a list of certificate fingerprints that we will refuse to talk to #blocklist: # - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72 # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. 
#disconnect_invalid: true # initiating_version controls which certificate version is used when initiating handshakes. # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. # initiating_version: 1 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: # "{nebula ip}": ["{routable ip/dns name}:{routable port}"] # Example, if your lighthouse has the nebula IP of 192.168.100.1 and has the real ip address of 100.64.22.11 and runs on port 4242: static_host_map: "192.168.100.1": ["100.64.22.11:4242"] # The static_map config stanza can be used to configure how the static_host_map behaves. #static_map: # cadence determines how frequently DNS is re-queried for updated IP addresses when a static_host_map entry contains # a DNS name. #cadence: 30s # network determines the type of IP addresses to ask the DNS server for. The default is "ip4" because nodes typically # do not know their public IPv4 address. Connecting to the Lighthouse via IPv4 allows the Lighthouse to detect the # public address. Other valid options are "ip6" and "ip" (returns both.) #network: ip4 # lookup_timeout is the DNS query timeout. #lookup_timeout: 250ms lighthouse: # am_lighthouse is used to enable lighthouse functionality for a node. This should ONLY be true on nodes # you have configured to be lighthouses in your network am_lighthouse: false # serve_dns optionally starts a dns listener that responds to various queries and can even be # delegated to for resolution #serve_dns: false #dns: # The DNS host defines the IP to bind the dns listener to. 
This also allows binding to the nebula node IP. #host: 0.0.0.0 #port: 53 # interval is the number of seconds between updates from this node to a lighthouse. # During updates, a node sends information about its current IP addresses to each lighthouse. interval: 60 # hosts is a list of lighthouse hosts this node should report to and query from # IMPORTANT: THIS SHOULD BE EMPTY ON LIGHTHOUSE NODES # IMPORTANT2: THIS SHOULD BE LIGHTHOUSES' NEBULA IPs, NOT LIGHTHOUSES' REAL ROUTABLE IPs hosts: - "192.168.100.1" # remote_allow_list allows you to control ip ranges that this node will # consider when handshaking to another node. By default, any remote IPs are # allowed. You can provide CIDRs here with `true` to allow and `false` to # deny. The most specific CIDR rule applies to each remote. If all rules are # "allow", the default will be "deny", and vice-versa. If both "allow" and # "deny" IPv4 rules are present, then you MUST set a rule for "0.0.0.0/0" as # the default. Similarly, if both "allow" and "deny" IPv6 rules are present, # then you MUST set a rule for "::/0" as the default. #remote_allow_list: # Example to block IPs from this subnet from being used for remote IPs. #"172.16.0.0/12": false # A more complicated example: allow public IPs, but only private IPs from a specific subnet #"0.0.0.0/0": true #"10.0.0.0/8": false #"10.42.42.0/24": true # EXPERIMENTAL: This option may change or disappear in the future. # Optionally allows the definition of remote_allow_list blocks # specific to an inside VPN IP CIDR. #remote_allow_ranges: # This rule would allow only private IPs for this VPN range #"10.42.42.0/24": #"192.168.0.0/16": true # local_allow_list allows you to filter which local IP addresses we advertise # to the lighthouses. This uses the same logic as `remote_allow_list`, but # additionally, you can specify an `interfaces` map of regular expressions # to match against interface names. The regexp must match the entire name.
# All interface rules must be either true or false (and the default will be # the inverse). CIDR rules are matched after interface name rules. # Default is all local IP addresses. #local_allow_list: # Example to block tun0 and all docker interfaces. #interfaces: #tun0: false #'docker.*': false # Example to only advertise this subnet to the lighthouse. #"10.0.0.0/8": true # advertise_addrs are routable addresses that will be included along with discovered addresses to report to the # lighthouse, the format is "ip:port". `port` can be `0`, in which case the actual listening port will be used in its # place, useful if `listen.port` is set to 0. # This option is mainly useful when there are static ip addresses the host can be reached at that nebula can not # typically discover on its own. Examples being port forwarding or multiple paths to the internet. #advertise_addrs: #- "1.1.1.1:4242" #- "1.2.3.4:0" # port will be replaced with the real listening port # EXPERIMENTAL: This option may change or disappear in the future. # This setting allows us to "guess" what the remote might be for a host # while we wait for the lighthouse response. #calculated_remotes: # For any Nebula IPs in 10.0.10.0/24, this will apply the mask and add # the calculated IP as an initial remote (while we wait for the response # from the lighthouse). Both CIDRs must have the same mask size. # For example, Nebula IP 10.0.10.123 will have a calculated remote of # 192.168.1.123 #10.0.10.0/24: #- mask: 192.168.1.0/24 # port: 4242 # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined, # however using port 0 will dynamically assign a port and is recommended for roaming nodes. 
listen: # To listen on only ipv4, use "0.0.0.0" host: "::" port: 4242 # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg) # default is 64, does not support reload #batch: 64 # Configure socket buffers for the udp side (outside), leave unset to use the system defaults. Values will be doubled by the kernel # Default is net.core.rmem_default and net.core.wmem_default (/proc/sys/net/core/rmem_default and /proc/sys/net/core/wmem_default) # Maximum is limited by memory in the system, SO_RCVBUFFORCE and SO_SNDBUFFORCE are used to avoid having to raise the system-wide # max, net.core.rmem_max and net.core.wmem_max #read_buffer: 10485760 #write_buffer: 10485760 # By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection # in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running # on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes. # valid values: always, never, private # This setting is reloadable. #send_recv_error: always # Similar to send_recv_error, this option lets you configure if you want to accept "recv_error" packets from remote hosts. # valid values: always, never, private # This setting is reloadable. #accept_recv_error: always # The so_mark option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier. # This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes, # allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set. # This setting is reloadable. #so_mark: 0 # Routines is the number of thread pairs to run that consume from the tun and UDP queues.
# Currently, this defaults to 1 which means we have 1 tun queue reader and 1 # UDP queue reader. Setting this above one will set IFF_MULTI_QUEUE on the tun # device and SO_REUSEPORT on the UDP socket to allow multiple queues. # This option is only supported on Linux. #routines: 1 punchy: # Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings punch: true # respond means that a node you are trying to reach will connect back out to you if your hole punching fails # this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT # Default is false #respond: true # delays a punch response for misbehaving NATs, default is 1 second. #delay: 1s # set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect. #respond_delay: 5s # Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes # IMPORTANT: this value must be identical on ALL NODES/LIGHTHOUSES. We do not/will not support use of different ciphers simultaneously! #cipher: aes # Preferred ranges is used to define a hint about the local network ranges, which speeds up discovering the fastest # path to a network-adjacent nebula node. # This setting is reloadable. #preferred_ranges: ["172.16.0.0/24"] # sshd can expose informational and administrative functions via ssh, and allows manual tweaking of various # network settings when debugging or testing.
#sshd: # Toggles the feature #enabled: true # Host and port to listen on, port 22 is not allowed for your safety #listen: 127.0.0.1:2222 # A file containing the ssh host private key to use # A decent way to generate one: ssh-keygen -t ed25519 -f ssh_host_ed25519_key -N "" < /dev/null #host_key: ./ssh_host_ed25519_key # Authorized users and their public keys #authorized_users: #- user: steeeeve # keys can be an array of strings or single string #keys: #- "ssh public key string" # Trusted SSH CA public keys. These are the public keys of the CAs that are allowed to sign SSH keys for access. #trusted_cas: #- "ssh public key string" # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: # Relays are a list of Nebula IPs that peers can use to relay packets to me. # IPs in this list must have am_relay set to true in their configs, otherwise # they will reject relay requests. #relays: #- 192.168.100.1 #- # Set am_relay to true to permit other hosts to list my IP in their relays config. Default false. am_relay: false # Set use_relays to false to prevent this instance from attempting to establish connections through relays. # default true use_relays: true # Configure the private interface. Note: addr is baked into the nebula certificate tun: # When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root) disabled: false # Name of the device. If not set, a default will be chosen by the OS. # For macOS: if set, must be in the form `utun[0-9]+`. # For NetBSD: Required to be set, must be in the form `tun[0-9]+` dev: nebula1 # Toggles forwarding of local broadcast packets, the address of which depends on the ip/mask encoded in pki.cert drop_local_broadcast: false # Toggles forwarding of multicast packets drop_multicast: false # Sets the transmit queue length, if you notice lots of transmit drops on the tun it may help to raise this number.
Default is 500 tx_queue: 500 # Default MTU for every packet, safe setting is (and the default) 1300 for internet based traffic mtu: 1300 # Route based MTU overrides: if you have known vpn ip paths that can support larger MTUs, you can increase/decrease them here routes: #- mtu: 8800 # route: 10.0.0.0/16 # Unsafe routes allow you to route traffic over nebula to non-nebula nodes # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula # NOTES: # * You will only see a single gateway in the routing table if you are not on linux # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights # # unsafe_routes: # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways # - route: 192.168.87.0/24 # via: # - gateway: 10.0.0.1 # - gateway: 10.0.0.2 # - gateway: 10.0.0.3 # # Multiple gateways with a weight, this will balance traffic accordingly # - route: 192.168.87.0/24 # via: # - gateway: 10.0.0.1 # weight: 10 # - gateway: 10.0.0.2 # weight: 5 # # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate # `via`: single node or list of gateways to use for this route # `mtu`: will default to tun mtu if this option is not specified # `metric`: will default to 0 if this option is not specified # `install`: will default to true, controls whether this route is installed in the system's routing table. # This setting is reloadable. unsafe_routes: #- route: 172.16.1.0/24 # via: 192.168.100.99 # mtu: 1300 # metric: 100 # install: true # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of # in nebula configuration files. Default false, not reloadable.
#use_system_route_table: false # Buffer size for reading route updates. 0 means default system buffer size. (/proc/sys/net/core/rmem_default). # If using massive route updates, for example BGP, you may need to increase this value to avoid packet loss. # SO_RCVBUFFORCE is used to avoid having to raise the system wide max #use_system_route_table_buffer_size: 0 # Configure logging level logging: # panic, fatal, error, warning, info, or debug. Default is info and is reloadable. #NOTE: Debug mode can log remotely controlled/untrusted data which can quickly fill a disk in some # scenarios. Debug logging is also CPU intensive and will decrease performance overall. # Only enable debug logging while actively investigating an issue. level: info # json or text formats currently available. Default is text format: text # Disable timestamp logging. Useful when output is redirected to a logging system that already adds timestamps. Default is false #disable_timestamp: true # timestamp format is specified in Go time format, see: # https://golang.org/pkg/time/#pkg-constants # default when `format: json`: "2006-01-02T15:04:05Z07:00" (RFC3339) # default when `format: text`: # when TTY attached: seconds since beginning of execution # otherwise: "2006-01-02T15:04:05Z07:00" (RFC3339) # As an example, to log as RFC3339 with millisecond precision, set to: #timestamp_format: "2006-01-02T15:04:05.000Z07:00" #stats: #type: graphite #prefix: nebula #protocol: tcp #host: 127.0.0.1:9999 #interval: 10s #type: prometheus #listen: 127.0.0.1:8080 #path: /metrics #namespace: prometheusns #subsystem: nebula #interval: 10s # enables counter metrics for meta packets # e.g.: `messages.tx.handshake` # NOTE: `message.{tx,rx}.recv_error` is always emitted #message_metrics: false # enables detailed counter metrics for lighthouse packets # e.g.: `lighthouse.rx.HostQuery` #lighthouse_metrics: false # Handshake Manager Settings #handshakes: # Handshakes are sent to all known addresses at each interval with a
linear backoff, # Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out #try_interval: 100ms #retries: 20 # query_buffer is the size of the buffer channel for querying lighthouses #query_buffer: 64 # trigger_buffer is the size of the buffer channel for quickly sending handshakes # after receiving the response for lighthouse queries #trigger_buffer: 64 # Tunnel manager settings #tunnels: # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has # elapsed. # In general, it is a good idea to enable this setting. It will be enabled by default in a future release. # This setting is reloadable #drop_inactive: false # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered # inactive and eligible to be dropped. # This setting is reloadable #inactivity_timeout: 10m # Nebula security group configuration firewall: # Action to take when a packet is not allowed by the firewall rules. # Can be one of: # `drop` (default): silently drop the packet. # `reject`: send a reject reply. # - For TCP, this will be a RST "Connection Reset" packet. # - For other protocols, this will be an ICMP port unreachable packet. outbound_action: drop inbound_action: drop # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.) # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr` # is explicitly defined. This is usually not the desired behavior and should be avoided! 
#default_local_cidr_any: false conntrack: tcp_timeout: 12m udp_timeout: 3m default_timeout: 10m # The firewall is default deny. There is no way to write a deny rule. # Rules are composed of a protocol, port, and one or more of host, group, CIDR, local CIDR, ca_name, or ca_sha # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). # proto: `any`, `tcp`, `udp`, or `icmp` # a port specification is ignored if proto is `icmp` # host: `any` or a literal hostname, i.e. `test-host` # group: `any` or a literal group name, i.e. `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. # This can be used to filter destinations when using unsafe_routes. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended use case.
# ca_name: An issuing CA name # ca_sha: An issuing CA shasum outbound: # Allow all outbound traffic from this node - port: any proto: any host: any inbound: # Allow icmp between any nebula hosts - port: any proto: icmp host: any # Allow tcp/443 from any host with BOTH laptop and home group - port: 443 proto: tcp groups: - laptop - home # Expose a subnet (unsafe route) to hosts with the group remote_client # This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate - port: 8080 proto: tcp group: remote_client local_cidr: 192.168.100.1/24 ================================================ FILE: examples/go_service/main.go ================================================ package main import ( "bufio" "fmt" "log" "net" "os" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/service" ) func main() { if err := run(); err != nil { log.Fatalf("%+v", err) } } func run() error { configStr := ` tun: user: true static_host_map: '192.168.100.1': ['localhost:4242'] listen: host: 0.0.0.0 port: 4241 lighthouse: am_lighthouse: false interval: 60 hosts: - '192.168.100.1' firewall: outbound: # Allow all outbound traffic from this node - port: any proto: any host: any inbound: # Allow icmp between any nebula hosts - port: any proto: icmp host: any - port: any proto: any host: any pki: ca: /home/rice/Developer/nebula-config/ca.crt cert: /home/rice/Developer/nebula-config/app.crt key: /home/rice/Developer/nebula-config/app.key ` var cfg config.C if err := cfg.LoadString(configStr); err != nil { return err } logger := logrus.New() logger.Out = os.Stdout ctrl, err := nebula.Main(&cfg, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { return err } svc, err := service.New(ctrl) if err != nil { return err } ln, err := svc.Listen("tcp", ":1234") if err != nil { return err } for { conn, err := ln.Accept() if err != nil { 
log.Printf("accept error: %s", err) break } defer func(conn net.Conn) { _ = conn.Close() }(conn) log.Printf("got connection") _, err = conn.Write([]byte("hello world\n")) if err != nil { log.Printf("write error: %s", err) } scanner := bufio.NewScanner(conn) for scanner.Scan() { message := scanner.Text() _, err = fmt.Fprintf(conn, "echo: %q\n", message) if err != nil { log.Printf("write error: %s", err) } log.Printf("got message %q", message) } if err := scanner.Err(); err != nil { log.Printf("scanner error: %s", err) break } } _ = svc.Close() if err := svc.Wait(); err != nil { return err } return nil } ================================================ FILE: examples/service_scripts/nebula.init.d.sh ================================================ #!/bin/sh ### BEGIN INIT INFO # Provides: nebula # Required-Start: $local_fs $network # Required-Stop: $local_fs $network # Default-Start: 2 3 4 5 # Default-Stop: 0 1 6 # Description: nebula mesh vpn client ### END INIT INFO SCRIPT="/usr/local/bin/nebula -config /etc/nebula/config.yml" RUNAS=root PIDFILE=/var/run/nebula.pid LOGFILE=/var/log/nebula.log start() { if [ -f $PIDFILE ] && kill -0 $(cat $PIDFILE); then echo 'Service already running' >&2 return 1 fi echo 'Starting nebula service…' >&2 local CMD="$SCRIPT &> \"$LOGFILE\" & echo \$!" su -c "$CMD" $RUNAS > "$PIDFILE" echo 'Service started' >&2 } stop() { if [ ! -f "$PIDFILE" ] || ! 
kill -0 $(cat "$PIDFILE"); then echo 'Service not running' >&2 return 1 fi echo 'Stopping nebula service…' >&2 kill -15 $(cat "$PIDFILE") && rm -f "$PIDFILE" echo 'Service stopped' >&2 } case "$1" in start) start ;; stop) stop ;; restart) stop start ;; *) echo "Usage: $0 {start|stop|restart}" esac ================================================ FILE: examples/service_scripts/nebula.open-rc ================================================ #!/sbin/openrc-run # # nebula service for open-rc systems extra_commands="checkconfig" : ${NEBULA_CONFDIR:=${RC_PREFIX%/}/etc/nebula} : ${NEBULA_CONFIG:=${NEBULA_CONFDIR}/config.yml} : ${NEBULA_BINARY:=${NEBULA_BINARY}${RC_PREFIX%/}/usr/local/sbin/nebula} command="${NEBULA_BINARY}" command_args="${NEBULA_OPTS} -config ${NEBULA_CONFIG}" supervisor="supervise-daemon" description="A scalable overlay networking tool with a focus on performance, simplicity and security" required_dirs="${NEBULA_CONFDIR}" required_files="${NEBULA_CONFIG}" checkconfig() { "${command}" -test ${command_args} || return 1 } start_pre() { if [ "${RC_CMD}" != "restart" ] ; then checkconfig || return $? fi } stop_pre() { if [ "${RC_CMD}" = "restart" ] ; then checkconfig || return $? 
fi } ================================================ FILE: examples/service_scripts/nebula.plist ================================================ KeepAlive Label net.defined.nebula WorkingDirectory /Users/{username}/.local/bin/nebula LimitLoadToSessionType Aqua Background LoginWindow StandardIO System ProgramArguments ./nebula -config ./config.yml RunAtLoad StandardErrorPath ./nebula.log StandardOutPath ./nebula.log UserName root ================================================ FILE: examples/service_scripts/nebula.service ================================================ [Unit] Description=Nebula overlay networking tool Wants=basic.target network-online.target nss-lookup.target time-sync.target After=basic.target network.target network-online.target Before=sshd.service [Service] Type=notify NotifyAccess=main SyslogIdentifier=nebula ExecReload=/bin/kill -HUP $MAINPID ExecStart=/usr/local/bin/nebula -config /etc/nebula/config.yml Restart=always [Install] WantedBy=multi-user.target ================================================ FILE: firewall/cache.go ================================================ package firewall import ( "sync/atomic" "time" "github.com/sirupsen/logrus" ) // ConntrackCache is used as a local routine cache to know if a given flow // has been seen in the conntrack table. type ConntrackCache map[Packet]struct{} type ConntrackCacheTicker struct { cacheV uint64 cacheTick atomic.Uint64 cache ConntrackCache } func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker { if d == 0 { return nil } c := &ConntrackCacheTicker{ cache: ConntrackCache{}, } go c.tick(d) return c } func (c *ConntrackCacheTicker) tick(d time.Duration) { for { time.Sleep(d) c.cacheTick.Add(1) } } // Get checks if the cache ticker has moved to the next version before returning // the map. If it has moved, we reset the map. 
func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache { if c == nil { return nil } if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { if l.Level == logrus.DebugLevel { l.WithField("len", ll).Debug("resetting conntrack cache") } c.cache = make(ConntrackCache, ll) } } return c.cache } ================================================ FILE: firewall/packet.go ================================================ package firewall import ( "encoding/json" "fmt" "net/netip" ) type m = map[string]any const ( ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever ProtoTCP = 6 ProtoUDP = 17 ProtoICMP = 1 ProtoICMPv6 = 58 PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` ) type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr // LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP. LocalPort uint16 // RemotePort is the source port for incoming traffic, or the destination port for outgoing. // For ICMP, it's the "identifier". 
This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool } func (fp *Packet) Copy() *Packet { return &Packet{ LocalAddr: fp.LocalAddr, RemoteAddr: fp.RemoteAddr, LocalPort: fp.LocalPort, RemotePort: fp.RemotePort, Protocol: fp.Protocol, Fragment: fp.Fragment, } } func (fp Packet) MarshalJSON() ([]byte, error) { var proto string switch fp.Protocol { case ProtoTCP: proto = "tcp" case ProtoICMP: proto = "icmp" case ProtoICMPv6: proto = "icmpv6" case ProtoUDP: proto = "udp" default: proto = fmt.Sprintf("unknown %v", fp.Protocol) } return json.Marshal(m{ "LocalAddr": fp.LocalAddr.String(), "RemoteAddr": fp.RemoteAddr.String(), "LocalPort": fp.LocalPort, "RemotePort": fp.RemotePort, "Protocol": proto, "Fragment": fp.Fragment, }) } ================================================ FILE: firewall.go ================================================ package nebula import ( "crypto/sha256" "encoding/hex" "errors" "fmt" "hash/fnv" "net/netip" "reflect" "slices" "strconv" "strings" "sync" "time" "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" ) type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } type conn struct { Expires time.Time // Time when this conntrack entry will expire // record why the original connection passed the firewall, so we can re-validate // after ruleset changes. 
Note, rulesVersion is a uint16 so that these two // fields pack for free after the uint32 above incoming bool rulesVersion uint16 } // TODO: need conntrack max tracked connections handling type Firewall struct { Conntrack *FirewallConntrack InRules *FirewallTable OutRules *FirewallTable InSendReject bool OutSendReject bool //TODO: we should have many more options for TCP, an option for ICMP, and mimic the kernel a bit better // https://www.kernel.org/doc/Documentation/networking/nf_conntrack-sysctl.txt TCPTimeout time.Duration //linux: 5 days max UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. // The vpn addresses are a full bit match while the unsafe networks only match the prefix routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. assignedNetworks []netip.Prefix hasUnsafeNetworks bool rules string rulesVersion uint16 defaultLocalCIDRAny bool incomingMetrics firewallMetrics outgoingMetrics firewallMetrics l *logrus.Logger } type firewallMetrics struct { droppedLocalAddr metrics.Counter droppedRemoteAddr metrics.Counter droppedNoRule metrics.Counter } type FirewallConntrack struct { sync.Mutex Conns map[firewall.Packet]*conn TimerWheel *TimerWheel[firewall.Packet] } // FirewallTable is the entry point for a rule, the evaluation order is: // Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR) type FirewallTable struct { TCP firewallPort UDP firewallPort ICMP firewallPort AnyProto firewallPort } func newFirewallTable() *FirewallTable { return &FirewallTable{ TCP: firewallPort{}, UDP: firewallPort{}, ICMP: firewallPort{}, AnyProto: firewallPort{}, } } type FirewallCA struct { Any *FirewallRule CANames map[string]*FirewallRule CAShas map[string]*FirewallRule } type FirewallRule struct { // Any makes Hosts, Groups, and 
CIDR irrelevant Any *firewallLocalCIDR Hosts map[string]*firewallLocalCIDR Groups []*firewallGroups CIDR *bart.Table[*firewallLocalCIDR] } type firewallGroups struct { Groups []string LocalCIDR *firewallLocalCIDR } // Even though ports are uint16, int32 maps are faster for lookup // Plus we can use `-1` for fragment rules type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool LocalCIDR *bart.Lite } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration if tcpTimeout < UDPTimeout { tmin = tcpTimeout tmax = UDPTimeout } else { tmin = UDPTimeout tmax = tcpTimeout } if defaultTimeout < tmin { tmin = defaultTimeout } else if defaultTimeout > tmax { tmax = defaultTimeout } routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { routableNetworks.Insert(n) hasUnsafeNetworks = true } return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, InRules: newFirewallTable(), OutRules: newFirewallTable(), TCPTimeout: tcpTimeout, UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, routableNetworks: routableNetworks, assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, l: l, incomingMetrics: firewallMetrics{ droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), droppedRemoteAddr: 
metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil), droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), }, outgoingMetrics: firewallMetrics{ droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil), droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil), droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), }, } } func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) } if certificate == nil { panic("No certificate available to reconfigure the firewall") } fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, //TODO: max_connections ) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { case "reject": fw.InSendReject = true case "drop": fw.InSendReject = false default: l.WithField("action", inboundAction).Warn("invalid firewall.inbound_action, defaulting to `drop`") fw.InSendReject = false } outboundAction := c.GetString("firewall.outbound_action", "drop") switch outboundAction { case "reject": fw.OutSendReject = true case "drop": fw.OutSendReject = false default: l.WithField("action", outboundAction).Warn("invalid firewall.outbound_action, defaulting to `drop`") fw.OutSendReject = false } err := AddFirewallRulesFromConfig(l, false, c, fw) if err != nil { return nil, err } err = AddFirewallRulesFromConfig(l, true, c, fw) if err != nil { return nil, err } return fw, nil } // AddRule properly creates the in memory rule structure for a firewall table. 
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { var ( ft *FirewallTable fp firewallPort ) if incoming { ft = f.InRules } else { ft = f.OutRules } switch proto { case firewall.ProtoTCP: fp = ft.TCP case firewall.ProtoUDP: fp = ft.UDP case firewall.ProtoICMP, firewall.ProtoICMPv6: //ICMP traffic doesn't have ports, so we always coerce to "any", even if a value is provided if startPort != firewall.PortAny { f.l.WithField("startPort", startPort).Warn("ignoring port specification for ICMP firewall rule") } startPort = firewall.PortAny endPort = firewall.PortAny fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto default: return fmt.Errorf("unknown protocol %v", proto) } // We need this rule string because we generate a hash. Removing this will break firewall reload. ruleString := fmt.Sprintf( "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, ) f.rules += ruleString + "\n" direction := "incoming" if !incoming { direction = "outgoing" } f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). 
Info("Firewall rule added") return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } // GetRuleHash returns a hash representation of all inbound and outbound rules func (f *Firewall) GetRuleHash() string { sum := sha256.Sum256([]byte(f.rules)) return hex.EncodeToString(sum[:]) } // GetRuleHashFNV returns a uint32 FNV-1 hash representation the rules, for use as a metric value func (f *Firewall) GetRuleHashFNV() uint32 { h := fnv.New32a() h.Write([]byte(f.rules)) return h.Sum32() } // GetRuleHashes returns both the sha256 and FNV-1 hashes, suitable for logging func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { table = "firewall.inbound" } else { table = "firewall.outbound" } r := c.Get(table) if r == nil { return nil } rs, ok := r.([]any) if !ok { return fmt.Errorf("%s failed to parse, should be an array of rules", table) } for i, t := range rs { r, err := convertRule(l, t, table, i) if err != nil { return fmt.Errorf("%s rule #%v; %s", table, i, err) } if r.Code != "" && r.Port != "" { return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) } if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) } var sPort, errPort string if r.Code != "" { errPort = "code" sPort = r.Code } else { errPort = "port" sPort = r.Port } var proto uint8 var startPort, endPort int32 switch r.Proto { case "any": proto = firewall.ProtoAny startPort, endPort, err = parsePort(sPort) case "tcp": proto = firewall.ProtoTCP startPort, endPort, err = parsePort(sPort) case "udp": proto = firewall.ProtoUDP startPort, endPort, err = 
parsePort(sPort) case "icmp": proto = firewall.ProtoICMP startPort = firewall.PortAny endPort = firewall.PortAny if sPort != "" { l.WithField("port", sPort).Warn("ignoring port specification for ICMP firewall rule") } default: return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } if err != nil { return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err) } if r.Cidr != "" && r.Cidr != "any" { _, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } if r.LocalCidr != "" && r.LocalCidr != "any" { _, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } } if warning := r.sanity(); warning != nil { l.Warnf("%s rule #%v; %s", table, i, warning) } err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) if err != nil { return fmt.Errorf("%s rule #%v; `%s`", table, i, err) } } return nil } var ErrUnknownNetworkType = errors.New("unknown network type") var ErrPeerRejected = errors.New("remote address is not within a network that we handle") var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks") var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. 
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { // Check if we spoke to this tuple, if we did then allow this packet if f.inConns(fp, h, caPool, localCache) { return nil } // Make sure remote address matches nebula certificate, and determine how to treat it if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } else { nwType, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } switch nwType { case NetworkTypeVPN: break // nothing special case NetworkTypeVPNPeer: f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: break // nothing special, one day this may have different FW rules default: f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrUnknownNetworkType //should never happen } } // Make sure we are supposed to be handling this local ip address if !f.routableNetworks.Contains(fp.LocalAddr) { f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } table := f.OutRules if incoming { table = f.InRules } // We now know which firewall table to check against if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { f.metrics(incoming).droppedNoRule.Inc(1) return ErrNoMatchingRule } // We always want to conntrack since it is a faster operation f.addConn(fp, incoming) return nil } func (f *Firewall) metrics(incoming bool) firewallMetrics { if incoming { return f.incomingMetrics } else { return f.outgoingMetrics } } // Destroy cleans up any known cyclical references so the object can be freed by GC. 
This should be called if a new // firewall object is created func (f *Firewall) Destroy() { //TODO: clean references if/when needed } func (f *Firewall) EmitStats() { conntrack := f.Conntrack conntrack.Lock() conntrackCount := len(conntrack.Conns) conntrack.Unlock() metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount)) metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion)) metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { if _, ok := localCache[fp]; ok { return true } } conntrack := f.Conntrack conntrack.Lock() // Purge every time we test ep, has := conntrack.TimerWheel.Purge() if has { f.evict(ep) } c, ok := conntrack.Conns[fp] if !ok { conntrack.Unlock() return false } if c.rulesVersion != f.rulesVersion { // This conntrack entry was for an older rule set, validate // it still passes with the current rule set table := f.OutRules if c.incoming { table = f.InRules } // We now know which firewall table to check against if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) { if f.l.Level >= logrus.DebugLevel { h.logger(f.l). WithField("fwPacket", fp). WithField("incoming", c.incoming). WithField("rulesVersion", f.rulesVersion). WithField("oldRulesVersion", c.rulesVersion). Debugln("dropping old conntrack entry, does not match new ruleset") } delete(conntrack.Conns, fp) conntrack.Unlock() return false } if f.l.Level >= logrus.DebugLevel { h.logger(f.l). WithField("fwPacket", fp). WithField("incoming", c.incoming). WithField("rulesVersion", f.rulesVersion). WithField("oldRulesVersion", c.rulesVersion). 
Debugln("keeping old conntrack entry, does match new ruleset") } c.rulesVersion = f.rulesVersion } switch fp.Protocol { case firewall.ProtoTCP: c.Expires = time.Now().Add(f.TCPTimeout) case firewall.ProtoUDP: c.Expires = time.Now().Add(f.UDPTimeout) default: c.Expires = time.Now().Add(f.DefaultTimeout) } conntrack.Unlock() if localCache != nil { localCache[fp] = struct{}{} } return true } func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { var timeout time.Duration c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: timeout = f.TCPTimeout case firewall.ProtoUDP: timeout = f.UDPTimeout default: timeout = f.DefaultTimeout } conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Add(fp, timeout) } // Record which rulesVersion allowed this connection, so we can retest after // firewall reload c.incoming = incoming c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c conntrack.Unlock() } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! func (f *Firewall) evict(p firewall.Packet) { // Are we still tracking this conn? 
conntrack := f.Conntrack t, ok := conntrack.Conns[p] if !ok { return } newT := t.Expires.Sub(time.Now()) // Timeout is in the future, re-add the timer if newT > 0 { conntrack.TimerWheel.Advance(time.Now()) conntrack.TimerWheel.Add(p, newT) return } // This conn is done delete(conntrack.Conns, p) } func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if ft.AnyProto.match(p, incoming, c, caPool) { return true } switch p.Protocol { case firewall.ProtoTCP: if ft.TCP.match(p, incoming, c, caPool) { return true } case firewall.ProtoUDP: if ft.UDP.match(p, incoming, c, caPool) { return true } case firewall.ProtoICMP, firewall.ProtoICMPv6: if ft.ICMP.match(p, incoming, c, caPool) { return true } } return false } func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } for i := startPort; i <= endPort; i++ { if _, ok := fp[i]; !ok { fp[i] = &FirewallCA{ CANames: make(map[string]*FirewallRule), CAShas: make(map[string]*FirewallRule), } } if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil { return err } } return nil } func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { // We don't have any allowed ports, bail if fp == nil { return false } // this branch is here to catch traffic from FirewallTable.Any.match and FirewallTable.ICMP.match if p.Protocol == firewall.ProtoICMP || p.Protocol == firewall.ProtoICMPv6 { // port numbers are re-used for connection tracking of ICMP, // but we don't want to actually filter on them. 
return fp[firewall.PortAny].match(p, c, caPool) } var port int32 if p.Fragment { port = firewall.PortFragment } else if incoming { port = int32(p.LocalPort) } else { port = int32(p.RemotePort) } if fp[port].match(p, c, caPool) { return true } return fp[firewall.PortAny].match(p, c, caPool) } func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), Groups: make([]*firewallGroups, 0), CIDR: new(bart.Table[*firewallLocalCIDR]), } } if caSha == "" && caName == "" { if fc.Any == nil { fc.Any = fr() } return fc.Any.addRule(f, groups, host, cidr, localCidr) } if caSha != "" { if _, ok := fc.CAShas[caSha]; !ok { fc.CAShas[caSha] = fr() } err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr) if err != nil { return err } } if caName != "" { if _, ok := fc.CANames[caName]; !ok { fc.CANames[caName] = fr() } err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr) if err != nil { return err } } return nil } func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if fc == nil { return false } if fc.Any.match(p, c) { return true } if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok { if t.match(p, c) { return true } } s, err := caPool.GetCAForCert(c.Certificate) if err != nil { return false } return fc.CANames[s.Certificate.Name()].match(p, c) } func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ LocalCIDR: new(bart.Lite), } } if fr.isAny(groups, host, cidr) { if fr.Any == nil { fr.Any = flc() } return fr.Any.addRule(f, localCidr) } if len(groups) > 0 { nlc := flc() err := nlc.addRule(f, localCidr) if err != nil { return err } fr.Groups = append(fr.Groups, &firewallGroups{ Groups: groups, LocalCIDR: nlc, }) } if host != "" { nlc := fr.Hosts[host] if 
nlc == nil { nlc = flc() } err := nlc.addRule(f, localCidr) if err != nil { return err } fr.Hosts[host] = nlc } if cidr != "" { c, err := netip.ParsePrefix(cidr) if err != nil { return err } nlc, _ := fr.CIDR.Get(c) if nlc == nil { nlc = flc() } err = nlc.addRule(f, localCidr) if err != nil { return err } fr.CIDR.Insert(c, nlc) } return nil } func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { if len(groups) == 0 && host == "" && cidr == "" { return true } if slices.Contains(groups, "any") { return true } if host == "any" { return true } if cidr == "any" { return true } return false } func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool { if fr == nil { return false } // Shortcut path for if groups, hosts, or cidr contained an `any` if fr.Any.match(p, c) { return true } // Need any of group, host, or cidr to match for _, sg := range fr.Groups { found := false for _, g := range sg.Groups { if _, ok := c.InvertedGroups[g]; !ok { found = false break } found = true } if found && sg.LocalCIDR.match(p, c) { return true } } if fr.Hosts != nil { if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc.match(p, c) { return true } } } for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) { if v.match(p, c) { return true } } return false } func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { if localCidr == "any" { flc.Any = true return nil } if localCidr == "" { if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { flc.Any = true return nil } for _, network := range f.assignedNetworks { flc.LocalCIDR.Insert(network) } return nil } c, err := netip.ParsePrefix(localCidr) if err != nil { return err } flc.LocalCIDR.Insert(c) return nil } func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { if flc == nil { return false } if flc.Any { return true } return flc.LocalCIDR.Contains(p.LocalAddr) } type rule struct { Port string Code 
string Proto string Host string Groups []string Cidr string LocalCidr string CAName string CASha string } func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { r := rule{} m, ok := p.(map[string]any) if !ok { return r, errors.New("could not parse rule") } toString := func(k string, m map[string]any) string { v, ok := m[k] if !ok { return "" } return fmt.Sprintf("%v", v) } r.Port = toString("port", m) r.Code = toString("code", m) r.Proto = toString("proto", m) r.Host = toString("host", m) r.Cidr = toString("cidr", m) r.LocalCidr = toString("local_cidr", m) r.CAName = toString("ca_name", m) r.CASha = toString("ca_sha", m) // Make sure group isn't an array if v, ok := m["group"].([]any); ok { if len(v) > 1 { return r, errors.New("group should contain a single value, an array with more than one entry was provided") } l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) m["group"] = v[0] } singleGroup := toString("group", m) if rg, ok := m["groups"]; ok { switch reflect.TypeOf(rg).Kind() { case reflect.Slice: v := reflect.ValueOf(rg) r.Groups = make([]string, v.Len()) for i := 0; i < v.Len(); i++ { r.Groups[i] = v.Index(i).Interface().(string) } case reflect.String: r.Groups = []string{rg.(string)} default: r.Groups = []string{fmt.Sprintf("%v", rg)} } } //flatten group vs groups if singleGroup != "" { // Check if we have both groups and group provided in the rule config if len(r.Groups) > 0 { return r, fmt.Errorf("only one of group or groups should be defined, both provided") } r.Groups = []string{singleGroup} } return r, nil } // sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value // rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr" func (r *rule) sanity() error { //port, proto, local_cidr are AND, no need to check here //ca_sha and ca_name don't have a 
wildcard value, no need to check here groupsEmpty := len(r.Groups) == 0 hostEmpty := r.Host == "" cidrEmpty := r.Cidr == "" if (groupsEmpty && hostEmpty && cidrEmpty) == true { return nil //no content! } groupsHasAny := slices.Contains(r.Groups, "any") if groupsHasAny && len(r.Groups) > 1 { return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups) } if r.Host == "any" { if !groupsEmpty { return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups) } if !cidrEmpty { return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr) } } if groupsHasAny { if !hostEmpty && r.Host != "any" { return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host) } if !cidrEmpty { return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr) } } if r.Code != "" { return fmt.Errorf("code specified as [%s]. 
Support for 'code' will be dropped in a future release, as it has never been functional", r.Code) } //todo alert on cidr-any return nil } func parsePort(s string) (int32, int32, error) { var err error const notAPort int32 = -2 if s == "any" { return firewall.PortAny, firewall.PortAny, nil } if s == "fragment" { return firewall.PortFragment, firewall.PortFragment, nil } if !strings.Contains(s, `-`) { rPort, err := strconv.Atoi(s) if err != nil { return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) } return int32(rPort), int32(rPort), nil } sPorts := strings.SplitN(s, `-`, 2) for i := range sPorts { sPorts[i] = strings.Trim(sPorts[i], " ") } if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" { return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) } rStartPort, err := strconv.Atoi(sPorts[0]) if err != nil { return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) } rEndPort, err := strconv.Atoi(sPorts[1]) if err != nil { return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) } startPort := int32(rStartPort) endPort := int32(rEndPort) if startPort == firewall.PortAny { endPort = firewall.PortAny } return startPort, endPort, nil } ================================================ FILE: firewall_test.go ================================================ package nebula import ( "bytes" "errors" "math" "net/netip" "testing" "time" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewFirewall(t *testing.T) { l := test.NewLogger() c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) assert.NotNil(t, conntrack.Conns) assert.NotNil(t, 
conntrack.TimerWheel) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) assert.Equal(t, time.Second, fw.TCPTimeout) assert.Equal(t, time.Minute, fw.UDPTimeout) assert.Equal(t, time.Hour, fw.DefaultTimeout) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } func TestFirewall_AddRule(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) c := &dummyCert{} fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) ti, err := netip.ParsePrefix("1.2.3.4/32") require.NoError(t, err) ti6, err := netip.ParsePrefix("fd12::34/128") require.NoError(t, err) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, 
fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) //no matter what port is given for icmp, it should end up as "any" assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, 
"ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.True(t, table.Any) table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.False(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp6, err := netip.ParsePrefix("::/0") require.NoError(t, err) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.True(t, table.Any) table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.False(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, 
fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } func TestFirewall_Drop(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } c := dummyCert{ name: "host1", networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")}, groups: []string{"default-group"}, issuer: "signer-shasum", } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &cert.CachedCertificate{ Certificate: &c, InvertedGroups: map[string]struct{}{"default-group": {}}, }, }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, 
ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack require.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", 
"", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_DropV6(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) myVpnNetworksTable := new(bart.Lite) myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, Fragment: false, } c := dummyCert{ name: "host1", networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")}, groups: []string{"default-group"}, issuer: "signer-shasum", } h := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &cert.CachedCertificate{ Certificate: &c, InvertedGroups: map[string]struct{}{"default-group": {}}, }, }, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, } h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack require.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("fd12::56") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) 
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { f := &Firewall{} ft := FirewallTable{ TCP: firewallPort{}, } pfix := netip.MustParsePrefix("172.1.1.1/32") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "") pfix6 := netip.MustParsePrefix("fd11::11/128") _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "") _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", 
pfix6.String(), "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { // This benchmark is showing us the cost of failing to match the protocol c := &cert.CachedCertificate{ Certificate: &dummyCert{}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)) } }) b.Run("pass proto, fail on port", func(b *testing.B) { // This benchmark is showing us the cost of matching a specific protocol but failing to match the port c := &cert.CachedCertificate{ Certificate: &dummyCert{}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)) } }) b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{}, } ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{}, } ip := netip.MustParsePrefix("fd99::99/128") for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")}, }, 
InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")}, }, InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")}, }, InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp)) } }) b.Run("pass on group on any local cidr", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", }, InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) b.Run("pass on group on specific local cidr", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", }, InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) b.Run("pass on group on specific local cidr6", func(b *testing.B) { c := &cert.CachedCertificate{ Certificate: &dummyCert{ name: "nope", }, InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { assert.True(b, 
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
		}
	})

	b.Run("pass on name", func(b *testing.B) {
		c := &cert.CachedCertificate{
			Certificate: &dummyCert{
				name: "good-host",
			},
			InvertedGroups: map[string]struct{}{"nope": {}},
		}
		for n := 0; n < b.N; n++ {
			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
		}
	})
}

// TestFirewall_Drop2 checks that a rule listing several groups only matches a
// peer whose certificate carries all of them: h1 lacks "test-group" and is
// rejected, while h carries both groups and is allowed.
func TestFirewall_Drop2(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	p := firewall.Packet{
		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
		LocalPort:  10,
		RemotePort: 90,
		Protocol:   firewall.ProtoUDP,
		Fragment:   false,
	}
	network := netip.MustParsePrefix("1.2.3.4/24")

	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host1",
			networks: []netip.Prefix{network},
		},
		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)

	c1 := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host1",
			networks: []netip.Prefix{network},
		},
		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
	}
	h1 := HostInfo{
		vpnAddrs: []netip.Addr{network.Addr()},
		ConnectionState: &ConnectionState{
			peerCert: &c1,
		},
	}
	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
	cp := cert.NewCAPool()

	// h1/c1 lacks the proper groups
	require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
	// c has the proper groups
	resetConntrack(fw)
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}

// TestFirewall_Drop3 checks inbound matching on host name and on CA sha: h1
// matches the "host1" rule, h2 matches the "signer-sha" rule, h3 matches
// neither and is rejected. It finally checks a rule keyed on a remote cidr.
func TestFirewall_Drop3(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	p := firewall.Packet{
		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
		LocalPort:  1,
		RemotePort: 1,
		Protocol:   firewall.ProtoUDP,
		Fragment:   false,
	}

	network := netip.MustParsePrefix("1.2.3.4/24")
	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host-owner",
			networks: []netip.Prefix{network},
		},
	}

	c1 := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host1",
			networks: []netip.Prefix{network},
			issuer:   "signer-sha-bad",
		},
	}
	h1 := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c1,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)

	c2 := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host2",
			networks: []netip.Prefix{network},
			issuer:   "signer-sha",
		},
	}
	h2 := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c2,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h2.buildNetworks(myVpnNetworksTable, c2.Certificate)

	c3 := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host3",
			networks: []netip.Prefix{network},
			issuer:   "signer-sha-bad",
		},
	}
	h3 := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c3,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h3.buildNetworks(myVpnNetworksTable, c3.Certificate)

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
	cp := cert.NewCAPool()

	// c1 should pass because host match
	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))

	// c2 should pass because ca sha match
	resetConntrack(fw)
	require.NoError(t, fw.Drop(p, true, &h2, cp, nil))

	// c3 should fail because no match
	resetConntrack(fw)
	// ErrorIs takes (actual, target); the previous assert.Equal had them swapped.
	require.ErrorIs(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)

	// Test a remote address match
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}

// TestFirewall_Drop3V6 is the IPv6 counterpart of the remote cidr rule check at
// the end of TestFirewall_Drop3.
func TestFirewall_Drop3V6(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))

	p := firewall.Packet{
		LocalAddr:  netip.MustParseAddr("fd12::34"),
		RemoteAddr: netip.MustParseAddr("fd12::34"),
		LocalPort:  1,
		RemotePort: 1,
		Protocol:   firewall.ProtoUDP,
		Fragment:   false,
	}

	network := netip.MustParsePrefix("fd12::34/120")
	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host-owner",
			networks: []netip.Prefix{network},
		},
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)

	// Test a remote address match
	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	cp := cert.NewCAPool()
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}

// TestFirewall_DropConntrackReload checks that conntrack entries carried over
// into a rebuilt firewall (with a bumped rulesVersion) are re-validated
// against the new rule set.
func TestFirewall_DropConntrackReload(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	p := firewall.Packet{
		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
		LocalPort:  10,
		RemotePort: 90,
		Protocol:   firewall.ProtoUDP,
		Fragment:   false,
	}
	network := netip.MustParsePrefix("1.2.3.4/24")

	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host1",
			networks: []netip.Prefix{network},
			groups:   []string{"default-group"},
			issuer:   "signer-shasum",
		},
		InvertedGroups: map[string]struct{}{"default-group": {}},
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
peerCert: &c,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)

	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
	cp := cert.NewCAPool()

	// Drop outbound
	assert.ErrorIs(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
	// Allow inbound
	resetConntrack(fw)
	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
	// Allow outbound because conntrack
	require.NoError(t, fw.Drop(p, false, &h, cp, nil))

	oldFw := fw
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
	fw.Conntrack = oldFw.Conntrack
	fw.rulesVersion = oldFw.rulesVersion + 1

	// Allow outbound because conntrack and new rules allow port 10
	require.NoError(t, fw.Drop(p, false, &h, cp, nil))

	oldFw = fw
	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
	fw.Conntrack = oldFw.Conntrack
	fw.rulesVersion = oldFw.rulesVersion + 1

	// Drop outbound because conntrack doesn't match new ruleset
	assert.ErrorIs(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
}

// TestFirewall_ICMPPortBehavior checks how ICMP packets interact with port
// fields: an ICMP rule allows them regardless of the packet's port values, a
// ProtoAny rule with a port range never matches them, and a ProtoAny/any-port
// rule allows them with conntrack keyed on the exact port values.
func TestFirewall_ICMPPortBehavior(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))

	network := netip.MustParsePrefix("1.2.3.4/24")

	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host1",
			networks: []netip.Prefix{network},
			groups:   []string{"default-group"},
			issuer:   "signer-shasum",
		},
		InvertedGroups: map[string]struct{}{"default-group": {}},
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &c,
		},
		vpnAddrs: []netip.Addr{network.Addr()},
	}
	h.buildNetworks(myVpnNetworksTable, c.Certificate)

	cp := cert.NewCAPool()

	templ := firewall.Packet{
		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
		Protocol:   firewall.ProtoICMP,
		Fragment:   false,
	}

	t.Run("ICMP allowed", func(t *testing.T) {
		fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
		require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))

		t.Run("zero ports", func(t *testing.T) {
			p := templ.Copy()
			p.LocalPort = 0
			p.RemotePort = 0
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Allow inbound
			resetConntrack(fw)
			require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
			//now also allow outbound
			require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
		})

		t.Run("nonzero ports", func(t *testing.T) {
			p := templ.Copy()
			p.LocalPort = 0xabcd
			p.RemotePort = 0x1234
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Allow inbound
			resetConntrack(fw)
			require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
			//now also allow outbound
			require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
		})
	})

	t.Run("Any proto, some ports allowed", func(t *testing.T) {
		fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
		require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))

		t.Run("zero ports, still blocked", func(t *testing.T) {
			p := templ.Copy()
			p.LocalPort = 0
			p.RemotePort = 0
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Inbound is dropped too: the port-range rule does not match ICMP
			resetConntrack(fw)
			assert.ErrorIs(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
			// ...so there is no conntrack entry to let outbound through
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
		})

		t.Run("nonzero ports, still blocked", func(t *testing.T) {
			p := templ.Copy()
			p.LocalPort = 0xabcd
			p.RemotePort = 0x1234
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Inbound is dropped too: the port-range rule does not match ICMP
			resetConntrack(fw)
			assert.ErrorIs(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
			// ...so there is no conntrack entry to let outbound through
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
		})

		t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
			p := templ.Copy()
			p.LocalPort = 80
			p.RemotePort = 80
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Inbound is dropped even though 80 is inside the rule's port range
			resetConntrack(fw)
			assert.ErrorIs(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
			// ...so there is no conntrack entry to let outbound through
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
		})
	})

	t.Run("Any proto, any port", func(t *testing.T) {
		fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
		require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))

		t.Run("zero ports, allowed", func(t *testing.T) {
			resetConntrack(fw)
			p := templ.Copy()
			p.LocalPort = 0
			p.RemotePort = 0
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Allow inbound
			resetConntrack(fw)
			require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
			//now also allow outbound
			require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
		})

		t.Run("nonzero ports, allowed", func(t *testing.T) {
			resetConntrack(fw)
			p := templ.Copy()
			p.LocalPort = 0xabcd
			p.RemotePort = 0x1234
			// Drop outbound
			assert.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
			// Allow inbound
			resetConntrack(fw)
			require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
			//now also allow outbound
			require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
			//different ID is blocked
			p.RemotePort++
			require.ErrorIs(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
		})
	})
}

// TestFirewall_DropIPSpoofing checks that a packet from c1 whose remote address
// (192.0.2.3) is not c1's own VPN address (192.0.2.2) is rejected with
// ErrInvalidRemoteIP.
func TestFirewall_DropIPSpoofing(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)
	myVpnNetworksTable := new(bart.Lite)
	myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))

	c := cert.CachedCertificate{
		Certificate: &dummyCert{
			name:     "host-owner",
			networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
		},
	}

	c1 :=
cert.CachedCertificate{ Certificate: &dummyCert{ name: "host", networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}, }, } h1 := HostInfo{ ConnectionState: &ConnectionState{ peerCert: &c1, }, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) cp := cert.NewCAPool() // Packet spoofed by `c1`. Note that the remote addr is not a valid one. p := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.0.2.1"), RemoteAddr: netip.MustParseAddr("192.0.2.3"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, Fragment: false, } assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) } func BenchmarkLookup(b *testing.B) { ml := func(m map[string]struct{}, a [][]string) { for n := 0; n < b.N; n++ { for _, sg := range a { found := false for _, g := range sg { if _, ok := m[g]; !ok { found = false break } found = true } if found { return } } } } b.Run("array to map best", func(b *testing.B) { m := map[string]struct{}{ "1ne": {}, "2wo": {}, "3hr": {}, "4ou": {}, "5iv": {}, "6ix": {}, } a := [][]string{ {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"}, {"one", "2wo", "3hr", "4ou", "5iv", "6ix"}, {"one", "two", "3hr", "4ou", "5iv", "6ix"}, {"one", "two", "thr", "4ou", "5iv", "6ix"}, {"one", "two", "thr", "fou", "5iv", "6ix"}, {"one", "two", "thr", "fou", "fiv", "6ix"}, {"one", "two", "thr", "fou", "fiv", "six"}, } for n := 0; n < b.N; n++ { ml(m, a) } }) b.Run("array to map worst", func(b *testing.B) { m := map[string]struct{}{ "one": {}, "two": {}, "thr": {}, "fou": {}, "fiv": {}, "six": {}, } a := [][]string{ {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"}, {"one", "2wo", "3hr", "4ou", "5iv", "6ix"}, {"one", "two", "3hr", "4ou", "5iv", "6ix"}, {"one", "two", 
"thr", "4ou", "5iv", "6ix"}, {"one", "two", "thr", "fou", "5iv", "6ix"}, {"one", "two", "thr", "fou", "fiv", "6ix"}, {"one", "two", "thr", "fou", "fiv", "six"}, } for n := 0; n < b.N; n++ { ml(m, a) } }) } func Test_parsePort(t *testing.T) { _, _, err := parsePort("") require.EqualError(t, err, "was not a number; ``") _, _, err = parsePort(" ") require.EqualError(t, err, "was not a number; ` `") _, _, err = parsePort("-") require.EqualError(t, err, "appears to be a range but could not be parsed; `-`") _, _, err = parsePort(" - ") require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") _, _, err = parsePort("a-b") require.EqualError(t, err, "beginning range was not a number; `a`") _, _, err = parsePort("1-b") require.EqualError(t, err, "ending range was not a number; `b`") s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) require.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) require.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) require.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) require.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) conf := config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should 
be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", 
"proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { l := test.NewLogger() // Test adding tcp rule conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule no port conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, 
mf.lastCall)

	// Test adding any rule
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)

	// Test adding rule with cidr
	cidr := netip.MustParsePrefix("10.0.0.0/8")
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)

	// Test adding rule with local_cidr
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)

	// Test adding rule with cidr ipv6
	cidr6 := netip.MustParsePrefix("fd00::/8")
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)

	// Test adding rule with any cidr
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)

	// Test adding rule with junk cidr
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")

	// Test adding rule with local_cidr ipv6
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)

	// Test adding rule with any local_cidr
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)

	// Test adding rule with junk local_cidr
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")

	// Test adding rule with ca_sha
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)

	// Test adding rule with ca_name
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)

	// Test single group
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)

	// Test single groups
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)

	// Test multiple AND groups
	conf = config.NewC(l)
	mf = &mockFirewall{}
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)

	// Test Add error
	conf = config.NewC(l)
	mf = &mockFirewall{}
	mf.nextCallReturn = errors.New("test error")
	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
}

// TestFirewall_convertRule checks how convertRule handles the "group" key: a
// single-element array is accepted and a warning is logged, a multi-element
// array is an error, and a plain string is accepted as-is.
func TestFirewall_convertRule(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)

	// Ensure group array of 1 is converted and a warning is printed
	c := map[string]any{
		"group": []any{"group1"},
	}

	r, err := convertRule(l, c, "test", 1)
	assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
	require.NoError(t, err)
	assert.Equal(t, []string{"group1"}, r.Groups)

	// Ensure group array of > 1 is errored
	ob.Reset()
	c = map[string]any{
		"group": []any{"group1", "group2"},
	}

	_, err = convertRule(l, c, "test", 1)
	assert.Empty(t, ob.String())
	// require.Error treats trailing arguments only as a failure message and
	// never inspects the error text, so assert on the text explicitly.
	require.ErrorContains(t, err, "group should contain a single value, an array with more than one entry was provided")

	// Make sure a well formed group is alright
	ob.Reset()
	c = map[string]any{
		"group": "group1",
	}

	r, err = convertRule(l, c, "test", 1)
	require.NoError(t, err)
	assert.Equal(t, []string{"group1"}, r.Groups)
}

// TestFirewall_convertRuleSanity checks that rule.sanity accepts ordinary
// converted rules and rejects ones where "any" is combined with other fields.
func TestFirewall_convertRuleSanity(t *testing.T) {
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)

	noWarningPlease := []map[string]any{
		{"group": "group1"},
		{"groups": []any{"group2"}},
		{"host": "bob"},
		{"cidr": "1.1.1.1/1"},
		{"groups": []any{"group2"}, "host": "bob"},
		{"cidr": "1.1.1.1/1", "host": "bob"},
		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
	}
	for _, c := range noWarningPlease {
		r, err := convertRule(l, c, "test", 1)
		require.NoError(t, err)
		require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
	}

	yesWarningPlease := []map[string]any{
		{"group": "group1"},
		{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
		{"groups": []any{"group2"}, "host": "bob"},
		{"cidr": "1.1.1.1/1", "host": "bob"},
		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
	}
	for _, c := range yesWarningPlease {
		c["host"] = "any"
		r, err := convertRule(l, c, "test", 1)
		require.NoError(t, err)
		err = r.sanity()
		require.Error(t, err, "I wanted a warning: %+v", c)
	}

	//reset the list
	yesWarningPlease = []map[string]any{
		{"group": "group1"},
		{"groups": []any{"group2"}},
		{"cidr": "1.1.1.1/1"},
		{"groups": []any{"group2"}, "host": "bob"},
		{"cidr": "1.1.1.1/1", "host": "bob"},
		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
	}
	for _, c := range yesWarningPlease {
		r, err := convertRule(l, c, "test", 1)
		require.NoError(t, err)
		r.Groups = append(r.Groups, "any")
		err = r.sanity()
		require.Error(t, err, "I wanted a warning: %+v", c)
	}
}

// testcase is one inbound packet plus the peer it appears to come from, and
// the error (or nil) fw.Drop is expected to return for it.
type testcase struct {
	h   *HostInfo
	p   firewall.Packet
	c   cert.Certificate
	err error
}

// Test clears fw's conntrack, runs the packet inbound through fw.Drop and
// requires either success (c.err == nil) or an error matching c.err.
func (c *testcase) Test(t *testing.T, fw *Firewall) {
	t.Helper()
	cp := cert.NewCAPool()
	resetConntrack(fw)
	err := fw.Drop(c.p, true, c.h, cp, nil)
	if c.err == nil {
		require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
	} else {
		// ErrorIs takes the actual error first and the expected target second;
		// the reversed order only passed on exact identity and rejected wrapped
		// errors.
		require.ErrorIs(t, err, c.err, "failed to drop remote address %s", c.p.RemoteAddr)
	}
}

// buildTestCase builds a peer whose certificate holds theirPrefixes (with a
// VPN address per prefix) and a UDP packet from that peer's first address to
// setup's first local address.
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
	c1 := dummyCert{
		name:     "host1",
		networks: theirPrefixes,
		groups:   []string{"default-group"},
		issuer:   "signer-shasum",
	}
	h := HostInfo{
		ConnectionState: &ConnectionState{
			peerCert: &cert.CachedCertificate{
				Certificate:    &c1,
				InvertedGroups: map[string]struct{}{"default-group": {}},
			},
		},
		vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
	}
	for i := range theirPrefixes {
		h.vpnAddrs[i] = theirPrefixes[i].Addr()
	}
	h.buildNetworks(setup.myVpnNetworksTable, &c1)
	p := firewall.Packet{
		LocalAddr:  setup.c.Networks()[0].Addr(), //todo?
		RemoteAddr: theirPrefixes[0].Addr(),
		LocalPort:  10,
		RemotePort: 90,
		Protocol:   firewall.ProtoUDP,
		Fragment:   false,
	}
	return testcase{
		h:   &h,
		p:   p,
		c:   &c1,
		err: err,
	}
}

// testsetup holds the local certificate, a firewall built for it, and a lookup
// table of its networks.
type testsetup struct {
	c                  dummyCert
	myVpnNetworksTable *bart.Lite
	fw                 *Firewall
}

// newSetup builds a testsetup for a local certificate covering myPrefixes.
func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
	c := dummyCert{
		name:     "me",
		networks: myPrefixes,
		groups:   []string{"default-group"},
		issuer:   "signer-shasum",
	}

	return newSetupFromCert(t, l, c)
}

// newSetupFromCert builds a firewall for c with a single inbound rule that
// allows any protocol and port from group "any", plus a table of c's networks.
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
	myVpnNetworksTable := new(bart.Lite)
	for _, prefix := range c.Networks() {
		myVpnNetworksTable.Insert(prefix)
	}
	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))

	return testsetup{
		c:                  c,
		fw:                 fw,
		myVpnNetworksTable: myVpnNetworksTable,
	}
}

// TestFirewall_Drop_EnforceIPMatch checks that fw.Drop validates the packet's
// local and remote addresses against our networks and the peer's addresses
// before rules are consulted.
func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
	t.Parallel()
	l := test.NewLogger()
	ob := &bytes.Buffer{}
	l.SetOutput(ob)

	myPrefix := netip.MustParsePrefix("1.1.1.1/8")
	// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
	t.Run("allow inbound all matching", func(t *testing.T) {
		t.Parallel()
		setup := newSetup(t, l, myPrefix)
		tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
		tc.Test(t, setup.fw)
	})
	// NOTE(review): despite its name, this case expects ErrInvalidLocalIP,
	// because the local address 1.2.3.8 is not our own address (1.1.1.1).
	t.Run("allow inbound local matching", func(t *testing.T) {
		t.Parallel()
		setup := newSetup(t, l, myPrefix)
		tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
		tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
		tc.Test(t, setup.fw)
	})
	t.Run("block inbound remote mismatched", func(t *testing.T) {
		t.Parallel()
		setup := newSetup(t, l, myPrefix)
		tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
		tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
		tc.Test(t, setup.fw)
	})
	t.Run("Block a vpn peer packet", func(t *testing.T) {
		t.Parallel()
		setup := newSetup(t, l, myPrefix)
		tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
		tc.Test(t, setup.fw)
	})
	twoPrefixes := []netip.Prefix{
		netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
	}
	t.Run("allow inbound one matching", func(t *testing.T) {
		t.Parallel()
		setup := newSetup(t, l, myPrefix)
		tc := buildTestCase(setup, nil, twoPrefixes...)
		tc.Test(t, setup.fw)
	})
	t.Run("block inbound multimismatch", func(t *testing.T) {
		t.Parallel()
		setup := newSetup(t, l, myPrefix)
		tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
		tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
		tc.Test(t, setup.fw)
	})
	t.Run("allow inbound 2nd one matching", func(t *testing.T) {
		t.Parallel()
		setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
		tc := buildTestCase(setup2, nil, twoPrefixes...)
		tc.p.RemoteAddr = twoPrefixes[1].Addr()
		tc.Test(t, setup2.fw)
	})
	t.Run("allow inbound unsafe route", func(t *testing.T) {
		t.Parallel()
		unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
		c := dummyCert{
			name:           "me",
			networks:       []netip.Prefix{myPrefix},
			unsafeNetworks: []netip.Prefix{unsafePrefix},
			groups:         []string{"default-group"},
			issuer:         "signer-shasum",
		}
		unsafeSetup := newSetupFromCert(t, l, c)
		tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
		tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
		tc.err = ErrNoMatchingRule
		tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
		require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
		tc.err = nil
		tc.Test(t, unsafeSetup.fw) //should pass
	})
}

// addRuleCall captures the arguments of one AddRule call.
type addRuleCall struct {
	incoming  bool
	proto     uint8
	startPort int32
	endPort   int32
	groups    []string
	host      string
	ip        string
	localIp   string
	caName    string
	caSha     string
}

// mockFirewall records the most recent AddRule call and returns
// nextCallReturn from that call only.
type mockFirewall struct {
	lastCall       addRuleCall
	nextCallReturn error
}

// AddRule stores its arguments in mf.lastCall and returns (then clears)
// mf.nextCallReturn.
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
	mf.lastCall = addRuleCall{
		incoming:  incoming,
		proto:     proto,
		startPort: startPort,
		endPort:   endPort,
		groups:    groups,
		host:      host,
		ip:        ip,
		localIp:   localIp,
		caName:    caName,
		caSha:     caSha,
	}

	err := mf.nextCallReturn
	mf.nextCallReturn = nil
	return err
}

// resetConntrack replaces fw's conntrack table with an empty map while holding
// the conntrack lock.
func resetConntrack(fw *Firewall) {
	fw.Conntrack.Lock()
	fw.Conntrack.Conns = map[firewall.Packet]*conn{}
	fw.Conntrack.Unlock()
}
================================================ FILE: go.mod ================================================ module github.com/slackhq/nebula go 1.25 require ( dario.cat/mergo v1.0.2 filippo.io/bigmod v0.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 github.com/gaissmai/bart v0.26.0 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 github.com/miekg/dns v1.1.70 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.4
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.47.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/sys v0.40.0 golang.org/x/term v0.39.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/mod v0.31.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.40.0 // indirect ) ================================================ FILE: go.sum ================================================ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= filippo.io/bigmod v0.1.0 h1:UNzDk7y9ADKST+axd9skUpBQeW7fG2KrTZyOE4uGQy8= filippo.io/bigmod v0.1.0/go.mod h1:OjOXDNlClLblvXdwgFFOQFJEocLhhtai8vGLy0JCZlI= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= 
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps= github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod 
h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk= github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc= 
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.1.70 h1:DZ4u2AV35VJxdD9Fo9fIWm119BsQL5cZU1cQ9s0LkqA= github.com/miekg/dns v1.1.70/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent 
v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg= github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f/go.mod h1:nwPd6pDNId/Xi16qtKrFHrauSwMNuvk+zcjk89wrnlA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= 
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= 
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= 
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod 
h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 
h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 
h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe h1:fre4i6mv4iBuz5lCMOzHD1rH1ljqHWSICFmZRbbgp3g= gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= ================================================ FILE: handshake_ix.go ================================================ package nebula import ( "bytes" "net/netip" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) // NOISE IX Handshakes // This function constructs a handshake packet, but does not actually send it // Sending is done by the handshake manager func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). 
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } cs := f.pki.getCertState() v := cs.initiatingVersion if hh.initiatingVersionOverride != cert.VersionPre1 { v = hh.initiatingVersionOverride } else if v < cert.Version2 { // If we're connecting to a v6 address we should encourage use of a V2 cert for _, a := range hh.hostinfo.vpnAddrs { if a.Is6() { v = cert.Version2 break } } } crt := cs.getCertificate(v) if crt == nil { f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", v). Error("Unable to handshake with host because no certificate is available") return false } crtHs := cs.getHandshakeBytes(v) if crtHs == nil { f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", v). Error("Unable to handshake with host because no certificate handshake bytes is available") return false } ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) if err != nil { f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", v). Error("Failed to create connection state") return false } hh.hostinfo.ConnectionState = ci hs := &NebulaHandshake{ Details: &NebulaHandshakeDetails{ InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), Cert: crtHs, CertVersion: uint32(v), }, } hsBytes, err := hs.Marshal() if err != nil { f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("certVersion", v). 
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
		return false
	}

	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)

	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
		return false
	}

	// We are sending handshake packet 1, so we don't expect to receive
	// handshake packet 1 from the responder
	ci.window.Update(f.l, 1)

	hh.hostinfo.HandshakePacket[0] = msg
	hh.ready = true
	return true
}

// ixHandshakeStage1 processes the first IX handshake message as the responder.
//
// It reads the noise message, recombines and verifies the initiator's
// certificate against the CA pool, builds a new HostInfo, generates handshake
// message 2 and hands the result to HandshakeManager.CheckAndComplete. The reply
// is sent directly over f.outside, or through via.relayHI when the packet
// arrived over a relay. If CheckAndComplete reports the exact same handshake
// was already completed (ErrAlreadySeen) the previously generated reply is
// re-sent instead. Every failure path logs and drops the packet.
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
	cs := f.pki.getCertState()
	crt := cs.GetDefaultCertificate()
	if crt == nil {
		f.l.WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			WithField("certVersion", cs.initiatingVersion).
			Error("Unable to handshake with host because no certificate is available")
		return
	}

	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Error("Failed to create connection state")
		return
	}

	// Mark packet 1 as seen so it doesn't show up as missed
	ci.window.Update(f.l, 1)

	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Error("Failed to call noise.ReadMessage")
		return
	}

	hs := &NebulaHandshake{}
	err = hs.Unmarshal(msg)
	if err != nil || hs.Details == nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Error("Failed unmarshal handshake message")
		return
	}

	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Info("Handshake did not contain a certificate")
		return
	}

	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
	if err != nil {
		fp, fperr := rc.Fingerprint()
		if fperr != nil {
			fp = ""
		}

		e := f.l.WithError(err).WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			WithField("certVpnNetworks", rc.Networks()).
			WithField("certFingerprint", fp)

		if f.l.Level >= logrus.DebugLevel {
			e = e.WithField("cert", rc)
		}

		e.Info("Invalid certificate from host")
		return
	}

	if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
		f.l.WithField("from", via).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
		return
	}

	if remoteCert.Certificate.Version() != ci.myCert.Version() {
		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
		myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
		if myCertOtherVersion == nil {
			// err is nil here (verification succeeded), so it is not attached to the log entry.
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithFields(m{
					"from":      via,
					"handshake": m{"stage": 1, "style": "ix_psk0"},
					"cert":      remoteCert,
				}).Debug("Might be unable to handshake with host due to missing certificate version")
			}
		} else {
			// Record the certificate we are actually using
			ci.myCert = myCertOtherVersion
		}
	}

	if len(remoteCert.Certificate.Networks()) == 0 {
		f.l.WithField("from", via).
			WithField("cert", remoteCert).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Info("No networks in certificate")
		return
	}

	certName := remoteCert.Certificate.Name()
	certVersion := remoteCert.Certificate.Version()
	fingerprint := remoteCert.Fingerprint
	issuer := remoteCert.Certificate.Issuer()
	vpnNetworks := remoteCert.Certificate.Networks()

	anyVpnAddrsInCommon := false
	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
	for i, network := range vpnNetworks {
		if f.myVpnAddrsTable.Contains(network.Addr()) {
			f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
			return
		}
		vpnAddrs[i] = network.Addr()
		if f.myVpnNetworksTable.Contains(network.Addr()) {
			anyVpnAddrsInCommon = true
		}
	}

	if !via.IsRelayed {
		// We only want to apply the remote allow list for direct tunnels here
		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				Debug("lighthouse.remote_allow_list denied incoming handshake")
			return
		}
	}

	myIndex, err := generateIndex(f.l)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
		return
	}

	hostinfo := &HostInfo{
		ConnectionState:   ci,
		localIndexId:      myIndex,
		remoteIndexId:     hs.Details.InitiatorIndex,
		vpnAddrs:          vpnAddrs,
		HandshakePacket:   make(map[uint8][]byte, 0),
		lastHandshakeTime: hs.Details.Time,
		relayState: RelayState{
			relays:         nil,
			relayForByAddr: map[netip.Addr]*Relay{},
			relayForByIdx:  map[uint32]*Relay{},
		},
	}

	msgRxL := f.l.WithFields(m{
		"vpnAddrs":       vpnAddrs,
		"from":           via,
		"certName":       certName,
		"certVersion":    certVersion,
		"fingerprint":    fingerprint,
		"issuer":         issuer,
		"initiatorIndex": hs.Details.InitiatorIndex,
		"responderIndex": hs.Details.ResponderIndex,
		"remoteIndex":    h.RemoteIndex,
		"handshake":      m{"stage": 1, "style": "ix_psk0"},
	})

	if anyVpnAddrsInCommon {
		msgRxL.Info("Handshake message received")
	} else {
		//todo warn if not lighthouse or relay?
		msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
	}

	hs.Details.ResponderIndex = myIndex
	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
	if hs.Details.Cert == nil {
		msgRxL.WithField("myCertVersion", ci.myCert.Version()).
			Error("Unable to handshake with host because no certificate handshake bytes is available")
		return
	}

	hs.Details.CertVersion = uint32(ci.myCert.Version())
	// Update the time in case their clock is way off from ours
	hs.Details.Time = uint64(time.Now().UnixNano())

	hsBytes, err := hs.Marshal()
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
		return
	}

	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
		return
	} else if dKey == nil || eKey == nil {
		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
		return
	}

	hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
	copy(hostinfo.HandshakePacket[0], packet[header.Len:])

	// Regardless of whether you are the sender or receiver, you should arrive here
	// and complete standing up the connection.
	hostinfo.HandshakePacket[2] = make([]byte, len(msg))
	copy(hostinfo.HandshakePacket[2], msg)

	// We are sending handshake packet 2, so we don't expect to receive
	// handshake packet 2 from the initiator.
	ci.window.Update(f.l, 2)

	ci.peerCert = remoteCert
	ci.dKey = NewNebulaCipherState(dKey)
	ci.eKey = NewNebulaCipherState(eKey)

	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
	if !via.IsRelayed {
		hostinfo.SetRemote(via.UdpAddr)
	}
	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)

	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
	if err != nil {
		switch err {
		case ErrAlreadySeen:
			// Update remote if preferred
			if existing.SetRemoteIfPreferred(f.hostMap, via) {
				// Send a test packet to ensure the other side has also switched to
				// the preferred remote
				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu))
			}

			// Re-send the reply we already generated for this exact handshake
			msg = existing.HandshakePacket[2]
			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
			if !via.IsRelayed {
				err := f.outside.WriteTo(msg, via.UdpAddr)
				if err != nil {
					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
						WithError(err).Error("Failed to send handshake message")
				} else {
					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
						Info("Handshake message sent")
				}
				return
			}

			if via.relay == nil {
				f.l.Error("Handshake send failed: both addr and via.relay are nil.")
				return
			}
			// NOTE(review): hostinfo is discarded on this path (existing is the live
			// tunnel); recording the relay on existing may have been intended — confirm.
			hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
			f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
			f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
				WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
				Info("Handshake message sent")
			return
		case ErrExistingHostInfo:
			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("oldHandshakeTime", existing.lastHandshakeTime).
				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				Info("Handshake too old")

			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12), make([]byte, mtu))
			return
		case ErrLocalIndexCollision:
			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
				Error("Failed to add HostInfo due to localIndex collision")
			return
		default:
			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
			// And we forget to update it here
			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
				WithField("certName", certName).
				WithField("certVersion", certVersion).
				WithField("fingerprint", fingerprint).
				WithField("issuer", issuer).
				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				Error("Failed to add HostInfo to HostMap")
			return
		}
	}

	// Do the send
	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
	if !via.IsRelayed {
		err = f.outside.WriteTo(msg, via.UdpAddr)
		log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
		if err != nil {
			log.WithError(err).Error("Failed to send handshake")
		} else {
			log.Info("Handshake message sent")
		}
	} else {
		if via.relay == nil {
			f.l.Error("Handshake send failed: both addr and via.relay are nil.")
			return
		}
		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
		// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
		// it's correctly marked as working.
		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("fingerprint", fingerprint).
			WithField("issuer", issuer).
			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			Info("Handshake message sent")
	}

	f.connectionManager.AddTrafficWatch(hostinfo)

	hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
}

// ixHandshakeStage2 processes the IX handshake reply (message 2) as the
// initiator for the pending handshake hh. It returns true when the caller
// should tear down the pending handshake and false when it should be kept
// (either because it completed or because the packet was ignored).
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
	if hh == nil {
		// Nothing here to tear down, got a bogus stage 2 packet
		return true
	}

	hh.Lock()
	defer hh.Unlock()

	hostinfo := hh.hostinfo
	if !via.IsRelayed {
		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
			return false
		}
	}

	ci := hostinfo.ConnectionState
	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
	if err != nil {
		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
			Error("Failed to call noise.ReadMessage")

		// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
		// to DOS us. Every error condition after this point tears down instead, to allow a possible good
		// handshake to complete in the near future
		return false
	} else if dKey == nil || eKey == nil {
		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			Error("Noise did not arrive at a key")

		// This should be impossible in IX but just in case, if we get here then there is no chance to recover
		// the handshake state machine. Tear it down
		return true
	}

	hs := &NebulaHandshake{}
	err = hs.Unmarshal(msg)
	if err != nil || hs.Details == nil {
		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")

		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
		return true
	}

	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
	if err != nil {
		f.l.WithError(err).WithField("from", via).
			WithField("vpnAddrs", hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			Info("Handshake did not contain a certificate")
		return true
	}

	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
	if err != nil {
		// Keep the fingerprint error in its own variable: shadowing err here would
		// make the log below report the fingerprint result (usually nil) instead of
		// the reason certificate verification failed.
		fp, fperr := rc.Fingerprint()
		if fperr != nil {
			fp = ""
		}

		e := f.l.WithError(err).WithField("from", via).
			WithField("vpnAddrs", hostinfo.vpnAddrs).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			WithField("certFingerprint", fp).
			WithField("certVpnNetworks", rc.Networks())

		if f.l.Level >= logrus.DebugLevel {
			e = e.WithField("cert", rc)
		}

		e.Info("Invalid certificate from host")
		return true
	}

	if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
		f.l.WithField("from", via).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
		return true
	}

	if len(remoteCert.Certificate.Networks()) == 0 {
		// err is nil here (verification succeeded), so it is not attached to the log entry.
		f.l.WithField("from", via).
			WithField("vpnAddrs", hostinfo.vpnAddrs).
			WithField("cert", remoteCert).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			Info("No networks in certificate")
		return true
	}

	vpnNetworks := remoteCert.Certificate.Networks()
	certName := remoteCert.Certificate.Name()
	certVersion := remoteCert.Certificate.Version()
	fingerprint := remoteCert.Fingerprint
	issuer := remoteCert.Certificate.Issuer()

	hostinfo.remoteIndexId = hs.Details.ResponderIndex
	hostinfo.lastHandshakeTime = hs.Details.Time

	// Store their cert and our symmetric keys
	ci.peerCert = remoteCert
	ci.dKey = NewNebulaCipherState(dKey)
	ci.eKey = NewNebulaCipherState(eKey)

	// Make sure the current udpAddr being used is set for responding
	if !via.IsRelayed {
		hostinfo.SetRemote(via.UdpAddr)
	} else {
		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
	}

	correctHostResponded := false
	anyVpnAddrsInCommon := false
	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
	for i, network := range vpnNetworks {
		vpnAddrs[i] = network.Addr()
		if f.myVpnNetworksTable.Contains(network.Addr()) {
			anyVpnAddrsInCommon = true
		}
		if hostinfo.vpnAddrs[0] == network.Addr() {
			// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
			correctHostResponded = true
		}
	}

	// Ensure the right host responded
	if !correctHostResponded {
		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
			WithField("from", via).
			WithField("certName", certName).
			WithField("certVersion", certVersion).
			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
			Info("Incorrect host responded to handshake")

		// Release our old handshake from pending, it should not continue
		f.handshakeManager.DeleteHostInfo(hostinfo)

		// Create a new hostinfo/handshake for the intended vpn ip
		//TODO is hostinfo.vpnAddrs[0] always the address to use?
		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
			// Block the current used address
			newHH.hostinfo.remotes = hostinfo.remotes
			newHH.hostinfo.remotes.BlockRemote(via)

			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
				WithField("vpnNetworks", vpnNetworks).
				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
				Info("Blocked addresses for handshakes")

			// Swap the packet store to benefit the original intended recipient
			newHH.packetStore = hh.packetStore
			hh.packetStore = []*cachedPacket{}

			// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
			hostinfo.vpnAddrs = vpnAddrs
			f.sendCloseTunnel(hostinfo)
		})

		return true
	}

	// Mark packet 2 as seen so it doesn't show up as missed
	ci.window.Update(f.l, 2)

	duration := time.Since(hh.startTime).Nanoseconds()
	msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
		WithField("certName", certName).
		WithField("certVersion", certVersion).
		WithField("fingerprint", fingerprint).
		WithField("issuer", issuer).
		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
		WithField("durationNs", duration).
		WithField("sentCachedPackets", len(hh.packetStore))
	if anyVpnAddrsInCommon {
		msgRxL.Info("Handshake message received")
	} else {
		//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.") } // Build up the radix for the firewall if we have subnets in the cert hostinfo.vpnAddrs = vpnAddrs hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) } if len(hh.packetStore) > 0 { nb := make([]byte, 12, 12) out := make([]byte, mtu) for _, cp := range hh.packetStore { cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out) } f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) } hostinfo.remotes.RefreshFromHandshake(vpnAddrs) f.metricHandshakes.Update(duration) return false } ================================================ FILE: handshake_manager.go ================================================ package nebula import ( "bytes" "context" "crypto/rand" "encoding/binary" "errors" "net/netip" "slices" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) const ( DefaultHandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 DefaultUseRelays = true ) var ( defaultHandshakeConfig = HandshakeConfig{ tryInterval: DefaultHandshakeTryInterval, retries: DefaultHandshakeRetries, triggerBuffer: DefaultHandshakeTriggerBuffer, useRelays: DefaultUseRelays, } ) type HandshakeConfig struct { tryInterval time.Duration retries int64 triggerBuffer int useRelays bool messageMetrics *MessageMetrics } type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex vpnIps map[netip.Addr]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo mainHostMap 
*HostMap lightHouse *LightHouse outside udp.Conn config HandshakeConfig OutboundHandshakeTimer *LockingTimerWheel[netip.Addr] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter f *Interface l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp trigger chan netip.Addr } type HandshakeHostInfo struct { sync.Mutex startTime time.Time // Time that we first started trying with this handshake ready bool // Is the handshake ready initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake? counter int64 // How many attempts have we made so far lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { if len(hh.packetStore) < 100 { tempPacket := make([]byte, len(packet)) copy(tempPacket, packet) hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket}) if l.Level >= logrus.DebugLevel { hh.hostinfo.logger(l). WithField("length", len(hh.packetStore)). WithField("stored", true). Debugf("Packet store") } } else { m.dropped.Inc(1) if l.Level >= logrus.DebugLevel { hh.hostinfo.logger(l). WithField("length", len(hh.packetStore)). WithField("stored", false). 
Debugf("Packet store") } } } func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, trigger: make(chan netip.Addr, config.triggerBuffer), OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), l: l, } } func (hm *HandshakeManager) Run(ctx context.Context) { clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() for { select { case <-ctx.Done(): return case vpnIP := <-hm.trigger: hm.handleOutbound(vpnIP, true) case now := <-clockSource.C: hm.NextOutboundHandshakeTimerTick(now) } } } func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp if !via.IsRelayed { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } switch h.Subtype { case header.HandshakeIXPSK0: switch h.MessageCounter { case 1: ixHandshakeStage1(hm.f, via, packet, h) case 2: newHostinfo := hm.queryIndex(h.RemoteIndex) tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) if tearDown && newHostinfo != nil { hm.DeleteHostInfo(newHostinfo.hostinfo) } } } } func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { hm.OutboundHandshakeTimer.Advance(now) for { vpnIp, has := hm.OutboundHandshakeTimer.Purge() if !has { break } hm.handleOutbound(vpnIp, false) } } func (hm 
*HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) { hh := hm.queryVpnIp(vpnIp) if hh == nil { return } hh.Lock() defer hh.Unlock() hostinfo := hh.hostinfo // If we are out of time, clean up if hh.counter >= hm.config.retries { hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())). WithField("initiatorIndex", hh.hostinfo.localIndexId). WithField("remoteIndex", hh.hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("durationNs", time.Since(hh.startTime).Nanoseconds()). Info("Handshake timed out") hm.metricTimedOut.Inc(1) hm.DeleteHostInfo(hostinfo) return } // Increment the counter to increase our delay, linear backoff hh.counter++ // Check if we have a handshake packet to transmit yet if !hh.ready { if !ixHandshakeStage0(hm.f, hh) { hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) return } } // Get a remotes object if we don't already have one. // This is mainly to protect us as this should never be the case // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. 
if lighthouseTriggered && !remotesHaveChanged { // If we didn't return here a lighthouse could cause us to aggressively send handshakes return } hh.lastRemotes = remotes // This will generate a load of queries for hosts with only 1 ip // (such as ones registered to the lighthouse with only a private IP) // So we only do it one time after attempting 5 handshakes already. if len(remotes) <= 1 && hh.counter == 5 { // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter hm.lightHouse.QueryServer(vpnIp) } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []netip.AddrPort hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { hostinfo.logger(hm.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake message") } else { sentTo = append(sentTo, addr) } }) // Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout, // so only log when the list of remotes has changed if remotesHaveChanged { hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") } else if hm.l.Level >= logrus.DebugLevel { hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). 
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Debug("Handshake message sent") } if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay through the host I'm trying to connect to if relay == vpnIp { continue } // Don't relay to myself if hm.f.myVpnAddrsTable.Contains(relay) { continue } relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hm.f.Handshake(relay) continue } // Check the relay HostInfo to see if we already established a relay through existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) if !ok { // No relays exist or requested yet. if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, } switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !hm.f.myVpnAddrs[0].Is4() { hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !vpnIp.Is4() { hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } b := hm.f.myVpnAddrs[0].As4() m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) b = vpnIp.As4() m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) case cert.Version2: m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) m.RelayToAddr = netAddrToProtoAddr(vpnIp) 
default: hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") continue } msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": relay}). Info("send CreateRelayRequest") } } continue } switch existingRelay.State { case Established: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) fallthrough case Requested: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") // Re-send the CreateRelay request, in case the previous one was lost. 
m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, } switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !hm.f.myVpnAddrs[0].Is4() { hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !vpnIp.Is4() { hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } b := hm.f.myVpnAddrs[0].As4() m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) b = vpnIp.As4() m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) case cert.Version2: m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) m.RelayToAddr = netAddrToProtoAddr(vpnIp) default: hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") continue } msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). WithError(err). Error("Failed to marshal Control message to create relay") } else { // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, "relay": relay}). Info("send CreateRelayRequest") } case PeerRequested: // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. fallthrough default: hostinfo.logger(hm.l). WithField("vpnIp", vpnIp). WithField("state", existingRelay.State). WithField("relay", relay). 
Errorf("Relay unexpected state") } } } // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) } } // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { hm.mainHostMap.RLock() h, ok := hm.mainHostMap.Hosts[vpnIp] hm.mainHostMap.RUnlock() if ok { // Do not attempt promotion if you are a lighthouse if !hm.lightHouse.amLighthouse { h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f) } return h, true } return hm.StartHandshake(vpnIp, cacheCb), false } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() if hh, ok := hm.vpnIps[vpnAddr]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) } hm.Unlock() return hh.hostinfo } hostinfo := &HostInfo{ vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ relays: nil, relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } hh := &HandshakeHostInfo{ hostinfo: hostinfo, startTime: time.Now(), } hm.vpnIps[vpnAddr] = hh hm.metricInitiated.Inc(1) hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval) if cacheCb != nil { cacheCb(hh) } // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr] if !doTrigger { // Add any calculated remotes, and trigger early handshake if one found doTrigger = 
hm.lightHouse.addCalculatedRemotes(vpnAddr) } if doTrigger { select { case hm.trigger <- vpnAddr: default: } } hm.Unlock() hm.lightHouse.QueryServer(vpnAddr) return hostinfo } var ( ErrExistingHostInfo = errors.New("existing hostinfo") ErrAlreadySeen = errors.New("already seen") ErrLocalIndexCollision = errors.New("local index collision") ) // CheckAndComplete checks for any conflicts in the main and pending hostmap // before adding hostinfo to main. If err is nil, it was added. Otherwise err will be: // // ErrAlreadySeen if we already have an entry in the hostmap that has seen the // exact same handshake packet // // ErrExistingHostInfo if we already have an entry in the hostmap for this // VpnIp and the new handshake was older than the one we currently have // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { hm.mainHostMap.Lock() defer hm.mainHostMap.Unlock() hm.Lock() defer hm.Unlock() // Check if we already have a tunnel with this vpn ip existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]] if found && existingHostInfo != nil { testHostInfo := existingHostInfo for testHostInfo != nil { // Is it just a delayed handshake packet? if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) { return testHostInfo, ErrAlreadySeen } testHostInfo = testHostInfo.next } // Is this a newer handshake? 
if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime && !existingHostInfo.ConnectionState.initiator { return existingHostInfo, ErrExistingHostInfo } existingHostInfo.logger(hm.l).Info("Taking new handshake") } existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId] if found { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } existingPendingIndex, found := hm.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingPendingIndex.hostinfo, ErrLocalIndexCollision } existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } // Complete is a simpler version of CheckAndComplete when we already know we // won't have a localIndexId collision because we already have an entry in the // pendingHostMap. An existing hostinfo is returned if there was one. func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { hm.mainHostMap.Lock() defer hm.mainHostMap.Unlock() hm.Lock() defer hm.Unlock() existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] if found && existingRemoteIndex != nil { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). 
Info("New host shadows existing host remoteIndex") } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. hm.unlockedDeleteHostInfo(hostinfo) hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) } // allocateIndex generates a unique localIndexId for this HostInfo // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { hm.mainHostMap.RLock() defer hm.mainHostMap.RUnlock() hm.Lock() defer hm.Unlock() for range 32 { index, err := generateIndex(hm.l) if err != nil { return err } _, inPending := hm.indexes[index] _, inMain := hm.mainHostMap.Indexes[index] if !inMain && !inPending { hh.hostinfo.localIndexId = index hm.indexes[index] = hh return nil } } return errors.New("failed to generate unique localIndexId") } func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { hm.Lock() defer hm.Unlock() hm.unlockedDeleteHostInfo(hostinfo) } func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { for _, addr := range hostinfo.vpnAddrs { delete(hm.vpnIps, addr) } if len(hm.vpnIps) == 0 { hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } delete(hm.indexes, hostinfo.localIndexId) if len(hm.indexes) == 0 { hm.indexes = map[uint32]*HandshakeHostInfo{} } if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). 
Debug("Pending hostmap hostInfo deleted") } } func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo } return nil } func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.vpnIps[vpnIp] } func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo { hh := hm.queryIndex(index) if hh != nil { return hh.hostinfo } return nil } func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.indexes[index] } func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix { return hm.mainHostMap.GetPreferredRanges() } func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() for _, v := range hm.vpnIps { f(v.hostinfo) } } func (hm *HandshakeManager) ForEachIndex(f controlEach) { hm.RLock() defer hm.RUnlock() for _, v := range hm.indexes { f(v.hostinfo) } } func (hm *HandshakeManager) EmitStats() { hm.RLock() hostLen := len(hm.vpnIps) indexLen := len(hm.indexes) hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) hm.mainHostMap.EmitStats() } // Utility functions below func generateIndex(l *logrus.Logger) (uint32, error) { b := make([]byte, 4) // Let zero mean we don't know the ID, so don't generate zero var index uint32 for index == 0 { _, err := rand.Read(b) if err != nil { l.Errorln(err) return 0, err } index = binary.BigEndian.Uint32(b) } if l.Level >= logrus.DebugLevel { l.WithField("index", index). 
Debug("Generated index") } return index, nil } func hsTimeout(tries int64, interval time.Duration) time.Duration { return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) } ================================================ FILE: handshake_manager_test.go ================================================ package nebula import ( "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() localrange := netip.MustParsePrefix("10.1.1.1/24") ip := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} mainHM := newHostMap(l) mainHM.preferredRanges.Store(&preferredRanges) lh := newTestLighthouse() cs := &CertState{ initiatingVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l} blah.f.pki.cs.Store(cs) now := time.Now() blah.NextOutboundHandshakeTimerTick(now) i := blah.StartHandshake(ip, nil) i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap assert.Empty(t, mainHM.Hosts) // Confirm they are in the pending index list assert.Contains(t, blah.vpnIps, ip) // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right for i := 1; i <= DefaultHandshakeRetries+1; i++ { now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval) blah.NextOutboundHandshakeTimerTick(now) } // Confirm they are still in the pending index list assert.Contains(t, blah.vpnIps, ip) // Tick 1 more time, a minute will certainly flush it out blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute)) 
// Confirm they have been removed assert.NotContains(t, blah.vpnIps, ip) } func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head for n != nil { c++ n = n.Next } } return c } type mockEncWriter struct { } func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { return } func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { return } func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { return } func (mw *mockEncWriter) Handshake(_ netip.Addr) {} func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { return nil } func (mw *mockEncWriter) GetCertState() *CertState { return &CertState{initiatingVersion: cert.Version2} } ================================================ FILE: header/header.go ================================================ package header import ( "encoding/binary" "encoding/json" "errors" "fmt" ) //Version 1 header: // 0 31 // |-----------------------------------------------------------------------| // | Version (uint4) | Type (uint4) | Subtype (uint8) | Reserved (uint16) | 32 // |-----------------------------------------------------------------------| // | Remote index (uint32) | 64 // |-----------------------------------------------------------------------| // | Message counter | 96 // | (uint64) | 128 // |-----------------------------------------------------------------------| // | payload... 
| type m = map[string]any const ( Version uint8 = 1 Len = 16 ) type MessageType uint8 type MessageSubType uint8 const ( Handshake MessageType = 0 Message MessageType = 1 RecvError MessageType = 2 LightHouse MessageType = 3 Test MessageType = 4 CloseTunnel MessageType = 5 Control MessageType = 6 ) var typeMap = map[MessageType]string{ Handshake: "handshake", Message: "message", RecvError: "recvError", LightHouse: "lightHouse", Test: "test", CloseTunnel: "closeTunnel", Control: "control", } const ( MessageNone MessageSubType = 0 MessageRelay MessageSubType = 1 ) const ( TestRequest MessageSubType = 0 TestReply MessageSubType = 1 ) const ( HandshakeIXPSK0 MessageSubType = 0 HandshakeXXPSK0 MessageSubType = 1 ) var ErrHeaderTooShort = errors.New("header is too short") var subTypeTestMap = map[MessageSubType]string{ TestRequest: "testRequest", TestReply: "testReply", } var subTypeNoneMap = map[MessageSubType]string{0: "none"} var subTypeMap = map[MessageType]*map[MessageSubType]string{ Message: { MessageNone: "none", MessageRelay: "relay", }, RecvError: &subTypeNoneMap, LightHouse: &subTypeNoneMap, Test: &subTypeTestMap, CloseTunnel: &subTypeNoneMap, Handshake: { HandshakeIXPSK0: "ix_psk0", }, Control: &subTypeNoneMap, } type H struct { Version uint8 Type MessageType Subtype MessageSubType Reserved uint16 RemoteIndex uint32 MessageCounter uint64 } // Encode uses the provided byte array to encode the provided header values into. 
// Byte array must be capped higher than HeaderLen or this will panic func Encode(b []byte, v uint8, t MessageType, st MessageSubType, ri uint32, c uint64) []byte { b = b[:Len] b[0] = v<<4 | byte(t&0x0f) b[1] = byte(st) binary.BigEndian.PutUint16(b[2:4], 0) binary.BigEndian.PutUint32(b[4:8], ri) binary.BigEndian.PutUint64(b[8:16], c) return b } // String creates a readable string representation of a header func (h *H) String() string { if h == nil { return "" } return fmt.Sprintf("ver=%d type=%s subtype=%s reserved=%#x remoteindex=%v messagecounter=%v", h.Version, h.TypeName(), h.SubTypeName(), h.Reserved, h.RemoteIndex, h.MessageCounter) } // MarshalJSON creates a json string representation of a header func (h *H) MarshalJSON() ([]byte, error) { return json.Marshal(m{ "version": h.Version, "type": h.TypeName(), "subType": h.SubTypeName(), "reserved": h.Reserved, "remoteIndex": h.RemoteIndex, "messageCounter": h.MessageCounter, }) } // Encode turns header into bytes func (h *H) Encode(b []byte) ([]byte, error) { if h == nil { return nil, errors.New("nil header") } return Encode(b, h.Version, h.Type, h.Subtype, h.RemoteIndex, h.MessageCounter), nil } // Parse is a helper function to parses given bytes into new Header struct func (h *H) Parse(b []byte) error { if len(b) < Len { return ErrHeaderTooShort } // get upper 4 bytes h.Version = uint8((b[0] >> 4) & 0x0f) // get lower 4 bytes h.Type = MessageType(b[0] & 0x0f) h.Subtype = MessageSubType(b[1]) h.Reserved = binary.BigEndian.Uint16(b[2:4]) h.RemoteIndex = binary.BigEndian.Uint32(b[4:8]) h.MessageCounter = binary.BigEndian.Uint64(b[8:16]) return nil } // TypeName will transform the headers message type into a human string func (h *H) TypeName() string { return TypeName(h.Type) } // TypeName will transform a nebula message type into a human string func TypeName(t MessageType) string { if n, ok := typeMap[t]; ok { return n } return "unknown" } // SubTypeName will transform the headers message sub type into a human 
string func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } // SubTypeName will transform a nebula message sub type into a human string func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { if x, ok := (*n)[s]; ok { return x } } return "unknown" } // NewHeader turns bytes into a header func NewHeader(b []byte) (*H, error) { h := new(H) if err := h.Parse(b); err != nil { return nil, err } return h, nil } ================================================ FILE: header/header_test.go ================================================ package header import ( "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type headerTest struct { expectedBytes []byte *H } // 0001 0010 00010010 var headerBigEndianTests = []headerTest{{ expectedBytes: []byte{0x54, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x9}, // 1010 0000 H: &H{ // 1111 1+2+4+8 = 15 Version: 5, Type: 4, Subtype: 0, Reserved: 0, RemoteIndex: 10, MessageCounter: 9, }, }, } func TestEncode(t *testing.T) { for _, tt := range headerBigEndianTests { b, err := tt.Encode(make([]byte, Len)) if err != nil { t.Fatal(err) } assert.Equal(t, tt.expectedBytes, b) } } func TestParse(t *testing.T) { for _, tt := range headerBigEndianTests { b := tt.expectedBytes parsedHeader := &H{} parsedHeader.Parse(b) if !reflect.DeepEqual(tt.H, parsedHeader) { t.Fatalf("got %#v; want %#v", parsedHeader, tt.H) } } } func TestTypeName(t *testing.T) { assert.Equal(t, "test", TypeName(Test)) assert.Equal(t, "test", (&H{Type: Test}).TypeName()) assert.Equal(t, "unknown", TypeName(99)) assert.Equal(t, "unknown", (&H{Type: 99}).TypeName()) } func TestSubTypeName(t *testing.T) { assert.Equal(t, "testRequest", SubTypeName(Test, TestRequest)) assert.Equal(t, "testRequest", (&H{Type: Test, Subtype: TestRequest}).SubTypeName()) assert.Equal(t, "unknown", SubTypeName(99, TestRequest)) assert.Equal(t, "unknown", (&H{Type: 99, 
Subtype: TestRequest}).SubTypeName()) assert.Equal(t, "unknown", SubTypeName(Test, 99)) assert.Equal(t, "unknown", (&H{Type: Test, Subtype: 99}).SubTypeName()) assert.Equal(t, "none", SubTypeName(Message, 0)) assert.Equal(t, "none", (&H{Type: Message, Subtype: 0}).SubTypeName()) } func TestTypeMap(t *testing.T) { // Force people to document this stuff assert.Equal(t, map[MessageType]string{ Handshake: "handshake", Message: "message", RecvError: "recvError", LightHouse: "lightHouse", Test: "test", CloseTunnel: "closeTunnel", Control: "control", }, typeMap) assert.Equal(t, map[MessageType]*map[MessageSubType]string{ Message: { MessageNone: "none", MessageRelay: "relay", }, RecvError: &subTypeNoneMap, LightHouse: &subTypeNoneMap, Test: &subTypeTestMap, CloseTunnel: &subTypeNoneMap, Handshake: { HandshakeIXPSK0: "ix_psk0", }, Control: &subTypeNoneMap, }, subTypeMap) } func TestHeader_String(t *testing.T) { assert.Equal( t, "ver=100 type=test subtype=testRequest reserved=0x63 remoteindex=98 messagecounter=97", (&H{100, Test, TestRequest, 99, 98, 97}).String(), ) } func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() require.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", string(b), ) } ================================================ FILE: hostmap.go ================================================ package nebula import ( "encoding/json" "errors" "fmt" "net" "net/netip" "slices" "sync" "sync/atomic" "time" "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to 
the lighthouse const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const MaxRemotes = 10 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // 5 allows for an initial handshake and each host pair re-handshaking twice const MaxHostInfosPerVpnIp = 5 // How long we should prevent roaming back to the previous IP. // This helps prevent flapping due to packets already in flight const RoamingSuppressSeconds = 2 const ( Requested = iota PeerRequested Established Disestablished ) const ( Unknowntype = iota ForwardingType TerminalType ) type Relay struct { Type int State int LocalIndex uint32 RemoteIndex uint32 PeerAddr netip.Addr } type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] l *logrus.Logger } // For synchronization, treat the pointed-to Relay struct as immutable. To edit the Relay // struct, make a copy of an existing value, edit the fileds in the copy, and // then store a pointer to the new copy in both realyForBy* maps. type RelayState struct { sync.RWMutex relays []netip.Addr // Ordered set of VpnAddrs of Hosts to use as relays to access this peer // For data race avoidance, the contents of a *Relay are treated immutably. 
To update a *Relay, copy the existing data, // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with // the RelayState Lock held) relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() for idx, val := range rs.relays { if val == ip { rs.relays = append(rs.relays[:idx], rs.relays[idx+1:]...) return } } } func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { rs.Lock() defer rs.Unlock() if r, ok := rs.relayForByAddr[vpnIp]; ok { newRelay := *r newRelay.State = state rs.relayForByAddr[newRelay.PeerAddr] = &newRelay rs.relayForByIdx[newRelay.LocalIndex] = &newRelay } } func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) { rs.Lock() defer rs.Unlock() if r, ok := rs.relayForByIdx[idx]; ok { newRelay := *r newRelay.State = state rs.relayForByAddr[newRelay.PeerAddr] = &newRelay rs.relayForByIdx[newRelay.LocalIndex] = &newRelay } } func (rs *RelayState) CopyAllRelayFor() []*Relay { rs.RLock() defer rs.RUnlock() ret := make([]*Relay, 0, len(rs.relayForByIdx)) for _, r := range rs.relayForByIdx { ret = append(ret, r) } return ret } func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByAddr[addr] return r, ok } func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() if !slices.Contains(rs.relays, ip) { rs.relays = append(rs.relays, ip) } } func (rs *RelayState) CopyRelayIps() []netip.Addr { ret := make([]netip.Addr, len(rs.relays)) rs.RLock() defer rs.RUnlock() copy(ret, rs.relays) return ret } func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) for relayIp := range 
rs.relayForByAddr { currentRelays = append(currentRelays, relayIp) } return currentRelays } func (rs *RelayState) CopyRelayForIdxs() []uint32 { rs.RLock() defer rs.RUnlock() ret := make([]uint32, 0, len(rs.relayForByIdx)) for i := range rs.relayForByIdx { ret = append(ret, i) } return ret } func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() r, ok := rs.relayForByAddr[vpnIp] if !ok { return false } newRelay := *r newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay rs.relayForByAddr[r.PeerAddr] = &newRelay return true } func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Relay, bool) { rs.Lock() defer rs.Unlock() r, ok := rs.relayForByIdx[localIdx] if !ok { return nil, false } newRelay := *r newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay rs.relayForByAddr[r.PeerAddr] = &newRelay return &newRelay, true } func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByAddr[vpnIp] return r, ok } func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIdx[idx] return r, ok } func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() rs.relayForByAddr[ip] = r rs.relayForByIdx[idx] = r } type NetworkType uint8 const ( NetworkTypeUnknown NetworkType = iota // NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate NetworkTypeVPN // NetworkTypeVPNPeer is a network that does not overlap one of our networks NetworkTypeVPNPeer // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() NetworkTypeUnsafe ) type HostInfo struct { remote netip.AddrPort remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 // 
vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks // The host may have other vpn addresses that are outside our // vpn networks but were removed because they are not usable vpnAddrs []netip.Addr // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. networks *bart.Table[NetworkType] relayState RelayState // HandshakePacket records the packets used to create this hostinfo // We need these to avoid replayed handshake packets creating new hostinfos which causes churn HandshakePacket map[uint8][]byte // nextLHQuery is the earliest we can ask the lighthouse for new information. // This is used to limit lighthouse re-queries in chatty clients nextLHQuery atomic.Int64 // lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH // for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like // with a handshake lastRebindCount int8 // lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally // Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator // This is used to avoid an attack where a handshake packet is replayed after some time lastHandshakeTime uint64 lastRoam time.Time lastRoamRemote netip.AddrPort // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. next, prev *HostInfo //TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing in, out, pendingDeletion atomic.Bool // lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use. // This value will be behind against actual tunnel utilization in the hot path. 
// This should only be used by the ConnectionManagers ticker routine. lastUsed time.Time } type ViaSender struct { UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay remoteIdx uint32 // remoteIdx is the index included in the header of the received packet relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. IsRelayed bool // IsRelayed is true if the packet was sent through a relay } func (v ViaSender) String() string { if v.IsRelayed { return fmt.Sprintf("%s (relayed)", v.UdpAddr) } return v.UdpAddr.String() } func (v ViaSender) MarshalJSON() ([]byte, error) { if v.IsRelayed { return json.Marshal(m{"relay": v.UdpAddr}) } return json.Marshal(m{"direct": v.UdpAddr}) } type cachedPacket struct { messageType header.MessageType messageSubType header.MessageSubType callback packetCallback packet []byte } type packetCallback func(t header.MessageType, st header.MessageSubType, h *HostInfo, p, nb, out []byte) type cachedPacketMetrics struct { sent metrics.Counter dropped metrics.Counter } func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { hm := newHostMap(l) hm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { hm.reload(c, false) }) l.WithField("preferredRanges", hm.GetPreferredRanges()). 
Info("Main HostMap created") return hm } func newHostMap(l *logrus.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{}, l: l, } } func (hm *HostMap) reload(c *config.C, initial bool) { if initial || c.HasChanged("preferred_ranges") { var preferredRanges []netip.Prefix rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) for _, rawPreferredRange := range rawPreferredRanges { preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") continue } preferredRanges = append(preferredRanges, preferredRange) } oldRanges := hm.preferredRanges.Swap(&preferredRanges) if !initial { hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed") } } } // EmitStats reports host, index, and relay counts to the stats collection system func (hm *HostMap) EmitStats() { hm.RLock() hostLen := len(hm.Hosts) indexLen := len(hm.Indexes) remoteIndexLen := len(hm.RemoteIndexes) relaysLen := len(hm.Relays) hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen)) metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen)) metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore hm.Lock() // If we have a previous or next hostinfo then we are not the last one for this vpn ip final := (hostinfo.next == nil && hostinfo.prev == nil) hm.unlockedDeleteHostInfo(hostinfo) 
hm.Unlock() return final } func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { hm.Lock() defer hm.Unlock() hm.unlockedMakePrimary(hostinfo) } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { // Get the current primary, if it exists oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] // Every address in the hostinfo gets elevated to primary for _, vpnAddr := range hostinfo.vpnAddrs { //NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on // indexes so it should be fine. hm.Hosts[vpnAddr] = hostinfo } // If we are already primary then we won't bother re-linking if oldHostinfo == hostinfo { return } // Unlink this hostinfo if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } if hostinfo.next != nil { hostinfo.next.prev = hostinfo.prev } // If there wasn't a previous primary then clear out any links if oldHostinfo == nil { hostinfo.next = nil hostinfo.prev = nil return } // Relink the hostinfo as primary hostinfo.next = oldHostinfo oldHostinfo.prev = hostinfo hostinfo.prev = nil } func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { for _, addr := range hostinfo.vpnAddrs { h := hm.Hosts[addr] for h != nil { if h == hostinfo { hm.unlockedInnerDeleteHostInfo(h, addr) } h = h.next } } } func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) { primary, ok := hm.Hosts[addr] isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil if ok && primary == hostinfo { // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it delete(hm.Hosts, addr) if len(hm.Hosts) == 0 { hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary hm.Hosts[addr] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } } else { // Relink if we were in the middle of multiple hostinfos for this vpn addr if hostinfo.prev != nil { 
hostinfo.prev.next = hostinfo.next } if hostinfo.next != nil { hostinfo.next.prev = hostinfo.prev } } hostinfo.next = nil hostinfo.prev = nil // The remote index uses index ids outside our control so lets make sure we are only removing // the remote index pointer here if it points to the hostinfo we are deleting hostinfo2, ok := hm.RemoteIndexes[hostinfo.remoteIndexId] if ok && hostinfo2 == hostinfo { delete(hm.RemoteIndexes, hostinfo.remoteIndexId) if len(hm.RemoteIndexes) == 0 { hm.RemoteIndexes = map[uint32]*HostInfo{} } } delete(hm.Indexes, hostinfo.localIndexId) if len(hm.Indexes) == 0 { hm.Indexes = map[uint32]*HostInfo{} } if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } if isLastHostinfo { // I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next // hops as 'Requested' so that new relay tunnels are created in the future. 
hm.unlockedDisestablishVpnAddrRelayFor(hostinfo) } // Clean up any local relay indexes for which I am acting as a relay hop for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } } func (hm *HostMap) QueryIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.Indexes[index]; ok { hm.RUnlock() return h } else { hm.RUnlock() return nil } } func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.Relays[index]; ok { hm.RUnlock() return h } else { hm.RUnlock() return nil } } func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.RemoteIndexes[index]; ok { hm.RUnlock() return h } else { hm.RUnlock() return nil } } func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { return hm.queryVpnAddr(vpnIp, nil) } func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() h, ok := hm.Hosts[relayHostIp] if !ok { return nil, nil, errors.New("unable to find host") } for h != nil { for _, targetIp := range targetIps { r, ok := h.relayState.QueryRelayForByIp(targetIp) if ok && r.State == Established { return h, r, nil } } h = h.next } return nil, nil, errors.New("unable to find host with relay") } func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) { for _, relayHostIp := range hi.relayState.CopyRelayIps() { if h, ok := hm.Hosts[relayHostIp]; ok { for h != nil { h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) h = h.next } } } for _, rs := range hi.relayState.CopyAllRelayFor() { if rs.Type == ForwardingType { if h, ok := hm.Hosts[rs.PeerAddr]; ok { for h != nil { h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) h = h.next } } } } } func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() // Do not attempt promotion if you are a 
lighthouse if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce) } return h } hm.RUnlock() return nil } // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } for _, addr := range hostinfo.vpnAddrs { hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). 
Debug("Hostmap vpnIp added") } } func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { existing := hm.Hosts[vpnAddr] hm.Hosts[vpnAddr] = hostinfo if existing != nil && existing != hostinfo { hostinfo.next = existing existing.prev = hostinfo } i := 1 check := hostinfo for check != nil { if i > MaxHostInfosPerVpnIp { hm.unlockedDeleteHostInfo(check) } check = check.next i++ } } func (hm *HostMap) GetPreferredRanges() []netip.Prefix { //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer return *hm.preferredRanges.Load() } func (hm *HostMap) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() for _, v := range hm.Hosts { f(v) } } func (hm *HostMap) ForEachIndex(f controlEach) { hm.RLock() defer hm.RUnlock() for _, v := range hm.Indexes { f(v) } } // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! 
func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote // return early if we are already on a preferred remote if remote.IsValid() { rIP := remote.Addr() for _, l := range preferredRanges { if l.Contains(rIP) { return } } } i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { if remote.IsValid() && (!addr.IsValid() || !preferred) { return } // Try to send a test packet to that host, this should // cause it to detect a roaming event and switch remotes ifce.sendTo(header.Test, header.TestRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }) } // Re query our lighthouses for new remotes occasionally if c%ifce.reQueryEvery.Load() == 0 && ifce.lightHouse != nil { now := time.Now().UnixNano() if now < i.nextLHQuery.Load() { return } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) ifce.lightHouse.QueryServer(i.vpnAddrs[0]) } } func (i *HostInfo) GetCert() *cert.CachedCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert } return nil } // TODO: Maybe use ViaSender here? func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { i.remote = remote i.remotes.LearnRemote(i.vpnAddrs[0], remote) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool { if via.IsRelayed { return false } currentRemote := i.remote if !currentRemote.IsValid() { i.SetRemote(via.UdpAddr) return true } // NOTE: We do this loop here instead of calling `isPreferred` in // remote_list.go so that we only have to loop over preferredRanges once. 
newIsPreferred := false for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote if l.Contains(currentRemote.Addr()) { return false } if l.Contains(via.UdpAddr.Addr()) { newIsPreferred = true } } if newIsPreferred { // Consider this a roaming event i.lastRoam = time.Now() i.lastRoamRemote = currentRemote i.SetRemote(via.UdpAddr) return true } return false } // buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { return // Simple case, no BART needed } } i.networks = new(bart.Table[NetworkType]) for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) if myVpnNetworksTable.Contains(network.Addr()) { i.networks.Insert(nprefix, NetworkTypeVPN) } else { i.networks.Insert(nprefix, NetworkTypeVPNPeer) } } for _, network := range c.UnsafeNetworks() { i.networks.Insert(network, NetworkTypeUnsafe) } } func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { if i == nil { return logrus.NewEntry(l) } li := l.WithField("vpnAddrs", i.vpnAddrs). WithField("localIndex", i.localIndexId). 
WithField("remoteIndex", i.remoteIndexId) if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { li = li.WithField("certName", peerCert.Certificate.Name()) } } return li } // Utility functions func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) if l.Level >= logrus.TraceLevel { l.WithField("interfaceName", i.Name).WithField("allow", allow).Trace("localAllowList.AllowName") } if !allow { continue } addrs, _ := i.Addrs() for _, rawAddr := range addrs { var addr netip.Addr switch v := rawAddr.(type) { case *net.IPNet: //continue addr, _ = netip.AddrFromSlice(v.IP) case *net.IPAddr: addr, _ = netip.AddrFromSlice(v.IP) } if !addr.IsValid() { if l.Level >= logrus.DebugLevel { l.WithField("localAddr", rawAddr).Debug("addr was invalid") } continue } addr = addr.Unmap() if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { isAllowed := allowList.Allow(addr) if l.Level >= logrus.TraceLevel { l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") } if !isAllowed { continue } finalAddrs = append(finalAddrs, addr) } } } return finalAddrs } ================================================ FILE: hostmap_test.go ================================================ package nebula import ( "net/netip" "testing" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() hm := newHostMap(l) f := &Interface{} h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} h3 := &HostInfo{vpnAddrs: 
[]netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) hm.unlockedAddHostInfo(h2, f) hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) assert.Equal(t, h3.localIndexId, h2.next.localIndexId) assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) assert.Equal(t, h4.localIndexId, h3.next.localIndexId) assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) assert.Nil(t, h4.next) // Swap h3/middle to primary hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h2.localIndexId, h1.next.localIndexId) assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) assert.Equal(t, h4.localIndexId, h2.next.localIndexId) assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) assert.Nil(t, h4.next) // Swap h4/tail to primary hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h1.localIndexId, h3.next.localIndexId) assert.Equal(t, h4.localIndexId, h3.prev.localIndexId) assert.Equal(t, h2.localIndexId, h1.next.localIndexId) assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) assert.Nil(t, h2.next) // Swap h4 again should be no-op hm.MakePrimary(h4) 
// Make sure we go h4 -> h3 -> h1 -> h2 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h1.localIndexId, h3.next.localIndexId) assert.Equal(t, h4.localIndexId, h3.prev.localIndexId) assert.Equal(t, h2.localIndexId, h1.next.localIndexId) assert.Equal(t, h3.localIndexId, h1.prev.localIndexId) assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) assert.Nil(t, h2.next) } func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() hm := newHostMap(l) f := &Interface{} h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5} h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) hm.unlockedAddHostInfo(h2, f) hm.unlockedAddHostInfo(h1, f) // h6 should be deleted assert.Nil(t, h6.next) assert.Nil(t, h6.prev) h := hm.QueryIndex(h6.localIndexId) assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h1.localIndexId, h2.prev.localIndexId) assert.Equal(t, h3.localIndexId, h2.next.localIndexId) assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) assert.Equal(t, h4.localIndexId, h3.next.localIndexId) assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) 
assert.Equal(t, h5.localIndexId, h4.next.localIndexId) assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) assert.Nil(t, h5.next) // Delete primary hm.DeleteHostInfo(h1) assert.Nil(t, h1.prev) assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h3.localIndexId, h2.next.localIndexId) assert.Equal(t, h2.localIndexId, h3.prev.localIndexId) assert.Equal(t, h4.localIndexId, h3.next.localIndexId) assert.Equal(t, h3.localIndexId, h4.prev.localIndexId) assert.Equal(t, h5.localIndexId, h4.next.localIndexId) assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) assert.Nil(t, h5.next) // Delete in the middle hm.DeleteHostInfo(h3) assert.Nil(t, h3.prev) assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h4.localIndexId, h2.next.localIndexId) assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) assert.Equal(t, h5.localIndexId, h4.next.localIndexId) assert.Equal(t, h4.localIndexId, h5.prev.localIndexId) assert.Nil(t, h5.next) // Delete the tail hm.DeleteHostInfo(h5) assert.Nil(t, h5.prev) assert.Nil(t, h5.next) // Make sure we go h2 -> h4 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) assert.Equal(t, h4.localIndexId, h2.next.localIndexId) assert.Equal(t, h2.localIndexId, h4.prev.localIndexId) assert.Nil(t, h4.next) // Delete the head hm.DeleteHostInfo(h2) assert.Nil(t, h2.prev) assert.Nil(t, h2.next) // Make sure we only have h4 prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, 
prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) assert.Nil(t, h4.next) // Delete the only item hm.DeleteHostInfo(h4) assert.Nil(t, h4.prev) assert.Nil(t, h4.next) // Make sure we have nil prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } func TestHostMap_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) hm := NewHostMapFromConfig(l, c) toS := func(ipn []netip.Prefix) []string { var s []string for _, n := range ipn { s = append(s, n.String()) } return s } assert.Empty(t, hm.GetPreferredRanges()) c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]") assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges())) c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]") assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges())) } func TestHostMap_RelayState(t *testing.T) { h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} a1 := netip.MustParseAddr("::1") a2 := netip.MustParseAddr("2001::1") h1.relayState.InsertRelayTo(a1) assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays) h1.relayState.InsertRelayTo(a2) assert.Equal(t, []netip.Addr{a1, a2}, h1.relayState.relays) // Ensure that the first relay added is the first one returned in the copy currentRelays := h1.relayState.CopyRelayIps() require.Len(t, currentRelays, 2) assert.Equal(t, a1, currentRelays[0]) // Deleting the last one in the list works ok h1.relayState.DeleteRelay(a2) assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays) // Deleting an element not in the list works ok h1.relayState.DeleteRelay(a2) assert.Equal(t, []netip.Addr{a1}, h1.relayState.relays) // Deleting the only element in the list works ok h1.relayState.DeleteRelay(a1) assert.Equal(t, []netip.Addr{}, h1.relayState.relays) } ================================================ FILE: hostmap_tester.go ================================================ //go:build e2e_testing package nebula // This file 
contains functions used to export information to the e2e testing framework import ( "net/netip" ) func (i *HostInfo) GetVpnAddrs() []netip.Addr { return i.vpnAddrs } func (i *HostInfo) GetLocalIndex() uint32 { return i.localIndexId } func (i *HostInfo) GetRemoteIndex() uint32 { return i.remoteIndexId } func (i *HostInfo) GetRelayState() *RelayState { return &i.relayState } ================================================ FILE: inside.go ================================================ package nebula import ( "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/routing" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err) } return } // Ignore local broadcast packets if f.dropLocalBroadcast { if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) { return } } if f.myVpnAddrsTable.Contains(fwPacket.RemoteAddr) { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) if err != nil { f.l.WithError(err).Error("Failed to forward to tun") } } // Otherwise, drop. On linux, we should never see these packets - Linux // routes packets from the nebula addr to the nebula addr through the loopback device. 
return } // Ignore multicast packets if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { return } hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } return } if !ready { return } dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l). WithField("fwPacket", fwPacket). WithField("reason", dropReason). Debugln("dropping outbound packet") } } } func (f *Interface) rejectInside(packet []byte, out []byte, q int) { if !f.firewall.InSendReject { return } out = iputil.CreateRejectPacket(packet, out) if len(out) == 0 { return } _, err := f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") } } func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *HostInfo, nb, out []byte, q int) { if !f.firewall.OutSendReject { return } out = iputil.CreateRejectPacket(packet, out) if len(out) == 0 { return } if len(out) > iputil.MaxRejectPacketSize { if f.l.GetLevel() >= logrus.InfoLevel { f.l. WithField("packet", packet). WithField("outPacket", out). Info("rejectOutside: packet too big, not sending") } return } f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } // Handshake will attempt to initiate a tunnel with the provided vpn address. 
This is a no-op if the tunnel is already established or being established // it does not check if it is within our vpn networks! func (f *Interface) Handshake(vpnAddr netip.Addr) { f.handshakeManager.GetOrHandshake(vpnAddr, nil) } // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { if f.myVpnNetworksTable.Contains(vpnAddr) { return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } return nil, false } // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { destinationAddr := fwPacket.RemoteAddr hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) // Host is inside the mesh, no routing required if hostinfo != nil { return hostinfo, ready } gateways := f.inside.RoutesFor(destinationAddr) switch len(gateways) { case 0: return nil, false case 1: // Single gateway route return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback) default: // Multi gateway route, perform ECMP categorization gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways) if !balancingOk { // This happens if the gateway buckets were not calculated, this _should_ never happen f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.") } var handshakeInfoForChosenGateway *HandshakeHostInfo var hhReceiver = func(hh *HandshakeHostInfo) { handshakeInfoForChosenGateway = hh } // Store the handshakeHostInfo for later. 
// If this node is not reachable we will attempt other nodes, if none are reachable we will // cache the packet for this gateway. if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready { return hostinfo, true } // It appears the selected gateway cannot be reached, find another gateway to fallback on. // The current implementation breaks ECMP but that seems better than no connectivity. // If ECMP is also required when a gateway is down then connectivity status // for each gateway needs to be kept and the weights recalculated when they go up or down. // This would also need to interact with unsafe_route updates through reloading the config or // use of the use_system_route_table option if f.l.Level >= logrus.DebugLevel { f.l.WithField("destination", destinationAddr). WithField("originalGateway", gatewayAddr). Debugln("Calculated gateway for ECMP not available, attempting other gateways") } for i := range gateways { // Skip the gateway that failed previously if gateways[i].Addr() == gatewayAddr { continue } // We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready { return hostinfo, true } } // No gateways reachable, cache the packet in the originally chosen gateway cacheCallback(handshakeInfoForChosenGateway) return hostinfo, false } } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { fp := &firewall.Packet{} err := newPacket(p, false, fp) if err != nil { f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err) return } // check if packet is in outbound fw rules dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). WithField("reason", dropReason). 
Debugln("dropping cached packet") } return } f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. // This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnAddr", vpnAddr). Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") } return } if !ready { return } f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out) } func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) { f.send(t, st, hi.ConnectionState, hi, p, nb, out) } func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) } func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } // SendVia sends a payload through a Relay tunnel. No authentication or encryption is done // to the payload for the ultimate target host, making this a useful method for sending // handshake messages to peers through relay tunnels. // via is the HostInfo through which the message is relayed. 
// ad is the plaintext data to authenticate, but not encrypt // nb is a buffer used to store the nonce value, re-used for performance reasons. // out is a buffer used to store the result of the Encrypt operation // q indicates which writer to use to send the packet. func (f *Interface) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool, ) { if noiseutil.EncryptLockNeeded { // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check via.ConnectionState.writeLock.Lock() } c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) f.connectionManager.Out(via) // Authenticate the header and payload, but do not encrypt for this message type. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. if len(out)+len(ad)+via.ConnectionState.eKey.Overhead() > cap(out) { if noiseutil.EncryptLockNeeded { via.ConnectionState.writeLock.Unlock() } via.logger(f.l). WithField("outCap", cap(out)). WithField("payloadLen", len(ad)). WithField("headerLen", len(out)). WithField("cipherOverhead", via.ConnectionState.eKey.Overhead()). Error("SendVia out buffer not large enough for relay") return } // The header bytes are written to the 'out' slice; Grow the slice to hold the header and associated data payload. offset := len(out) out = out[:offset+len(ad)] // In one call path, the associated data _is_ already stored in out. In other call paths, the associated data must // be copied into 'out'. 
if !nocopy { copy(out[offset:], ad) } var err error out, err = via.ConnectionState.eKey.EncryptDanger(out, out, nil, c, nb) if noiseutil.EncryptLockNeeded { via.ConnectionState.writeLock.Unlock() } if err != nil { via.logger(f.l).WithError(err).Info("Failed to EncryptDanger in sendVia") return } err = f.writers[0].WriteTo(out, via.remote) if err != nil { via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia") } f.connectionManager.RelayUsed(relay.LocalIndex) } func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { return } useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out if useRelay { if len(out) < header.Len { // out always has a capacity of mtu, but not always a length greater than the header.Len. // Grow it to make sure the next operation works. out = out[:header.Len] } // Save a header's worth of data at the front of the 'out' buffer. out = out[header.Len:] } if noiseutil.EncryptLockNeeded { // NOTE: for goboring AESGCMTLS we need to lock because of the nonce check ci.writeLock.Lock() } c := ci.messageCounter.Add(1) //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) f.connectionManager.Out(hostinfo) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // all our addrs and enable a faster roaming. if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. 
f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") } } var err error out, err = ci.eKey.EncryptDanger(out, out, p, c, nb) if noiseutil.EncryptLockNeeded { ci.writeLock.Unlock() } if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).WithField("counter", c). WithField("attemptedCounter", c). Error("Failed to encrypt outgoing packet") return } if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") continue } f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true) break } } } ================================================ FILE: inside_bsd.go ================================================ //go:build darwin || dragonfly || freebsd || netbsd || openbsd package nebula const immediatelyForwardToSelf bool = true ================================================ FILE: inside_generic.go ================================================ //go:build !darwin && !dragonfly && !freebsd && !netbsd && !openbsd package nebula const immediatelyForwardToSelf bool = false ================================================ FILE: interface.go 
================================================ package nebula import ( "context" "errors" "fmt" "io" "net/netip" "os" "runtime" "sync/atomic" "time" "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) const mtu = 9001 type InterfaceConfig struct { HostMap *HostMap Outside udp.Conn Inside overlay.Device pki *PKI Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager lightHouse *LightHouse connectionManager *connectionManager DropLocalBroadcast bool DropMulticast bool routines int MessageMetrics *MessageMetrics version string relayManager *relayManager punchy *Punchy tryPromoteEvery uint32 reQueryEvery uint32 reQueryWait time.Duration ConntrackCacheTimeout time.Duration l *logrus.Logger } type Interface struct { hostMap *HostMap outside udp.Conn inside overlay.Device pki *PKI firewall *Firewall connectionManager *connectionManager handshakeManager *HandshakeManager serveDns bool createTime time.Time lightHouse *LightHouse myBroadcastAddrsTable *bart.Lite myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate myVpnAddrsTable *bart.Lite myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate myVpnNetworksTable *bart.Lite dropLocalBroadcast bool dropMulticast bool routines int disconnectInvalid atomic.Bool closed atomic.Bool relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 reQueryWait atomic.Int64 sendRecvErrorConfig recvErrorConfig acceptRecvErrorConfig recvErrorConfig // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse rebindCount int8 version string conntrackCacheTimeout time.Duration writers []udp.Conn readers []io.ReadWriteCloser metricHandshakes metrics.Histogram messageMetrics 
*MessageMetrics cachedPacketMetrics *cachedPacketMetrics l *logrus.Logger } type EncWriter interface { SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool, ) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) Handshake(vpnAddr netip.Addr) GetHostInfo(vpnAddr netip.Addr) *HostInfo GetCertState() *CertState } type recvErrorConfig uint8 const ( recvErrorAlways recvErrorConfig = iota recvErrorNever recvErrorPrivate ) func (s recvErrorConfig) ShouldRecvError(endpoint netip.AddrPort) bool { switch s { case recvErrorPrivate: return endpoint.Addr().IsPrivate() case recvErrorAlways: return true case recvErrorNever: return false default: panic(fmt.Errorf("invalid recvErrorConfig value: %d", s)) } } func (s recvErrorConfig) String() string { switch s { case recvErrorAlways: return "always" case recvErrorNever: return "never" case recvErrorPrivate: return "private" default: return fmt.Sprintf("invalid(%d)", s) } } func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Outside == nil { return nil, errors.New("no outside connection") } if c.Inside == nil { return nil, errors.New("no inside interface (tun)") } if c.pki == nil { return nil, errors.New("no certificate state") } if c.Firewall == nil { return nil, errors.New("no firewall rules") } if c.connectionManager == nil { return nil, errors.New("no connection manager") } cs := c.pki.getCertState() ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, firewall: c.Firewall, serveDns: c.ServeDns, handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), 
myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, myVpnAddrsTable: cs.myVpnAddrsTable, myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, relayManager: c.relayManager, connectionManager: c.connectionManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, cachedPacketMetrics: &cachedPacketMetrics{ sent: metrics.GetOrRegisterCounter("hostinfo.cached_packets.sent", nil), dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), }, l: c.l, } ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) ifce.connectionManager.intf = ifce return ifce, nil } // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. func (f *Interface) activate() { // actually turn on tun dev addr, err := f.outside.LocalAddr() if err != nil { f.l.WithError(err).Error("Failed to get udp listen address") } f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). WithField("build", f.version).WithField("udpAddr", addr). WithField("boringcrypto", boringEnabled()). 
Info("Nebula interface is active") if f.routines > 1 { if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { f.routines = 1 f.l.Warn("routines is not supported on this platform, falling back to a single routine") } } metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) // Prepare n tun queues var reader io.ReadWriteCloser = f.inside for i := 0; i < f.routines; i++ { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { f.l.Fatal(err) } } f.readers[i] = reader } if err := f.inside.Activate(); err != nil { f.inside.Close() f.l.Fatal(err) } } func (f *Interface) run() { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { go f.listenOut(i) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { go f.listenIn(f.readers[i], i) } } func (f *Interface) listenOut(i int) { runtime.LockOSThread() var li udp.Conn if i > 0 { li = f.writers[i] } else { li = f.outside } ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() plaintext := make([]byte, udp.MTU) h := &header.H{} fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) for { n, err := reader.Read(packet) if err != nil { if errors.Is(err, os.ErrClosed) && f.closed.Load() { return } f.l.WithError(err).Error("Error while reading outbound packet") // This only seems to happen when something fatal happens to the fd, so exit. 
os.Exit(2) } f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadAcceptRecvError) c.RegisterReloadCallback(f.reloadDisconnectInvalid) c.RegisterReloadCallback(f.reloadMisc) for _, udpConn := range f.writers { c.RegisterReloadCallback(udpConn.ReloadConfig) } } func (f *Interface) reloadDisconnectInvalid(c *config.C) { initial := c.InitialLoad() if initial || c.HasChanged("pki.disconnect_invalid") { f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true)) if !initial { f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load()) } } } func (f *Interface) reloadFirewall(c *config.C) { //TODO: need to trigger/detect if the certificate changed too if c.HasChanged("firewall") == false { f.l.Debug("No firewall config change detected") return } fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return } oldFw := f.firewall conntrack := oldFw.Conntrack conntrack.Lock() defer conntrack.Unlock() fw.rulesVersion = oldFw.rulesVersion + 1 // If rulesVersion is back to zero, we have wrapped all the way around. Be // safe and just reset conntrack in this case. if fw.rulesVersion == 0 { f.l.WithField("firewallHashes", fw.GetRuleHashes()). WithField("oldFirewallHashes", oldFw.GetRuleHashes()). WithField("rulesVersion", fw.rulesVersion). Warn("firewall rulesVersion has overflowed, resetting conntrack") } else { fw.Conntrack = conntrack } f.firewall = fw oldFw.Destroy() f.l.WithField("firewallHashes", fw.GetRuleHashes()). WithField("oldFirewallHashes", oldFw.GetRuleHashes()). WithField("rulesVersion", fw.rulesVersion). 
Info("New firewall has been installed") } func (f *Interface) reloadSendRecvError(c *config.C) { if c.InitialLoad() || c.HasChanged("listen.send_recv_error") { stringValue := c.GetString("listen.send_recv_error", "always") switch stringValue { case "always": f.sendRecvErrorConfig = recvErrorAlways case "never": f.sendRecvErrorConfig = recvErrorNever case "private": f.sendRecvErrorConfig = recvErrorPrivate default: if c.GetBool("listen.send_recv_error", true) { f.sendRecvErrorConfig = recvErrorAlways } else { f.sendRecvErrorConfig = recvErrorNever } } f.l.WithField("sendRecvError", f.sendRecvErrorConfig.String()). Info("Loaded send_recv_error config") } } func (f *Interface) reloadAcceptRecvError(c *config.C) { if c.InitialLoad() || c.HasChanged("listen.accept_recv_error") { stringValue := c.GetString("listen.accept_recv_error", "always") switch stringValue { case "always": f.acceptRecvErrorConfig = recvErrorAlways case "never": f.acceptRecvErrorConfig = recvErrorNever case "private": f.acceptRecvErrorConfig = recvErrorPrivate default: if c.GetBool("listen.accept_recv_error", true) { f.acceptRecvErrorConfig = recvErrorAlways } else { f.acceptRecvErrorConfig = recvErrorNever } } f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). 
Info("Loaded accept_recv_error config") } } func (f *Interface) reloadMisc(c *config.C) { if c.HasChanged("counters.try_promote") { n := c.GetUint32("counters.try_promote", defaultPromoteEvery) f.tryPromoteEvery.Store(n) f.l.Info("counters.try_promote has changed") } if c.HasChanged("counters.requery_every_packets") { n := c.GetUint32("counters.requery_every_packets", defaultReQueryEvery) f.reQueryEvery.Store(n) f.l.Info("counters.requery_every_packets has changed") } if c.HasChanged("timers.requery_wait_duration") { n := c.GetDuration("timers.requery_wait_duration", defaultReQueryWait) f.reQueryWait.Store(int64(n)) f.l.Info("timers.requery_wait_duration has changed") } } func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) defer ticker.Stop() udpStats := udp.NewUDPStatsEmitter(f.writers) certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil) certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) for { select { case <-ctx.Done(): return case <-ticker.C: f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() certState := f.pki.getCertState() defaultCrt := certState.GetDefaultCertificate() certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) certInitiatingVersion.Update(int64(defaultCrt.Version())) // Report the max certificate version we are capable of using if certState.v2Cert != nil { certMaxVersion.Update(int64(certState.v2Cert.Version())) } else { certMaxVersion.Update(int64(certState.v1Cert.Version())) } } } } func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { return f.hostMap.QueryVpnAddr(vpnIp) } func (f *Interface) GetCertState() *CertState { return f.pki.getCertState() } func (f *Interface) Close() error { f.closed.Store(true) for _, u := range f.writers { err := u.Close() if err != nil { f.l.WithError(err).Error("Error 
while closing udp socket") } } for i, r := range f.readers { if i == 0 { continue // f.readers[0] is f.inside, which we want to save for last } if err := r.Close(); err != nil { f.l.WithError(err).Error("Error while closing tun reader") } } // Release the tun device return f.inside.Close() } ================================================ FILE: iputil/packet.go ================================================ package iputil import ( "encoding/binary" "golang.org/x/net/ipv4" ) const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header // - 8 byte icmpv4 header // - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header) MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8 ) func CreateRejectPacket(packet []byte, out []byte) []byte { if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version { return nil } switch packet[9] { case 6: // tcp return ipv4CreateRejectTCPPacket(packet, out) default: return ipv4CreateRejectICMPPacket(packet, out) } } func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte { ihl := int(packet[0]&0x0f) << 2 if len(packet) < ihl { // We need at least this many bytes for this to be a valid packet return nil } // ICMP reply includes original header and first 8 bytes of the packet packetLen := len(packet) if packetLen > ihl+8 { packetLen = ihl + 8 } outLen := ipv4.HeaderLen + 8 + packetLen if outLen > cap(out) { return nil } out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl ipHdr[1] = 0 // DSCP, ECN binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length ipHdr[4] = 0 // id ipHdr[5] = 0 // . ipHdr[6] = 0 // flags, fragment offset ipHdr[7] = 0 // . ipHdr[8] = 64 // TTL ipHdr[9] = 1 // protocol (icmp) ipHdr[10] = 0 // checksum ipHdr[11] = 0 // . 
// Swap dest / src IPs copy(ipHdr[12:16], packet[16:20]) copy(ipHdr[16:20], packet[12:16]) // Calculate checksum binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0)) // ICMP Destination Unreachable icmpOut := out[ipv4.HeaderLen:] icmpOut[0] = 3 // type (Destination unreachable) icmpOut[1] = 3 // code (Port unreachable error) icmpOut[2] = 0 // checksum icmpOut[3] = 0 // . icmpOut[4] = 0 // unused icmpOut[5] = 0 // . icmpOut[6] = 0 // . icmpOut[7] = 0 // . // Copy original IP header and first 8 bytes as body copy(icmpOut[8:], packet[:packetLen]) // Calculate checksum binary.BigEndian.PutUint16(icmpOut[2:], tcpipChecksum(icmpOut, 0)) return out } func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte { const tcpLen = 20 ihl := int(packet[0]&0x0f) << 2 outLen := ipv4.HeaderLen + tcpLen if len(packet) < ihl+tcpLen { // We need at least this many bytes for this to be a valid packet return nil } if outLen > cap(out) { return nil } out = out[:outLen] ipHdr := out[0:ipv4.HeaderLen] ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2) // version, ihl ipHdr[1] = 0 // DSCP, ECN binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length ipHdr[4] = 0 // id ipHdr[5] = 0 // . ipHdr[6] = 0 // flags, fragment offset ipHdr[7] = 0 // . ipHdr[8] = 64 // TTL ipHdr[9] = 6 // protocol (tcp) ipHdr[10] = 0 // checksum ipHdr[11] = 0 // . 
// Swap dest / src IPs copy(ipHdr[12:16], packet[16:20]) copy(ipHdr[16:20], packet[12:16]) // Calculate checksum binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0)) // TCP RST tcpIn := packet[ihl:] var ackSeq, seq uint32 outFlags := byte(0b00000100) // RST // Set seq and ackSeq based on how iptables/netfilter does it in Linux: // - https://github.com/torvalds/linux/blob/v5.19/net/ipv4/netfilter/nf_reject_ipv4.c#L193-L221 inAck := tcpIn[13]&0b00010000 != 0 if inAck { seq = binary.BigEndian.Uint32(tcpIn[8:]) } else { inSyn := uint32((tcpIn[13] & 0b00000010) >> 1) inFin := uint32(tcpIn[13] & 0b00000001) // seq from the packet + syn + fin + tcp segment length ackSeq = binary.BigEndian.Uint32(tcpIn[4:]) + inSyn + inFin + uint32(len(tcpIn)) - uint32(tcpIn[12]>>4)<<2 outFlags |= 0b00010000 // ACK } tcpOut := out[ipv4.HeaderLen:] // Swap dest / src ports copy(tcpOut[0:2], tcpIn[2:4]) copy(tcpOut[2:4], tcpIn[0:2]) binary.BigEndian.PutUint32(tcpOut[4:], seq) binary.BigEndian.PutUint32(tcpOut[8:], ackSeq) tcpOut[12] = (tcpLen >> 2) << 4 // data offset, reserved, NS tcpOut[13] = outFlags // CWR, ECE, URG, ACK, PSH, RST, SYN, FIN tcpOut[14] = 0 // window size tcpOut[15] = 0 // . tcpOut[16] = 0 // checksum tcpOut[17] = 0 // . tcpOut[18] = 0 // URG Pointer tcpOut[19] = 0 // . 
// Calculate checksum csum := ipv4PseudoheaderChecksum(ipHdr[12:16], ipHdr[16:20], 6, tcpLen) binary.BigEndian.PutUint16(tcpOut[16:], tcpipChecksum(tcpOut, csum)) return out } func CreateICMPEchoResponse(packet, out []byte) []byte { // Return early if this is not a simple ICMP Echo Request //TODO: make constants out of these if !(len(packet) >= 28 && len(packet) <= 9001 && packet[0] == 0x45 && packet[9] == 0x01 && packet[20] == 0x08) { return nil } // We don't support fragmented packets if packet[7] != 0 || (packet[6]&0x2F != 0) { return nil } out = out[:len(packet)] copy(out, packet) // Swap dest / src IPs and recalculate checksum ipv4 := out[0:20] copy(ipv4[12:16], packet[16:20]) copy(ipv4[16:20], packet[12:16]) ipv4[10] = 0 ipv4[11] = 0 binary.BigEndian.PutUint16(ipv4[10:], tcpipChecksum(ipv4, 0)) // Change type to ICMP Echo Reply and recalculate checksum icmp := out[20:] icmp[0] = 0 icmp[2] = 0 icmp[3] = 0 binary.BigEndian.PutUint16(icmp[2:], tcpipChecksum(icmp, 0)) return out } // calculates the TCP/IP checksum defined in rfc1071. The passed-in // csum is any initial checksum data that's already been computed. // // based on: // - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L50-L70 func tcpipChecksum(data []byte, csum uint32) uint16 { // to handle odd lengths, we loop to length - 1, incrementing by 2, then // handle the last byte specifically by checking against the original // length. length := len(data) - 1 for i := 0; i < length; i += 2 { // For our test packet, doing this manually is about 25% faster // (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16. 
csum += uint32(data[i]) << 8 csum += uint32(data[i+1]) } if len(data)%2 == 1 { csum += uint32(data[length]) << 8 } for csum > 0xffff { csum = (csum >> 16) + (csum & 0xffff) } return ^uint16(csum) } // based on: // - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L26-L35 func ipv4PseudoheaderChecksum(src, dst []byte, proto, length uint32) (csum uint32) { csum += (uint32(src[0]) + uint32(src[2])) << 8 csum += uint32(src[1]) + uint32(src[3]) csum += (uint32(dst[0]) + uint32(dst[2])) << 8 csum += uint32(dst[1]) + uint32(dst[3]) csum += proto csum += length & 0xffff csum += length >> 16 return csum } ================================================ FILE: iputil/packet_test.go ================================================ package iputil import ( "net" "testing" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" ) func Test_CreateRejectPacket(t *testing.T) { h := ipv4.Header{ Len: 20, Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Protocol: 1, // ICMP } b, err := h.Marshal() if err != nil { t.Fatalf("h.Marhshal: %v", err) } b = append(b, []byte{0, 3, 0, 4}...) expectedLen := ipv4.HeaderLen + 8 + h.Len + 4 out := make([]byte, expectedLen) rejectPacket := CreateRejectPacket(b, out) assert.NotNil(t, rejectPacket) assert.Len(t, rejectPacket, expectedLen) // ICMP with max header len h = ipv4.Header{ Len: 60, Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Protocol: 1, // ICMP Options: make([]byte, 40), } b, err = h.Marshal() if err != nil { t.Fatalf("h.Marhshal: %v", err) } b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...) 
expectedLen = MaxRejectPacketSize out = make([]byte, MaxRejectPacketSize) rejectPacket = CreateRejectPacket(b, out) assert.NotNil(t, rejectPacket) assert.Len(t, rejectPacket, expectedLen) // TCP with max header len h = ipv4.Header{ Len: 60, Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Protocol: 6, // TCP Options: make([]byte, 40), } b, err = h.Marshal() if err != nil { t.Fatalf("h.Marhshal: %v", err) } b = append(b, []byte{0, 3, 0, 4}...) b = append(b, make([]byte, 16)...) expectedLen = ipv4.HeaderLen + 20 out = make([]byte, expectedLen) rejectPacket = CreateRejectPacket(b, out) assert.NotNil(t, rejectPacket) assert.Len(t, rejectPacket, expectedLen) } ================================================ FILE: lighthouse.go ================================================ package nebula import ( "context" "encoding/binary" "errors" "fmt" "net" "net/netip" "slices" "strconv" "sync" "sync/atomic" "time" "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) var ErrHostNotKnown = errors.New("host not known") var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr") type LightHouse struct { //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Lite punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses // map of vpn addr to answers addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and // respond with. // - When we are not a lighthouse, this filters which addresses we accept // from lighthouses. 
remoteAllowList atomic.Pointer[RemoteAllowList] // filters local addresses that we advertise to lighthouses localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply handshakeTrigger chan<- netip.Addr // staticList exists to avoid having a bool in each addrMap entry // since static should be rare staticList atomic.Pointer[map[netip.Addr]struct{}] lighthouses atomic.Pointer[[]netip.Addr] interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 advertiseAddrs atomic.Pointer[[]netip.AddrPort] // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] queryChan chan netip.Addr calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter l *logrus.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { return nil, util.NewContextualError("lighthouse.am_lighthouse enabled on node but no port number is set in config", nil, nil) } // If port is dynamic, discover it if nebulaPort == 0 && pc != nil { uPort, err := pc.LocalAddr() if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } nebulaPort = uint32(uPort.Port()) } h := LightHouse{ ctx: ctx, amLighthouse: amLighthouse, myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, punchConn: pc, punchy: p, queryChan: 
make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } lighthouses := make([]netip.Addr, 0) h.lighthouses.Store(&lighthouses) staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) } else { h.metricHolepunchTx = metrics.NilCounter{} } err := h.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := h.reload(c, false) switch v := err.(type) { case *util.ContextualError: v.Log(l) case error: l.WithError(err).Error("failed to reload lighthouse") } }) h.startQueryWorker() return &h, nil } func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } func (lh *LightHouse) GetLighthouses() []netip.Addr { return *lh.lighthouses.Load() } func (lh *LightHouse) GetRemoteAllowList() *RemoteAllowList { return lh.remoteAllowList.Load() } func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { return lh.localAllowList.Load() } func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort { return *lh.advertiseAddrs.Load() } func (lh *LightHouse) GetRelaysForMe() []netip.Addr { return *lh.relaysForMe.Load() } func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } func (lh *LightHouse) GetUpdateInterval() int64 { return lh.interval.Load() } func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) advAddrs := make([]netip.AddrPort, 0) for i, rawAddr := range rawAdvAddrs { host, sport, err := net.SplitHostPort(rawAddr) if err != nil { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } addrs, err := 
net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) if err != nil { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } if len(addrs) == 0 { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) } port, err := strconv.Atoi(sport) if err != nil { return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } if port == 0 { port = int(lh.nebulaPort) } //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used addr := addrs[0].Unmap() if lh.myVpnNetworksTable.Contains(addr) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } advAddrs = append(advAddrs, netip.AddrPortFrom(addr, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) if !initial { lh.l.Info("lighthouse.advertise_addrs has changed") } } if initial || c.HasChanged("lighthouse.interval") { lh.interval.Store(int64(c.GetInt("lighthouse.interval", 10))) if !initial { lh.l.Infof("lighthouse.interval changed to %v", lh.interval.Load()) if lh.updateCancel != nil { // May not always have a running routine lh.updateCancel() } lh.StartUpdateWorker() } } if initial || c.HasChanged("lighthouse.remote_allow_list") || c.HasChanged("lighthouse.remote_allow_ranges") { ral, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { return util.NewContextualError("Invalid lighthouse.remote_allow_list", nil, err) } lh.remoteAllowList.Store(ral) if !initial { lh.l.Info("lighthouse.remote_allow_list and/or lighthouse.remote_allow_ranges has changed") } } if initial || c.HasChanged("lighthouse.local_allow_list") { lal, err := NewLocalAllowListFromConfig(c, 
"lighthouse.local_allow_list") if err != nil { return util.NewContextualError("Invalid lighthouse.local_allow_list", nil, err) } lh.localAllowList.Store(lal) if !initial { lh.l.Info("lighthouse.local_allow_list has changed") } } if initial || c.HasChanged("lighthouse.calculated_remotes") { cr, err := NewCalculatedRemotesFromConfig(c, "lighthouse.calculated_remotes") if err != nil { return util.NewContextualError("Invalid lighthouse.calculated_remotes", nil, err) } lh.calculatedRemotes.Store(cr) if !initial { lh.l.Info("lighthouse.calculated_remotes has changed") } } //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { // Clean up. Entries still in the static_host_map will be re-built. // Entries no longer present must have their (possible) background DNS goroutines stopped. if existingStaticList := lh.staticList.Load(); existingStaticList != nil { lh.RLock() for staticVpnAddr := range *existingStaticList { if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { am.hr.Cancel() } } lh.RUnlock() } // Build a new list based on current config. 
staticList := make(map[netip.Addr]struct{})
		err := lh.loadStaticMap(c, staticList)
		if err != nil {
			return err
		}

		lh.staticList.Store(&staticList)
		if !initial {
			if c.HasChanged("static_host_map") {
				lh.l.Info("static_host_map has changed")
			}
			if c.HasChanged("static_map.cadence") {
				lh.l.Info("static_map.cadence has changed")
			}
			if c.HasChanged("static_map.network") {
				lh.l.Info("static_map.network has changed")
			}
			if c.HasChanged("static_map.lookup_timeout") {
				lh.l.Info("static_map.lookup_timeout has changed")
			}
		}
	}

	// This section must stay after the static_host_map section: parseLighthouses checks every
	// lighthouse against lh.GetStaticHostList(), which was just stored above.
	if initial || c.HasChanged("lighthouse.hosts") {
		lhList, err := lh.parseLighthouses(c)
		if err != nil {
			return err
		}

		lh.lighthouses.Store(&lhList)
		if !initial {
			//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
			lh.l.Info("lighthouse.hosts has changed")
		}
	}

	// NOTE(review): only a change to relay.relays re-runs this section; toggling relay.am_relay
	// alone on reload is not picked up here.
	if initial || c.HasChanged("relay.relays") {
		switch c.GetBool("relay.am_relay", false) {
		case true:
			// Relays aren't allowed to specify other relays
			if len(c.GetStringSlice("relay.relays", nil)) > 0 {
				lh.l.Info("Ignoring relays from config because am_relay is true")
			}
			relaysForMe := []netip.Addr{}
			lh.relaysForMe.Store(&relaysForMe)
		case false:
			// Unparseable entries are logged and skipped rather than failing the reload.
			relaysForMe := []netip.Addr{}
			for _, v := range c.GetStringSlice("relay.relays", nil) {
				configRIP, err := netip.ParseAddr(v)
				if err != nil {
					lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed")
				} else {
					lh.l.WithField("relay", v).Info("Read relay from config")
					relaysForMe = append(relaysForMe, configRIP)
				}
			}
			lh.relaysForMe.Store(&relaysForMe)
		}
	}

	return nil
}

// parseLighthouses reads lighthouse.hosts and returns the configured lighthouse vpn addrs.
// Entries outside our vpn networks only produce a warning. It returns an error if an entry does
// not parse as an IP, or if any lighthouse has no entry in the current static host list.
func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
	if lh.amLighthouse && len(lhs) != 0 {
		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
	}
	out := make([]netip.Addr, len(lhs))

	for i, host := range lhs {
		addr, err := netip.ParseAddr(host)
		if err != nil {
			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
		}

		if !lh.myVpnNetworksTable.Contains(addr) {
			lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
				Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
		}
		out[i] = addr
	}

	if !lh.amLighthouse && len(out) == 0 {
		lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
	}

	// Every lighthouse must be reachable via a static_host_map entry.
	staticList := lh.GetStaticHostList()
	for i := range out {
		if _, ok := staticList[out[i]]; !ok {
			return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
		}
	}

	return out, nil
}

// getStaticMapCadence parses static_map.cadence (default "30s") with time.ParseDuration.
// NOTE(review): zero or negative durations are not rejected here; how NewHostnameResults
// treats them is not visible from this file.
func getStaticMapCadence(c *config.C) (time.Duration, error) {
	cadence := c.GetString("static_map.cadence", "30s")
	d, err := time.ParseDuration(cadence)
	if err != nil {
		return 0, err
	}
	return d, nil
}

// getStaticMapLookupTimeout parses static_map.lookup_timeout (default "250ms") with time.ParseDuration.
func getStaticMapLookupTimeout(c *config.C) (time.Duration, error) {
	lookupTimeout := c.GetString("static_map.lookup_timeout", "250ms")
	d, err := time.ParseDuration(lookupTimeout)
	if err != nil {
		return 0, err
	}
	return d, nil
}

// getStaticMapNetwork returns static_map.network (default "ip4"), which must be "ip", "ip4" or "ip6".
func getStaticMapNetwork(c *config.C) (string, error) {
	network := c.GetString("static_map.network", "ip4")
	if network != "ip" && network != "ip4" && network != "ip6" {
		return "", fmt.Errorf("static_map.network must be one of ip, ip4, or ip6")
	}
	return network, nil
}

// loadStaticMap registers every static_host_map entry through addStaticRemotes and records each
// added vpn addr in staticList. The static_map.* settings are parsed once and shared by all entries.
func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
	d, err := getStaticMapCadence(c)
	if err != nil {
		return err
	}

	network, err := getStaticMapNetwork(c)
	if err != nil {
		return err
	}

	lookupTimeout, err := getStaticMapLookupTimeout(c)
	if err != nil {
		return err
	}

	shm := c.GetMap("static_host_map", map[string]any{})
	// NOTE: shm is a Go map, so the iteration order (and therefore the "entry" numbers
	// reported in errors and warnings) is not stable between runs.
	i := 0

	for k, v := range shm {
		vpnAddr, err := netip.ParseAddr(fmt.Sprintf("%v", k))
		if err != nil {
			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
		}

		if !lh.myVpnNetworksTable.Contains(vpnAddr) {
			lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
				Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
		}

		// A value may be a single remote or a list of remotes; normalize to []string.
		vals, ok := v.([]any)
		if !ok {
			vals = []any{v}
		}
		remoteAddrs := []string{}
		for _, v := range vals {
			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
		}

		err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnAddr, remoteAddrs, staticList)
		if err != nil {
			return err
		}
		i++
	}

	return nil
}

// Query returns the cached RemoteList for vpnAddr, or nil if there is none. For addrs that are
// not lighthouses it also hands vpnAddr to QueryServer, which queues an asynchronous lookup.
func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList {
	if !lh.IsLighthouseAddr(vpnAddr) {
		lh.QueryServer(vpnAddr)
	}
	lh.RLock()
	if v, ok := lh.addrMap[vpnAddr]; ok {
		lh.RUnlock()
		return v
	}
	lh.RUnlock()
	return nil
}

// QueryServer is asynchronous so no reply should be expected.
// The channel send below blocks until the query worker (or free buffer space) accepts it.
func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) {
	// Don't put lighthouse addrs in the query channel because we can't query lighthouses about lighthouses
	if lh.amLighthouse || lh.IsLighthouseAddr(vpnAddr) {
		return
	}

	lh.queryChan <- vpnAddr
}

// QueryCache returns the RemoteList stored under vpnAddrs[0]. If there is none, it creates or
// aliases one for all of vpnAddrs via unlockedGetRemoteList. vpnAddrs must be non-empty.
func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
	lh.RLock()
	if v, ok := lh.addrMap[vpnAddrs[0]]; ok {
		lh.RUnlock()
		return v
	}
	lh.RUnlock()

	lh.Lock()
	defer lh.Unlock()
	// Add an entry if we don't already have one
	return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
}

// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnAddr
// If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
// It returns (found, the byte count returned by f, the error returned by f).
func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
	lh.RLock()
	// Do we have an entry in the main cache?
if v, ok := lh.addrMap[vpnAddr]; ok {
		// Swap lh lock for remote list lock
		v.RLock()
		defer v.RUnlock()

		lh.RUnlock()

		// We may be asking about a non primary address so lets get the primary address
		if slices.Contains(v.vpnAddrs, vpnAddr) {
			vpnAddr = v.vpnAddrs[0]
		}
		c := v.cache[vpnAddr]
		// Make sure we have a cache entry for the (primary) address before handing it to f
		if c != nil {
			n, err := f(c)
			return true, n, err
		}
		return false, 0, nil
	}

	lh.RUnlock()
	return false, 0, nil
}

// DeleteVpnAddrs removes every addrMap key in allVpnAddrs that points at the same RemoteList as
// allVpnAddrs[0]. Nothing is removed if allVpnAddrs is empty or if any of its addrs is in the
// static host list.
func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
	// allVpnAddrs[0] is used below; an empty slice would panic while holding lh's lock.
	if len(allVpnAddrs) == 0 {
		return
	}

	// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
	staticList := lh.GetStaticHostList()
	for _, addr := range allVpnAddrs {
		if _, ok := staticList[addr]; ok {
			return
		}
	}

	// None of the VpnAddrs were present. Now we can do the deletes.
	lh.Lock()
	defer lh.Unlock()

	rm, ok := lh.addrMap[allVpnAddrs[0]]
	if !ok {
		return
	}

	for _, addr := range allVpnAddrs {
		// Only drop keys that still share the primary addr's RemoteList.
		if lh.addrMap[addr] == rm {
			delete(lh.addrMap, addr)
			if lh.l.Level >= logrus.DebugLevel {
				lh.l.Debugf("deleting %s from lighthouse.", addr)
			}
		}
	}
}

// addStaticRemotes adds a static host entry for vpnAddr as ourselves as the owner
// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
// NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
	lh.Lock()
	am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr})
	am.Lock()
	defer am.Unlock()
	ctx := lh.ctx
	lh.Unlock()

	hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() {
		// This callback runs whenever the DNS hostname resolver finds a different set of addr's
		// in its resolution for hostnames.
am.Lock()
		defer am.Unlock()
		am.shouldRebuild = true
	})
	// NOTE(review): am is still locked here (deferred Unlock above) and the callback also locks
	// it, so this assumes NewHostnameResults never invokes the callback synchronously.
	if err != nil {
		return util.NewContextualError("Static host address could not be parsed", m{"vpnAddr": vpnAddr, "entry": i + 1}, err)
	}
	am.unlockedSetHostnamesResults(hr)

	// Seed the list from hr.GetAddrs(), filtered through shouldAdd. The entries are owned by
	// our own first vpn addr (lh.myVpnNetworks[0]).
	for _, addrPort := range hr.GetAddrs() {
		if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
			continue
		}
		switch {
		case addrPort.Addr().Is4():
			am.unlockedPrependV4(lh.myVpnNetworks[0].Addr(), netAddrToProtoV4AddrPort(addrPort.Addr(), addrPort.Port()))
		case addrPort.Addr().Is6():
			am.unlockedPrependV6(lh.myVpnNetworks[0].Addr(), netAddrToProtoV6AddrPort(addrPort.Addr(), addrPort.Port()))
		}
	}

	// Mark it as static in the caller provided map
	staticList[vpnAddr] = struct{}{}
	return nil
}

// addCalculatedRemotes adds any calculated remotes based on the
// lighthouse.calculated_remotes configuration. It returns true if any
// calculated remotes were computed. They are passed to unlockedSetV4/unlockedSetV6
// together with the unlockedShouldAddV4/V6 filters, which may still reject some of them.
func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
	tree := lh.getCalculatedRemotes()
	if tree == nil {
		return false
	}
	calculatedRemotes, ok := tree.Lookup(vpnAddr)
	if !ok {
		return false
	}

	var calculatedV4 []*V4AddrPort
	var calculatedV6 []*V6AddrPort
	for _, cr := range calculatedRemotes {
		if vpnAddr.Is4() {
			c := cr.ApplyV4(vpnAddr)
			if c != nil {
				calculatedV4 = append(calculatedV4, c)
			}
		} else if vpnAddr.Is6() {
			c := cr.ApplyV6(vpnAddr)
			if c != nil {
				calculatedV6 = append(calculatedV6, c)
			}
		}
	}

	// Hold lh's lock only long enough to find or create the RemoteList; am's lock is taken
	// before lh's is released and is held for the rest of the function.
	lh.Lock()
	am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr})
	am.Lock()
	defer am.Unlock()
	lh.Unlock()

	if len(calculatedV4) > 0 {
		am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV4, lh.unlockedShouldAddV4)
	}

	if len(calculatedV6) > 0 {
		am.unlockedSetV6(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV6, lh.unlockedShouldAddV6)
	}

	return len(calculatedV4) > 0 || len(calculatedV6) > 0
}

// unlockedGetRemoteList assumes you have the lh lock
// It returns the RemoteList already stored under any addr of allAddrs, or creates one that is
// stored under all of them.
func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
	// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
	for i, addr := range allAddrs {
		am, ok := lh.addrMap[addr]
		if ok {
			// NOTE(review): only allAddrs[0] is aliased to the found list; other addrs in
			// allAddrs that were absent from addrMap are not added here.
			if i != 0 {
				lh.addrMap[allAddrs[0]] = am
			}
			return am
		}
	}

	am := NewRemoteList(allAddrs, lh.shouldAdd)
	for _, addr := range allAddrs {
		lh.addrMap[addr] = am
	}
	return am
}

// shouldAdd reports whether underlay addr to may be recorded for vpnAddrs. The remote allow
// list must permit it (AllowAll), and it must not be inside one of our own vpn networks.
func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
	allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
	if lh.l.Level >= logrus.TraceLevel {
		lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
			Trace("remoteAllowList.Allow")
	}
	if !allow {
		return false
	}

	if lh.myVpnNetworksTable.Contains(to) {
		return false
	}

	return true
}

// unlockedShouldAddV4 checks if to is allowed by our allow list
// and that it is not inside one of our own vpn networks.
func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool {
	udpAddr := protoV4AddrPortToNetAddrPort(to)
	allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
	if lh.l.Level >= logrus.TraceLevel {
		lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
			Trace("remoteAllowList.Allow")
	}

	if !allow {
		return false
	}

	if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) {
		return false
	}

	return true
}

// unlockedShouldAddV6 checks if to is allowed by our allow list
// and that it is not inside one of our own vpn networks.
func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool {
	udpAddr := protoV6AddrPortToNetAddrPort(to)
	allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr())
	if lh.l.Level >= logrus.TraceLevel {
		lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow).
Trace("remoteAllowList.Allow") } if !allow { return false } if lh.myVpnNetworksTable.Contains(udpAddr.Addr()) { return false } return true } func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { l := lh.GetLighthouses() return slices.Contains(l, vpnAddr) } func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool { l := lh.GetLighthouses() for i := range vpnAddrs { if slices.Contains(l, vpnAddrs[i]) { return true } } return false } func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { return } go func() { nb := make([]byte, 12, 12) out := make([]byte, mtu) for { select { case <-lh.ctx.Done(): return case addr := <-lh.queryChan: lh.innerQueryServer(addr, nb, out) } } }() } func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { if lh.IsLighthouseAddr(addr) { return } msg := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{}, } var v1Query, v2Query []byte var err error var v cert.Version queried := 0 lighthouses := lh.GetLighthouses() for _, lhVpnAddr := range lighthouses { hi := lh.ifce.GetHostInfo(lhVpnAddr) if hi != nil { v = hi.ConnectionState.myCert.Version() } else { v = lh.ifce.GetCertState().initiatingVersion } if v == cert.Version1 { if !addr.Is4() { lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). Error("Can't query lighthouse for v6 address using a v1 protocol") continue } if v1Query == nil { b := addr.As4() msg.Details.VpnAddr = nil msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) v1Query, err = msg.Marshal() if err != nil { lh.l.WithError(err).WithField("queryVpnAddr", addr). WithField("lighthouseAddr", lhVpnAddr). 
Error("Failed to marshal lighthouse v1 query payload") continue } } lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) queried++ } else if v == cert.Version2 { if v2Query == nil { msg.Details.OldVpnAddr = 0 msg.Details.VpnAddr = netAddrToProtoAddr(addr) v2Query, err = msg.Marshal() if err != nil { lh.l.WithError(err).WithField("queryVpnAddr", addr). WithField("lighthouseAddr", lhVpnAddr). Error("Failed to marshal lighthouse v2 query payload") continue } } lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) queried++ } else { lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) continue } } lh.metricTx(NebulaMeta_HostQuery, int64(queried)) } func (lh *LightHouse) StartUpdateWorker() { interval := lh.GetUpdateInterval() if lh.amLighthouse || interval == 0 { return } clockSource := time.NewTicker(time.Second * time.Duration(interval)) updateCtx, cancel := context.WithCancel(lh.ctx) lh.updateCancel = cancel go func() { defer clockSource.Stop() for { lh.SendUpdate() select { case <-updateCtx.Done(): return case <-clockSource.C: continue } } }() } func (lh *LightHouse) SendUpdate() { var v4 []*V4AddrPort var v6 []*V6AddrPort for _, e := range lh.GetAdvertiseAddrs() { if e.Addr().Is4() { v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port())) } else { v6 = append(v6, netAddrToProtoV6AddrPort(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() for _, e := range localAddrs(lh.l, lal) { if lh.myVpnNetworksTable.Contains(e) { continue } // Only add addrs that aren't my VPN/tun networks if e.Is4() { v4 = append(v4, netAddrToProtoV4AddrPort(e, uint16(lh.nebulaPort))) } else { v6 = append(v6, netAddrToProtoV6AddrPort(e, uint16(lh.nebulaPort))) } } nb := make([]byte, 12, 12) out := make([]byte, mtu) var v1Update, v2Update []byte var err error updated := 0 lighthouses := lh.GetLighthouses() for _, lhVpnAddr := range lighthouses { var v cert.Version hi := 
lh.ifce.GetHostInfo(lhVpnAddr)
		// Use our cert version on an existing hostinfo for this lighthouse if there is one,
		// otherwise the version we would initiate a handshake with.
		if hi != nil {
			v = hi.ConnectionState.myCert.Version()
		} else {
			v = lh.ifce.GetCertState().initiatingVersion
		}

		if v == cert.Version1 {
			// The v1 payload is built once and reused for every v1 lighthouse.
			if v1Update == nil {
				// v1 can only carry an IPv4 vpn addr. While this fails, v1Update stays nil, so the
				// warning is logged again for each v1 lighthouse.
				if !lh.myVpnNetworks[0].Addr().Is4() {
					lh.l.WithField("lighthouseAddr", lhVpnAddr).
						Warn("cannot update lighthouse using v1 protocol without an IPv4 address")
					continue
				}
				// v1 relays are encoded as big-endian uint32s, so IPv6 relays are skipped.
				var relays []uint32
				for _, r := range lh.GetRelaysForMe() {
					if !r.Is4() {
						continue
					}
					b := r.As4()
					relays = append(relays, binary.BigEndian.Uint32(b[:]))
				}
				b := lh.myVpnNetworks[0].Addr().As4()
				msg := NebulaMeta{
					Type: NebulaMeta_HostUpdateNotification,
					Details: &NebulaMetaDetails{
						V4AddrPorts:      v4,
						V6AddrPorts:      v6,
						OldRelayVpnAddrs: relays,
						OldVpnAddr:       binary.BigEndian.Uint32(b[:]),
					},
				}

				v1Update, err = msg.Marshal()
				if err != nil {
					lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
						Error("Error while marshaling for lighthouse v1 update")
					continue
				}
			}

			lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out)
			updated++

		} else if v == cert.Version2 {
			// The v2 payload is built once and reused. It sets no vpn addr field; the receiving
			// lighthouse uses the sender's tunnel vpn addrs instead.
			if v2Update == nil {
				var relays []*Addr
				for _, r := range lh.GetRelaysForMe() {
					relays = append(relays, netAddrToProtoAddr(r))
				}

				msg := NebulaMeta{
					Type: NebulaMeta_HostUpdateNotification,
					Details: &NebulaMetaDetails{
						V4AddrPorts:   v4,
						V6AddrPorts:   v6,
						RelayVpnAddrs: relays,
					},
				}

				v2Update, err = msg.Marshal()
				if err != nil {
					lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr).
Error("Error while marshaling for lighthouse v2 update") continue } } lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) updated++ } else { lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) continue } } lh.metricTx(NebulaMeta_HostUpdateNotification, int64(updated)) } type LightHouseHandler struct { lh *LightHouse nb []byte out []byte pb []byte meta *NebulaMeta l *logrus.Logger } func (lh *LightHouse) NewRequestHandler() *LightHouseHandler { lhh := &LightHouseHandler{ lh: lh, nb: make([]byte, 12, 12), out: make([]byte, mtu), l: lh.l, pb: make([]byte, mtu), meta: &NebulaMeta{ Details: &NebulaMetaDetails{}, }, } return lhh } func (lh *LightHouse) metricRx(t NebulaMeta_MessageType, i int64) { lh.metrics.Rx(header.MessageType(t), 0, i) } func (lh *LightHouse) metricTx(t NebulaMeta_MessageType, i int64) { lh.metrics.Tx(header.MessageType(t), 0, i) } // This method is similar to Reset(), but it re-uses the pointer structs // so that we don't have to re-allocate them func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { details := lhh.meta.Details lhh.meta.Reset() // Keep the array memory around details.V4AddrPorts = details.V4AddrPorts[:0] details.V6AddrPorts = details.V6AddrPorts[:0] details.RelayVpnAddrs = details.RelayVpnAddrs[:0] details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] details.OldVpnAddr = 0 details.VpnAddr = nil lhh.meta.Details = details return lhh.meta } func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") return } if n.Details == nil { lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). 
Error("Invalid lighthouse update") return } lhh.lh.metricRx(n.Type, 1) switch n.Type { case NebulaMeta_HostQuery: lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I don't answer queries, but received from: ", addr) } return } queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() if err != nil { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). Debugln("Dropping malformed HostQuery") } return } if useVersion == cert.Version1 && queryVpnAddr.Is6() { // this case really shouldn't be possible to represent, but reject it anyway. if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). 
Debugln("invalid vpn addr for v1 handleHostQuery") } return } found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply if useVersion == cert.Version1 { b := queryVpnAddr.As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) } else { n.Details.VpnAddr = netAddrToProtoAddr(queryVpnAddr) } lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) if !found { return } if err != nil { lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) } // sendHostPunchNotification signals the other side to punch some zero byte udp packets func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) { whereToPunch := fromVpnAddrs[0] found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) var useVersion cert.Version if targetHI == nil { useVersion = lhh.lh.ifce.GetCertState().initiatingVersion } else { crt := targetHI.GetCert().Certificate useVersion = crt.Version() // we can only retarget if we have a hostinfo newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs) if ok { whereToPunch = newDest } else { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") } } } if useVersion == cert.Version1 { if !whereToPunch.Is4() { return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") } b := whereToPunch.As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) } else if useVersion == cert.Version2 { 
n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch)
		} else {
			return 0, errors.New("unsupported version")
		}

		lhh.coalesceAnswers(useVersion, c, n)

		return n.MarshalTo(lhh.pb)
	})

	if !found {
		return
	}

	if err != nil {
		lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for")
		return
	}

	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
	w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}

// coalesceAnswers appends everything cached in c to n.Details. That covers the learned and
// reported v4/v6 underlay addrs, plus the relays encoded for protocol version v: uint32s for v1
// (IPv6 relays skipped) or Addr messages for v2.
func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
	if c.v4 != nil {
		if c.v4.learned != nil {
			n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned)
		}
		// len of a nil slice is 0, so a separate nil check is redundant.
		if len(c.v4.reported) > 0 {
			n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...)
		}
	}

	if c.v6 != nil {
		if c.v6.learned != nil {
			n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned)
		}
		if len(c.v6.reported) > 0 {
			n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...)
} } if c.relay != nil { if v == cert.Version1 { b := [4]byte{} for _, r := range c.relay.relay { if !r.Is4() { continue } b = r.As4() n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) } } else if v == cert.Version2 { for _, r := range c.relay.relay { n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) } } else { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("version", v).Debug("unsupported protocol version") } } } } func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() if err != nil { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") } return } relays := n.Details.GetRelays() lhh.lh.Lock() am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am.Lock() lhh.lh.Unlock() am.unlockedSetV4(fromVpnAddrs[0], certVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(fromVpnAddrs[0], certVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { case lhh.lh.handshakeTrigger <- certVpnAddr: default: } } func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) } return } // not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr var detailsVpnAddr netip.Addr var useVersion cert.Version if n.Details.OldVpnAddr != 0 { //v1 always sets this field b := [4]byte{} binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) detailsVpnAddr = netip.AddrFrom4(b) useVersion = cert.Version1 
} else if n.Details.VpnAddr != nil {
		//this field is "optional" in v2, but if it's set, we should enforce it
		detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
		useVersion = cert.Version2
	} else {
		detailsVpnAddr = netip.Addr{}
		useVersion = cert.Version2
	}

	//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
	if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
		if lhh.l.Level >= logrus.DebugLevel {
			lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
		}
		return
	}

	relays := n.Details.GetRelays()

	// Store the reported addrs and relays in the sender's own RemoteList, owned by fromVpnAddrs[0].
	lhh.lh.Lock()
	am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs)
	am.Lock()
	lhh.lh.Unlock()

	am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
	am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
	am.unlockedSetRelay(fromVpnAddrs[0], relays)
	am.Unlock()

	// Build the ack in the protocol version derived above.
	n = lhh.resetMeta()
	n.Type = NebulaMeta_HostUpdateNotificationAck
	switch useVersion {
	case cert.Version1:
		if !fromVpnAddrs[0].Is4() {
			lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
			return
		}
		vpnAddrB := fromVpnAddrs[0].As4()
		n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
	case cert.Version2:
		// do nothing, we want to send a blank message
	default:
		lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
		return
	}

	ln, err := n.MarshalTo(lhh.pb)
	if err != nil {
		lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack")
		return
	}

	lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1)
	w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
}

// handleHostPunchNotification handles a punch notification from a configured lighthouse. After
// the punchy delay it sends a one-byte UDP packet to every advertised addr the remote allow list
// permits. If punchy respond is enabled, it also sends a nebula test packet to the host.
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
	//It's possible the lighthouse is communicating with us using a non primary vpn addr,
	//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
	if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
		return
	}

	detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
	if err != nil {
		if lhh.l.Level >= logrus.DebugLevel {
			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
		}
		return
	}

	// The punch payload is a single 0 byte.
	empty := []byte{0}
	punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
		if !vpnPeer.IsValid() {
			return
		}

		// One goroutine per punch target; the WriteTo result is not checked.
		go func() {
			time.Sleep(lhh.lh.punchy.GetDelay())
			lhh.lh.metricHolepunchTx.Inc(1)
			lhh.lh.punchConn.WriteTo(empty, vpnPeer)
		}()

		if lhh.l.Level >= logrus.DebugLevel {
			lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
		}
	}

	remoteAllowList := lhh.lh.GetRemoteAllowList()
	for _, a := range n.Details.V4AddrPorts {
		b := protoV4AddrPortToNetAddrPort(a)
		if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
			punch(b, detailsVpnAddr)
		}
	}

	for _, a := range n.Details.V6AddrPorts {
		b := protoV6AddrPortToNetAddrPort(a)
		if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
			punch(b, detailsVpnAddr)
		}
	}

	// This sends a nebula test packet to the host trying to contact us. In the case
	// of a double nat or other difficult scenario, this may help establish
	// a tunnel.
	if lhh.lh.punchy.GetRespond() {
		go func() {
			time.Sleep(lhh.lh.punchy.GetRespondDelay())
			if lhh.l.Level >= logrus.DebugLevel {
				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
			}
			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
			// managed by a channel.
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
		}()
	}
}

// protoAddrToNetAddr converts a protobuf Addr (Hi/Lo big-endian halves of a 16-byte address) to
// a netip.Addr. IPv4-mapped IPv6 addresses are unmapped to plain IPv4.
func protoAddrToNetAddr(addr *Addr) netip.Addr {
	b := [16]byte{}
	binary.BigEndian.PutUint64(b[:8], addr.Hi)
	binary.BigEndian.PutUint64(b[8:], addr.Lo)
	return netip.AddrFrom16(b).Unmap()
}

// protoV4AddrPortToNetAddrPort converts a protobuf V4AddrPort to a netip.AddrPort.
// The uint32 port is truncated to uint16.
func protoV4AddrPortToNetAddrPort(ap *V4AddrPort) netip.AddrPort {
	b := [4]byte{}
	binary.BigEndian.PutUint32(b[:], ap.Addr)
	return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ap.Port))
}

// protoV6AddrPortToNetAddrPort converts a protobuf V6AddrPort to a netip.AddrPort.
// Unlike protoAddrToNetAddr, the address is not unmapped. The uint32 port is truncated to uint16.
func protoV6AddrPortToNetAddrPort(ap *V6AddrPort) netip.AddrPort {
	b := [16]byte{}
	binary.BigEndian.PutUint64(b[:8], ap.Hi)
	binary.BigEndian.PutUint64(b[8:], ap.Lo)
	return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ap.Port))
}

// netAddrToProtoAddr encodes addr, in its 16-byte form, as a protobuf Addr.
func netAddrToProtoAddr(addr netip.Addr) *Addr {
	b := addr.As16()
	return &Addr{
		Hi: binary.BigEndian.Uint64(b[:8]),
		Lo: binary.BigEndian.Uint64(b[8:]),
	}
}

// netAddrToProtoV4AddrPort encodes an IPv4 (or IPv4-mapped) addr and port as a protobuf V4AddrPort.
// addr.As4 panics for a zero Addr or a non-mapped IPv6 addr, so callers must only pass IPv4.
func netAddrToProtoV4AddrPort(addr netip.Addr, port uint16) *V4AddrPort {
	v4Addr := addr.As4()
	return &V4AddrPort{
		Addr: binary.BigEndian.Uint32(v4Addr[:]),
		Port: uint32(port),
	}
}

// netAddrToProtoV6AddrPort encodes addr, in its 16-byte form, and port as a protobuf V6AddrPort.
func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort {
	v6Addr := addr.As16()
	return &V6AddrPort{
		Hi:   binary.BigEndian.Uint64(v6Addr[:8]),
		Lo:   binary.BigEndian.Uint64(v6Addr[8:]),
		Port: uint32(port),
	}
}

// GetRelays returns every relay vpn addr in d: the v1 uint32 relays (as IPv4) followed by the v2 Addr relays.
func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
	var relays []netip.Addr
	if len(d.OldRelayVpnAddrs) > 0 {
		b := [4]byte{}
		for _, r := range d.OldRelayVpnAddrs {
			binary.BigEndian.PutUint32(b[:], r)
			relays = append(relays, netip.AddrFrom4(b))
		}
	}

	if len(d.RelayVpnAddrs) > 0 {
		for _, r := range d.RelayVpnAddrs {
			relays = append(relays, protoAddrToNetAddr(r))
		}
	}

	return relays
}

// findNetworkUnion returns the first addr in addrs that is contained in one of prefixes, and true.
// Prefixes are checked in order, so a match against an earlier prefix wins. If nothing matches,
// it returns the zero netip.Addr and false.
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
	for i := range prefixes {
		for j := range addrs {
			if prefixes[i].Contains(addrs[j]) {
				return addrs[j], true
			}
		}
	}
	return netip.Addr{}, false
}

// GetVpnAddrAndVersion returns the vpn addr carried by d and the protocol version implied by the
// field that is set. A non-zero OldVpnAddr means v1; a non-nil VpnAddr means v2. If neither is
// set, it returns ErrBadDetailsVpnAddr.
func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
	if d.OldVpnAddr != 0 {
		b := [4]byte{}
		binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
		detailsVpnAddr := netip.AddrFrom4(b)
		return detailsVpnAddr, cert.Version1, nil
	} else if d.VpnAddr != nil {
		detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
		return detailsVpnAddr, cert.Version2, nil
	} else {
		return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
	}
}

================================================
FILE: lighthouse_test.go
================================================
package nebula

import (
	"encoding/binary"
	"fmt"
	"net/netip"
	"testing"

	"github.com/gaissmai/bart"
	"github.com/slackhq/nebula/cert"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/header"
	"github.com/slackhq/nebula/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.yaml.in/yaml/v3"
)

func TestOldIPv4Only(t *testing.T) {
	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
	b := []byte{8, 129, 130, 132, 80, 16, 10}
	var m V4AddrPort
	err := m.Unmarshal(b)
	require.NoError(t, err)
	ip := netip.MustParseAddr("10.1.1.1")
	bp := ip.As4()
	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
}

func Test_lhStaticMapping(t *testing.T) {
	l := test.NewLogger()
	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
	nt := new(bart.Lite)
	nt.Insert(myVpnNet)
	cs := &CertState{
		myVpnNetworks:      []netip.Prefix{myVpnNet},
		myVpnNetworksTable: nt,
	}
	lh1 := "10.128.0.2"

	c := config.NewC(l)
	c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
	c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
	_, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
	require.NoError(t, err)

	lh2 := "10.128.0.3"
	c = config.NewC(l)
	c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
	c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
	_, err = NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
	require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
}

func TestReloadLighthouseInterval(t *testing.T) {
	l := test.NewLogger()
	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
	nt := new(bart.Lite)
	nt.Insert(myVpnNet)
	cs := &CertState{
		myVpnNetworks:      []netip.Prefix{myVpnNet},
		myVpnNetworksTable: nt,
	}
	lh1 := "10.128.0.2"

	c := config.NewC(l)
	c.Settings["lighthouse"] = map[string]any{
		"hosts":    []any{lh1},
		"interval": "1s",
	}

	c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
	lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
	require.NoError(t, err)
	lh.ifce = &mockEncWriter{}

	// The first one routine is kicked off by main.go currently, lets make sure that one dies
	require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
	assert.Equal(t, int64(5), lh.interval.Load())

	// Subsequent calls are killed off by the LightHouse.Reload function
	require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
	assert.Equal(t, int64(10), lh.interval.Load())

	// If this completes then nothing is stealing our reload routine
	require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
	assert.Equal(t, int64(11), lh.interval.Load())
}

func BenchmarkLighthouseHandleRequest(b *testing.B) {
	l := test.NewLogger()
	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
	nt := new(bart.Lite)
	nt.Insert(myVpnNet)
	cs := &CertState{
		myVpnNetworks:      []netip.Prefix{myVpnNet},
		myVpnNetworksTable: nt,
	}

	c := config.NewC(l)
	lh, err := NewLightHouseFromConfig(b.Context(), l, c, cs, nil, nil)
	require.NoError(b, err)

	hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")

	vpnIp3 := netip.MustParseAddr("0.0.0.3")
	lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
	lh.addrMap[vpnIp3].unlockedSetV4(
		vpnIp3,
		vpnIp3,
		[]*V4AddrPort{
			netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()), }, func(netip.Addr, *V4AddrPort) bool { return true }, ) rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") vpnIp2 := netip.MustParseAddr("0.0.0.3") lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil) lh.addrMap[vpnIp2].unlockedSetV4( vpnIp3, vpnIp3, []*V4AddrPort{ netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()), netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()), }, func(netip.Addr, *V4AddrPort) bool { return true }, ) mw := &mockEncWriter{} hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ OldVpnAddr: 4, V4AddrPorts: nil, }, } p, err := req.Marshal() require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) } }) b.Run("found", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ OldVpnAddr: 3, V4AddrPorts: nil, }, } p, err := req.Marshal() require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) } }) } func TestLighthouse_Memory(t *testing.T) { l := test.NewLogger() myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242") myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242") myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242") myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242") myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242") myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243") myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244") myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245") myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246") myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247") myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248") myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249") myVpnIp := netip.MustParseAddr("10.128.0.2") 
theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
	theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
	theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
	theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
	theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
	theirVpnIp := netip.MustParseAddr("10.128.0.3")

	c := config.NewC(l)
	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
	c.Settings["listen"] = map[string]any{"port": 4242}

	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
	nt := new(bart.Lite)
	nt.Insert(myVpnNet)
	cs := &CertState{
		myVpnNetworks:      []netip.Prefix{myVpnNet},
		myVpnNetworksTable: nt,
	}
	lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
	// Check the error before touching lh: on failure lh may be nil and the
	// assignment below would panic instead of reporting the real error.
	require.NoError(t, err)
	lh.ifce = &mockEncWriter{}
	lhh := lh.NewRequestHandler()

	// Test that my first update responds with just that
	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)

	// Ensure we don't accumulate addresses
	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)

	// Grow it back to 2
	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)

	// Update a different host and ask about it
	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)

	// Have both hosts ask about the other
	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)

	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)

	// Make sure we didn't get changed
	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)

	// Ensure proper ordering and limiting
	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
	newLHHostUpdate(
		myUdpAddr0,
		myVpnIp,
		[]netip.AddrPort{
			myUdpAddr1,
			myUdpAddr2,
			myUdpAddr3,
			myUdpAddr4,
			myUdpAddr5,
			myUdpAddr5, //Duplicated on purpose
			myUdpAddr6,
			myUdpAddr7,
			myUdpAddr8,
			myUdpAddr9,
			myUdpAddr10, // This should get cut
			myUdpAddr11, // This should get cut
		}, lhh)

	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
	assertIp4InArray(
		t,
		r.msg.Details.V4AddrPorts,
		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
	)

	// Make sure we won't add ips in our vpn network
	bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
	bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
	good := netip.MustParseAddrPort("1.128.0.99:4242")
	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
	assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
}

func TestLighthouse_reload(t *testing.T) {
	l := test.NewLogger()
	c := config.NewC(l)
	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
	c.Settings["listen"] = map[string]any{"port": 4242}

	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
	nt := new(bart.Lite)
	nt.Insert(myVpnNet)
	cs := &CertState{
		myVpnNetworks:      []netip.Prefix{myVpnNet},
		myVpnNetworksTable: nt,
	}

	lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
	require.NoError(t, err)

	nc := map[string]any{
		"static_host_map": map[string]any{
			"10.128.0.2": []any{"1.1.1.1:4242"},
		},
	}
	rc, err := yaml.Marshal(nc)
require.NoError(t, err)
	// ReloadConfigString returns an error; ignoring it would let a bad reload
	// go unnoticed and make the lh.reload assertion below meaningless.
	require.NoError(t, c.ReloadConfigString(string(rc)))

	err = lh.reload(c, false)
	require.NoError(t, err)
}

// newLHHostRequest sends a HostQuery for queryVpnIp to lhh as if it came from
// myVpnIp at fromAddr, and returns the last HostQueryReply the handler sent
// (the zero testLhReply if it sent none).
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
	req := &NebulaMeta{
		Type:    NebulaMeta_HostQuery,
		Details: &NebulaMetaDetails{},
	}

	// IPv4 targets use the legacy OldVpnAddr field, anything else uses VpnAddr.
	if queryVpnIp.Is4() {
		bip := queryVpnIp.As4()
		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
	} else {
		req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
	}

	b, err := req.Marshal()
	if err != nil {
		panic(err)
	}

	filter := NebulaMeta_HostQueryReply
	w := &testEncWriter{
		metaFilter: &filter,
	}
	lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
	return w.lastReply
}

// newLHHostUpdate sends a HostUpdateNotification to lhh on behalf of vpnIp
// (arriving from fromAddr) advertising addrs, split into the V4/V6 lists.
func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
	req := &NebulaMeta{
		Type:    NebulaMeta_HostUpdateNotification,
		Details: &NebulaMetaDetails{},
	}

	if vpnIp.Is4() {
		bip := vpnIp.As4()
		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
	} else {
		req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
	}

	for _, v := range addrs {
		if v.Addr().Is4() {
			req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
		} else {
			req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
		}
	}

	b, err := req.Marshal()
	if err != nil {
		panic(err)
	}

	w := &testEncWriter{}
	lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
}

// testLhReply captures one message a testEncWriter was asked to send.
type testLhReply struct {
	nebType    header.MessageType
	nebSubType header.MessageSubType
	vpnIp      netip.Addr
	msg        *NebulaMeta
}

// testEncWriter is a stub writer for lighthouse handler tests. It decodes every
// payload as a NebulaMeta and keeps the most recent one (only those of type
// *metaFilter when metaFilter is set). GetCertState reports protocolVersion as
// the initiating version.
type testEncWriter struct {
	lastReply       testLhReply
	metaFilter      *NebulaMeta_MessageType
	protocolVersion cert.Version
}

func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
}

func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
}

func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
	msg := &NebulaMeta{}
	// Validate the payload before recording it so a bad message never becomes lastReply.
	if err := msg.Unmarshal(p); err != nil {
		panic(err)
	}

	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
		tw.lastReply = testLhReply{
			nebType:    t,
			nebSubType: st,
			vpnIp:      hostinfo.vpnAddrs[0],
			msg:        msg,
		}
	}
}

func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
	msg := &NebulaMeta{}
	if err := msg.Unmarshal(p); err != nil {
		panic(err)
	}

	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
		tw.lastReply = testLhReply{
			nebType:    t,
			nebSubType: st,
			vpnIp:      vpnIp,
			msg:        msg,
		}
	}
}

func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
	return nil
}

func (tw *testEncWriter) GetCertState() *CertState {
	return &CertState{initiatingVersion: tw.protocolVersion}
}

// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
	t.Helper() // report failures at the caller's line
	if !assert.Len(t, have, len(want)) {
		return
	}

	for k, w := range want {
		h := protoV4AddrPortToNetAddrPort(have[k])
		if h != w {
			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
		}
	}
}

func Test_findNetworkUnion(t *testing.T) {
	var out netip.Addr
	var ok bool

	tenDot := netip.MustParsePrefix("10.0.0.0/8")
	oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
	fe80 := netip.MustParsePrefix("fe80::/8")
	fc00 := netip.MustParsePrefix("fc00::/7")

	a1 := netip.MustParseAddr("10.0.0.1")
	afe81 := netip.MustParseAddr("fe80::1")

	// testify convention is assert.Equal(t, expected, actual).

	//simple
	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
	assert.True(t, ok)
	assert.Equal(t, a1, out)

	//mixed lengths
	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
	assert.True(t, ok)
	assert.Equal(t, a1, out)
	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
	assert.True(t, ok)
	assert.Equal(t, a1, out)

	//mixed family
	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
	assert.True(t, ok)
	assert.Equal(t, a1, out)
	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
	assert.True(t, ok)
	assert.Equal(t, a1, out)

	//ordering
	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
	assert.True(t, ok)
	assert.Equal(t, a1, out)
	out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
	assert.True(t, ok)
	assert.Equal(t, afe81, out)

	//some mismatches
	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
	assert.True(t, ok)
	assert.Equal(t, afe81, out)
	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
	assert.True(t, ok)
	assert.Equal(t, afe81, out)

	//falsey cases: only ok is checked, the returned address is irrelevant
	_, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
	assert.False(t, ok)
	_, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
	assert.False(t, ok)
	_, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
	assert.False(t, ok)
	_, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
	assert.False(t, ok)
}

func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
	l := test.NewLogger()

	myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")

	testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
	testStaticHost := netip.MustParseAddr("10.128.0.42")
	//myVpnIp := netip.MustParseAddr("10.128.0.2")

	c := config.NewC(l)
	lh1 := "10.128.0.2"
	c.Settings["lighthouse"] = map[string]any{
		"hosts":    []any{lh1},
		"interval": "1s",
	}

	c.Settings["listen"] = map[string]any{"port": 4242}
	c.Settings["static_host_map"] = map[string]any{
		lh1:           []any{"1.1.1.1:4242"},
		"10.128.0.42": []any{"1.2.3.4:4242"},
	}

	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
	nt := new(bart.Lite)
	nt.Insert(myVpnNet)
	cs := &CertState{
		myVpnNetworks:      []netip.Prefix{myVpnNet},
		myVpnNetworksTable: nt,
	}

	lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil)
require.NoError(t, err) lh.ifce = &mockEncWriter{} //test that we actually have the static entry: out := lh.Query(testStaticHost) assert.NotNil(t, out) assert.Equal(t, out.vpnAddrs[0], testStaticHost) out.Rebuild([]netip.Prefix{}) //why tho assert.Equal(t, out.addrs[0], myUdpAddr2) //bolt on a lower numbered primary IP am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost}) am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost} lh.addrMap[testSameHostNotStatic] = am out.Rebuild([]netip.Prefix{}) //??? //test that we actually have the static entry: out = lh.Query(testStaticHost) assert.NotNil(t, out) assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic) assert.Equal(t, out.vpnAddrs[1], testStaticHost) assert.Equal(t, out.addrs[0], myUdpAddr2) //test that we actually have the static entry for BOTH: out2 := lh.Query(testSameHostNotStatic) assert.Same(t, out2, out) //now do the delete lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost}) //verify out = lh.Query(testSameHostNotStatic) assert.NotNil(t, out) if out == nil { t.Fatal("expected non-nil query for the static host") } assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic) assert.Equal(t, out.vpnAddrs[1], testStaticHost) assert.Equal(t, out.addrs[0], myUdpAddr2) } func TestLighthouse_DeletesWork(t *testing.T) { l := test.NewLogger() myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242") testHost := netip.MustParseAddr("10.128.0.42") c := config.NewC(l) lh1 := "10.128.0.2" c.Settings["lighthouse"] = map[string]any{ "hosts": []any{lh1}, "interval": "1s", } c.Settings["listen"] = map[string]any{"port": 4242} c.Settings["static_host_map"] = map[string]any{ lh1: []any{"1.1.1.1:4242"}, } myVpnNet := netip.MustParsePrefix("10.128.0.1/24") nt := new(bart.Lite) nt.Insert(myVpnNet) cs := &CertState{ myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} 
//insert the host am := lh.unlockedGetRemoteList([]netip.Addr{testHost}) am.vpnAddrs = []netip.Addr{testHost} am.addrs = []netip.AddrPort{myUdpAddr2} lh.addrMap[testHost] = am am.Rebuild([]netip.Prefix{}) //??? //test that we actually have the entry: out := lh.Query(testHost) assert.NotNil(t, out) assert.Equal(t, out.vpnAddrs[0], testHost) out.Rebuild([]netip.Prefix{}) //why tho assert.Equal(t, out.addrs[0], myUdpAddr2) //now do the delete lh.DeleteVpnAddrs([]netip.Addr{testHost}) //verify out = lh.Query(testHost) assert.Nil(t, out) } ================================================ FILE: logger.go ================================================ package nebula import ( "fmt" "strings" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) func configLogger(l *logrus.Logger, c *config.C) error { // set up our logging level logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info"))) if err != nil { return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels) } l.SetLevel(logLevel) disableTimestamp := c.GetBool("logging.disable_timestamp", false) timestampFormat := c.GetString("logging.timestamp_format", "") fullTimestamp := (timestampFormat != "") if timestampFormat == "" { timestampFormat = time.RFC3339 } logFormat := strings.ToLower(c.GetString("logging.format", "text")) switch logFormat { case "text": l.Formatter = &logrus.TextFormatter{ TimestampFormat: timestampFormat, FullTimestamp: fullTimestamp, DisableTimestamp: disableTimestamp, } case "json": l.Formatter = &logrus.JSONFormatter{ TimestampFormat: timestampFormat, DisableTimestamp: disableTimestamp, } default: return fmt.Errorf("unknown log format `%s`. 
possible formats: %s", logFormat, []string{"text", "json"}) } return nil } ================================================ FILE: main.go ================================================ package nebula import ( "context" "fmt" "net" "net/netip" "runtime/debug" "strings" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" "go.yaml.in/yaml/v3" ) type m = map[string]any func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) { ctx, cancel := context.WithCancel(context.Background()) // Automatically cancel the context if Main returns an error, to signal all created goroutines to quit. defer func() { if reterr != nil { cancel() } }() if buildVersion == "" { buildVersion = moduleVersion() } l := logger l.Formatter = &logrus.TextFormatter{ FullTimestamp: true, } // Print the config if in test, the exit comes later if configTest { b, err := yaml.Marshal(c.Settings) if err != nil { return nil, err } // Print the final config l.Println(string(b)) } err := configLogger(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err) } c.RegisterReloadCallback(func(c *config.C) { err := configLogger(l, c) if err != nil { l.WithError(err).Error("Failed to configure the logger") } }) pki, err := NewPKIFromConfig(l, c) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } 
wireSSHReload(l, ssh, c) var sshStart func() if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available") sshStart = nil } } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // All non system modifying configuration consumption should live above this line // tun config, listeners, anything modifying the computer should be below //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// var routines int // If `routines` is set, use that and ignore the specific values if routines = c.GetInt("routines", 0); routines != 0 { if routines < 1 { routines = 1 } if routines > 1 { l.WithField("routines", routines).Info("Using multiple routines") } } else { // deprecated and undocumented tunQueues := c.GetInt("tun.routines", 1) udpQueues := c.GetInt("listen.routines", 1) routines = max(tunQueues, udpQueues) if routines != 1 { l.WithField("routines", routines).Warn("Setting tun.routines and listen.routines is deprecated. Use `routines` instead") } } // EXPERIMENTAL // Intentionally not documented yet while we do more testing and determine // a good default value. 
conntrackCacheTimeout := c.GetDuration("firewall.conntrack.routine_cache_timeout", 0) if routines > 1 && !c.IsSet("firewall.conntrack.routine_cache_timeout") { // Use a different default if we are running with multiple routines conntrackCacheTimeout = 1 * time.Second } if conntrackCacheTimeout > 0 { l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") } var tun overlay.Device if !configTest { c.CatchHUP(ctx) if deviceFactory == nil { deviceFactory = overlay.NewDeviceFromConfig } tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } defer func() { if reterr != nil { tun.Close() } }() } // set up our UDP listener udpConns := make([]udp.Conn, routines) port := c.GetInt("listen.port", 0) if !configTest { rawListenHost := c.GetString("listen.host", "0.0.0.0") var listenHost netip.Addr if rawListenHost == "[::]" { // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. 
listenHost = netip.IPv6Unspecified() } else { ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } if len(ips) == 0 { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } listenHost = ips[0].Unmap() } for i := 0; i < routines; i++ { l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } udpServer.ReloadConfig(c) udpConns[i] = udpServer // If port is dynamic, discover it before the next pass through the for loop // This way all routines will use the same port correctly if port == 0 { uPort, err := udpServer.LocalAddr() if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } port = int(uPort.Port()) } } } hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } var messageMetrics *MessageMetrics if c.GetBool("stats.message_metrics", false) { messageMetrics = newMessageMetrics() } else { messageMetrics = newMessageMetricsOnlyRecvError() } useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false) handshakeConfig := HandshakeConfig{ tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), useRelays: useRelays, messageMetrics: messageMetrics, } 
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger serveDns := false if c.GetBool("lighthouse.serve_dns", false) { if c.GetBool("lighthouse.am_lighthouse", false) { serveDns = true } else { l.Warn("DNS server refusing to run because this host is not a lighthouse.") } } ifConfig := &InterfaceConfig{ HostMap: hostMap, Inside: tun, Outside: udpConns[0], pki: pki, Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, connectionManager: connManager, lightHouse: lightHouse, tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), DropMulticast: c.GetBool("tun.drop_multicast", false), routines: routines, MessageMetrics: messageMetrics, version: buildVersion, relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, } var ifce *Interface if !configTest { ifce, err = NewInterface(ctx, ifConfig) if err != nil { return nil, fmt.Errorf("failed to initialize interface: %s", err) } ifce.writers = udpConns lightHouse.ifce = ifce ifce.RegisterConfigChangeCallbacks(c) ifce.reloadDisconnectInvalid(c) ifce.reloadSendRecvError(c) ifce.reloadAcceptRecvError(c) handshakeManager.f = ifce go handshakeManager.Run(ctx) } statsStart, err := startStats(l, c, buildVersion, configTest) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } if configTest { return nil, nil } go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) attachCommands(l, c, ssh, ifce) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") 
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) } return &Control{ ifce, l, ctx, cancel, sshStart, statsStart, dnsStart, lightHouse.StartUpdateWorker, connManager.Start, }, nil } func moduleVersion() string { info, ok := debug.ReadBuildInfo() if !ok { return "" } for _, dep := range info.Deps { if dep.Path == "github.com/slackhq/nebula" { return strings.TrimPrefix(dep.Version, "v") } } return "" } ================================================ FILE: message_metrics.go ================================================ package nebula import ( "fmt" "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/header" ) type MessageMetrics struct { rx [][]metrics.Counter tx [][]metrics.Counter rxUnknown metrics.Counter txUnknown metrics.Counter } func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { if m != nil { if t >= 0 && int(t) < len(m.rx) && s >= 0 && int(s) < len(m.rx[t]) { m.rx[t][s].Inc(i) } else if m.rxUnknown != nil { m.rxUnknown.Inc(i) } } } func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int64) { if m != nil { if t >= 0 && int(t) < len(m.tx) && s >= 0 && int(s) < len(m.tx[t]) { m.tx[t][s].Inc(i) } else if m.txUnknown != nil { m.txUnknown.Inc(i) } } } func newMessageMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { return [][]metrics.Counter{ { metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.handshake_ixpsk0", t), nil), }, nil, {metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)}, {metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.lighthouse", t), nil)}, { metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_request", t), nil), metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.test_response", t), nil), }, {metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.close_tunnel", t), nil)}, } } return &MessageMetrics{ rx: gen("rx"), tx: gen("tx"), rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), txUnknown: 
metrics.GetOrRegisterCounter("messages.tx.other", nil), } } // Historically we only recorded recv_error, so this is backwards compat func newMessageMetricsOnlyRecvError() *MessageMetrics { gen := func(t string) [][]metrics.Counter { return [][]metrics.Counter{ nil, nil, {metrics.GetOrRegisterCounter(fmt.Sprintf("messages.%s.recv_error", t), nil)}, } } return &MessageMetrics{ rx: gen("rx"), tx: gen("tx"), } } func newLighthouseMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { h := make([][]metrics.Counter, len(NebulaMeta_MessageType_name)) used := []NebulaMeta_MessageType{ NebulaMeta_HostQuery, NebulaMeta_HostQueryReply, NebulaMeta_HostUpdateNotification, NebulaMeta_HostPunchNotification, NebulaMeta_HostUpdateNotificationAck, } for _, i := range used { h[i] = []metrics.Counter{metrics.GetOrRegisterCounter(fmt.Sprintf("lighthouse.%s.%s", t, i.String()), nil)} } return h } return &MessageMetrics{ rx: gen("rx"), tx: gen("tx"), rxUnknown: metrics.GetOrRegisterCounter("lighthouse.rx.other", nil), txUnknown: metrics.GetOrRegisterCounter("lighthouse.tx.other", nil), } } ================================================ FILE: nebula.pb.go ================================================ // Code generated by protoc-gen-gogo. DO NOT EDIT. // source: nebula.proto package nebula import ( fmt "fmt" proto "github.com/gogo/protobuf/proto" io "io" math "math" math_bits "math/bits" ) // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal var _ = fmt.Errorf var _ = math.Inf // This is a compile-time assertion to ensure that this generated file // is compatible with the proto package it is being compiled against. // A compilation error at this line likely means your copy of the // proto package needs to be updated. 
// NOTE(review): everything below is protoc-gen-gogo output generated from nebula.proto (see the DO NOT EDIT header); change the .proto and regenerate rather than editing by hand.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type NebulaMeta_MessageType int32 const ( NebulaMeta_None NebulaMeta_MessageType = 0 NebulaMeta_HostQuery NebulaMeta_MessageType = 1 NebulaMeta_HostQueryReply NebulaMeta_MessageType = 2 NebulaMeta_HostUpdateNotification NebulaMeta_MessageType = 3 NebulaMeta_HostMovedNotification NebulaMeta_MessageType = 4 NebulaMeta_HostPunchNotification NebulaMeta_MessageType = 5 NebulaMeta_HostWhoami NebulaMeta_MessageType = 6 NebulaMeta_HostWhoamiReply NebulaMeta_MessageType = 7 NebulaMeta_PathCheck NebulaMeta_MessageType = 8 NebulaMeta_PathCheckReply NebulaMeta_MessageType = 9 NebulaMeta_HostUpdateNotificationAck NebulaMeta_MessageType = 10 ) var NebulaMeta_MessageType_name = map[int32]string{ 0: "None", 1: "HostQuery", 2: "HostQueryReply", 3: "HostUpdateNotification", 4: "HostMovedNotification", 5: "HostPunchNotification", 6: "HostWhoami", 7: "HostWhoamiReply", 8: "PathCheck", 9: "PathCheckReply", 10: "HostUpdateNotificationAck", } var NebulaMeta_MessageType_value = map[string]int32{ "None": 0, "HostQuery": 1, "HostQueryReply": 2, "HostUpdateNotification": 3, "HostMovedNotification": 4, "HostPunchNotification": 5, "HostWhoami": 6, "HostWhoamiReply": 7, "PathCheck": 8, "PathCheckReply": 9, "HostUpdateNotificationAck": 10, } func (x NebulaMeta_MessageType) String() string { return proto.EnumName(NebulaMeta_MessageType_name, int32(x)) } func (NebulaMeta_MessageType) EnumDescriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{0, 0} } type NebulaPing_MessageType int32 const ( NebulaPing_Ping NebulaPing_MessageType = 0 NebulaPing_Reply NebulaPing_MessageType = 1 ) var NebulaPing_MessageType_name = map[int32]string{ 0: "Ping", 1: "Reply", } var NebulaPing_MessageType_value = map[string]int32{ "Ping": 0, "Reply": 1, } func (x NebulaPing_MessageType) String() string { return proto.EnumName(NebulaPing_MessageType_name, int32(x)) } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, 
[]int) { return fileDescriptor_2d65afa7693df5ef, []int{5, 0} } type NebulaControl_MessageType int32 const ( NebulaControl_None NebulaControl_MessageType = 0 NebulaControl_CreateRelayRequest NebulaControl_MessageType = 1 NebulaControl_CreateRelayResponse NebulaControl_MessageType = 2 ) var NebulaControl_MessageType_name = map[int32]string{ 0: "None", 1: "CreateRelayRequest", 2: "CreateRelayResponse", } var NebulaControl_MessageType_value = map[string]int32{ "None": 0, "CreateRelayRequest": 1, "CreateRelayResponse": 2, } func (x NebulaControl_MessageType) String() string { return proto.EnumName(NebulaControl_MessageType_name, int32(x)) } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{8, 0} } type NebulaMeta struct { Type NebulaMeta_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaMeta_MessageType" json:"Type,omitempty"` Details *NebulaMetaDetails `protobuf:"bytes,2,opt,name=Details,proto3" json:"Details,omitempty"` } func (m *NebulaMeta) Reset() { *m = NebulaMeta{} } func (m *NebulaMeta) String() string { return proto.CompactTextString(m) } func (*NebulaMeta) ProtoMessage() {} func (*NebulaMeta) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{0} } func (m *NebulaMeta) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *NebulaMeta) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_NebulaMeta.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *NebulaMeta) XXX_Merge(src proto.Message) { xxx_messageInfo_NebulaMeta.Merge(m, src) } func (m *NebulaMeta) XXX_Size() int { return m.Size() } func (m *NebulaMeta) XXX_DiscardUnknown() { xxx_messageInfo_NebulaMeta.DiscardUnknown(m) } var xxx_messageInfo_NebulaMeta proto.InternalMessageInfo func (m *NebulaMeta) GetType() NebulaMeta_MessageType { if m
!= nil { return m.Type } return NebulaMeta_None } func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { if m != nil { return m.Details } return nil } type NebulaMetaDetails struct { OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } func (m *NebulaMetaDetails) String() string { return proto.CompactTextString(m) } func (*NebulaMetaDetails) ProtoMessage() {} func (*NebulaMetaDetails) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{1} } func (m *NebulaMetaDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *NebulaMetaDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_NebulaMetaDetails.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *NebulaMetaDetails) XXX_Merge(src proto.Message) { xxx_messageInfo_NebulaMetaDetails.Merge(m, src) } func (m *NebulaMetaDetails) XXX_Size() int { return m.Size() } func (m *NebulaMetaDetails) XXX_DiscardUnknown() { xxx_messageInfo_NebulaMetaDetails.DiscardUnknown(m) } var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo // Deprecated: Do not use.
func (m *NebulaMetaDetails) GetOldVpnAddr() uint32 { if m != nil { return m.OldVpnAddr } return 0 } func (m *NebulaMetaDetails) GetVpnAddr() *Addr { if m != nil { return m.VpnAddr } return nil } // Deprecated: Do not use. func (m *NebulaMetaDetails) GetOldRelayVpnAddrs() []uint32 { if m != nil { return m.OldRelayVpnAddrs } return nil } func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { if m != nil { return m.RelayVpnAddrs } return nil } func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { if m != nil { return m.V4AddrPorts } return nil } func (m *NebulaMetaDetails) GetV6AddrPorts() []*V6AddrPort { if m != nil { return m.V6AddrPorts } return nil } func (m *NebulaMetaDetails) GetCounter() uint32 { if m != nil { return m.Counter } return 0 } type Addr struct { Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` } func (m *Addr) Reset() { *m = Addr{} } func (m *Addr) String() string { return proto.CompactTextString(m) } func (*Addr) ProtoMessage() {} func (*Addr) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{2} } func (m *Addr) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *Addr) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_Addr.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *Addr) XXX_Merge(src proto.Message) { xxx_messageInfo_Addr.Merge(m, src) } func (m *Addr) XXX_Size() int { return m.Size() } func (m *Addr) XXX_DiscardUnknown() { xxx_messageInfo_Addr.DiscardUnknown(m) } var xxx_messageInfo_Addr proto.InternalMessageInfo func (m *Addr) GetHi() uint64 { if m != nil { return m.Hi } return 0 } func (m *Addr) GetLo() uint64 { if m != nil { return m.Lo } return 0 } type V4AddrPort struct { Addr uint32 `protobuf:"varint,1,opt,name=Addr,proto3" 
json:"Addr,omitempty"` Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` } func (m *V4AddrPort) Reset() { *m = V4AddrPort{} } func (m *V4AddrPort) String() string { return proto.CompactTextString(m) } func (*V4AddrPort) ProtoMessage() {} func (*V4AddrPort) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{3} } func (m *V4AddrPort) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *V4AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_V4AddrPort.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *V4AddrPort) XXX_Merge(src proto.Message) { xxx_messageInfo_V4AddrPort.Merge(m, src) } func (m *V4AddrPort) XXX_Size() int { return m.Size() } func (m *V4AddrPort) XXX_DiscardUnknown() { xxx_messageInfo_V4AddrPort.DiscardUnknown(m) } var xxx_messageInfo_V4AddrPort proto.InternalMessageInfo func (m *V4AddrPort) GetAddr() uint32 { if m != nil { return m.Addr } return 0 } func (m *V4AddrPort) GetPort() uint32 { if m != nil { return m.Port } return 0 } type V6AddrPort struct { Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` } func (m *V6AddrPort) Reset() { *m = V6AddrPort{} } func (m *V6AddrPort) String() string { return proto.CompactTextString(m) } func (*V6AddrPort) ProtoMessage() {} func (*V6AddrPort) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{4} } func (m *V6AddrPort) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *V6AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_V6AddrPort.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err
!= nil { return nil, err } return b[:n], nil } } func (m *V6AddrPort) XXX_Merge(src proto.Message) { xxx_messageInfo_V6AddrPort.Merge(m, src) } func (m *V6AddrPort) XXX_Size() int { return m.Size() } func (m *V6AddrPort) XXX_DiscardUnknown() { xxx_messageInfo_V6AddrPort.DiscardUnknown(m) } var xxx_messageInfo_V6AddrPort proto.InternalMessageInfo func (m *V6AddrPort) GetHi() uint64 { if m != nil { return m.Hi } return 0 } func (m *V6AddrPort) GetLo() uint64 { if m != nil { return m.Lo } return 0 } func (m *V6AddrPort) GetPort() uint32 { if m != nil { return m.Port } return 0 } type NebulaPing struct { Type NebulaPing_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaPing_MessageType" json:"Type,omitempty"` Time uint64 `protobuf:"varint,2,opt,name=Time,proto3" json:"Time,omitempty"` } func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{5} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *NebulaPing) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_NebulaPing.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *NebulaPing) XXX_Merge(src proto.Message) { xxx_messageInfo_NebulaPing.Merge(m, src) } func (m *NebulaPing) XXX_Size() int { return m.Size() } func (m *NebulaPing) XXX_DiscardUnknown() { xxx_messageInfo_NebulaPing.DiscardUnknown(m) } var xxx_messageInfo_NebulaPing proto.InternalMessageInfo func (m *NebulaPing) GetType() NebulaPing_MessageType { if m != nil { return m.Type } return NebulaPing_Ping } func (m *NebulaPing) GetTime() uint64 { if m != nil { return m.Time } return 0 } type NebulaHandshake struct { Details *NebulaHandshakeDetails 
`protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"` Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,proto3" json:"Hmac,omitempty"` } func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *NebulaHandshake) XXX_Merge(src proto.Message) { xxx_messageInfo_NebulaHandshake.Merge(m, src) } func (m *NebulaHandshake) XXX_Size() int { return m.Size() } func (m *NebulaHandshake) XXX_DiscardUnknown() { xxx_messageInfo_NebulaHandshake.DiscardUnknown(m) } var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails { if m != nil { return m.Details } return nil } func (m *NebulaHandshake) GetHmac() []byte { if m != nil { return m.Hmac } return nil } type NebulaHandshakeDetails struct { Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"` InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"` ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` } func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } 
func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) { xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src) } func (m *NebulaHandshakeDetails) XXX_Size() int { return m.Size() } func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() { xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m) } var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo func (m *NebulaHandshakeDetails) GetCert() []byte { if m != nil { return m.Cert } return nil } func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 { if m != nil { return m.InitiatorIndex } return 0 } func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 { if m != nil { return m.ResponderIndex } return 0 } func (m *NebulaHandshakeDetails) GetCookie() uint64 { if m != nil { return m.Cookie } return 0 } func (m *NebulaHandshakeDetails) GetTime() uint64 { if m != nil { return m.Time } return 0 } func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { if m != nil { return m.CertVersion } return 0 } type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` ResponderRelayIndex uint32 `protobuf:"varint,3,opt,name=ResponderRelayIndex,proto3" 
json:"ResponderRelayIndex,omitempty"` OldRelayToAddr uint32 `protobuf:"varint,4,opt,name=OldRelayToAddr,proto3" json:"OldRelayToAddr,omitempty"` // Deprecated: Do not use. OldRelayFromAddr uint32 `protobuf:"varint,5,opt,name=OldRelayFromAddr,proto3" json:"OldRelayFromAddr,omitempty"` // Deprecated: Do not use. RelayToAddr *Addr `protobuf:"bytes,6,opt,name=RelayToAddr,proto3" json:"RelayToAddr,omitempty"` RelayFromAddr *Addr `protobuf:"bytes,7,opt,name=RelayFromAddr,proto3" json:"RelayFromAddr,omitempty"` } func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } func (m *NebulaControl) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { return xxx_messageInfo_NebulaControl.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) if err != nil { return nil, err } return b[:n], nil } } func (m *NebulaControl) XXX_Merge(src proto.Message) { xxx_messageInfo_NebulaControl.Merge(m, src) } func (m *NebulaControl) XXX_Size() int { return m.Size() } func (m *NebulaControl) XXX_DiscardUnknown() { xxx_messageInfo_NebulaControl.DiscardUnknown(m) } var xxx_messageInfo_NebulaControl proto.InternalMessageInfo func (m *NebulaControl) GetType() NebulaControl_MessageType { if m != nil { return m.Type } return NebulaControl_None } func (m *NebulaControl) GetInitiatorRelayIndex() uint32 { if m != nil { return m.InitiatorRelayIndex } return 0 } func (m *NebulaControl) GetResponderRelayIndex() uint32 { if m != nil { return m.ResponderRelayIndex } return 0 } // Deprecated: Do not use. func (m *NebulaControl) GetOldRelayToAddr() uint32 { if m != nil { return m.OldRelayToAddr } return 0 } // Deprecated: Do not use. 
func (m *NebulaControl) GetOldRelayFromAddr() uint32 { if m != nil { return m.OldRelayFromAddr } return 0 } func (m *NebulaControl) GetRelayToAddr() *Addr { if m != nil { return m.RelayToAddr } return nil } func (m *NebulaControl) GetRelayFromAddr() *Addr { if m != nil { return m.RelayFromAddr } return nil } func init() { proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) proto.RegisterEnum("nebula.NebulaControl_MessageType", NebulaControl_MessageType_name, NebulaControl_MessageType_value) proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") proto.RegisterType((*Addr)(nil), "nebula.Addr") proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl") } func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ // 785 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 
0xed, 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 
0x49, 0x4d, 0x97, 0xc8, 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 
0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *NebulaMeta) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *NebulaMeta) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.Details != nil { { size, err := m.Details.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x12 } if m.Type != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Type)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func (m *NebulaMetaDetails) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *NebulaMetaDetails) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if len(m.RelayVpnAddrs) > 0 { for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { { size, err := m.RelayVpnAddrs[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x3a } } if m.VpnAddr != nil { { size, err := m.VpnAddr.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x32 } if len(m.OldRelayVpnAddrs) > 0 { dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) var j3 int for _, num := range m.OldRelayVpnAddrs { for num >= 1<<7 { dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) num >>= 
7 j3++ } dAtA4[j3] = uint8(num) j3++ } i -= j3 copy(dAtA[i:], dAtA4[:j3]) i = encodeVarintNebula(dAtA, i, uint64(j3)) i-- dAtA[i] = 0x2a } if len(m.V6AddrPorts) > 0 { for iNdEx := len(m.V6AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { size, err := m.V6AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x22 } } if m.Counter != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Counter)) i-- dAtA[i] = 0x18 } if len(m.V4AddrPorts) > 0 { for iNdEx := len(m.V4AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { size, err := m.V4AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x12 } } if m.OldVpnAddr != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.OldVpnAddr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func (m *Addr) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *Addr) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *Addr) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.Lo != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) i-- dAtA[i] = 0x10 } if m.Hi != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func (m *V4AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *V4AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *V4AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.Port != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Port)) i-- 
dAtA[i] = 0x10 } if m.Addr != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Addr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func (m *V6AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *V6AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.Port != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Port)) i-- dAtA[i] = 0x18 } if m.Lo != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) i-- dAtA[i] = 0x10 } if m.Hi != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func (m *NebulaPing) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *NebulaPing) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *NebulaPing) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.Time != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Time)) i-- dAtA[i] = 0x10 } if m.Type != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Type)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func (m *NebulaHandshake) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *NebulaHandshake) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if len(m.Hmac) > 0 { i -= len(m.Hmac) copy(dAtA[i:], m.Hmac) i = 
encodeVarintNebula(dAtA, i, uint64(len(m.Hmac))) i-- dAtA[i] = 0x12 } if m.Details != nil { { size, err := m.Details.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0xa } return len(dAtA) - i, nil } func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *NebulaHandshakeDetails) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.CertVersion != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) i-- dAtA[i] = 0x40 } if m.Time != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Time)) i-- dAtA[i] = 0x28 } if m.Cookie != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Cookie)) i-- dAtA[i] = 0x20 } if m.ResponderIndex != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.ResponderIndex)) i-- dAtA[i] = 0x18 } if m.InitiatorIndex != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorIndex)) i-- dAtA[i] = 0x10 } if len(m.Cert) > 0 { i -= len(m.Cert) copy(dAtA[i:], m.Cert) i = encodeVarintNebula(dAtA, i, uint64(len(m.Cert))) i-- dAtA[i] = 0xa } return len(dAtA) - i, nil } func (m *NebulaControl) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) if err != nil { return nil, err } return dAtA[:n], nil } func (m *NebulaControl) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } func (m *NebulaControl) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int _ = l if m.RelayFromAddr != nil { { size, err := m.RelayFromAddr.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = 
encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x3a } if m.RelayToAddr != nil { { size, err := m.RelayToAddr.MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } i -= size i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- dAtA[i] = 0x32 } if m.OldRelayFromAddr != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayFromAddr)) i-- dAtA[i] = 0x28 } if m.OldRelayToAddr != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayToAddr)) i-- dAtA[i] = 0x20 } if m.ResponderRelayIndex != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.ResponderRelayIndex)) i-- dAtA[i] = 0x18 } if m.InitiatorRelayIndex != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorRelayIndex)) i-- dAtA[i] = 0x10 } if m.Type != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Type)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } func encodeVarintNebula(dAtA []byte, offset int, v uint64) int { offset -= sovNebula(v) base := offset for v >= 1<<7 { dAtA[offset] = uint8(v&0x7f | 0x80) v >>= 7 offset++ } dAtA[offset] = uint8(v) return base } func (m *NebulaMeta) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Type != 0 { n += 1 + sovNebula(uint64(m.Type)) } if m.Details != nil { l = m.Details.Size() n += 1 + l + sovNebula(uint64(l)) } return n } func (m *NebulaMetaDetails) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.OldVpnAddr != 0 { n += 1 + sovNebula(uint64(m.OldVpnAddr)) } if len(m.V4AddrPorts) > 0 { for _, e := range m.V4AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } if m.Counter != 0 { n += 1 + sovNebula(uint64(m.Counter)) } if len(m.V6AddrPorts) > 0 { for _, e := range m.V6AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } if len(m.OldRelayVpnAddrs) > 0 { l = 0 for _, e := range m.OldRelayVpnAddrs { l += sovNebula(uint64(e)) } n += 1 + sovNebula(uint64(l)) + l } if m.VpnAddr != nil { l = m.VpnAddr.Size() n += 1 + l + sovNebula(uint64(l)) } if len(m.RelayVpnAddrs) > 0 { for _, e := range m.RelayVpnAddrs { l = e.Size() n 
+= 1 + l + sovNebula(uint64(l)) } } return n } func (m *Addr) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Hi != 0 { n += 1 + sovNebula(uint64(m.Hi)) } if m.Lo != 0 { n += 1 + sovNebula(uint64(m.Lo)) } return n } func (m *V4AddrPort) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Addr != 0 { n += 1 + sovNebula(uint64(m.Addr)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) } return n } func (m *V6AddrPort) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Hi != 0 { n += 1 + sovNebula(uint64(m.Hi)) } if m.Lo != 0 { n += 1 + sovNebula(uint64(m.Lo)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) } return n } func (m *NebulaPing) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Type != 0 { n += 1 + sovNebula(uint64(m.Type)) } if m.Time != 0 { n += 1 + sovNebula(uint64(m.Time)) } return n } func (m *NebulaHandshake) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Details != nil { l = m.Details.Size() n += 1 + l + sovNebula(uint64(l)) } l = len(m.Hmac) if l > 0 { n += 1 + l + sovNebula(uint64(l)) } return n } func (m *NebulaHandshakeDetails) Size() (n int) { if m == nil { return 0 } var l int _ = l l = len(m.Cert) if l > 0 { n += 1 + l + sovNebula(uint64(l)) } if m.InitiatorIndex != 0 { n += 1 + sovNebula(uint64(m.InitiatorIndex)) } if m.ResponderIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderIndex)) } if m.Cookie != 0 { n += 1 + sovNebula(uint64(m.Cookie)) } if m.Time != 0 { n += 1 + sovNebula(uint64(m.Time)) } if m.CertVersion != 0 { n += 1 + sovNebula(uint64(m.CertVersion)) } return n } func (m *NebulaControl) Size() (n int) { if m == nil { return 0 } var l int _ = l if m.Type != 0 { n += 1 + sovNebula(uint64(m.Type)) } if m.InitiatorRelayIndex != 0 { n += 1 + sovNebula(uint64(m.InitiatorRelayIndex)) } if m.ResponderRelayIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderRelayIndex)) } if m.OldRelayToAddr != 0 { n += 1 + sovNebula(uint64(m.OldRelayToAddr)) } if 
m.OldRelayFromAddr != 0 { n += 1 + sovNebula(uint64(m.OldRelayFromAddr)) } if m.RelayToAddr != nil { l = m.RelayToAddr.Size() n += 1 + l + sovNebula(uint64(l)) } if m.RelayFromAddr != nil { l = m.RelayFromAddr.Size() n += 1 + l + sovNebula(uint64(l)) } return n } func sovNebula(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } func sozNebula(x uint64) (n int) { return sovNebula(uint64((x << 1) ^ uint64((int64(x) >> 63)))) } func (m *NebulaMeta) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: NebulaMeta: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: NebulaMeta: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) } m.Type = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Type |= NebulaMeta_MessageType(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } if m.Details == nil { m.Details = &NebulaMetaDetails{} } if err := 
m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: NebulaMetaDetails: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: NebulaMetaDetails: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field OldVpnAddr", wireType) } m.OldVpnAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.OldVpnAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field V4AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } m.V4AddrPorts = append(m.V4AddrPorts, &V4AddrPort{}) if err := 
m.V4AddrPorts[len(m.V4AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex case 3: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Counter", wireType) } m.Counter = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Counter |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 4: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field V6AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } m.V6AddrPorts = append(m.V6AddrPorts, &V6AddrPort{}) if err := m.V6AddrPorts[len(m.V6AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex case 5: if wireType == 0 { var v uint32 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ v |= uint32(b&0x7F) << shift if b < 0x80 { break } } m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ packedLen |= int(b&0x7F) << shift if b < 0x80 { break } } if packedLen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + packedLen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } var elementCount int var count int for _, integer := range dAtA[iNdEx:postIndex] { if integer < 128 { count++ } } elementCount = 
count if elementCount != 0 && len(m.OldRelayVpnAddrs) == 0 { m.OldRelayVpnAddrs = make([]uint32, 0, elementCount) } for iNdEx < postIndex { var v uint32 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ v |= uint32(b&0x7F) << shift if b < 0x80 { break } } m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } } else { return fmt.Errorf("proto: wrong wireType = %d for field OldRelayVpnAddrs", wireType) } case 6: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field VpnAddr", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } if m.VpnAddr == nil { m.VpnAddr = &Addr{} } if err := m.VpnAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex case 7: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnAddrs", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } m.RelayVpnAddrs = append(m.RelayVpnAddrs, &Addr{}) if err := m.RelayVpnAddrs[len(m.RelayVpnAddrs)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return 
ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *Addr) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: Addr: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: Addr: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) } m.Hi = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Hi |= uint64(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) } m.Lo = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Lo |= uint64(b&0x7F) << shift if b < 0x80 { break } } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *V4AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << 
shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: V4AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: V4AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType) } m.Addr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Addr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Port", wireType) } m.Port = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Port |= uint32(b&0x7F) << shift if b < 0x80 { break } } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *V6AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: V6AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: V6AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) } m.Hi = 0 for shift := uint(0); 
; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Hi |= uint64(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) } m.Lo = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Lo |= uint64(b&0x7F) << shift if b < 0x80 { break } } case 3: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Port", wireType) } m.Port = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Port |= uint32(b&0x7F) << shift if b < 0x80 { break } } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *NebulaPing) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: NebulaPing: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: NebulaPing: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) } m.Type = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := 
dAtA[iNdEx] iNdEx++ m.Type |= NebulaPing_MessageType(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) } m.Time = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Time |= uint64(b&0x7F) << shift if b < 0x80 { break } } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *NebulaHandshake) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: NebulaHandshake: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: NebulaHandshake: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } if m.Details == nil { m.Details = &NebulaHandshakeDetails{} } if err := m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx 
= postIndex case 2: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Hmac", wireType) } var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ byteLen |= int(b&0x7F) << shift if b < 0x80 { break } } if byteLen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + byteLen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } m.Hmac = append(m.Hmac[:0], dAtA[iNdEx:postIndex]...) if m.Hmac == nil { m.Hmac = []byte{} } iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: NebulaHandshakeDetails: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: NebulaHandshakeDetails: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Cert", wireType) } var byteLen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ byteLen |= int(b&0x7F) << shift if b < 0x80 { break } } if byteLen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + byteLen if postIndex < 0 
{ return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } m.Cert = append(m.Cert[:0], dAtA[iNdEx:postIndex]...) if m.Cert == nil { m.Cert = []byte{} } iNdEx = postIndex case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field InitiatorIndex", wireType) } m.InitiatorIndex = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.InitiatorIndex |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 3: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field ResponderIndex", wireType) } m.ResponderIndex = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.ResponderIndex |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 4: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Cookie", wireType) } m.Cookie = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Cookie |= uint64(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType) } m.Time = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Time |= uint64(b&0x7F) << shift if b < 0x80 { break } } case 8: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) } m.CertVersion = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.CertVersion |= uint32(b&0x7F) << shift if b < 0x80 { break } } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } 
if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func (m *NebulaControl) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { preIndex := iNdEx var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= uint64(b&0x7F) << shift if b < 0x80 { break } } fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { return fmt.Errorf("proto: NebulaControl: wiretype end group for non-group") } if fieldNum <= 0 { return fmt.Errorf("proto: NebulaControl: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) } m.Type = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.Type |= NebulaControl_MessageType(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field InitiatorRelayIndex", wireType) } m.InitiatorRelayIndex = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.InitiatorRelayIndex |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 3: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field ResponderRelayIndex", wireType) } m.ResponderRelayIndex = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.ResponderRelayIndex |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 4: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field 
OldRelayToAddr", wireType) } m.OldRelayToAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.OldRelayToAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field OldRelayFromAddr", wireType) } m.OldRelayFromAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ m.OldRelayFromAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 6: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field RelayToAddr", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } if m.RelayToAddr == nil { m.RelayToAddr = &Addr{} } if err := m.RelayToAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex case 7: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field RelayFromAddr", wireType) } var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula } if iNdEx >= l { return io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ msglen |= int(b&0x7F) << shift if b < 0x80 { break } } if msglen < 0 { return ErrInvalidLengthNebula } postIndex := iNdEx + msglen if postIndex < 0 { return ErrInvalidLengthNebula } if postIndex > l { return io.ErrUnexpectedEOF } if m.RelayFromAddr == nil { m.RelayFromAddr = &Addr{} } if err := m.RelayFromAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex default: iNdEx = preIndex skippy, 
err := skipNebula(dAtA[iNdEx:]) if err != nil { return err } if (skippy < 0) || (iNdEx+skippy) < 0 { return ErrInvalidLengthNebula } if (iNdEx + skippy) > l { return io.ErrUnexpectedEOF } iNdEx += skippy } } if iNdEx > l { return io.ErrUnexpectedEOF } return nil } func skipNebula(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 depth := 0 for iNdEx < l { var wire uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return 0, ErrIntOverflowNebula } if iNdEx >= l { return 0, io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ wire |= (uint64(b) & 0x7F) << shift if b < 0x80 { break } } wireType := int(wire & 0x7) switch wireType { case 0: for shift := uint(0); ; shift += 7 { if shift >= 64 { return 0, ErrIntOverflowNebula } if iNdEx >= l { return 0, io.ErrUnexpectedEOF } iNdEx++ if dAtA[iNdEx-1] < 0x80 { break } } case 1: iNdEx += 8 case 2: var length int for shift := uint(0); ; shift += 7 { if shift >= 64 { return 0, ErrIntOverflowNebula } if iNdEx >= l { return 0, io.ErrUnexpectedEOF } b := dAtA[iNdEx] iNdEx++ length |= (int(b) & 0x7F) << shift if b < 0x80 { break } } if length < 0 { return 0, ErrInvalidLengthNebula } iNdEx += length case 3: depth++ case 4: if depth == 0 { return 0, ErrUnexpectedEndOfGroupNebula } depth-- case 5: iNdEx += 4 default: return 0, fmt.Errorf("proto: illegal wireType %d", wireType) } if iNdEx < 0 { return 0, ErrInvalidLengthNebula } if depth == 0 { return iNdEx, nil } } return 0, io.ErrUnexpectedEOF } var ( ErrInvalidLengthNebula = fmt.Errorf("proto: negative length found during unmarshaling") ErrIntOverflowNebula = fmt.Errorf("proto: integer overflow") ErrUnexpectedEndOfGroupNebula = fmt.Errorf("proto: unexpected end of group") ) ================================================ FILE: nebula.proto ================================================ syntax = "proto3"; package nebula; option go_package = "github.com/slackhq/nebula"; message NebulaMeta { enum MessageType { None = 0; HostQuery = 1; HostQueryReply = 2; 
HostUpdateNotification = 3; HostMovedNotification = 4; HostPunchNotification = 5; HostWhoami = 6; HostWhoamiReply = 7; PathCheck = 8; PathCheckReply = 9; HostUpdateNotificationAck = 10; } MessageType Type = 1; NebulaMetaDetails Details = 2; }

message NebulaMetaDetails {
  uint32 OldVpnAddr = 1 [deprecated = true];
  Addr VpnAddr = 6;
  repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
  repeated Addr RelayVpnAddrs = 7;
  repeated V4AddrPort V4AddrPorts = 2;
  repeated V6AddrPort V6AddrPorts = 4;
  uint32 counter = 3;
}

message Addr {
  uint64 Hi = 1;
  uint64 Lo = 2;
}

message V4AddrPort {
  uint32 Addr = 1;
  uint32 Port = 2;
}

message V6AddrPort {
  uint64 Hi = 1;
  uint64 Lo = 2;
  uint32 Port = 3;
}

message NebulaPing {
  enum MessageType {
    Ping = 0;
    Reply = 1;
  }
  MessageType Type = 1;
  uint64 Time = 2;
}

message NebulaHandshake {
  NebulaHandshakeDetails Details = 1;
  bytes Hmac = 2;
}

message NebulaHandshakeDetails {
  bytes Cert = 1;
  uint32 InitiatorIndex = 2;
  uint32 ResponderIndex = 3;
  uint64 Cookie = 4;
  uint64 Time = 5;
  uint32 CertVersion = 8;
  // reserved for WIP multiport
  reserved 6, 7;
}

message NebulaControl {
  enum MessageType {
    None = 0;
    CreateRelayRequest = 1;
    CreateRelayResponse = 2;
  }
  MessageType Type = 1;
  uint32 InitiatorRelayIndex = 2;
  uint32 ResponderRelayIndex = 3;
  uint32 OldRelayToAddr = 4 [deprecated = true];
  uint32 OldRelayFromAddr = 5 [deprecated = true];
  Addr RelayToAddr = 6;
  Addr RelayFromAddr = 7;
}

================================================
FILE: noise.go
================================================
package nebula

import (
	"crypto/cipher"
	"encoding/binary"
	"errors"

	"github.com/flynn/noise"
)

// endianness is the single method of binary.ByteOrder this file needs: it is
// used to write the 64-bit counter n into bytes 4-11 of the AEAD nonce.
type endianness interface {
	PutUint64(b []byte, v uint64)
}

// noiseEndianness is the byte order of the counter portion of the nonce
// (big-endian).
var noiseEndianness endianness = binary.BigEndian

// NebulaCipherState wraps a noise.Cipher so that the Encrypt/Decrypt helpers
// below can type-assert it to cipher.AEAD and drive Seal/Open directly with a
// caller-supplied nonce buffer, instead of going through noise's own
// CipherState nonce handling.
type NebulaCipherState struct {
	c noise.Cipher
	//k [32]byte
	//n uint64
}

// NewNebulaCipherState returns a NebulaCipherState wrapping s.Cipher().
// s must be non-nil.
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
	return &NebulaCipherState{c: s.Cipher()}
}

// EncryptDanger encrypts and authenticates a given payload.
// // out is a destination slice to hold the output of the EncryptDanger operation. // - ad is additional data, which will be authenticated and appended to out, but not encrypted. // - plaintext is encrypted, authenticated and appended to out. // - n is a nonce value which must never be re-used with this key. // - nb is a buffer used for temporary storage in the implementation of this call, which should // be re-used by callers to minimize garbage collection. func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { // TODO: Is this okay now that we have made messageCounter atomic? // Alternative may be to split the counter space into ranges //if n <= s.n { // return nil, errors.New("CRITICAL: a duplicate counter value was used") //} //s.n = n nb[0] = 0 nb[1] = 0 nb[2] = 0 nb[3] = 0 noiseEndianness.PutUint64(nb[4:], n) out = s.c.(cipher.AEAD).Seal(out, nb, plaintext, ad) //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) return out, nil } else { return nil, errors.New("no cipher state available to encrypt") } } func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { if s != nil { nb[0] = 0 nb[1] = 0 nb[2] = 0 nb[3] = 0 noiseEndianness.PutUint64(nb[4:], n) return s.c.(cipher.AEAD).Open(out, nb, ciphertext, ad) } else { return []byte{}, nil } } func (s *NebulaCipherState) Overhead() int { if s != nil { return s.c.(cipher.AEAD).Overhead() } return 0 } ================================================ FILE: noiseutil/boring.go ================================================ //go:build boringcrypto // +build boringcrypto package noiseutil import ( "crypto/aes" "crypto/cipher" "encoding/binary" // unsafe needed for go:linkname _ "unsafe" "github.com/flynn/noise" ) // EncryptLockNeeded indicates if calls to Encrypt need a lock // This is true for boringcrypto because the Seal function verifies that the // 
// nonce is strictly increasing.
const EncryptLockNeeded = true

// NewGCMTLS is no longer exposed in go1.19+, so we need to link it in
// See: https://github.com/golang/go/issues/56326
//
// NewGCMTLS is the internal method used with boringcrypto that provides a
// validated mode of AES-GCM which enforces the nonce is strictly
// monotonically increasing. This is the TLS 1.2 specification for nonce
// generation (which also matches the method used by the Noise Protocol)
//
//   - https://github.com/golang/go/blob/go1.19/src/crypto/tls/cipher_suites.go#L520-L522
//   - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L235-L237
//   - https://github.com/golang/go/blob/go1.19/src/crypto/internal/boring/aes.go#L250
//   - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/include/openssl/aead.h#L379-L381
//   - https://github.com/google/boringssl/blob/ae223d6138807a13006342edfeef32e813246b39/crypto/fipsmodule/cipher/e_aes.c#L1082-L1093
//
//go:linkname newGCMTLS crypto/internal/boring.NewGCMTLS
func newGCMTLS(c cipher.Block) (cipher.AEAD, error)

// cipherFn adapts a key-to-cipher constructor plus a name into a
// noise.CipherFunc.
type cipherFn struct {
	fn   func([32]byte) noise.Cipher
	name string
}

// Cipher builds a noise.Cipher for key k by calling the wrapped constructor.
func (c cipherFn) Cipher(k [32]byte) noise.Cipher {
	return c.fn(k)
}

// CipherName returns the name the cipherFn was created with.
func (c cipherFn) CipherName() string {
	return c.name
}

// CipherAESGCM is the AES256-GCM AEAD cipher (using NewGCMTLS when GoBoring is present)
var CipherAESGCM noise.CipherFunc = cipherFn{cipherAESGCMBoring, "AESGCM"}

// cipherAESGCMBoring builds an AES-GCM noise.Cipher for key k through
// newGCMTLS. It panics if either the AES block or the GCM construction fails.
// Nonces are 12 bytes: four zero bytes followed by big-endian n.
func cipherAESGCMBoring(k [32]byte) noise.Cipher {
	c, err := aes.NewCipher(k[:])
	if err != nil {
		panic(err)
	}
	gcm, err := newGCMTLS(c)
	if err != nil {
		panic(err)
	}
	return aeadCipher{
		gcm,
		func(n uint64) []byte {
			var nonce [12]byte
			binary.BigEndian.PutUint64(nonce[4:], n)
			return nonce[:]
		},
	}
}

// aeadCipher pairs a cipher.AEAD with a function that derives the nonce from
// a counter, and exposes the Encrypt/Decrypt methods used by noise.
type aeadCipher struct {
	cipher.AEAD
	nonce func(uint64) []byte
}

// Encrypt seals plaintext (authenticating ad) with the nonce derived from n,
// appending the result to out.
func (c aeadCipher) Encrypt(out []byte, n uint64, ad, plaintext []byte) []byte {
	return c.Seal(out, c.nonce(n), plaintext, ad)
}

// Decrypt opens ciphertext (authenticating ad) with the nonce derived from n,
// appending the plaintext to out.
func (c aeadCipher) Decrypt(out []byte, n uint64, ad, ciphertext []byte) ([]byte, error) {
	return c.Open(out, c.nonce(n), ciphertext, ad)
}

================================================
FILE: noiseutil/boring_test.go
================================================
//go:build boringcrypto
// +build boringcrypto

package noiseutil

import (
	"crypto/boring"
	"encoding/hex"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestEncryptLockNeeded(t *testing.T) {
	assert.True(t, EncryptLockNeeded)
}

// Ensure NewGCMTLS validates the nonce is non-repeating
func TestNewGCMTLS(t *testing.T) {
	assert.True(t, boring.Enabled())

	// Test Case 16 from GCM Spec:
	//  - (now dead link): http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-spec.pdf
	//  - as listed in boringssl tests: https://github.com/google/boringssl/blob/fips-20220613/crypto/cipher_extra/test/cipher_tests.txt#L412-L418
	key, _ := hex.DecodeString("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308")
	iv, _ := hex.DecodeString("cafebabefacedbaddecaf888")
	plaintext, _ := hex.DecodeString("d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a721c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39")
	aad, _ := hex.DecodeString("feedfacedeadbeeffeedfacedeadbeefabaddad2")
	expected, _ := hex.DecodeString("522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662")
	expectedTag, _ := hex.DecodeString("76fc6ece0f4e1768cddf8853bb2d551b")

	// Seal returns ciphertext with the authentication tag appended.
	expected = append(expected, expectedTag...)

	var keyArray [32]byte
	copy(keyArray[:], key)
	c := CipherAESGCM.Cipher(keyArray)
	aead := c.(aeadCipher).AEAD

	dst := aead.Seal([]byte{}, iv, plaintext, aad)
	assert.Equal(t, expected, dst)

	// We expect this to fail since we are re-encrypting with a repeat IV
	assert.PanicsWithError(t, "boringcrypto: EVP_AEAD_CTX_seal failed", func() {
		dst = aead.Seal([]byte{}, iv, plaintext, aad)
	})
}

================================================
FILE: noiseutil/nist.go
================================================
package noiseutil

import (
	"crypto/ecdh"
	"crypto/rand"
	"fmt"
	"io"

	"github.com/flynn/noise"
)

// DHP256 is the NIST P-256 ECDH function
var DHP256 noise.DHFunc = newNISTCurve("P256", ecdh.P256(), 32)

// nistCurve implements noise.DHFunc on top of a crypto/ecdh NIST curve.
type nistCurve struct {
	name  string
	curve ecdh.Curve
	// NOTE(review): dhLen is stored but not read by any method in this file;
	// DHLen returns pubLen instead (see the comment there).
	dhLen  int
	pubLen int
}

// newNISTCurve describes curve under the given name; byteLen is the size in
// bytes of one field element, used to derive the uncompressed public key size.
func newNISTCurve(name string, curve ecdh.Curve, byteLen int) nistCurve {
	return nistCurve{
		name:  name,
		curve: curve,
		dhLen: byteLen,
		// Standard uncompressed format, type (1 byte) plus both coordinates
		pubLen: 1 + 2*byteLen,
	}
}

// GenerateKeypair creates a new key pair on the curve, reading randomness from
// rng, or from crypto/rand.Reader when rng is nil.
func (c nistCurve) GenerateKeypair(rng io.Reader) (noise.DHKey, error) {
	if rng == nil {
		rng = rand.Reader
	}
	privkey, err := c.curve.GenerateKey(rng)
	if err != nil {
		return noise.DHKey{}, err
	}
	pubkey := privkey.PublicKey()
	return noise.DHKey{Private: privkey.Bytes(), Public: pubkey.Bytes()}, nil
}

// DH performs ECDH between privkey and pubkey, both in the encodings accepted
// by ecdh.Curve.NewPrivateKey and ecdh.Curve.NewPublicKey, and returns the
// shared secret.
func (c nistCurve) DH(privkey, pubkey []byte) ([]byte, error) {
	ecdhPubKey, err := c.curve.NewPublicKey(pubkey)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
	}
	ecdhPrivKey, err := c.curve.NewPrivateKey(privkey)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal private key: %w", err)
	}

	return ecdhPrivKey.ECDH(ecdhPubKey)
}

// DHLen returns the uncompressed public key length for the curve.
func (c nistCurve) DHLen() int {
	// NOTE: Noise Protocol specifies "DHLen" to represent two things:
	//  - The size of the public key
	//  - The return size of the DH() function
	// But for standard NIST ECDH, the sizes of these are different.
	// Luckily, the flynn/noise library actually only uses this DHLen()
	// value to represent the public key size, so that is what we are
	// returning here. The length of the DH() return bytes are unaffected by
	// this value here.
	return c.pubLen
}

// DHName returns the curve name given to newNISTCurve.
func (c nistCurve) DHName() string {
	return c.name
}

================================================
FILE: noiseutil/notboring.go
================================================
//go:build !boringcrypto
// +build !boringcrypto

package noiseutil

import (
	"github.com/flynn/noise"
)

// EncryptLockNeeded indicates if calls to Encrypt need a lock
const EncryptLockNeeded = false

// CipherAESGCM is the standard noise.CipherAESGCM when boringcrypto is not enabled
var CipherAESGCM noise.CipherFunc = noise.CipherAESGCM

================================================
FILE: noiseutil/notboring_test.go
================================================
//go:build !boringcrypto
// +build !boringcrypto

package noiseutil

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestEncryptLockNeeded(t *testing.T) {
	assert.False(t, EncryptLockNeeded)
}

================================================
FILE: noiseutil/pkcs11.go
================================================
package noiseutil

import (
	"crypto/ecdh"
	"fmt"
	"strings"

	"github.com/slackhq/nebula/pkclient"

	"github.com/flynn/noise"
)

// DHP256PKCS11 is the NIST P-256 ECDH function, accepting either a raw private
// key or a PKCS#11 URI in place of the private key (see nistP11Curve.DH).
var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32)

// nistP11Curve is a nistCurve whose DH can delegate the private-key operation
// to a PKCS#11 token.
type nistP11Curve struct {
	nistCurve
}

// newNISTP11Curve builds a nistP11Curve; the arguments mean the same as for
// newNISTCurve.
func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve {
	return nistP11Curve{
		newNISTCurve(name, curve, byteLen),
	}
}

// DH performs ECDH with pubkey. When privkey is a PKCS#11 URI (prefix
// "pkcs11:") the derivation is done by a freshly opened pkclient session,
// which is closed before returning; any other privkey is handled by DHP256.
func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) {
	//for this function "privkey" is actually a pkcs11 URI
	pkStr := string(privkey)

	//to set up a handshake, we need to also do non-pkcs11-DH. Handle that here.
	if !strings.HasPrefix(pkStr, "pkcs11:") {
		return DHP256.DH(privkey, pubkey)
	}
	ecdhPubKey, err := c.curve.NewPublicKey(pubkey)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
	}

	//this is not the most performant way to do this (a long-lived client would be better)
	//but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users.
	client, err := pkclient.FromUrl(pkStr)
	if err != nil {
		return nil, err
	}
	// The Close error is deliberately discarded: the derived secret (or
	// DeriveNoise error) is what the caller needs.
	defer func(client *pkclient.PKClient) {
		_ = client.Close()
	}(client)

	return client.DeriveNoise(ecdhPubKey.Bytes())
}

================================================
FILE: notboring.go
================================================
//go:build !boringcrypto

package nebula

// boringEnabled reports whether BoringCrypto is in use. This variant is
// compiled only without the boringcrypto build tag, so it always returns false.
var boringEnabled = func() bool { return false }

================================================
FILE: outside.go
================================================
package nebula

import (
	"encoding/binary"
	"errors"
	"net/netip"
	"time"

	"github.com/google/gopacket/layers"
	"golang.org/x/net/ipv6"

	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/firewall"
	"github.com/slackhq/nebula/header"
	"golang.org/x/net/ipv4"
)

const (
	minFwPacketLen = 4
)

func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
	err := h.Parse(packet)
	if err != nil {
		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
		if len(packet) > 1 {
			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
		}
		return
	}

	//l.Error("in packet ", header, packet[HeaderLen:])
	if !via.IsRelayed {
		if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
			if f.l.Level >= logrus.DebugLevel {
				f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
			}
			return
		}
	}

	var hostinfo *HostInfo
	// verify if we've seen this index before, otherwise respond to the handshake
initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } var ci *ConnectionState if hostinfo != nil { ci = hostinfo.ConnectionState } switch h.Type { case header.Message: if !f.handleEncrypted(ci, via, h) { return } switch h.Subtype { case header.MessageNone: if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { return } case header.MessageRelay: // The entire body is sent as AD, not encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice // which will gracefully fail in the DecryptDanger call. signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) if err != nil { return } // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. f.handleHostRoaming(hostinfo, via) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo) f.connectionManager.RelayUsed(h.RemoteIndex) relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. 
hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } switch relay.Type { case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. via = ViaSender{ UdpAddr: via.UdpAddr, relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay, IsRelayed: true, } f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") return } // If that relay is Established, forward the payload through it if targetRelay.State == Established { switch targetRelay.Type { case ForwardingType: // Forward this packet through the relay tunnel // Find the target HostInfo f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) return case TerminalType: hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } } case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). 
Error("Failed to decrypt lighthouse packet") return } //TODO: assert via is not relayed lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). Error("Failed to decrypt test packet") return } if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding f.handleHostRoaming(hostinfo, via) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } // Fallthrough to the bottom to record incoming traffic // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they // are unauthenticated case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.handshakeManager.HandleIncoming(via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) f.handleRecvError(via.UdpAddr, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) if !f.handleEncrypted(ci, via, h) { return } hostinfo.logger(f.l).WithField("from", via). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). 
Error("Failed to decrypt Control packet") return } f.relayManager.HandleControlMsg(hostinfo, d, f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) return } f.handleHostRoaming(hostinfo, via) f.connectionManager.In(hostinfo) } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). 
Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote hostinfo.SetRemote(via.UdpAddr) } } // handleEncrypted returns true if a packet should be processed, false otherwise func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect if ci == nil { if !via.IsRelayed { f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) } return false } // If the window check fails, refuse to process the packet, but don't send a recv error if !ci.window.Check(f.l, h.MessageCounter) { return false } return true } var ( ErrPacketTooShort = errors.New("packet is too short") ErrUnknownIPVersion = errors.New("packet is an unknown ip version") ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length") ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short") ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short") ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet") ) // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { if len(data) < 1 { return ErrPacketTooShort } version := int((data[0] >> 4) & 0x0f) switch version { case ipv4.Version: return parseV4(data, incoming, fp) case ipv6.Version: return parseV6(data, incoming, fp) } return ErrUnknownIPVersion } func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { dataLen := len(data) if dataLen < ipv6.HeaderLen { return ErrIPv6PacketTooShort } if incoming { fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24]) fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40]) } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) } protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header offset := ipv6.HeaderLen // Start at the 
end of the ipv6 header next := 0 for { if protoAt >= dataLen { break } proto := layers.IPProtocol(data[protoAt]) switch proto { case layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) fp.RemotePort = 0 fp.LocalPort = 0 fp.Fragment = false return nil case layers.IPProtocolICMPv6: if dataLen < offset+6 { return ErrIPv6PacketTooShort } fp.Protocol = uint8(proto) fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6 icmptype := data[offset+1] switch icmptype { case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply: fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier default: fp.RemotePort = 0 } fp.Fragment = false return nil case layers.IPProtocolTCP, layers.IPProtocolUDP: if dataLen < offset+4 { return ErrIPv6PacketTooShort } fp.Protocol = uint8(proto) if incoming { fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) } else { fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) } fp.Fragment = false return nil case layers.IPProtocolIPv6Fragment: // Fragment header is 8 bytes, need at least offset+4 to read the offset field if dataLen < offset+8 { return ErrIPv6PacketTooShort } // Check if this is the first fragment fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits if fragmentOffset != 0 { // Non-first fragment, use what we have now and stop processing fp.Protocol = data[offset] fp.Fragment = true fp.RemotePort = 0 fp.LocalPort = 0 return nil } // The next loop should be the transport layer since we are the first fragment next = 8 // Fragment headers are always 8 bytes case layers.IPProtocolAH: // Auth headers, used by IPSec, have a different meaning for header length if dataLen <= offset+1 { break } next = int(data[offset+1]+2) << 2 default: // Normal ipv6 header length 
processing if dataLen <= offset+1 { break } next = int(data[offset+1]+1) << 3 } if next <= 0 { // Safety check, each ipv6 header has to be at least 8 bytes next = 8 } protoAt = offset offset = offset + next } return ErrIPv6CouldNotFindPayload } func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Do we at least have an ipv4 header worth of data? if len(data) < ipv4.HeaderLen { return ErrIPv4PacketTooShort } // Adjust our start position based on the advertised ip header length ihl := int(data[0]&0x0f) << 2 // Well-formed ip header length? if ihl < ipv4.HeaderLen { return ErrIPv4InvalidHeaderLength } // Check if this is the second or further fragment of a fragmented packet. flagsfrags := binary.BigEndian.Uint16(data[6:8]) fp.Fragment = (flagsfrags & 0x1FFF) != 0 // Firewall handles protocol checks fp.Protocol = data[9] // Accounting for a variable header length, do we have enough data for our src/dst tuples? minLen := ihl if !fp.Fragment { if fp.Protocol == firewall.ProtoICMP { minLen += minFwPacketLen + 2 } else { minLen += minFwPacketLen } } if len(data) < minLen { return ErrIPv4InvalidHeaderLength } if incoming { // Firewall packets are locally oriented fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) } if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier fp.LocalPort = 0 //code would be uint16(data[ihl+1]) } else if incoming { fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } else { fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } 
return nil } func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb) if err != nil { return nil, err } if !hostinfo.ConnectionState.window.Update(f.l, mc) { hostinfo.logger(f.l).WithField("header", h). Debugln("dropping out of window packet") return nil, errors.New("out of window packet") } return out, nil } func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { var err error out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") return false } err = newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).WithError(err).WithField("packet", out). Warnf("Error while validating inbound packet") return false } if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). Debugln("dropping out of window packet") return false } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).WithField("fwPacket", fwPacket). WithField("reason", dropReason). 
Debugln("dropping inbound packet") } return false } f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") } return true } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { if f.sendRecvErrorConfig.ShouldRecvError(endpoint) { f.sendRecvError(endpoint, index) } } func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) _ = f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). Debug("Recv error sent") } } func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). Debug("Recv error received, ignoring") return } if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). Debug("Recv error received") } hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) if hostinfo == nil { f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") return } if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } f.closeTunnel(hostinfo) // We also delete it from pending hostmap to allow for fast reconnect. 
f.handshakeManager.DeleteHostInfo(hostinfo) } ================================================ FILE: outside_test.go ================================================ package nebula import ( "bytes" "encoding/binary" "net" "net/netip" "testing" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" ) func Test_newPacket(t *testing.T) { p := &firewall.Packet{} // length fails err := newPacket([]byte{}, true, p) require.ErrorIs(t, err, ErrPacketTooShort) err = newPacket([]byte{0x40}, true, p) require.ErrorIs(t, err, ErrIPv4PacketTooShort) err = newPacket([]byte{0x60}, true, p) require.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ Version: 1, Len: 100, Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Options: []byte{0, 1, 0, 2}, } b, _ := h.Marshal() err = newPacket(b, true, p) require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) require.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ Version: 1, Len: 100, Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Options: []byte{0, 1, 0, 2}, Protocol: firewall.ProtoTCP, } b, _ = h.Marshal() b = append(b, []byte{0, 3, 0, 4}...) 
err = newPacket(b, true, p) require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) assert.Equal(t, uint16(3), p.RemotePort) assert.Equal(t, uint16(4), p.LocalPort) assert.False(t, p.Fragment) // account for variable ip header length - outgoing h = ipv4.Header{ Version: 1, Protocol: 2, Len: 100, Src: net.IPv4(10, 0, 0, 1), Dst: net.IPv4(10, 0, 0, 2), Options: []byte{0, 1, 0, 2}, } b, _ = h.Marshal() b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) require.NoError(t, err) assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) assert.Equal(t, uint16(6), p.RemotePort) assert.Equal(t, uint16(5), p.LocalPort) assert.False(t, p.Fragment) } func Test_newPacket_v6(t *testing.T) { p := &firewall.Packet{} // invalid ipv6 ip := layers.IPv6{ Version: 6, HopLimit: 128, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } buffer := gopacket.NewSerializeBuffer() opt := gopacket.SerializeOptions{ ComputeChecksums: false, FixLengths: false, } err := gopacket.SerializeLayers(buffer, opt, &ip) require.NoError(t, err) err = newPacket(buffer.Bytes(), true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A v6 packet with a hop-by-hop extension // ICMPv6 Payload (Echo Request) icmpLayer := layers.ICMPv6{ TypeCode: layers.ICMPv6TypeEchoRequest, } // Hop-by-Hop Extension Header hopOption := layers.IPv6HopByHopOption{} hopOption.OptionData = []byte{0, 0, 0, 0} hopByHop := layers.IPv6HopByHop{} hopByHop.Options = append(hopByHop.Options, &hopOption) ip = layers.IPv6{ Version: 6, HopLimit: 128, NextHeader: layers.IPProtocolIPv6Destination, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } buffer.Clear() err = gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{ 
ComputeChecksums: false, FixLengths: true, }, &ip, &hopByHop, &icmpLayer) if err != nil { panic(err) } // Ensure buffer length checks during parsing with the next 2 tests. // A full IPv6 header and 1 byte in the first extension, but missing // the length byte. err = newPacket(buffer.Bytes()[:41], true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A full IPv6 header plus 1 full extension, but only 1 byte of the // next layer, missing length byte err = newPacket(buffer.Bytes()[:49], true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) err = nil // A good ICMP packet ip = layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolICMPv6, HopLimit: 128, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } icmp := layers.ICMPv6{ TypeCode: layers.ICMPv6TypeEchoRequest, Checksum: 0x1234, } buffer.Clear() require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp)) require.Error(t, newPacket(buffer.Bytes(), true, p)) buffer.Clear() echo := layers.ICMPv6Echo{ Identifier: 0xabcd, SeqNumber: 1234, } require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo)) require.NoError(t, newPacket(buffer.Bytes(), true, p)) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint16(0xabcd), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.False(t, p.Fragment) // A good ESP packet b := buffer.Bytes() b[6] = byte(layers.IPProtocolESP) err = newPacket(b, true, p) require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint16(0), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.False(t, p.Fragment) // A good None packet b = buffer.Bytes() b[6] = byte(layers.IPProtocolNoNextHeader) err = newPacket(b, 
true, p) require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint16(0), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.False(t, p.Fragment) // An unknown protocol packet b = buffer.Bytes() b[6] = 255 // 255 is a reserved protocol number err = newPacket(b, true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good UDP packet ip = layers.IPv6{ Version: 6, NextHeader: firewall.ProtoUDP, HopLimit: 128, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } udp := layers.UDP{ SrcPort: layers.UDPPort(36123), DstPort: layers.UDPPort(22), } err = udp.SetNetworkLayerForChecksum(&ip) require.NoError(t, err) buffer.Clear() err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) if err != nil { panic(err) } b = buffer.Bytes() // incoming err = newPacket(b, true, p) require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint16(36123), p.RemotePort) assert.Equal(t, uint16(22), p.LocalPort) assert.False(t, p.Fragment) // outgoing err = newPacket(b, false, p) require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint16(36123), p.LocalPort) assert.Equal(t, uint16(22), p.RemotePort) assert.False(t, p.Fragment) // Too short UDP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good TCP packet b[6] = byte(layers.IPProtocolTCP) // incoming err = newPacket(b, true, p) require.NoError(t, err) assert.Equal(t, 
uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint16(36123), p.RemotePort) assert.Equal(t, uint16(22), p.LocalPort) assert.False(t, p.Fragment) // outgoing err = newPacket(b, false, p) require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint16(36123), p.LocalPort) assert.Equal(t, uint16(22), p.RemotePort) assert.False(t, p.Fragment) // Too short TCP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good UDP packet with an AH header ip = layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolAH, HopLimit: 128, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } ah := layers.IPSecAH{ AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef}, } ah.NextHeader = layers.IPProtocolUDP udpHeader := []byte{ 0x8d, 0x1b, // Source port 36123 0x00, 0x16, // Destination port 22 0x00, 0x00, // Length 0x00, 0x00, // Checksum } buffer.Clear() err = ip.SerializeTo(buffer, opt) if err != nil { panic(err) } b = buffer.Bytes() ahb := serializeAH(&ah) b = append(b, ahb...) b = append(b, udpHeader...) 
err = newPacket(b, true, p) require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint16(36123), p.RemotePort) assert.Equal(t, uint16(22), p.LocalPort) assert.False(t, p.Fragment) // Ensure buffer bounds checking during processing err = newPacket(b[:41], true, p) require.ErrorIs(t, err, ErrIPv6PacketTooShort) // Invalid AH header b = buffer.Bytes() err = newPacket(b, true, p) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) } func Test_newPacket_ipv6Fragment(t *testing.T) { p := &firewall.Packet{} ip := &layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolIPv6Fragment, HopLimit: 64, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } // First fragment fragHeader1 := []byte{ uint8(layers.IPProtocolUDP), // Next Header (UDP) 0x00, // Reserved 0x00, // Fragment Offset high byte (0) 0x01, // Fragment Offset low byte & flags (M=1) 0x00, 0x00, 0x00, 0x01, // Identification } udpHeader := []byte{ 0x8d, 0x1b, // Source port 36123 0x00, 0x16, // Destination port 22 0x00, 0x00, // Length 0x00, 0x00, // Checksum } buffer := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } err := ip.SerializeTo(buffer, opts) if err != nil { t.Fatal(err) } firstFrag := buffer.Bytes() firstFrag = append(firstFrag, fragHeader1...) firstFrag = append(firstFrag, udpHeader...) firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) 
// Test first fragment incoming err = newPacket(firstFrag, true, p) require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint16(36123), p.RemotePort) assert.Equal(t, uint16(22), p.LocalPort) assert.False(t, p.Fragment) // Test first fragment outgoing err = newPacket(firstFrag, false, p) require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint16(36123), p.LocalPort) assert.Equal(t, uint16(22), p.RemotePort) assert.False(t, p.Fragment) // Second fragment fragHeader2 := []byte{ uint8(layers.IPProtocolUDP), // Next Header (UDP) 0x00, // Reserved 0xb9, // Fragment Offset high byte (185) 0x01, // Fragment Offset low byte & flags (M=1) 0x00, 0x00, 0x00, 0x01, // Identification } buffer.Clear() err = ip.SerializeTo(buffer, opts) if err != nil { t.Fatal(err) } secondFrag := buffer.Bytes() secondFrag = append(secondFrag, fragHeader2...) secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) 
// Test second fragment incoming err = newPacket(secondFrag, true, p) require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint16(0), p.RemotePort) assert.Equal(t, uint16(0), p.LocalPort) assert.True(t, p.Fragment) // Test second fragment outgoing err = newPacket(secondFrag, false, p) require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint16(0), p.LocalPort) assert.Equal(t, uint16(0), p.RemotePort) assert.True(t, p.Fragment) // Too short of a fragment packet err = newPacket(secondFrag[:len(secondFrag)-10], false, p) require.ErrorIs(t, err, ErrIPv6PacketTooShort) } func BenchmarkParseV6(b *testing.B) { // Regular UDP packet ip := &layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolUDP, HopLimit: 64, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } udp := &layers.UDP{ SrcPort: layers.UDPPort(36123), DstPort: layers.UDPPort(22), } buffer := gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ ComputeChecksums: false, FixLengths: true, } err := gopacket.SerializeLayers(buffer, opts, ip, udp) if err != nil { b.Fatal(err) } normalPacket := buffer.Bytes() // First Fragment packet ipFrag := &layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolIPv6Fragment, HopLimit: 64, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } fragHeader := []byte{ uint8(layers.IPProtocolUDP), // Next Header (UDP) 0x00, // Reserved 0x00, // Fragment Offset high byte (0) 0x01, // Fragment Offset low byte & flags (M=1) 0x00, 0x00, 0x00, 0x01, // Identification } udpHeader := []byte{ 0x8d, 0x7b, // Source port 36123 0x00, 0x16, // Destination port 22 0x00, 0x00, // Length 0x00, 0x00, // Checksum } 
buffer.Clear() err = ipFrag.SerializeTo(buffer, opts) if err != nil { b.Fatal(err) } firstFrag := buffer.Bytes() firstFrag = append(firstFrag, fragHeader...) firstFrag = append(firstFrag, udpHeader...) firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) // Second Fragment packet fragHeader[2] = 0xb9 // offset 185 buffer.Clear() err = ipFrag.SerializeTo(buffer, opts) if err != nil { b.Fatal(err) } secondFrag := buffer.Bytes() secondFrag = append(secondFrag, fragHeader...) secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) fp := &firewall.Packet{} b.Run("Normal", func(b *testing.B) { for i := 0; i < b.N; i++ { if err = parseV6(normalPacket, true, fp); err != nil { b.Fatal(err) } } }) b.Run("FirstFragment", func(b *testing.B) { for i := 0; i < b.N; i++ { if err = parseV6(firstFrag, true, fp); err != nil { b.Fatal(err) } } }) b.Run("SecondFragment", func(b *testing.B) { for i := 0; i < b.N; i++ { if err = parseV6(secondFrag, true, fp); err != nil { b.Fatal(err) } } }) // Evil packet evilPacket := &layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolIPv6HopByHop, HopLimit: 64, SrcIP: net.IPv6linklocalallrouters, DstIP: net.IPv6linklocalallnodes, } hopHeader := []byte{ uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop) 0x00, // Length 0x00, 0x00, // Options and padding 0x00, 0x00, 0x00, 0x00, // More options and padding } lastHopHeader := []byte{ uint8(layers.IPProtocolUDP), // Next Header (UDP) 0x00, // Length 0x00, 0x00, // Options and padding 0x00, 0x00, 0x00, 0x00, // More options and padding } buffer.Clear() err = evilPacket.SerializeTo(buffer, opts) if err != nil { b.Fatal(err) } evilBytes := buffer.Bytes() for range 200 { evilBytes = append(evilBytes, hopHeader...) } evilBytes = append(evilBytes, lastHopHeader...) evilBytes = append(evilBytes, udpHeader...) evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...) 
b.Run("200 HopByHop headers", func(b *testing.B) { for i := 0; i < b.N; i++ { if err = parseV6(evilBytes, false, fp); err != nil { b.Fatal(err) } } }) } // Ensure authentication data is a multiple of 8 bytes by padding if necessary func padAuthData(authData []byte) []byte { // Length of Authentication Data must be a multiple of 8 bytes paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary if paddingLength > 0 { authData = append(authData, make([]byte, paddingLength)...) } return authData } // Custom function to manually serialize IPSecAH for both IPv4 and IPv6 func serializeAH(ah *layers.IPSecAH) []byte { buf := new(bytes.Buffer) // Ensure Authentication Data is a multiple of 8 bytes ah.AuthenticationData = padAuthData(ah.AuthenticationData) // Calculate Payload Length (in 32-bit words, minus 2) payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2 // Serialize fields if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil { panic(err) } if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil { panic(err) } if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil { panic(err) } if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil { panic(err) } if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil { panic(err) } if len(ah.AuthenticationData) > 0 { if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil { panic(err) } } return buf.Bytes() } ================================================ FILE: overlay/device.go ================================================ package overlay import ( "io" "net/netip" "github.com/slackhq/nebula/routing" ) type Device interface { io.ReadWriteCloser Activate() error Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool NewMultiQueueReader() (io.ReadWriteCloser, error) } ================================================ FILE: overlay/route.go 
================================================ package overlay import ( "fmt" "math" "net" "net/netip" "runtime" "strconv" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) type Route struct { MTU int Metric int Cidr netip.Prefix Via routing.Gateways Install bool } // Equal determines if a route that could be installed in the system route table is equal to another // Via is ignored since that is only consumed within nebula itself func (r Route) Equal(t Route) bool { if r.Cidr != t.Cidr { return false } if r.Metric != t.Metric { return false } if r.MTU != t.MTU { return false } if r.Install != t.Install { return false } return true } func (r Route) String() string { s := r.Cidr.String() if r.Metric != 0 { s += fmt.Sprintf(" metric: %v", r.Metric) } return s } func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) { routeTree := new(bart.Table[routing.Gateways]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } gateways := r.Via if len(gateways) > 0 { routing.CalculateBucketsForGateways(gateways) routeTree.Insert(r.Cidr, gateways) } } return routeTree, nil } func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") if r == nil { return []Route{}, nil } rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.routes is not an array") } if len(rawRoutes) < 1 { return []Route{}, nil } routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1) } rMtu, ok := m["mtu"] if !ok { return nil, fmt.Errorf("entry %v.mtu in tun.routes is not present", i+1) } mtu, ok := rMtu.(int) if !ok { mtu, err = strconv.Atoi(rMtu.(string)) if err != nil { return nil, fmt.Errorf("entry %v.mtu in tun.routes 
is not an integer: %v", i+1, err) } } if mtu < 500 { return nil, fmt.Errorf("entry %v.mtu in tun.routes is below 500: %v", i+1, mtu) } rRoute, ok := m["route"] if !ok { return nil, fmt.Errorf("entry %v.route in tun.routes is not present", i+1) } r := Route{ Install: true, MTU: mtu, } r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } found := false for _, network := range networks { if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() { found = true break } } if !found { return nil, fmt.Errorf( "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v", i+1, r.Cidr.String(), networks, ) } routes[i] = r } return routes, nil } func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") if r == nil { return []Route{}, nil } rawRoutes, ok := r.([]any) if !ok { return nil, fmt.Errorf("tun.unsafe_routes is not an array") } if len(rawRoutes) < 1 { return []Route{}, nil } routes := make([]Route, len(rawRoutes)) for i, r := range rawRoutes { m, ok := r.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1) } var mtu int if rMtu, ok := m["mtu"]; ok { mtu, ok = rMtu.(int) if !ok { mtu, err = strconv.Atoi(rMtu.(string)) if err != nil { return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is not an integer: %v", i+1, err) } } if mtu != 0 && mtu < 500 { return nil, fmt.Errorf("entry %v.mtu in tun.unsafe_routes is below 500: %v", i+1, mtu) } } rMetric, ok := m["metric"] if !ok { rMetric = 0 } metric, ok := rMetric.(int) if !ok { _, err = strconv.ParseInt(rMetric.(string), 10, 32) if err != nil { return nil, fmt.Errorf("entry %v.metric in tun.unsafe_routes is not an integer: %v", i+1, err) } } if metric < 0 || metric > math.MaxInt32 { return nil, fmt.Errorf("entry %v.metric in tun.unsafe_routes 
is not in range (0-%d) : %v", i+1, math.MaxInt32, metric) } rVia, ok := m["via"] if !ok { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1) } var gateways routing.Gateways switch via := rVia.(type) { case string: viaIp, err := netip.ParseAddr(via) if err != nil { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) } gateways = routing.Gateways{routing.NewGateway(viaIp, 1)} case []any: gateways = make(routing.Gateways, len(via)) for ig, v := range via { gatewayMap, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1) } rGateway, ok := gatewayMap["gateway"] if !ok { return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1) } parsedGateway, ok := rGateway.(string) if !ok { return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1) } gatewayIp, err := netip.ParseAddr(parsedGateway) if err != nil { return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err) } rGatewayWeight, ok := gatewayMap["weight"] if !ok { rGatewayWeight = 1 } gatewayWeight, ok := rGatewayWeight.(int) if !ok { _, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32) if err != nil { return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1) } } if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 { return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight) } gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight) } default: return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia) } rRoute, ok := m["route"] if !ok { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } install := true rInstall, ok := 
m["install"] if ok { install, err = strconv.ParseBool(fmt.Sprintf("%v", rInstall)) if err != nil { return nil, fmt.Errorf("entry %v.install in tun.unsafe_routes is not a boolean: %v", i+1, err) } } r := Route{ Via: gateways, MTU: mtu, Metric: metric, Install: install, } r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } for _, network := range networks { if network.Contains(r.Cidr.Addr()) { return nil, fmt.Errorf( "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v", i+1, r.Cidr.String(), network.String(), ) } } routes[i] = r } return routes, nil } func ipWithin(o *net.IPNet, i *net.IPNet) bool { // Make sure o contains the lowest form of i if !o.Contains(i.IP.Mask(i.Mask)) { return false } // Find the max ip in i ip4 := i.IP.To4() if ip4 == nil { return false } last := make(net.IP, len(ip4)) copy(last, ip4) for x := range ip4 { last[x] |= ^i.Mask[x] } // Make sure o contains the max if !o.Contains(last) { return false } return true } ================================================ FILE: overlay/route_test.go ================================================ package overlay import ( "fmt" "net/netip" "testing" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) // test no routes config routes, err := parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[string]any{"routes": "hi"} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = 
map[string]any{"routes": []any{}} routes, err = parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range 
c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[string]any{"routes": []any{ map[string]any{"mtu": "9000", "route": "10.0.0.0/29"}, map[string]any{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) tested := 0 for _, r := range routes { assert.True(t, r.Install) if r.MTU == 8000 { assert.Equal(t, "10.0.0.1/32", r.Cidr.String()) tested++ } else { assert.Equal(t, 9000, r.MTU) assert.Equal(t, "10.0.0.0/29", r.Cidr.String()) tested++ } } if tested != 2 { t.Fatal("Did not see both routes") } } func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route 
c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via for _, invalidValue := range []any{ 127, false, nil, 1.0, []string{"1", "2"}, } { c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue)) } // Unparsable list of via c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string") // unparsable via c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // unparsable gateway c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP") // missing gateway element c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": 
[]any{map[string]any{"weight": "1"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present") // unparsable weight element c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer") // missing route c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) require.NoError(t, err) // above network range c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) 
assert.Len(t, routes, 1) require.NoError(t, err) // no mtu c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"}, map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0}, map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 4) tested := 0 for _, r := range routes { if r.MTU == 8000 { assert.Equal(t, "1.0.0.1/32", r.Cidr.String()) assert.False(t, r.Install) tested++ } else if r.MTU == 9000 { 
assert.Equal(t, 9000, r.MTU) assert.Equal(t, "1.0.0.0/29", r.Cidr.String()) assert.True(t, r.Install) tested++ } else { assert.Equal(t, 1500, r.MTU) assert.Equal(t, 1234, r.Metric) assert.Equal(t, "1.0.0.2/32", r.Cidr.String()) assert.True(t, r.Install) tested++ } } if tested != 4 { t.Fatal("Did not see all unsafe_routes") } } func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) c.Settings["tun"] = map[string]any{"unsafe_routes": []any{ map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") require.NoError(t, err) r, ok := routeTree.Lookup(ip) assert.True(t, ok) nip, err := netip.ParseAddr("192.168.0.1") require.NoError(t, err) assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.0.0.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.True(t, ok) nip, err = netip.ParseAddr("192.168.0.2") require.NoError(t, err) assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("1.1.0.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } func Test_makeMultipathUnsafeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") require.NoError(t, err) c.Settings["tun"] = map[string]any{ "unsafe_routes": []any{ map[string]any{ "route": "192.168.86.0/24", "via": "192.168.100.10", }, map[string]any{ "route": "192.168.87.0/24", "via": []any{ map[string]any{ "gateway": "10.0.0.1", }, map[string]any{ "gateway": "10.0.0.2", }, map[string]any{ "gateway": "10.0.0.3", }, }, }, map[string]any{ "route": "192.168.89.0/24", "via": []any{ map[string]any{ "gateway": "10.0.0.1", "weight": 10, }, map[string]any{ "gateway": 
"10.0.0.2", "weight": 5, }, }, }, }, } routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 3) routeTree, err := makeRouteTree(l, routes, true) require.NoError(t, err) ip, err := netip.ParseAddr("192.168.86.1") require.NoError(t, err) r, ok := routeTree.Lookup(ip) assert.True(t, ok) nip, err := netip.ParseAddr("192.168.100.10") require.NoError(t, err) assert.Equal(t, nip, r[0].Addr()) ip, err = netip.ParseAddr("192.168.87.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.True(t, ok) expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1), routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1), routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)} routing.CalculateBucketsForGateways(expectedGateways) assert.ElementsMatch(t, expectedGateways, r) ip, err = netip.ParseAddr("192.168.89.1") require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.True(t, ok) expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10), routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)} routing.CalculateBucketsForGateways(expectedGateways) assert.ElementsMatch(t, expectedGateways, r) } ================================================ FILE: overlay/tun.go ================================================ package overlay import ( "fmt" "net" "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) const DefaultMTU = 1300 type NameError struct { Name string Underlying error } func (e *NameError) Error() string { return fmt.Sprintf("could not set tun device name: %s because %s", e.Name, e.Underlying) } // TODO: We may be able to remove routines type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case 
c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: return newTun(c, l, vpnNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, vpnNetworks) } } func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } routes, err := parseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } routes = append(routes, unsafeRoutes...) return true, routes, nil } // findRemovedRoutes will return all routes that are not present in the newRoutes list and would affect the system route table. // Via is not used to evaluate since it does not affect the system route table. 
func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route { var removed []Route has := func(entry Route) bool { for _, check := range newRoutes { if check.Equal(entry) { return true } } return false } for _, oldEntry := range oldRoutes { if !has(oldEntry) { removed = append(removed, oldEntry) } } return removed } func prefixToMask(prefix netip.Prefix) netip.Addr { pLen := 128 if prefix.Addr().Is4() { pLen = 32 } addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) return addr } func flipBytes(b []byte) []byte { for i := 0; i < len(b); i++ { b[i] ^= 0xFF } return b } func orBytes(a []byte, b []byte) []byte { ret := make([]byte, len(a)) for i := 0; i < len(a); i++ { ret[i] = a[i] | b[i] } return ret } func getBroadcast(cidr netip.Prefix) netip.Addr { broadcast, _ := netip.AddrFromSlice( orBytes( cidr.Addr().AsSlice(), flipBytes(prefixToMask(cidr).AsSlice()), ), ) return broadcast } func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) { for _, gateway := range gateways { if dest.Addr().Is4() && gateway.Addr().Is4() { return gateway, nil } if dest.Addr().Is6() && gateway.Addr().Is6() { return gateway, nil } } return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest) } ================================================ FILE: overlay/tun_android.go ================================================ //go:build !e2e_testing // +build !e2e_testing package overlay import ( "fmt" "io" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser fd int vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking 
mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t := &tun{ ReadWriteCloser: file, fd: deviceFd, vpnNetworks: vpnNetworks, l: l, } err := t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } func (t tun) Activate() error { return nil } func (t *tun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !change { return nil } routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err } // Teach nebula how to handle the routes t.Routes.Store(&routes) t.routeTree.Store(routeTree) return nil } func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *tun) Name() string { return "android" } func (t *tun) SupportsMultiqueue() bool { return false } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } ================================================ FILE: overlay/tun_darwin.go ================================================ //go:build !ios && !e2e_testing // +build !ios,!e2e_testing package overlay import ( "errors" "fmt" "io" "net/netip" "os" "sync/atomic" "syscall" "unsafe" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) type 
tun struct { io.ReadWriteCloser Device string vpnNetworks []netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } type ifReq struct { Name [unix.IFNAMSIZ]byte Flags uint16 pad [8]byte } const ( _SIOCAIFADDR_IN6 = 2155899162 _UTUN_OPT_IFNAME = 2 _IN6_IFF_NODAD = 0x0020 _IN6_IFF_SECURED = 0x0400 utunControlName = "com.apple.net.utun_control" ) type ifreqMTU struct { Name [16]byte MTU int32 pad [8]byte } type addrLifetime struct { Expire float64 Preferred float64 Vltime uint32 Pltime uint32 } type ifreqAlias4 struct { Name [unix.IFNAMSIZ]byte Addr unix.RawSockaddrInet4 DstAddr unix.RawSockaddrInet4 MaskAddr unix.RawSockaddrInet4 } type ifreqAlias6 struct { Name [unix.IFNAMSIZ]byte Addr unix.RawSockaddrInet6 DstAddr unix.RawSockaddrInet6 PrefixMask unix.RawSockaddrInet6 Flags uint32 Lifetime addrLifetime } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { _, err := fmt.Sscanf(name, "utun%d", &ifIndex) if err != nil || ifIndex < 0 { // NOTE: we don't make this error so we don't break existing // configs that set a name before it was used. 
l.Warn("interface name must be utun[0-9]+ on Darwin, ignoring") ifIndex = -1 } } fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) if err != nil { return nil, fmt.Errorf("system socket: %v", err) } var ctlInfo = &unix.CtlInfo{} copy(ctlInfo.Name[:], utunControlName) err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { return nil, fmt.Errorf("CTLIOCGINFO: %v", err) } err = unix.Connect(fd, &unix.SockaddrCtl{ ID: ctlInfo.Id, Unit: uint32(ifIndex) + 1, }) if err != nil { return nil, fmt.Errorf("SYS_CONNECT: %v", err) } name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME) if err != nil { return nil, fmt.Errorf("failed to retrieve tun name: %w", err) } err = unix.SetNonblock(fd, true) if err != nil { return nil, fmt.Errorf("SetNonblock: %v", err) } t := &tun{ ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, vpnNetworks: vpnNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } err = t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } func (t *tun) Close() error { if t.ReadWriteCloser != nil { return t.ReadWriteCloser.Close() } return nil } func (t *tun) Activate() error { devName := t.deviceBytes() s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP, ) if err != nil { return err } defer unix.Close(s) fd := uintptr(s) // Set the MTU on the device ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { return fmt.Errorf("failed to set tun mtu: %v", err) } // Get the device flags ifrf 
:= ifReq{Name: devName} if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to get tun flags: %s", err) } linkAddr, err := getLinkAddr(t.Device) if err != nil { return err } if linkAddr == nil { return fmt.Errorf("unable to discover link_addr for tun interface") } t.linkAddr = linkAddr for _, network := range t.vpnNetworks { if network.Addr().Is4() { err = t.activate4(network) if err != nil { return err } } else { err = t.activate6(network) if err != nil { return err } } } // Run the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to run tun device: %s", err) } // Unsafe path routes return t.addRoutes(false) } func (t *tun) activate4(network netip.Prefix) error { s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP, ) if err != nil { return err } defer unix.Close(s) ifr := ifreqAlias4{ Name: t.deviceBytes(), Addr: unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: network.Addr().As4(), }, DstAddr: unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: network.Addr().As4(), }, MaskAddr: unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: prefixToMask(network).As4(), }, } if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { return fmt.Errorf("failed to set tun v4 address: %s", err) } err = addRoute(network, t.linkAddr) if err != nil { return err } return nil } func (t *tun) activate6(network netip.Prefix) error { s, err := unix.Socket( unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP, ) if err != nil { return err } defer unix.Close(s) ifr := ifreqAlias6{ Name: t.deviceBytes(), Addr: unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: network.Addr().As16(), }, PrefixMask: unix.RawSockaddrInet6{ Len: 
unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: prefixToMask(network).As16(), }, Lifetime: addrLifetime{ // never expires Vltime: 0xffffffff, Pltime: 0xffffffff, }, Flags: _IN6_IFF_NODAD, } if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { return fmt.Errorf("failed to set tun address: %s", err) } return nil } func (t *tun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !change { return nil } routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err } // Teach nebula how to handle the routes before establishing them in the system table oldRoutes := t.Routes.Swap(&routes) t.routeTree.Store(routeTree) if !initial { // Remove first, if the system removes a wanted route hopefully it will be re-added next err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) if err != nil { util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } // Ensure any routes we actually want are installed err = t.addRoutes(true) if err != nil { // Catch any stray logs util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } return routing.Gateways{} } // Get the LinkAddr for the interface of the given name // Is there an easier way to fetch this when we create the interface? // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers. 
func getLinkAddr(name string) (*netroute.LinkAddr, error) { rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) if err != nil { return nil, err } msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib) if err != nil { return nil, err } for _, m := range msgs { switch m := m.(type) { case *netroute.InterfaceMessage: if m.Name == name { sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr) if ok { return sa, nil } } } } return nil, nil } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { t.l.WithField("route", r.Cidr). Warnf("unable to add unsafe_route, identical route already exists") } else { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } } } else { t.l.WithField("route", r).Info("Added route") } } return nil } func (t *tun) removeRoutes(routes []Route) error { for _, r := range routes { if !r.Install { continue } err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") } } return nil } func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, } if prefix.Addr().Is4() { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: gateway, } } else { route.Addrs = 
[]netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: gateway, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) } return nil } func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, } if prefix.Addr().Is4() { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: gateway, } } else { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: gateway, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) } return nil } func (t *tun) Read(to []byte) (int, error) { buf := make([]byte, len(to)+4) n, err := t.ReadWriteCloser.Read(buf) copy(to, buf[4:]) return n - 4, err } // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { buf := t.out if cap(buf) < len(from)+4 { buf = make([]byte, len(from)+4) t.out = buf } buf = buf[:len(from)+4] if len(from) == 0 { return 0, syscall.EIO } // Determine the IP Family for the NULL L2 Header ipVer := from[0] >> 4 if ipVer == 4 { buf[3] = syscall.AF_INET } else 
if ipVer == 6 { buf[3] = syscall.AF_INET6 } else { return 0, fmt.Errorf("unable to determine IP version from packet") } copy(buf[4:], from) n, err := t.ReadWriteCloser.Write(buf) return n - 4, err } func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *tun) Name() string { return t.Device } func (t *tun) SupportsMultiqueue() bool { return false } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } ================================================ FILE: overlay/tun_disabled.go ================================================ package overlay import ( "fmt" "io" "net/netip" "strings" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/routing" ) type disabledTun struct { read chan []byte vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter rx metrics.Counter l *logrus.Logger } func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, read: make(chan []byte, queueLen), l: l, } if metricsEnabled { tun.tx = metrics.GetOrRegisterCounter("messages.tx.message", nil) tun.rx = metrics.GetOrRegisterCounter("messages.rx.message", nil) } else { tun.tx = &metrics.NilCounter{} tun.rx = &metrics.NilCounter{} } return tun } func (*disabledTun) Activate() error { return nil } func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways { return routing.Gateways{} } func (t *disabledTun) Networks() []netip.Prefix { return t.vpnNetworks } func (*disabledTun) Name() string { return "disabled" } func (t *disabledTun) Read(b []byte) (int, error) { r, ok := <-t.read if !ok { return 0, io.EOF } if len(r) > len(b) { return 0, fmt.Errorf("packet larger than mtu: %d > %d bytes", len(r), len(b)) } t.tx.Inc(1) if t.l.Level >= logrus.DebugLevel { 
t.l.WithField("raw", prettyPacket(r)).Debugf("Write payload") } return copy(b, r), nil } func (t *disabledTun) handleICMPEchoRequest(b []byte) bool { out := make([]byte, len(b)) out = iputil.CreateICMPEchoResponse(b, out) if out == nil { return false } // attempt to write it, but don't block select { case t.read <- out: default: t.l.Debugf("tun_disabled: dropped ICMP Echo Reply response") } return true } func (t *disabledTun) Write(b []byte) (int, error) { t.rx.Inc(1) // Check for ICMP Echo Request before spending time doing the full parsing if t.handleICMPEchoRequest(b) { if t.l.Level >= logrus.DebugLevel { t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun responded to ICMP Echo Request") } } else if t.l.Level >= logrus.DebugLevel { t.l.WithField("raw", prettyPacket(b)).Debugf("Disabled tun received unexpected payload") } return len(b), nil } func (t *disabledTun) SupportsMultiqueue() bool { return true } func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return t, nil } func (t *disabledTun) Close() error { if t.read != nil { close(t.read) t.read = nil } return nil } type prettyPacket []byte func (p prettyPacket) String() string { var s strings.Builder for i, b := range p { if i > 0 && i%8 == 0 { s.WriteString(" ") } s.WriteString(fmt.Sprintf("%02x ", b)) } return s.String() } ================================================ FILE: overlay/tun_freebsd.go ================================================ //go:build !e2e_testing // +build !e2e_testing package overlay import ( "bytes" "errors" "fmt" "io" "io/fs" "net/netip" "sync/atomic" "syscall" "time" "unsafe" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) const ( // FIODGNAME is defined in sys/sys/filio.h on FreeBSD // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678) FIODGNAME = 
0x80106678 TUNSIFMODE = 0x8004745e TUNSIFHEAD = 0x80047460 OSIOCAIFADDR_IN6 = 0x8088691b IN6_IFF_NODAD = 0x0020 ) type fiodgnameArg struct { length int32 pad [4]byte buf unsafe.Pointer } type ifreqRename struct { Name [unix.IFNAMSIZ]byte Data uintptr } type ifreqDestroy struct { Name [unix.IFNAMSIZ]byte pad [16]byte } type ifReq struct { Name [unix.IFNAMSIZ]byte Flags uint16 } type ifreqMTU struct { Name [unix.IFNAMSIZ]byte MTU int32 } type addrLifetime struct { Expire uint64 Preferred uint64 Vltime uint32 Pltime uint32 } type ifreqAlias4 struct { Name [unix.IFNAMSIZ]byte Addr unix.RawSockaddrInet4 DstAddr unix.RawSockaddrInet4 MaskAddr unix.RawSockaddrInet4 VHid uint32 } type ifreqAlias6 struct { Name [unix.IFNAMSIZ]byte Addr unix.RawSockaddrInet6 DstAddr unix.RawSockaddrInet6 PrefixMask unix.RawSockaddrInet6 Flags uint32 Lifetime addrLifetime VHid uint32 } type tun struct { Device string vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] linkAddr *netroute.LinkAddr l *logrus.Logger devFd int } func (t *tun) Read(to []byte) (int, error) { // use readv() to read from the tunnel device, to eliminate the need for copying the buffer if t.devFd < 0 { return -1, syscall.EINVAL } // first 4 bytes is protocol family, in network byte order head := make([]byte, 4) iovecs := []syscall.Iovec{ {&head[0], 4}, {&to[0], uint64(len(to))}, } n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) var err error if errno != 0 { err = syscall.Errno(errno) } else { err = nil } // fix bytes read number to exclude header bytesRead := int(n) if bytesRead < 0 { return bytesRead, err } else if bytesRead < 4 { return 0, err } else { return bytesRead - 4, err } } // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { // use writev() to write to the tunnel device, to eliminate the need for copying the buffer if t.devFd < 0 { 
return -1, syscall.EINVAL } if len(from) <= 1 { return 0, syscall.EIO } ipVer := from[0] >> 4 var head []byte // first 4 bytes is protocol family, in network byte order if ipVer == 4 { head = []byte{0, 0, 0, syscall.AF_INET} } else if ipVer == 6 { head = []byte{0, 0, 0, syscall.AF_INET6} } else { return 0, fmt.Errorf("unable to determine IP version from packet") } iovecs := []syscall.Iovec{ {&head[0], 4}, {&from[0], uint64(len(from))}, } n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) var err error if errno != 0 { err = syscall.Errno(errno) } else { err = nil } return int(n) - 4, err } func (t *tun) Close() error { if t.devFd >= 0 { err := syscall.Close(t.devFd) if err != nil { t.l.WithError(err).Error("Error closing device") } t.devFd = -1 c := make(chan struct{}) go func() { // destroying the interface can block if a read() is still pending. Do this asynchronously. defer close(c) s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) if err == nil { defer syscall.Close(s) ifreq := ifreqDestroy{Name: t.deviceBytes()} err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) } if err != nil { t.l.WithError(err).Error("Error destroying tunnel") } }() // wait up to 1 second so we start blocking at the ioctl select { case <-c: case <-time.After(1 * time.Second): } } return nil } func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error deviceName := c.GetString("tun.dev", "") if deviceName != "" { fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0) } if errors.Is(err, fs.ErrNotExist) || deviceName == "" { // If the device doesn't already exist, request a new one and rename it fd, err = 
syscall.Open("/dev/tun", syscall.O_RDWR, 0) } if err != nil { return nil, err } // Read the name of the interface var name [16]byte arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg))) if ctrlErr == nil { // set broadcast mode and multicast ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST) ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode))) } if ctrlErr == nil { // turn on link-layer mode, to support ipv6 ifhead := uint32(1) ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead))) } if ctrlErr != nil { return nil, err } ifName := string(bytes.TrimRight(name[:], "\x00")) if deviceName == "" { deviceName = ifName } // If the name doesn't match the desired interface name, rename it now if ifName != deviceName { s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return nil, err } defer syscall.Close(s) fd := uintptr(s) var fromName [16]byte var toName [16]byte copy(fromName[:], ifName) copy(toName[:], deviceName) ifrr := ifreqRename{ Name: fromName, Data: uintptr(unsafe.Pointer(&toName)), } // Set the device name _ = ioctl(fd, syscall.SIOCSIFNAME, uintptr(unsafe.Pointer(&ifrr))) } t := &tun{ Device: deviceName, vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, devFd: fd, } err = t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *tun) addIp(cidr netip.Prefix) error { if cidr.Addr().Is4() { ifr := ifreqAlias4{ Name: t.deviceBytes(), Addr: unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: cidr.Addr().As4(), }, DstAddr: unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: getBroadcast(cidr).As4(), }, MaskAddr: unix.RawSockaddrInet4{ Len: 
unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: prefixToMask(cidr).As4(), }, VHid: 0, } s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) // Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) } return nil } if cidr.Addr().Is6() { ifr := ifreqAlias6{ Name: t.deviceBytes(), Addr: unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: cidr.Addr().As16(), }, PrefixMask: unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: prefixToMask(cidr).As16(), }, Lifetime: addrLifetime{ Expire: 0, Preferred: 0, Vltime: 0xffffffff, Pltime: 0xffffffff, }, Flags: IN6_IFF_NODAD, } s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) } return nil } return fmt.Errorf("unknown address type %v", cidr) } func (t *tun) Activate() error { // Setup our default MTU err := t.setMTU() if err != nil { return err } linkAddr, err := getLinkAddr(t.Device) if err != nil { return err } if linkAddr == nil { return fmt.Errorf("unable to discover link_addr for tun interface") } t.linkAddr = linkAddr for i := range t.vpnNetworks { err := t.addIp(t.vpnNetworks[i]) if err != nil { return err } } return t.addRoutes(false) } func (t *tun) setMTU() error { // Set the MTU on the device s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)} err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))) return err } func 
(t *tun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !change { return nil } routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err } // Teach nebula how to handle the routes before establishing them in the system table oldRoutes := t.Routes.Swap(&routes) t.routeTree.Store(routeTree) if !initial { // Remove first, if the system removes a wanted route hopefully it will be re-added next err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) if err != nil { util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } // Ensure any routes we actually want are installed err = t.addRoutes(true) if err != nil { // Catch any stray logs util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *tun) Name() string { return t.Device } func (t *tun) SupportsMultiqueue() bool { return false } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } err := addRoute(r.Cidr, t.linkAddr) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } } else { t.l.WithField("route", r).Info("Added route") } } return nil } func (t *tun) removeRoutes(routes []Route) error { for _, r := range routes { if !r.Install { continue } err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else 
{ t.l.WithField("route", r).Info("Removed route") } } return nil } func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, } if prefix.Addr().Is4() { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: gateway, } } else { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: gateway, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { if errors.Is(err, unix.EEXIST) { // Try to do a change route.Type = unix.RTM_CHANGE data, err = route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage for change: %w", err) } _, err = unix.Write(sock, data[:]) fmt.Println("DOING CHANGE") return err } return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) } return nil } func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, } if prefix.Addr().Is4() { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: 
prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: gateway, } } else { route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: gateway, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) } return nil } // getLinkAddr Gets the link address for the interface of the given name func getLinkAddr(name string) (*netroute.LinkAddr, error) { rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) if err != nil { return nil, err } msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib) if err != nil { return nil, err } for _, m := range msgs { switch m := m.(type) { case *netroute.InterfaceMessage: if m.Name == name { sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr) if ok { return sa, nil } } } } return nil, nil } ================================================ FILE: overlay/tun_ios.go ================================================ //go:build ios && !e2e_testing // +build ios,!e2e_testing package overlay import ( "errors" "fmt" "io" "net/netip" "os" "sync" "sync/atomic" "syscall" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser vpnNetworks []netip.Prefix Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger } func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, 
ReadWriteCloser: &tunReadCloser{f: file}, l: l, } err := t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *tun) Activate() error { return nil } func (t *tun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !change { return nil } routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err } // Teach nebula how to handle the routes t.Routes.Store(&routes) t.routeTree.Store(routeTree) return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } // The following is hoisted up from water, we do this so we can inject our own fd on iOS type tunReadCloser struct { f io.ReadWriteCloser rMu sync.Mutex rBuf []byte wMu sync.Mutex wBuf []byte } func (tr *tunReadCloser) Read(to []byte) (int, error) { tr.rMu.Lock() defer tr.rMu.Unlock() if cap(tr.rBuf) < len(to)+4 { tr.rBuf = make([]byte, len(to)+4) } tr.rBuf = tr.rBuf[:len(to)+4] n, err := tr.f.Read(tr.rBuf) copy(to, tr.rBuf[4:]) return n - 4, err } func (tr *tunReadCloser) Write(from []byte) (int, error) { if len(from) == 0 { return 0, syscall.EIO } tr.wMu.Lock() defer tr.wMu.Unlock() if cap(tr.wBuf) < len(from)+4 { tr.wBuf = make([]byte, len(from)+4) } tr.wBuf = tr.wBuf[:len(from)+4] // Determine the IP Family for the NULL L2 Header ipVer := from[0] >> 4 if ipVer == 4 { tr.wBuf[3] = syscall.AF_INET } else if ipVer == 6 { tr.wBuf[3] = syscall.AF_INET6 } else { return 0, errors.New("unable to determine IP version from packet") } copy(tr.wBuf[4:], from) n, err := tr.f.Write(tr.wBuf) return n - 4, err } func (tr *tunReadCloser) Close() error { return tr.f.Close() } func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *tun) Name() string { 
return "iOS"
}

// SupportsMultiqueue always reports false on iOS.
func (t *tun) SupportsMultiqueue() bool {
	return false
}

// NewMultiQueueReader is not implemented on iOS and always returns an error.
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
	return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

================================================
FILE: overlay/tun_linux.go
================================================
//go:build !android && !e2e_testing
// +build !android,!e2e_testing

package overlay

import (
	"fmt"
	"io"
	"net"
	"net/netip"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/gaissmai/bart"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/routing"
	"github.com/slackhq/nebula/util"
	"github.com/vishvananda/netlink"
	"golang.org/x/sys/unix"
)

// tun is the Linux overlay device backed by /dev/net/tun and configured through
// ioctl and netlink.
type tun struct {
	io.ReadWriteCloser
	fd          int
	Device      string
	vpnNetworks []netip.Prefix
	// MaxMTU is applied to the device itself (setMTU); DefaultMTU comes from tun.mtu and
	// is used for routes that do not set their own MTU (reload, advMSS).
	MaxMTU      int
	DefaultMTU  int
	TXQueueLen  int
	deviceIndex int
	ioctlFd     uintptr
	Routes      atomic.Pointer[[]Route]
	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
	// routeChan is the done channel of the system route subscription; Close closes it.
	routeChan                 chan struct{}
	useSystemRoutes           bool
	useSystemRoutesBufferSize int

	// These are routes learned from `tun.use_system_route_table`
	// stored here to make it easier to restore them after a reload
	routesFromSystem     map[netip.Prefix]routing.Gateways
	routesFromSystemLock sync.Mutex

	l *logrus.Logger
}

func (t *tun) Networks() []netip.Prefix {
	return t.vpnNetworks
}

// ifReq is the ioctl argument carrying an interface name and flags; it is used with
// TUNSETIFF, SIOCGIFFLAGS and SIOCSIFFLAGS.
type ifReq struct {
	Name  [16]byte
	Flags uint16
	pad   [8]byte
}

// ifreqMTU is the ioctl argument for SIOCSIFMTU (see setMTU).
type ifreqMTU struct {
	Name [16]byte
	MTU  int32
	pad  [8]byte
}

// ifreqQLEN is the ioctl argument for SIOCSIFTXQLEN (see Activate).
type ifreqQLEN struct {
	Name  [16]byte
	Value int32
	pad   [8]byte
}

// newTunFromFd wraps a caller-supplied device descriptor and names the device "tun0".
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")

	t, err := newTunGeneric(c, l, file, vpnNetworks)
	if err != nil {
		return nil, err
	}

	t.Device = "tun0"

	return t, nil
}

// newTun opens /dev/net/tun (creating the device node if it is missing), attaches it
// to the interface named by tun.dev with TUNSETIFF and builds the tun around it.
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
	if err != nil {
		// If /dev/net/tun doesn't
exist, try to create it (will happen in docker) if os.IsNotExist(err) { err = os.MkdirAll("/dev/net", 0755) if err != nil { return nil, fmt.Errorf("/dev/net/tun doesn't exist, failed to mkdir -p /dev/net: %w", err) } err = unix.Mknod("/dev/net/tun", unix.S_IFCHR|0600, int(unix.Mkdev(10, 200))) if err != nil { return nil, fmt.Errorf("failed to create /dev/net/tun: %w", err) } fd, err = unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, fmt.Errorf("created /dev/net/tun, but still failed: %w", err) } } else { return nil, err } } var req ifReq req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } nameStr := c.GetString("tun.dev", "") copy(req.Name[:], nameStr) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, &NameError{ Name: nameStr, Underlying: err, } } name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } t.Device = name return t, nil } func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), routesFromSystem: map[netip.Prefix]routing.Gateways{}, l: l, } err := t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *tun) reload(c *config.C, initial bool) error { routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !routeChange && !c.HasChanged("tun.mtu") { 
return nil } routeTree, err := makeRouteTree(t.l, routes, true) if err != nil { return err } // Bring along any routes learned from the system route table on reload t.routesFromSystemLock.Lock() for dst, gw := range t.routesFromSystem { routeTree.Insert(dst, gw) } t.routesFromSystemLock.Unlock() oldDefaultMTU := t.DefaultMTU oldMaxMTU := t.MaxMTU newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU) newMaxMTU := newDefaultMTU for i, r := range routes { if r.MTU == 0 { routes[i].MTU = newDefaultMTU } if r.MTU > t.MaxMTU { newMaxMTU = r.MTU } } t.MaxMTU = newMaxMTU t.DefaultMTU = newDefaultMTU // Teach nebula how to handle the routes before establishing them in the system table oldRoutes := t.Routes.Swap(&routes) t.routeTree.Store(routeTree) if !initial { if oldMaxMTU != newMaxMTU { t.setMTU() t.l.Infof("Set max MTU to %v was %v", t.MaxMTU, oldMaxMTU) } if oldDefaultMTU != newDefaultMTU { for i := range t.vpnNetworks { err := t.setDefaultRoute(t.vpnNetworks[i]) if err != nil { t.l.Warn(err) } else { t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) } } } // Remove first, if the system removes a wanted route hopefully it will be re-added next t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) // Ensure any routes we actually want are installed err = t.addRoutes(true) if err != nil { // This should never be called since addRoutes should log its own errors in a reload condition util.LogWithContextIfNeeded("Failed to refresh routes", err, t.l) } } return nil } func (t *tun) SupportsMultiqueue() bool { return true } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { return nil, err } var req ifReq req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, err } file := os.NewFile(uintptr(fd), "/dev/net/tun") return file, 
nil
}

// RoutesFor returns the gateways the current route tree holds for ip, or the zero
// value when no route matches.
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
	r, _ := t.routeTree.Load().Lookup(ip)
	return r
}

// Write writes b to the device fd, looping until all of b is written. It stops early on
// a write error, or with io.ErrUnexpectedEOF when a write makes no progress.
func (t *tun) Write(b []byte) (int, error) {
	var nn int
	maximum := len(b)

	for {
		n, err := unix.Write(t.fd, b[nn:maximum])
		if n > 0 {
			nn += n
		}
		if nn == len(b) {
			return nn, err
		}

		if err != nil {
			return nn, err
		}

		if n == 0 {
			return nn, io.ErrUnexpectedEOF
		}
	}
}

// deviceBytes returns t.Device as a fixed-size, NUL-padded name for ioctl requests.
// NOTE(review): it indexes by byte offset, so a name longer than 16 bytes would panic.
func (t *tun) deviceBytes() (o [16]byte) {
	for i, c := range t.Device {
		o[i] = byte(c)
	}
	return
}

// hasNetlinkAddr reports whether any address in al is Equal to x.
func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
	for i := range al {
		if al[i].Equal(x) {
			return true
		}
	}
	return false
}

// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
func (t *tun) addIPs(link netlink.Link) error {
	newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
	for i := range t.vpnNetworks {
		newAddrs[i] = &netlink.Addr{
			IPNet: &net.IPNet{
				IP:   t.vpnNetworks[i].Addr().AsSlice(),
				Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
			},
			Label: t.vpnNetworks[i].Addr().Zone(),
		}
	}

	//add all new addresses
	for i := range newAddrs {
		//AddrReplace still adds new IPs, but if their properties change it will change them as well
		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
			return err
		}
	}

	//iterate over remainder, remove whoever shouldn't be there
	al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
	if err != nil {
		return fmt.Errorf("failed to get tun address list: %s", err)
	}

	for i := range al {
		if hasNetlinkAddr(newAddrs, al[i]) {
			continue
		}
		err = netlink.AddrDel(link, &al[i])
		if err != nil {
			t.l.WithError(err).Error("failed to remove address from tun address list")
		} else {
			t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
		}
	}

	return nil
}

// Activate configures and brings up the device: it optionally starts the system route
// watcher, sets MTU, tx queue length and addresses, sets the per-network default
// routes, installs the configured routes and finally marks the interface running.
func (t *tun) Activate() error {
	devName := t.deviceBytes()

	if t.useSystemRoutes {
		t.watchRoutes()
	}

	s, err := unix.Socket(
		unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is
// fine
		unix.SOCK_DGRAM,
		unix.IPPROTO_IP,
	)
	if err != nil {
		return err
	}
	t.ioctlFd = uintptr(s)

	// Set the device name
	// NOTE(review): this ioctl is SIOCGIFFLAGS: it reads the current interface flags into
	// ifrf (later OR'd with IFF_UP / IFF_RUNNING); it does not set the name.
	ifrf := ifReq{Name: devName}
	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
		return fmt.Errorf("failed to set tun device name: %s", err)
	}

	link, err := netlink.LinkByName(t.Device)
	if err != nil {
		return fmt.Errorf("failed to get tun device link: %s", err)
	}
	t.deviceIndex = link.Attrs().Index

	// Setup our default MTU
	t.setMTU()

	// Set the transmit queue length
	ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
	if err = ioctl(t.ioctlFd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
		// If we can't set the queue length nebula will still work but it may lead to packet loss
		t.l.WithError(err).Error("Failed to set tun tx queue length")
	}

	// Disable IPv6 link-local address generation on the device (failure is only a warning).
	const modeNone = 1
	if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
		t.l.WithError(err).Warn("Failed to disable link local address generation")
	}

	if err = t.addIPs(link); err != nil {
		return err
	}

	// Bring up the interface
	ifrf.Flags = ifrf.Flags | unix.IFF_UP
	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
		return fmt.Errorf("failed to bring the tun device up: %s", err)
	}

	//set route MTU
	for i := range t.vpnNetworks {
		if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
			return fmt.Errorf("failed to set default route MTU: %w", err)
		}
	}

	// Set the routes
	if err = t.addRoutes(false); err != nil {
		return err
	}

	// Run the interface
	ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
		return fmt.Errorf("failed to run tun device: %s", err)
	}

	return nil
}

// setMTU sets the device MTU to t.MaxMTU; a failure is logged, not returned.
func (t *tun) setMTU() {
	// Set the MTU on the device
	ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MaxMTU)}
	if err := ioctl(t.ioctlFd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
		// This is currently a non fatal condition
// because the route table must have the MTU set appropriately as well
		t.l.WithError(err).Error("Failed to set tun mtu")
	}
}

// setDefaultRoute replaces the link-scope route for cidr on this device so that it
// carries t.DefaultMTU (and an advmss when needed). RouteReplace is retried up to two
// more times, 100ms apart, before an error is returned.
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
	dr := &net.IPNet{
		IP:   cidr.Masked().Addr().AsSlice(),
		Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
	}

	nr := netlink.Route{
		LinkIndex: t.deviceIndex,
		Dst:       dr,
		MTU:       t.DefaultMTU,
		AdvMSS:    t.advMSS(Route{}),
		Scope:     unix.RT_SCOPE_LINK,
		Src:       net.IP(cidr.Addr().AsSlice()),
		Protocol:  unix.RTPROT_KERNEL,
		Table:     unix.RT_TABLE_MAIN,
		Type:      unix.RTN_UNICAST,
	}
	err := netlink.RouteReplace(&nr)
	if err != nil {
		t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
		//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
		for i := 0; i < 2; i++ {
			time.Sleep(100 * time.Millisecond)
			err = netlink.RouteReplace(&nr)
			if err == nil {
				break
			} else {
				t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
			}
		}
		if err != nil {
			return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
		}
	}

	return nil
}

// addRoutes installs every Install-marked route on this device with RouteReplace.
// With logErrors set, failures are logged and the loop continues; otherwise the first
// failure is returned.
func (t *tun) addRoutes(logErrors bool) error {
	// Path routes
	routes := *t.Routes.Load()
	for _, r := range routes {
		if !r.Install {
			continue
		}

		dr := &net.IPNet{
			IP:   r.Cidr.Masked().Addr().AsSlice(),
			Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
		}

		nr := netlink.Route{
			LinkIndex: t.deviceIndex,
			Dst:       dr,
			MTU:       r.MTU,
			AdvMSS:    t.advMSS(r),
			Scope:     unix.RT_SCOPE_LINK,
		}

		if r.Metric > 0 {
			nr.Priority = r.Metric
		}

		err := netlink.RouteReplace(&nr)
		if err != nil {
			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
			if logErrors {
				retErr.Log(t.l)
			} else {
				return retErr
			}
		} else {
			t.l.WithField("route", r).Info("Added route")
		}
	}

	return nil
}

// removeRoutes deletes each Install-marked route from the system table; failures are
// logged, not returned.
func (t *tun) removeRoutes(routes []Route) {
	for _, r := range routes {
		if !r.Install {
			continue
		}

		dr :=
&net.IPNet{
			IP:   r.Cidr.Masked().Addr().AsSlice(),
			Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
		}

		nr := netlink.Route{
			LinkIndex: t.deviceIndex,
			Dst:       dr,
			MTU:       r.MTU,
			AdvMSS:    t.advMSS(r),
			Scope:     unix.RT_SCOPE_LINK,
		}

		if r.Metric > 0 {
			nr.Priority = r.Metric
		}

		err := netlink.RouteDel(&nr)
		if err != nil {
			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
		} else {
			t.l.WithField("route", r).Info("Removed route")
		}
	}
}

func (t *tun) Name() string {
	return t.Device
}

// advMSS returns the advmss for a route: its MTU (t.DefaultMTU when the route has none)
// minus 40, or 0 when that MTU already equals the device MTU (t.MaxMTU).
func (t *tun) advMSS(r Route) int {
	mtu := r.MTU
	if r.MTU == 0 {
		mtu = t.DefaultMTU
	}

	// We only need to set advmss if the route MTU does not match the device MTU
	if mtu != t.MaxMTU {
		return mtu - 40
	}
	return 0
}

// watchRoutes subscribes to kernel route updates and hands each one to updateRoutes on
// a background goroutine. The goroutine exits when the update channel is closed or when
// the done channel (stored in t.routeChan and closed by Close) is closed.
func (t *tun) watchRoutes() {
	rch := make(chan netlink.RouteUpdate)
	doneChan := make(chan struct{})

	netlinkOptions := netlink.RouteSubscribeOptions{
		ReceiveBufferSize:      t.useSystemRoutesBufferSize,
		ReceiveBufferForceSize: t.useSystemRoutesBufferSize != 0,
		ErrorCallback:          func(e error) { t.l.WithError(e).Errorf("netlink error") },
	}

	if err := netlink.RouteSubscribeWithOptions(rch, doneChan, netlinkOptions); err != nil {
		t.l.WithError(err).Errorf("failed to subscribe to system route changes")
		return
	}

	t.routeChan = doneChan

	go func() {
		for {
			select {
			case r, ok := <-rch:
				if ok {
					t.updateRoutes(r)
				} else {
					// may be should do something here as
					// netlink stops sending updates
					return
				}
			case <-doneChan:
				// netlink.RouteSubscriber will close the rch for us
				return
			}
		}
	}()
}

// isGatewayInVpnNetworks reports whether gwAddr lies inside any of t.vpnNetworks.
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
	withinNetworks := false
	for i := range t.vpnNetworks {
		if t.vpnNetworks[i].Contains(gwAddr) {
			withinNetworks = true
			break
		}
	}
	return withinNetworks
}

// getGatewaysFromRoute collects the gateways of r (its own gateway plus any multipath
// next hops) that use this device and lie inside the VPN networks, then passes them to
// routing.CalculateBucketsForGateways before returning them.
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
	var gateways routing.Gateways

	link, err := netlink.LinkByName(t.Device)
	if err != nil {
		t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
		return gateways
	}

	// If this route is
// relevant to our interface and there is a gateway then add it
	if r.LinkIndex == link.Attrs().Index {
		gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
		if ok {
			if t.isGatewayInVpnNetworks(gwAddr) {
				gateways = append(gateways, routing.NewGateway(gwAddr, 1))
			} else {
				// Gateway isn't in our overlay network, ignore
				t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
			}
		} else {
			t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
		}
	}

	for _, p := range r.MultiPath {
		// If this route is relevant to our interface and there is a gateway then add it
		if p.LinkIndex == link.Attrs().Index {
			gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
			if ok {
				if t.isGatewayInVpnNetworks(gwAddr) {
					gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
				} else {
					// Gateway isn't in our overlay network, ignore
					t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
				}
			} else {
				t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
			}
		}
	}

	routing.CalculateBucketsForGateways(gateways)
	return gateways
}

// getGatewayAddr extracts a route's next-hop address, preferring gw (RTA_GATEWAY) and
// falling back to via (RTA_VIA). IPv4-mapped IPv6 results are unmapped; the bool is
// false when neither source yields a valid address.
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
	// Try to use the old RTA_GATEWAY first
	gwAddr, ok := netip.AddrFromSlice(gw)
	if !ok {
		// Fallback to the new RTA_VIA
		rVia, ok := via.(*netlink.Via)
		if ok {
			gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
		}
	}

	if gwAddr.IsValid() {
		gwAddr = gwAddr.Unmap()
		return gwAddr, true
	}

	return netip.Addr{}, false
}

// updateRoutes applies one kernel route update: it records (or forgets) the route in
// t.routesFromSystem and publishes a cloned route tree containing the change.
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
	gateways := t.getGatewaysFromRoute(&r.Route)
	if len(gateways) == 0 {
		// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
		return
	}

	if r.Dst == nil {
		t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
		return
	}

	dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
	if !ok {
		t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
		return
	}

	ones, _ := r.Dst.Mask.Size()
	dst := netip.PrefixFrom(dstAddr, ones)

	// Mutate a clone and publish it with a single atomic Store so readers never see a
	// partially updated tree.
	// NOTE(review): Load/Clone/Store happen outside routesFromSystemLock, so a concurrent
	// reload that stores its own tree could be overwritten here; confirm this is benign.
	newTree := t.routeTree.Load().Clone()

	t.routesFromSystemLock.Lock()
	if r.Type == unix.RTM_NEWROUTE {
		t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
		t.routesFromSystem[dst] = gateways
		newTree.Insert(dst, gateways)

	} else {
		t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
		delete(t.routesFromSystem, dst)
		newTree.Delete(dst)
	}
	t.routesFromSystemLock.Unlock()
	t.routeTree.Store(newTree)
}

// Close stops the system route watcher (if any) and closes the device and the ioctl
// socket; errors from the individual closes are ignored.
// NOTE(review): routeChan is not cleared, so calling Close twice would panic on the
// second close of the channel.
func (t *tun) Close() error {
	if t.routeChan != nil {
		close(t.routeChan)
	}

	if t.ReadWriteCloser != nil {
		_ = t.ReadWriteCloser.Close()
	}

	if t.ioctlFd > 0 {
		_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
		t.ioctlFd = 0
	}

	return nil
}

================================================
FILE: overlay/tun_linux_test.go
================================================
//go:build !e2e_testing
// +build !e2e_testing

package overlay

import "testing"

// runAdvMSSTests pairs a device MTU configuration and a route with the value advMSS is
// expected to return.
var runAdvMSSTests = []struct {
	name     string
	tun      *tun
	r        Route
	expected int
}{
	// Standard case, default MTU is the device max MTU
	{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
	{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
	{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},

	// Case where we have a route MTU set higher than the default
	{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
	{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
	{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
}

func TestTunAdvMSS(t *testing.T) {
	for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
			o := tt.tun.advMSS(tt.r)
			if o != tt.expected {
				t.Errorf("got %d, want %d", o, tt.expected)
			}
		})
	}
}

================================================
FILE: overlay/tun_netbsd.go
================================================
//go:build !e2e_testing
// +build !e2e_testing

package overlay

import (
	"errors"
	"fmt"
	"io"
	"net/netip"
	"os"
	"regexp"
	"sync/atomic"
	"syscall"
	"unsafe"

	"github.com/gaissmai/bart"
	"github.com/sirupsen/logrus"
	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/routing"
	"github.com/slackhq/nebula/util"
	netroute "golang.org/x/net/route"
	"golang.org/x/sys/unix"
)

// Raw ioctl request numbers used by addIp (SIOCAIFADDR_IN6) and Activate (TUNSIF*).
const (
	SIOCAIFADDR_IN6 = 0x8080696b
	TUNSIFHEAD      = 0x80047442
	TUNSIFMODE      = 0x80047458
)

// ifreqAlias4 is the SIOCAIFADDR argument addIp uses for IPv4 addresses.
type ifreqAlias4 struct {
	Name     [unix.IFNAMSIZ]byte
	Addr     unix.RawSockaddrInet4
	DstAddr  unix.RawSockaddrInet4
	MaskAddr unix.RawSockaddrInet4
}

// ifreqAlias6 is the SIOCAIFADDR_IN6 argument addIp uses for IPv6 addresses.
type ifreqAlias6 struct {
	Name       [unix.IFNAMSIZ]byte
	Addr       unix.RawSockaddrInet6
	DstAddr    unix.RawSockaddrInet6
	PrefixMask unix.RawSockaddrInet6
	Flags      uint32
	Lifetime   addrLifetime
}

// ifreq carries an interface name plus an integer value (see doIoctlByName and Close).
type ifreq struct {
	Name [unix.IFNAMSIZ]byte
	data int
}

// addrLifetime is the lifetime section of ifreqAlias6.
type addrLifetime struct {
	Expire    uint64
	Preferred uint64
	Vltime    uint32
	Pltime    uint32
}

// tun is the NetBSD overlay device opened from /dev/tunN.
type tun struct {
	Device      string
	vpnNetworks []netip.Prefix
	MTU         int
	Routes      atomic.Pointer[[]Route]
	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
	l           *logrus.Logger
	f           *os.File
	fd          int
}

var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)

// newTunFromFd is not supported on NetBSD and always returns an error.
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}

// newTun opens /dev/<tun.dev> (tun.dev must look like tunN), attempts to make it
// non-blocking, loads the initial routes and registers a config reload callback.
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
	// Try to open tun device
	var err error
	deviceName := c.GetString("tun.dev", "")
	if deviceName == "" {
		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
	}
	if !deviceNameRE.MatchString(deviceName) {
		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
	}

	fd,
err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0) if err != nil { return nil, err } err = unix.SetNonblock(fd, true) if err != nil { l.WithError(err).Warn("Failed to set the tun device as nonblocking") } t := &tun{ f: os.NewFile(uintptr(fd), ""), fd: fd, Device: deviceName, vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } err = t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *tun) Close() error { if t.f != nil { if err := t.f.Close(); err != nil { return fmt.Errorf("error closing tun file: %w", err) } // t.f.Close should have handled it for us but let's be extra sure _ = unix.Close(t.fd) s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) ifr := ifreq{Name: t.deviceBytes()} err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr))) return err } return nil } func (t *tun) Read(to []byte) (int, error) { rc, err := t.f.SyscallConn() if err != nil { return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) } var errno syscall.Errno var n uintptr err = rc.Read(func(fd uintptr) bool { // first 4 bytes is protocol family, in network byte order head := [4]byte{} iovecs := []syscall.Iovec{ {&head[0], 4}, {&to[0], uint64(len(to))}, } n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) if errno.Temporary() { // We got an EAGAIN, EINTR, or EWOULDBLOCK, go again return false } return true }) if err != nil { if err == syscall.EBADF || err.Error() == "use of closed file" { // Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are // https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121 return 0, os.ErrClosed } return 0, fmt.Errorf("failed to make read call 
for tun: %w", err)
	}

	if errno != 0 {
		return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
	}

	// fix bytes read number to exclude header
	// Anything shorter than the header, including a nonsensical negative count, carries
	// no payload: report 0 instead of a negative length, which io.Reader forbids.
	bytesRead := int(n)
	if bytesRead < 4 {
		return 0, nil
	}
	return bytesRead - 4, nil
}

// Write is only valid for single threaded use
//
// Write prefixes from with a 4 byte protocol family header and writes both with writev.
func (t *tun) Write(from []byte) (int, error) {
	if len(from) <= 1 {
		return 0, syscall.EIO
	}

	ipVer := from[0] >> 4
	var head [4]byte
	// first 4 bytes is protocol family, in network byte order
	if ipVer == 4 {
		head[3] = syscall.AF_INET
	} else if ipVer == 6 {
		head[3] = syscall.AF_INET6
	} else {
		return 0, fmt.Errorf("unable to determine IP version from packet")
	}

	rc, err := t.f.SyscallConn()
	if err != nil {
		return 0, err
	}

	var errno syscall.Errno
	var n uintptr
	err = rc.Write(func(fd uintptr) bool {
		iovecs := []syscall.Iovec{
			{&head[0], 4},
			{&from[0], uint64(len(from))},
		}

		n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
		// According to NetBSD documentation for TUN, writes will only return errors in which
		// this packet will never be delivered so just go on living life.
return true }) if err != nil { return 0, err } if errno != 0 { return 0, errno } return int(n) - 4, err } func (t *tun) addIp(cidr netip.Prefix) error { if cidr.Addr().Is4() { var req ifreqAlias4 req.Name = t.deviceBytes() req.Addr = unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: cidr.Addr().As4(), } req.DstAddr = unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: cidr.Addr().As4(), } req.MaskAddr = unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: prefixToMask(cidr).As4(), } s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err) } return nil } if cidr.Addr().Is6() { var req ifreqAlias6 req.Name = t.deviceBytes() req.Addr = unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: cidr.Addr().As16(), } req.PrefixMask = unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: prefixToMask(cidr).As16(), } req.Lifetime = addrLifetime{ Vltime: 0xffffffff, Pltime: 0xffffffff, } s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) } return nil } return fmt.Errorf("unknown address type %v", cidr) } func (t *tun) Activate() error { mode := int32(unix.IFF_BROADCAST) err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode))) if err != nil { return fmt.Errorf("failed to set tun device mode: %w", err) } v := 1 err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v))) if err != nil { return fmt.Errorf("failed to set tun device head: %w", err) } 
err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
	if err != nil {
		return fmt.Errorf("failed to set tun mtu: %w", err)
	}

	for i := range t.vpnNetworks {
		err = t.addIp(t.vpnNetworks[i])
		if err != nil {
			return err
		}
	}

	return t.addRoutes(false)
}

// doIoctlByName issues ioctl ctl for this device on a temporary AF_INET socket, passing
// value in the integer slot of an ifreq.
func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
	s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
	if err != nil {
		return err
	}
	defer syscall.Close(s)

	ir := ifreq{Name: t.deviceBytes(), data: int(value)}
	err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
	return err
}

// reload recomputes routes from config. After the initial call it also syncs the system
// route table: routes no longer wanted are removed first, then wanted routes are added.
func (t *tun) reload(c *config.C, initial bool) error {
	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
	if err != nil {
		return err
	}

	if !initial && !change {
		return nil
	}

	routeTree, err := makeRouteTree(t.l, routes, false)
	if err != nil {
		return err
	}

	// Teach nebula how to handle the routes before establishing them in the system table
	oldRoutes := t.Routes.Swap(&routes)
	t.routeTree.Store(routeTree)

	if !initial {
		// Remove first, if the system removes a wanted route hopefully it will be re-added next
		err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
		if err != nil {
			util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
		}

		// Ensure any routes we actually want are installed
		err = t.addRoutes(true)
		if err != nil {
			// Catch any stray logs
			util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
		}
	}

	return nil
}

func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
	r, _ := t.routeTree.Load().Lookup(ip)
	return r
}

func (t *tun) Networks() []netip.Prefix {
	return t.vpnNetworks
}

func (t *tun) Name() string {
	return t.Device
}

func (t *tun) SupportsMultiqueue() bool {
	return false
}

func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
	return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}

// addRoutes installs each Install-marked route that has a via, picking the gateway from
// the VPN networks. With logErrors set, failures are logged and skipped; otherwise the
// first failure is returned.
func (t *tun) addRoutes(logErrors bool) error {
	routes := *t.Routes.Load()
	for _, r := range routes {
		if len(r.Via) == 0 || !r.Install {
			// We don't allow
route MTUs so only install routes with a via continue } err := addRoute(r.Cidr, t.vpnNetworks) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } } else { t.l.WithField("route", r).Info("Added route") } } return nil } func (t *tun) removeRoutes(routes []Route) error { for _, r := range routes { if !r.Install { continue } err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") } } return nil } func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP | unix.RTF_GATEWAY, Seq: 1, } if prefix.Addr().Is4() { gw, err := selectGateway(prefix, gateways) if err != nil { return err } route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, } } else { gw, err := selectGateway(prefix, gateways) if err != nil { return err } route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { if errors.Is(err, unix.EEXIST) { // Try to do a change route.Type = 
unix.RTM_CHANGE
			data, err = route.Marshal()
			if err != nil {
				return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
			}
			_, err = unix.Write(sock, data[:])
			return err
		}
		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
	}

	return nil
}

// delRoute removes the gateway route for prefix through the routing socket; the gateway
// is chosen with selectGateway exactly as addRoute does.
func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
	if err != nil {
		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
	}
	defer unix.Close(sock)

	route := netroute.RouteMessage{
		Version: unix.RTM_VERSION,
		Type:    unix.RTM_DELETE,
		Seq:     1,
	}

	if prefix.Addr().Is4() {
		gw, err := selectGateway(prefix, gateways)
		if err != nil {
			return err
		}
		route.Addrs = []netroute.Addr{
			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
		}
	} else {
		gw, err := selectGateway(prefix, gateways)
		if err != nil {
			return err
		}
		route.Addrs = []netroute.Addr{
			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
		}
	}

	data, err := route.Marshal()
	if err != nil {
		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
	}

	_, err = unix.Write(sock, data[:])
	if err != nil {
		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
	}

	return nil
}

================================================
FILE: overlay/tun_notwin.go
================================================
//go:build !windows
// +build !windows

package overlay

import "syscall"

// ioctl issues the raw ioctl(2) syscall with the given arguments and returns the
// resulting errno as an error, or nil on success.
func ioctl(a1, a2, a3 uintptr) error {
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, a1, a2, a3)
	if errno != 0 {
		return errno
	}
	return nil
}

================================================
FILE: overlay/tun_openbsd.go
================================================
//go:build !e2e_testing // +build !e2e_testing package overlay import ( "errors" "fmt" "io" "net/netip" "os" "regexp" "sync/atomic" "syscall" "unsafe" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" ) const ( SIOCAIFADDR_IN6 = 0x8080691a ) type ifreqAlias4 struct { Name [unix.IFNAMSIZ]byte Addr unix.RawSockaddrInet4 DstAddr unix.RawSockaddrInet4 MaskAddr unix.RawSockaddrInet4 } type ifreqAlias6 struct { Name [unix.IFNAMSIZ]byte Addr unix.RawSockaddrInet6 DstAddr unix.RawSockaddrInet6 PrefixMask unix.RawSockaddrInet6 Flags uint32 Lifetime [2]uint32 } type ifreq struct { Name [unix.IFNAMSIZ]byte data int } type tun struct { Device string vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger f *os.File fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } if !deviceNameRE.MatchString(deviceName) { return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0) if err != nil { return nil, err } err = unix.SetNonblock(fd, true) if err != nil { l.WithError(err).Warn("Failed to set the tun device as nonblocking") } t := &tun{ f: os.NewFile(uintptr(fd), ""), fd: fd, Device: deviceName, vpnNetworks: vpnNetworks, MTU: 
c.GetInt("tun.mtu", DefaultMTU), l: l, } err = t.reload(c, true) if err != nil { return nil, err } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *tun) Close() error { if t.f != nil { if err := t.f.Close(); err != nil { return fmt.Errorf("error closing tun file: %w", err) } // t.f.Close should have handled it for us but let's be extra sure _ = unix.Close(t.fd) } return nil } func (t *tun) Read(to []byte) (int, error) { buf := make([]byte, len(to)+4) n, err := t.f.Read(buf) copy(to, buf[4:]) return n - 4, err } // Write is only valid for single threaded use func (t *tun) Write(from []byte) (int, error) { buf := t.out if cap(buf) < len(from)+4 { buf = make([]byte, len(from)+4) t.out = buf } buf = buf[:len(from)+4] if len(from) == 0 { return 0, syscall.EIO } // Determine the IP Family for the NULL L2 Header ipVer := from[0] >> 4 if ipVer == 4 { buf[3] = syscall.AF_INET } else if ipVer == 6 { buf[3] = syscall.AF_INET6 } else { return 0, fmt.Errorf("unable to determine IP version from packet") } copy(buf[4:], from) n, err := t.f.Write(buf) return n - 4, err } func (t *tun) addIp(cidr netip.Prefix) error { if cidr.Addr().Is4() { var req ifreqAlias4 req.Name = t.deviceBytes() req.Addr = unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: cidr.Addr().As4(), } req.DstAddr = unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: cidr.Addr().As4(), } req.MaskAddr = unix.RawSockaddrInet4{ Len: unix.SizeofSockaddrInet4, Family: unix.AF_INET, Addr: prefixToMask(cidr).As4(), } s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err) } err = addRoute(cidr, 
t.vpnNetworks) if err != nil { return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err) } return nil } if cidr.Addr().Is6() { var req ifreqAlias6 req.Name = t.deviceBytes() req.Addr = unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: cidr.Addr().As16(), } req.PrefixMask = unix.RawSockaddrInet6{ Len: unix.SizeofSockaddrInet6, Family: unix.AF_INET6, Addr: prefixToMask(cidr).As16(), } req.Lifetime[0] = 0xffffffff req.Lifetime[1] = 0xffffffff s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil { return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) } return nil } return fmt.Errorf("unknown address type %v", cidr) } func (t *tun) Activate() error { err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU)) if err != nil { return fmt.Errorf("failed to set tun mtu: %w", err) } for i := range t.vpnNetworks { err = t.addIp(t.vpnNetworks[i]) if err != nil { return err } } return t.addRoutes(false) } func (t *tun) doIoctlByName(ctl uintptr, value uint32) error { s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return err } defer syscall.Close(s) ir := ifreq{Name: t.deviceBytes(), data: int(value)} err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir))) return err } func (t *tun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !change { return nil } routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err } // Teach nebula how to handle the routes before establishing them in the system table oldRoutes := t.Routes.Swap(&routes) t.routeTree.Store(routeTree) if !initial { // Remove first, if the system removes a wanted route hopefully it will be re-added next err := 
t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) if err != nil { util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } // Ensure any routes we actually want are installed err = t.addRoutes(true) if err != nil { // Catch any stray logs util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } return nil } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *tun) Name() string { return t.Device } func (t *tun) SupportsMultiqueue() bool { return false } func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } err := addRoute(r.Cidr, t.vpnNetworks) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } } else { t.l.WithField("route", r).Info("Added route") } } return nil } func (t *tun) removeRoutes(routes []Route) error { for _, r := range routes { if !r.Install { continue } err := delRoute(r.Cidr, t.vpnNetworks) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") } } return nil } func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) } return } func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP | unix.RTF_GATEWAY, 
Seq: 1, } if prefix.Addr().Is4() { gw, err := selectGateway(prefix, gateways) if err != nil { return err } route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, } } else { gw, err := selectGateway(prefix, gateways) if err != nil { return err } route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { if errors.Is(err, unix.EEXIST) { // Try to do a change route.Type = unix.RTM_CHANGE data, err = route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage for change: %w", err) } _, err = unix.Write(sock, data[:]) return err } return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) } return nil } func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error { sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) } defer unix.Close(sock) route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, } if prefix.Addr().Is4() { gw, err := selectGateway(prefix, gateways) if err != nil { return err } route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, } } else { gw, err := selectGateway(prefix, gateways) if err != nil { return err } route.Addrs = []netroute.Addr{ unix.RTAX_DST: &netroute.Inet6Addr{IP: 
prefix.Masked().Addr().As16()}, unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, } } data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) } return nil } ================================================ FILE: overlay/tun_tester.go ================================================ //go:build e2e_testing // +build e2e_testing package overlay import ( "fmt" "io" "net/netip" "os" "sync/atomic" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) type TestTun struct { Device string vpnNetworks []netip.Prefix Routes []Route routeTree *bart.Table[routing.Gateways] l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err } routeTree, err := makeRouteTree(l, routes, false) if err != nil { return nil, err } return &TestTun{ Device: c.GetString("tun.dev", ""), vpnNetworks: vpnNetworks, Routes: routes, routeTree: routeTree, l: l, rxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10), }, nil } func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } // Send will place a byte array onto the receive queue for nebula to consume // These are unencrypted ip layer frames destined for another nebula node. 
// packets should exit the udp side, capture them with udpConn.Get func (t *TestTun) Send(packet []byte) { if t.closed.Load() { return } if t.l.Level >= logrus.DebugLevel { t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet") } t.rxPackets <- packet } // Get will pull an unencrypted ip layer frame from the transmit queue // nebula meant to send this message to some application on the local system // packets were ingested from the udp side, you can send them with udpConn.Send func (t *TestTun) Get(block bool) []byte { if block { return <-t.TxPackets } select { case p := <-t.TxPackets: return p default: return nil } } //********************************************************************************************************************// // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Lookup(ip) return r } func (t *TestTun) Activate() error { return nil } func (t *TestTun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *TestTun) Name() string { return t.Device } func (t *TestTun) Write(b []byte) (n int, err error) { if t.closed.Load() { return 0, io.ErrClosedPipe } packet := make([]byte, len(b), len(b)) copy(packet, b) t.TxPackets <- packet return len(b), nil } func (t *TestTun) Close() error { if t.closed.CompareAndSwap(false, true) { close(t.rxPackets) close(t.TxPackets) } return nil } func (t *TestTun) Read(b []byte) (int, error) { p, ok := <-t.rxPackets if !ok { return 0, os.ErrClosed } copy(b, p) return len(p), nil } func (t *TestTun) SupportsMultiqueue() bool { return false } func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } ================================================ FILE: overlay/tun_windows.go 
================================================ //go:build !e2e_testing // +build !e2e_testing package overlay import ( "crypto" "fmt" "io" "net/netip" "os" "path/filepath" "runtime" "sync/atomic" "syscall" "unsafe" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { Device string vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger tun *wintun.NativeTun } func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) } deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) } t := &winTun{ Device: deviceName, vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } err = t.reload(c, true) if err != nil { return nil, err } var tunDevice wintun.Device tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. // Trying a second time resolves the issue. 
l.WithError(err).Debug("Failed to create wintun device, retrying") tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { return nil, &NameError{ Name: deviceName, Underlying: fmt.Errorf("create TUN device failed: %w", err), } } } t.tun = tunDevice.(*wintun.NativeTun) c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) if err != nil { util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) } }) return t, nil } func (t *winTun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } if !initial && !change { return nil } routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err } // Teach nebula how to handle the routes before establishing them in the system table oldRoutes := t.Routes.Swap(&routes) t.routeTree.Store(routeTree) if !initial { // Remove first, if the system removes a wanted route hopefully it will be re-added next err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) if err != nil { util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) } // Ensure any routes we actually want are installed err = t.addRoutes(true) if err != nil { // Catch any stray logs util.LogWithContextIfNeeded("Failed to add routes", err, t.l) } } return nil } func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) err := luid.SetIPAddresses(t.vpnNetworks) if err != nil { return fmt.Errorf("failed to set address: %w", err) } err = t.addRoutes(false) if err != nil { return err } return nil } func (t *winTun) addRoutes(logErrors bool) error { luid := winipcfg.LUID(t.tun.LUID()) routes := *t.Routes.Load() foundDefault4 := false for _, r := range routes { if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } // Add our unsafe route // Windows does not support multipath routes natively, so we install only a single route. 
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) continue } else { return retErr } } else { t.l.WithField("route", r).Info("Added route") } if !foundDefault4 { if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { foundDefault4 = true } } } ipif, err := luid.IPInterface(windows.AF_INET) if err != nil { return fmt.Errorf("failed to get ip interface: %w", err) } ipif.NLMTU = uint32(t.MTU) if foundDefault4 { ipif.UseAutomaticMetric = false ipif.Metric = 0 } if err := ipif.Set(); err != nil { return fmt.Errorf("failed to set ip interface: %w", err) } return nil } func (t *winTun) removeRoutes(routes []Route) error { luid := winipcfg.LUID(t.tun.LUID()) for _, r := range routes { if !r.Install { continue } // See comment on luid.AddRoute err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") } } return nil } func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *winTun) Networks() []netip.Prefix { return t.vpnNetworks } func (t *winTun) Name() string { return t.Device } func (t *winTun) Read(b []byte) (int, error) { return t.tun.Read(b, 0) } func (t *winTun) Write(b []byte) (int, error) { return t.tun.Write(b, 0) } func (t *winTun) SupportsMultiqueue() bool { return false } func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") } func (t *winTun) Close() error { // It seems that the Windows networking stack doesn't like 
it when we destroy interfaces that have active routes, // so to be certain, just remove everything before destroying. luid := winipcfg.LUID(t.tun.LUID()) _ = luid.FlushRoutes(windows.AF_INET) _ = luid.FlushIPAddresses(windows.AF_INET) _ = luid.FlushRoutes(windows.AF_INET6) _ = luid.FlushIPAddresses(windows.AF_INET6) _ = luid.FlushDNS(windows.AF_INET) _ = luid.FlushDNS(windows.AF_INET6) return t.tun.Close() } func generateGUIDByDeviceName(name string) (*windows.GUID, error) { // GUID is 128 bit hash := crypto.MD5.New() _, err := hash.Write([]byte(tunGUIDLabel)) if err != nil { return nil, err } _, err = hash.Write([]byte(name)) if err != nil { return nil, err } sum := hash.Sum(nil) return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } func checkWinTunExists() error { myPath, err := os.Executable() if err != nil { return err } arch := runtime.GOARCH switch arch { case "386": //NOTE: wintun bundles 386 as x86 arch = "x86" } _, err = syscall.LoadDLL(filepath.Join(filepath.Dir(myPath), "dist", "windows", "wintun", "bin", arch, "wintun.dll")) return err } ================================================ FILE: overlay/user.go ================================================ package overlay import ( "io" "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() return &UserDevice{ vpnNetworks: vpnNetworks, outboundReader: or, outboundWriter: ow, inboundReader: ir, inboundWriter: iw, }, nil } type UserDevice struct { vpnNetworks []netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter inboundReader *io.PipeReader inboundWriter *io.PipeWriter } func (d *UserDevice) Activate() error { return 
nil } func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { return routing.Gateways{routing.NewGateway(ip, 1)} } func (d *UserDevice) SupportsMultiqueue() bool { return true } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { return d.inboundReader, d.outboundWriter } func (d *UserDevice) Read(p []byte) (n int, err error) { return d.outboundReader.Read(p) } func (d *UserDevice) Write(p []byte) (n int, err error) { return d.inboundWriter.Write(p) } func (d *UserDevice) Close() error { d.inboundWriter.Close() d.outboundWriter.Close() return nil } ================================================ FILE: pkclient/pkclient.go ================================================ package pkclient import ( "crypto/ecdsa" "crypto/x509" "fmt" "io" "strconv" "github.com/stefanberger/go-pkcs11uri" ) type Client interface { io.Closer GetPubKey() ([]byte, error) DeriveNoise(peerPubKey []byte) ([]byte, error) Test() error } const NoiseKeySize = 32 func FromUrl(pkurl string) (*PKClient, error) { uri := pkcs11uri.New() uri.SetAllowAnyModule(true) //todo err := uri.Parse(pkurl) if err != nil { return nil, err } module, err := uri.GetModule() if err != nil { return nil, err } slotid := 0 slot, ok := uri.GetPathAttribute("slot-id", false) if !ok { slotid = 0 } else { slotid, err = strconv.Atoi(slot) if err != nil { return nil, err } } pin, _ := uri.GetPIN() id, _ := uri.GetPathAttribute("id", false) label, _ := uri.GetPathAttribute("object", false) return New(module, uint(slotid), pin, id, label) } func ecKeyToArray(key *ecdsa.PublicKey) []byte { x := make([]byte, 32) y := make([]byte, 32) key.X.FillBytes(x) key.Y.FillBytes(y) return append([]byte{0x04}, append(x, y...)...) 
} func formatPubkeyFromPublicKeyInfoAttr(d []byte) ([]byte, error) { e, err := x509.ParsePKIXPublicKey(d) if err != nil { return nil, err } switch t := e.(type) { case *ecdsa.PublicKey: return ecKeyToArray(e.(*ecdsa.PublicKey)), nil default: return nil, fmt.Errorf("unknown public key type: %T", t) } } func (c *PKClient) Test() error { pub, err := c.GetPubKey() if err != nil { return fmt.Errorf("failed to get public key: %w", err) } out, err := c.DeriveNoise(pub) //do an ECDH with ourselves as a quick test if err != nil { return err } if len(out) != NoiseKeySize { return fmt.Errorf("got a key of %d bytes, expected %d", len(out), NoiseKeySize) } return nil } ================================================ FILE: pkclient/pkclient_cgo.go ================================================ //go:build cgo && pkcs11 package pkclient import ( "encoding/asn1" "errors" "fmt" "log" "math/big" "github.com/miekg/pkcs11" "github.com/miekg/pkcs11/p11" ) type PKClient struct { module p11.Module session p11.Session id []byte label []byte privKeyObj p11.Object pubKeyObj p11.Object } type ecdsaSignature struct { R, S *big.Int } // New tries to open a session with the HSM, select the slot and login to it func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { module, err := p11.OpenModule(hsmPath) if err != nil { return nil, fmt.Errorf("failed to load module library: %s", hsmPath) } slots, err := module.Slots() if err != nil { module.Destroy() return nil, err } // Try to open a session on the slot slotIdx := 0 for i, slot := range slots { if slot.ID() == slotId { slotIdx = i break } } client := &PKClient{ module: module, id: []byte(id), label: []byte(label), } client.session, err = slots[slotIdx].OpenWriteSession() if err != nil { module.Destroy() return nil, fmt.Errorf("failed to open session on slot %d", slotId) } if len(pin) != 0 { err = client.session.Login(pin) if err != nil { // ignore "already logged in" if !errors.Is(err, 
pkcs11.Error(256)) { _ = client.session.Close() return nil, fmt.Errorf("unable to login. error: %w", err) } } } // Make sure the hsm has a private key for deriving client.privKeyObj, err = client.findDeriveKey(client.id, client.label, true) if err != nil { _ = client.Close() //log out, close session, destroy module return nil, fmt.Errorf("failed to find private key for deriving: %w", err) } return client, nil } // Close cleans up properly and logs out func (c *PKClient) Close() error { var err error = nil if c.session != nil { _ = c.session.Logout() //if logout fails, we still want to close err = c.session.Close() } c.module.Destroy() return err } // Try to find a suitable key on the hsm for key derivation // parameter GET_PUB_KEY sets the search pattern for a public or private key func (c *PKClient) findDeriveKey(id []byte, label []byte, private bool) (key p11.Object, err error) { keyClass := pkcs11.CKO_PRIVATE_KEY if !private { keyClass = pkcs11.CKO_PUBLIC_KEY } keyAttrs := []*pkcs11.Attribute{ //todo, not all HSMs seem to report this, even if its true: pkcs11.NewAttribute(pkcs11.CKA_DERIVE, true), pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), } if id != nil && len(id) != 0 { keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) } if label != nil && len(label) != 0 { keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) } return c.session.FindObject(keyAttrs) } func (c *PKClient) listDeriveKeys(id []byte, label []byte, private bool) { keyClass := pkcs11.CKO_PRIVATE_KEY if !private { keyClass = pkcs11.CKO_PUBLIC_KEY } keyAttrs := []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_CLASS, keyClass), } if id != nil && len(id) != 0 { keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) } if label != nil && len(label) != 0 { keyAttrs = append(keyAttrs, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) } objects, err := c.session.FindObjects(keyAttrs) if err != nil { return } for _, obj := range objects { l, err := 
obj.Label() log.Printf("%s, %v", l, err) a, err := obj.Attribute(pkcs11.CKA_DERIVE) log.Printf("DERIVE: %s %v, %v", l, a, err) } } // SignASN1 signs some data. Returns the ASN.1 encoded signature. func (c *PKClient) SignASN1(data []byte) ([]byte, error) { mech := pkcs11.NewMechanism(pkcs11.CKM_ECDSA_SHA256, nil) sk := p11.PrivateKey(c.privKeyObj) rawSig, err := sk.Sign(*mech, data) if err != nil { return nil, err } // PKCS #11 Mechanisms v2.30: // "The signature octets correspond to the concatenation of the ECDSA values r and s, // both represented as an octet string of equal length of at most nLen with the most // significant byte first. If r and s have different octet length, the shorter of both // must be padded with leading zero octets such that both have the same octet length. // Loosely spoken, the first half of the signature is r and the second half is s." r := new(big.Int).SetBytes(rawSig[:len(rawSig)/2]) s := new(big.Int).SetBytes(rawSig[len(rawSig)/2:]) return asn1.Marshal(ecdsaSignature{r, s}) } // DeriveNoise derives a shared secret using the input public key against the private key that was found during setup. // Returns a fixed 32 byte array. 
func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) { // Before we call derive, we need to have an array of attributes which specify the type of // key to be returned, in our case, it's the shared secret key, produced via deriving // This template pulled from OpenSC pkclient-tool.c line 4038 attrTemplate := []*pkcs11.Attribute{ pkcs11.NewAttribute(pkcs11.CKA_TOKEN, false), pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, pkcs11.CKK_GENERIC_SECRET), pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, false), pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, true), pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, true), pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize), } // Set up the parameters which include the peer's public key ecdhParams := pkcs11.NewECDH1DeriveParams(pkcs11.CKD_NULL, nil, peerPubKey) mech := pkcs11.NewMechanism(pkcs11.CKM_ECDH1_DERIVE, ecdhParams) sk := p11.PrivateKey(c.privKeyObj) tmpKey, err := sk.Derive(*mech, attrTemplate) if err != nil { return nil, err } if tmpKey == nil || len(tmpKey) == 0 { return nil, fmt.Errorf("got an empty secret key") } secret := make([]byte, NoiseKeySize) copy(secret[:], tmpKey[:NoiseKeySize]) return secret, nil } func (c *PKClient) GetPubKey() ([]byte, error) { d, err := c.privKeyObj.Attribute(pkcs11.CKA_PUBLIC_KEY_INFO) if err != nil { return nil, err } if d != nil && len(d) > 0 { return formatPubkeyFromPublicKeyInfoAttr(d) } c.pubKeyObj, err = c.findDeriveKey(c.id, c.label, false) if err != nil { return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and looking up the public key also failed: %w", err) } d, err = c.pubKeyObj.Attribute(pkcs11.CKA_EC_POINT) if err != nil { return nil, fmt.Errorf("pkcs11 module gave us a nil CKA_PUBLIC_KEY_INFO, and reading CKA_EC_POINT also failed: %w", err) } if d == nil || 
len(d) < 1 { return nil, fmt.Errorf("pkcs11 module gave us a nil or empty CKA_EC_POINT") } switch len(d) { case 65: //length of 0x04 + len(X) + len(Y) return d, nil case 67: //as above, DER-encoded IIRC? return d[2:], nil default: return nil, fmt.Errorf("unknown public key length: %d", len(d)) } } ================================================ FILE: pkclient/pkclient_stub.go ================================================ //go:build !cgo || !pkcs11 package pkclient import "errors" type PKClient struct { } var notImplemented = errors.New("not implemented") func New(hsmPath string, slotId uint, pin string, id string, label string) (*PKClient, error) { return nil, notImplemented } func (c *PKClient) Close() error { return nil } func (c *PKClient) SignASN1(data []byte) ([]byte, error) { return nil, notImplemented } func (c *PKClient) DeriveNoise(_ []byte) ([]byte, error) { return nil, notImplemented } func (c *PKClient) GetPubKey() ([]byte, error) { return nil, notImplemented } ================================================ FILE: pki.go ================================================ package nebula import ( "encoding/binary" "encoding/json" "errors" "fmt" "net" "net/netip" "os" "slices" "strings" "sync/atomic" "time" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/util" ) type PKI struct { cs atomic.Pointer[CertState] caPool atomic.Pointer[cert.CAPool] l *logrus.Logger } type CertState struct { v1Cert cert.Certificate v1HandshakeBytes []byte v2Cert cert.Certificate v2HandshakeBytes []byte initiatingVersion cert.Version privateKey []byte pkcs11Backed bool cipher string myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Lite myVpnAddrs []netip.Addr myVpnAddrsTable *bart.Lite myVpnBroadcastAddrsTable *bart.Lite } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { pki := &PKI{l: l} err := pki.reload(c, true) if err != nil { return nil, err } 
c.RegisterReloadCallback(func(c *config.C) { rErr := pki.reload(c, false) if rErr != nil { util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l) } }) return pki, nil } func (p *PKI) GetCAPool() *cert.CAPool { return p.caPool.Load() } func (p *PKI) getCertState() *CertState { return p.cs.Load() } func (p *PKI) reload(c *config.C, initial bool) error { err := p.reloadCerts(c, initial) if err != nil { if initial { return err } err.Log(p.l) } err = p.reloadCAPool(c) if err != nil { if initial { return err } err.Log(p.l) } return nil } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { newState, err := newCertStateFromConfig(c) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } if !initial { currentState := p.cs.Load() if newState.v1Cert != nil { if currentState.v1Cert == nil { //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). } else { // did IP in cert change? if so, don't set if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { return util.NewContextualError( "Networks in new cert was different from old", m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1}, nil, ) } if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { return util.NewContextualError( "Curve in new v1 cert was different from old", m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1}, nil, ) } } } if newState.v2Cert != nil { if currentState.v2Cert == nil { //adding certs is fine, actually } else { // did IP in cert change? 
if so, don't set if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { return util.NewContextualError( "Networks in new cert was different from old", m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2}, nil, ) } if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { return util.NewContextualError( "Curve in new cert was different from old", m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2}, nil, ) } } } else if currentState.v2Cert != nil { //newState.v1Cert is non-nil bc empty certstates aren't permitted if newState.v1Cert == nil { return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err) } //if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) { return util.NewContextualError( "Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert", m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()}, nil, ) } } // Cipher cant be hot swapped so just leave it at what it was before newState.cipher = currentState.cipher } else { newState.cipher = c.GetString("cipher", "aes") //TODO: this sucks and we should make it not a global switch newState.cipher { case "aes": noiseEndianness = binary.BigEndian case "chachapoly": noiseEndianness = binary.LittleEndian default: return util.NewContextualError( "unknown cipher", m{"cipher": newState.cipher}, nil, ) } } p.cs.Store(newState) if initial { p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") } return nil } func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { caPool, err := loadCAPoolFromConfig(p.l, c) if err != nil { return 
util.NewContextualError("Failed to load ca from config", nil, err) } p.caPool.Store(caPool) p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints") return nil } func (cs *CertState) GetDefaultCertificate() cert.Certificate { c := cs.getCertificate(cs.initiatingVersion) if c == nil { panic("No default certificate found") } return c } func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { switch v { case cert.Version1: return cs.v1Cert case cert.Version2: return cs.v2Cert } return nil } // getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. // Callers must check if the return []byte is nil. func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { switch v { case cert.Version1: return cs.v1HandshakeBytes case cert.Version2: return cs.v2HandshakeBytes default: return nil } } func (cs *CertState) String() string { b, err := cs.MarshalJSON() if err != nil { return fmt.Sprintf("error marshaling certificate state: %v", err) } return string(b) } func (cs *CertState) MarshalJSON() ([]byte, error) { msg := []json.RawMessage{} if cs.v1Cert != nil { b, err := cs.v1Cert.MarshalJSON() if err != nil { return nil, err } msg = append(msg, b) } if cs.v2Cert != nil { b, err := cs.v2Cert.MarshalJSON() if err != nil { return nil, err } msg = append(msg, b) } return json.Marshal(msg) } func newCertStateFromConfig(c *config.C) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") if privPathOrPEM == "" { return nil, errors.New("no pki.key path or PEM data provided") } rawKey, curve, isPkcs11, err := loadPrivateKey(privPathOrPEM) if err != nil { return nil, err } var rawCert []byte pubPathOrPEM := c.GetString("pki.cert", "") if pubPathOrPEM == "" { return nil, errors.New("no pki.cert path or PEM data provided") } if strings.Contains(pubPathOrPEM, "-----BEGIN") { rawCert = []byte(pubPathOrPEM) pubPathOrPEM = "" } else { rawCert, err = 
os.ReadFile(pubPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err) } } var crt, v1, v2 cert.Certificate for { // Load the certificate crt, rawCert, err = loadCertificate(rawCert) if err != nil { return nil, err } switch crt.Version() { case cert.Version1: if v1 != nil { return nil, fmt.Errorf("v1 certificate already found in pki.cert") } v1 = crt case cert.Version2: if v2 != nil { return nil, fmt.Errorf("v2 certificate already found in pki.cert") } v2 = crt default: return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) } if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { break } } if v1 == nil && v2 == nil { return nil, errors.New("no certificates found in pki.cert") } useInitiatingVersion := uint32(1) if v1 == nil { // The only condition that requires v2 as the default is if only a v2 certificate is present // We do this to avoid having to configure it specifically in the config file useInitiatingVersion = 2 } rawInitiatingVersion := c.GetUint32("pki.initiating_version", useInitiatingVersion) var initiatingVersion cert.Version switch rawInitiatingVersion { case 1: if v1 == nil { return nil, fmt.Errorf("can not use pki.initiating_version 1 without a v1 certificate in pki.cert") } initiatingVersion = cert.Version1 case 2: initiatingVersion = cert.Version2 default: return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) } func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, myVpnNetworksTable: new(bart.Lite), myVpnAddrsTable: new(bart.Lite), myVpnBroadcastAddrsTable: new(bart.Lite), } if v1 != nil && v2 != nil { if !slices.Equal(v1.PublicKey(), v2.PublicKey()) { return nil, util.NewContextualError("v1 and v2 public 
keys are not the same, ignoring", nil, nil) } if v1.Curve() != v2.Curve() { return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) } if v1.Networks()[0] != v2.Networks()[0] { return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil) } cs.initiatingVersion = dv } if v1 != nil { if pkcs11backed { //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm } else { if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") } } v1hs, err := v1.MarshalForHandshakes() if err != nil { return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) } cs.v1Cert = v1 cs.v1HandshakeBytes = v1hs if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version1 } } if v2 != nil { if pkcs11backed { //NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm } else { if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") } } v2hs, err := v2.MarshalForHandshakes() if err != nil { return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) } cs.v2Cert = v2 cs.v2HandshakeBytes = v2hs if cs.initiatingVersion == 0 { cs.initiatingVersion = cert.Version2 } } var crt cert.Certificate crt = cs.getCertificate(cert.Version2) if crt == nil { // v2 certificates are a superset, only look at v1 if its all we have crt = cs.getCertificate(cert.Version1) } for _, network := range crt.Networks() { cs.myVpnNetworks = append(cs.myVpnNetworks, network) cs.myVpnNetworksTable.Insert(network) cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen())) if network.Addr().Is4() { addr := network.Masked().Addr().As4() mask := 
net.CIDRMask(network.Bits(), network.Addr().BitLen()) binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen())) } } return &cs, nil } func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { var pemPrivateKey []byte if strings.Contains(privPathOrPEM, "-----BEGIN") { pemPrivateKey = []byte(privPathOrPEM) privPathOrPEM = "" rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) if err != nil { return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) } } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { rawKey = []byte(privPathOrPEM) return rawKey, cert.Curve_P256, true, nil } else { pemPrivateKey, err = os.ReadFile(privPathOrPEM) if err != nil { return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) } rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) if err != nil { return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) } } return } func loadCertificate(b []byte) (cert.Certificate, []byte, error) { c, b, err := cert.UnmarshalCertificateFromPEM(b) if err != nil { return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err) } if c.Expired(time.Now()) { return nil, b, fmt.Errorf("nebula certificate for this host is expired") } if len(c.Networks()) == 0 { return nil, b, fmt.Errorf("no networks encoded in certificate") } if c.IsCA() { return nil, b, fmt.Errorf("host certificate is a CA certificate") } return c, b, nil } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { var rawCA []byte var err error caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") } if strings.Contains(caPathOrPEM, "-----BEGIN") { 
rawCA = []byte(caPathOrPEM) } else { rawCA, err = os.ReadFile(caPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) } } caPool, err := cert.NewCAPoolFromPEM(rawCA) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { if crt.Certificate.Expired(time.Now()) { expired++ l.WithField("cert", crt).Warn("expired certificate present in CA pool") } } if expired >= len(caPool.CAs) { return nil, errors.New("no valid CA certificates present") } } else if err != nil { return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) } bl := c.GetStringSlice("pki.blocklist", []string{}) if len(bl) > 0 { for _, fp := range bl { caPool.BlocklistFingerprint(fp) } l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") } return caPool, nil } ================================================ FILE: punchy.go ================================================ package nebula import ( "sync/atomic" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type Punchy struct { punch atomic.Bool respond atomic.Bool delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool l *logrus.Logger } func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy { p := &Punchy{l: l} p.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { p.reload(c, false) }) return p } func (p *Punchy) reload(c *config.C, initial bool) { if initial { var yes bool if c.IsSet("punchy.punch") { yes = c.GetBool("punchy.punch", false) } else { // Deprecated fallback yes = c.GetBool("punchy", false) } p.punch.Store(yes) if yes { p.l.Info("punchy enabled") } else { p.l.Info("punchy disabled") } } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") } if 
initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") { var yes bool if c.IsSet("punchy.respond") { yes = c.GetBool("punchy.respond", false) } else { // Deprecated fallback yes = c.GetBool("punch_back", false) } p.respond.Store(yes) if !initial { p.l.Infof("punchy.respond changed to %v", p.GetRespond()) } } //NOTE: this will not apply to any in progress operations, only the next one if initial || c.HasChanged("punchy.delay") { p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) if !initial { p.l.Infof("punchy.delay changed to %s", p.GetDelay()) } } if initial || c.HasChanged("punchy.target_all_remotes") { p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) if !initial { p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed") } } if initial || c.HasChanged("punchy.respond_delay") { p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) if !initial { p.l.Infof("punchy.respond_delay changed to %s", p.GetRespondDelay()) } } } func (p *Punchy) GetPunch() bool { return p.punch.Load() } func (p *Punchy) GetRespond() bool { return p.respond.Load() } func (p *Punchy) GetDelay() time.Duration { return (time.Duration)(p.delay.Load()) } func (p *Punchy) GetRespondDelay() time.Duration { return (time.Duration)(p.respondDelay.Load()) } func (p *Punchy) GetTargetEverything() bool { return p.punchEverything.Load() } ================================================ FILE: punchy_test.go ================================================ package nebula import ( "testing" "time" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewPunchyFromConfig(t *testing.T) { l := test.NewLogger() c := config.NewC(l) // Test defaults p := NewPunchyFromConfig(l, c) assert.False(t, p.GetPunch()) assert.False(t, p.GetRespond()) assert.Equal(t, time.Second, p.GetDelay()) 
assert.Equal(t, 5*time.Second, p.GetRespondDelay()) // punchy deprecation c.Settings["punchy"] = true p = NewPunchyFromConfig(l, c) assert.True(t, p.GetPunch()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} p = NewPunchyFromConfig(l, c) assert.True(t, p.GetPunch()) // punch_back deprecation c.Settings["punch_back"] = true p = NewPunchyFromConfig(l, c) assert.True(t, p.GetRespond()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false p = NewPunchyFromConfig(l, c) assert.True(t, p.GetRespond()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetDelay()) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} p = NewPunchyFromConfig(l, c) assert.Equal(t, time.Minute, p.GetRespondDelay()) } func TestPunchy_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) delay, _ := time.ParseDuration("1m") require.NoError(t, c.LoadString(` punchy: delay: 1m respond: false `)) p := NewPunchyFromConfig(l, c) assert.Equal(t, delay, p.GetDelay()) assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") require.NoError(t, c.ReloadConfigString(` punchy: delay: 10m respond: true `)) p.reload(c, false) assert.Equal(t, newDelay, p.GetDelay()) assert.True(t, p.GetRespond()) } ================================================ FILE: relay_manager.go ================================================ package nebula import ( "context" "encoding/binary" "errors" "fmt" "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type relayManager struct { l *logrus.Logger hostmap *HostMap amRelay atomic.Bool } func NewRelayManager(ctx context.Context, l *logrus.Logger, hostmap *HostMap, c *config.C) *relayManager { rm := &relayManager{ l: l, hostmap: hostmap, } rm.reload(c, true) 
c.RegisterReloadCallback(func(c *config.C) { err := rm.reload(c, false) if err != nil { l.WithError(err).Error("Failed to reload relay_manager") } }) return rm } func (rm *relayManager) reload(c *config.C, initial bool) error { if initial || c.HasChanged("relay.am_relay") { rm.setAmRelay(c.GetBool("relay.am_relay", false)) } return nil } func (rm *relayManager) GetAmRelay() bool { return rm.amRelay.Load() } func (rm *relayManager) setAmRelay(v bool) { rm.amRelay.Store(v) } // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for range 32 { index, err := generateIndex(l) if err != nil { return 0, err } _, inRelays := hm.Relays[index] if !inRelays { // Avoid standing up a relay that can't be used since only the primary hostinfo // will be pointed to by the relay logic //TODO: if there was an existing primary and it had relay state, should we merge? hm.unlockedMakePrimary(relayHostInfo) hm.Relays[index] = relayHostInfo newRelay := Relay{ Type: relayType, State: state, LocalIndex: index, PeerAddr: vpnIp, } if remoteIdx != nil { newRelay.RemoteIndex = *remoteIdx } relayHostInfo.relayState.InsertRelay(vpnIp, index, &newRelay) return index, nil } } return 0, errors.New("failed to generate unique localIndexId") } // EstablishRelay updates a Requested Relay to become an Established Relay, which can pass traffic. 
func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { fields := logrus.Fields{ "relay": relayHostInfo.vpnAddrs[0], "initiatorRelayIndex": m.InitiatorRelayIndex, } if m.RelayFromAddr == nil { fields["relayFrom"] = m.OldRelayFromAddr } else { fields["relayFrom"] = m.RelayFromAddr } if m.RelayToAddr == nil { fields["relayTo"] = m.OldRelayToAddr } else { fields["relayTo"] = m.RelayToAddr } rm.l.WithFields(fields).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } return relay, nil } func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { msg := &NebulaControl{} err := msg.Unmarshal(d) if err != nil { h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") return } var v cert.Version if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 { v = cert.Version1 b := [4]byte{} binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr) msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr) msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) } else { v = cert.Version2 } switch msg.Type { case NebulaControl_CreateRelayRequest: rm.handleCreateRelayRequest(v, h, f, msg) case NebulaControl_CreateRelayResponse: rm.handleCreateRelayResponse(v, h, f, msg) } } func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), "relayTo": protoAddrToNetAddr(m.RelayToAddr), "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, "vpnAddrs": h.vpnAddrs}). 
Info("handleCreateRelayResponse") target := m.RelayToAddr targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { rm.l.WithError(err).Error("Failed to update relay for relayTo") return } // Do I need to complete the relays now? if relay.Type == TerminalType { return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") return } switch peerRelay.State { case Requested: // I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer // to respond to complete the connection. case PeerRequested, Disestablished, Established: peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established) resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, } if v == cert.Version1 { peer := peerHostInfo.vpnAddrs[0] if !peer.Is4() { rm.l.WithField("relayFrom", peer). WithField("relayTo", target). WithField("initiatorRelayIndex", resp.InitiatorRelayIndex). WithField("responderRelayIndex", resp.ResponderRelayIndex). WithField("vpnAddrs", peerHostInfo.vpnAddrs). Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address") return } b := peer.As4() resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) b = targetAddr.As4() resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) } else { resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) resp.RelayToAddr = target } msg, err := resp.Marshal() if err != nil { rm.l.WithError(err). 
Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ "relayFrom": resp.RelayFromAddr, "relayTo": resp.RelayToAddr, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnAddrs": peerHostInfo.vpnAddrs}). Info("send CreateRelayResponse") } } } func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { from := protoAddrToNetAddr(m.RelayFromAddr) target := protoAddrToNetAddr(m.RelayToAddr) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": m.InitiatorRelayIndex, "vpnAddrs": h.vpnAddrs}) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. if f.myVpnAddrsTable.Contains(from) { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? if f.myVpnAddrsTable.Contains(target) { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { case Requested: ok = h.relayState.CompleteRelayByIP(from, m.InitiatorRelayIndex) if !ok { logMsg.Error("Relay State not found") return } case Established: if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. 
logMsg.WithFields(logrus.Fields{ "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } case Disestablished: if existingRelay.RemoteIndex != m.InitiatorRelayIndex { // We got a brand new Relay request, because its index is different than what we saw before. // This should never happen. The peer should never change an index, once created. logMsg.WithFields(logrus.Fields{ "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } // Mark the relay as 'Established' because it's safe to use again h.relayState.UpdateRelayForByIpState(from, Established) case PeerRequested: // I should never be in this state, because I am terminal, not forwarding. logMsg.WithFields(logrus.Fields{ "existingRemoteIndex": existingRelay.RemoteIndex, "state": existingRelay.State}).Error("Unexpected Relay State found") } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) if err != nil { logMsg.WithError(err).Error("Failed to add relay") return } } relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { logMsg.WithField("from", from).Error("Relay State not found") return } resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, } if v == cert.Version1 { b := from.As4() resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) b = target.As4() resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) } else { resp.RelayFromAddr = netAddrToProtoAddr(from) resp.RelayToAddr = netAddrToProtoAddr(target) } msg, err := resp.Marshal() if err != nil { logMsg. 
WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } return } else { // the target is not me. Create a relay to the target, from me. if !rm.GetAmRelay() { return } peer := rm.hostmap.QueryVpnAddr(target) if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! f.Handshake(target) return } if !peer.remote.IsValid() { // Only create relays to peers for whom I have a direct connection return } var index uint32 var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex } else { // Allocate an index in the hostMap for this relay peer index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) if err != nil { return } } peer.relayState.UpdateRelayForByIpState(from, Requested) // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, } if v == cert.Version1 { if !h.vpnAddrs[0].Is4() { rm.l.WithField("relayFrom", h.vpnAddrs[0]). WithField("relayTo", target). WithField("initiatorRelayIndex", req.InitiatorRelayIndex). WithField("responderRelayIndex", req.ResponderRelayIndex). WithField("vpnAddr", target). Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address") return } b := h.vpnAddrs[0].As4() req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) b = target.As4() req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) } else { req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) req.RelayToAddr = netAddrToProtoAddr(target) } msg, err := req.Marshal() if err != nil { logMsg. 
WithError(err).Error("relayManager Failed to marshal Control message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ "relayFrom": h.vpnAddrs[0], "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnAddr": target}). Info("send CreateRelayRequest") } // Also track the half-created Relay state just received _, ok = h.relayState.QueryRelayForByIp(target) if !ok { _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { logMsg. WithError(err).Error("relayManager Failed to allocate a local index for relay") return } } } } ================================================ FILE: remote_list.go ================================================ package nebula import ( "context" "net" "net/netip" "slices" "sort" "strconv" "sync" "sync/atomic" "time" "github.com/sirupsen/logrus" ) // forEachFunc is used to benefit folks that want to do work inside the lock type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { Learned []netip.AddrPort `json:"learned,omitempty"` Reported []netip.AddrPort `json:"reported,omitempty"` Relay []netip.Addr `json:"relay"` } // cache is an internal struct that splits v4 and v6 addresses inside the cache map type cache struct { v4 *cacheV4 v6 *cacheV6 relay *cacheRelay } type cacheRelay 
struct { relay []netip.Addr } // cacheV4 stores learned and reported ipv4 records under cache type cacheV4 struct { learned *V4AddrPort reported []*V4AddrPort } // cacheV4 stores learned and reported ipv6 records under cache type cacheV6 struct { learned *V6AddrPort reported []*V6AddrPort } type hostnamePort struct { name string port uint16 } type hostnamesResults struct { hostnames []hostnamePort network string lookupTimeout time.Duration cancelFn func() l *logrus.Logger ips atomic.Pointer[map[netip.AddrPort]struct{}] } func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, network string, timeout time.Duration, hostPorts []string, onUpdate func()) (*hostnamesResults, error) { r := &hostnamesResults{ hostnames: make([]hostnamePort, len(hostPorts)), network: network, lookupTimeout: timeout, l: l, } // Fastrack IP addresses to ensure they're immediately available for use. // DNS lookups for hostnames that aren't hardcoded IP's will happen in a background goroutine. 
performBackgroundLookup := false ips := map[netip.AddrPort]struct{}{} for idx, hostPort := range hostPorts { rIp, sPort, err := net.SplitHostPort(hostPort) if err != nil { return nil, err } iPort, err := strconv.Atoi(sPort) if err != nil { return nil, err } r.hostnames[idx] = hostnamePort{name: rIp, port: uint16(iPort)} addr, err := netip.ParseAddr(rIp) if err != nil { // This address is a hostname, not an IP address performBackgroundLookup = true continue } // Save the IP address immediately ips[netip.AddrPortFrom(addr, uint16(iPort))] = struct{}{} } r.ips.Store(&ips) // Time for the DNS lookup goroutine if performBackgroundLookup { newCtx, cancel := context.WithCancel(ctx) r.cancelFn = cancel ticker := time.NewTicker(d) go func() { defer ticker.Stop() for { netipAddrs := map[netip.AddrPort]struct{}{} for _, hostPort := range r.hostnames { timeoutCtx, timeoutCancel := context.WithTimeout(ctx, r.lookupTimeout) addrs, err := net.DefaultResolver.LookupNetIP(timeoutCtx, r.network, hostPort.name) timeoutCancel() if err != nil { l.WithFields(logrus.Fields{"hostname": hostPort.name, "network": r.network}).WithError(err).Error("DNS resolution failed for static_map host") continue } for _, a := range addrs { netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{} } } origSet := r.ips.Load() different := false for a := range *origSet { if _, ok := netipAddrs[a]; !ok { different = true break } } if !different { for a := range netipAddrs { if _, ok := (*origSet)[a]; !ok { different = true break } } } if different { l.WithFields(logrus.Fields{"origSet": origSet, "newSet": netipAddrs}).Info("DNS results changed for host list") r.ips.Store(&netipAddrs) onUpdate() } select { case <-newCtx.Done(): return case <-ticker.C: continue } } }() } return r, nil } func (hr *hostnamesResults) Cancel() { if hr != nil && hr.cancelFn != nil { hr.cancelFn() } } func (hr *hostnamesResults) GetAddrs() []netip.AddrPort { var retSlice []netip.AddrPort if hr != nil { p := hr.ips.Load() 
if p != nil { for k := range *p { retSlice = append(retSlice, k) } } } return retSlice } // RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos. // It serves as a local cache of query replies, host update notifications, and locally learned addresses type RemoteList struct { // Every interaction with internals requires a lock! sync.RWMutex // The full list of vpn addresses assigned to this host vpnAddrs []netip.Addr // A deduplicated set of underlay addresses. Any accessor should lock beforehand. addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. relays []netip.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet cache map[netip.Addr]*cache hr *hostnamesResults // shouldAdd is a nillable function that decides if x should be added to addrs. shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. 
// They should not be tried again during a handshake badRemotes []netip.AddrPort // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool } // NewRemoteList creates a new empty RemoteList func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList { r := &RemoteList{ vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), relays: make([]netip.Addr, 0), cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } copy(r.vpnAddrs, vpnAddrs) return r } func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { // Cancel any existing hostnamesResults DNS goroutine to release resources r.hr.Cancel() r.hr = hr } // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() return len(r.addrs) } // ForEach locks and will call the forEachFunc for every deduplicated address in the list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) { r.Rebuild(preferredRanges) r.RLock() for _, v := range r.addrs { forEach(v, isPreferred(v.Addr(), preferredRanges)) } r.RUnlock() } // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort { if r == nil { return nil } r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() c := make([]netip.AddrPort, len(r.addrs)) for i, v := range r.addrs { c[i] = v } return c } // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr // Currently this is only needed when HostInfo.SetRemote is called as that should cover both 
handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() if remote.Addr().Is4() { r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port())) } else { r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port())) } } // CopyCache locks and creates a more human friendly form of the internal address cache. // This may contain duplicates and blocked addresses func (r *RemoteList) CopyCache() *CacheMap { r.RLock() defer r.RUnlock() cm := make(CacheMap) getOrMake := func(vpnIp string) *Cache { c := cm[vpnIp] if c == nil { c = &Cache{ Learned: make([]netip.AddrPort, 0), Reported: make([]netip.AddrPort, 0), Relay: make([]netip.Addr, 0), } cm[vpnIp] = c } return c } for owner, mc := range r.cache { c := getOrMake(owner.String()) if mc.v4 != nil { if mc.v4.learned != nil { c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned)) } for _, a := range mc.v4.reported { c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned)) } for _, a := range mc.v6.reported { c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a)) } } if mc.relay != nil { for _, a := range mc.relay.relay { c.Relay = append(c.Relay, a) } } } return &cm } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list func (r *RemoteList) BlockRemote(bad ViaSender) { if bad.IsRelayed { return } r.Lock() defer r.Unlock() // Check if we already blocked this addr if r.unlockedIsBad(bad.UdpAddr) { return } // We copy here because we are taking something else's memory and we can't trust everything r.badRemotes = append(r.badRemotes, bad.UdpAddr) // Mark the next interaction 
must recollect/dedupe r.shouldRebuild = true } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { r.RLock() defer r.RUnlock() c := make([]netip.AddrPort, len(r.badRemotes)) for i, v := range r.badRemotes { c[i] = v } return c } // RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) { r.Lock() r.badRemotes = nil r.vpnAddrs = make([]netip.Addr, len(vpnAddrs)) copy(r.vpnAddrs, vpnAddrs) r.Unlock() } // ResetBlockedRemotes locks and clears the blocked remotes list func (r *RemoteList) ResetBlockedRemotes() { r.Lock() r.badRemotes = nil r.Unlock() } // Rebuild locks and generates the deduplicated address list only if there is work to be done // There is generally no reason to call this directly but it is safe to do so func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { r.Lock() defer r.Unlock() // Only rebuild if the cache changed if r.shouldRebuild { r.unlockedCollect() r.shouldRebuild = false } // Always re-sort, preferredRanges can change via HUP r.unlockedSort(preferredRanges) } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { return slices.Contains(r.badRemotes, remote) } // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, 
to []*V4AddrPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // Reset the slice c.reported = c.reported[:0] // We can't take their array but we can take their pointers for _, v := range to[:minInt(len(to), MaxRemotes)] { if check(vpnIp, v) { c.reported = append(c.reported, v) } } } func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) // Reset the slice c.relay = c.relay[:0] // We can't take their array but we can take their pointers c.relay = append(c.relay, to[:minInt(len(to), MaxRemotes)]...) } // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // We are doing the easy append because this is rarely called c.reported = append([]*V4AddrPort{to}, c.reported...) 
if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } } // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // Reset the slice c.reported = c.reported[:0] // We can't take their array but we can take their pointers for _, v := range to[:minInt(len(to), MaxRemotes)] { if check(vpnIp, v) { c.reported = append(c.reported, v) } } } // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // We are doing the easy append because this is rarely called c.reported = append([]*V6AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } } func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} r.cache[ownerVpnIp] = am } // Avoid occupying memory for relay if we never have any if am.relay == nil { am.relay = &cacheRelay{} } return am.relay } // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. 
// The caller must dirty the learned address cache if required func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} r.cache[ownerVpnIp] = am } // Avoid occupying memory for v6 addresses if we never have any if am.v4 == nil { am.v4 = &cacheV4{} } return am.v4 } // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} r.cache[ownerVpnIp] = am } // Avoid occupying memory for v4 addresses if we never have any if am.v6 == nil { am.v6 = &cacheV6{} } return am.v6 } // unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list. // The result of this function can contain duplicates. unlockedSort handles cleaning it. 
func (r *RemoteList) unlockedCollect() { addrs := r.addrs[:0] relays := r.relays[:0] for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { u := protoV4AddrPortToNetAddrPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { u := protoV4AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } } if c.v6 != nil { if c.v6.learned != nil { u := protoV6AddrPortToNetAddrPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { u := protoV6AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } } if c.relay != nil { for _, v := range c.relay.relay { relays = append(relays, v) } } } dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) { if !r.unlockedIsBad(addr) { addrs = append(addrs, addr) } } } r.addrs = addrs r.relays = relays } // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { // Use a map to deduplicate any relay addresses dedupedRelays := map[netip.Addr]struct{}{} for _, relay := range r.relays { dedupedRelays[relay] = struct{}{} } r.relays = r.relays[:0] for relay := range dedupedRelays { r.relays = append(r.relays, relay) } // Put them in a somewhat consistent order after de-duplication slices.SortFunc(r.relays, func(a, b netip.Addr) int { return a.Compare(b) }) // Now the addrs n := len(r.addrs) if n < 2 { return } lessFunc := func(i, j int) bool { a := r.addrs[i] b := r.addrs[j] // Preferred addresses first aPref := isPreferred(a.Addr(), preferredRanges) bPref := isPreferred(b.Addr(), preferredRanges) switch { case aPref && !bPref: // If i is preferred and j is not, i is less than j return true case !aPref && bPref: // If j is preferred then i is not due to the else, i is not less than j return false default: // 
Both i an j are either preferred or not, sort within that } // ipv6 addresses 2nd a4 := a.Addr().Is4() b4 := b.Addr().Is4() switch { case a4 == false && b4 == true: // If i is v6 and j is v4, i is less than j return true case a4 == true && b4 == false: // If j is v6 and i is v4, i is not less than j return false case a4 == true && b4 == true: // i and j are both ipv4 aPrivate := a.Addr().IsPrivate() bPrivate := b.Addr().IsPrivate() switch { case !aPrivate && bPrivate: // If i is a public ip (not private) and j is a private ip, i is less then j return true case aPrivate && !bPrivate: // If j is public (not private) then i is private due to the else, i is not less than j return false default: // Both i an j are either public or private, sort within that } default: // Both i an j are either ipv4 or ipv6, sort within that } // lexical order of ips 3rd c := a.Addr().Compare(b.Addr()) if c == 0 { // Ips are the same, Lexical order of ports 4th return a.Port() < b.Port() } // Ip wasn't the same return c < 0 } // Sort it sort.Slice(r.addrs, lessFunc) // Deduplicate a, b := 0, 1 for b < n { if r.addrs[a] != r.addrs[b] { a++ if a != b { r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] } } b++ } r.addrs = r.addrs[:a+1] return } // minInt returns the minimum integer of a or b func minInt(a, b int) int { if a < b { return a } return b } // isPreferred returns true of the ip is contained in the preferredRanges list func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { for _, p := range preferredRanges { if p.Contains(ip) { return true } } return false } ================================================ FILE: remote_list_test.go ================================================ package nebula import ( "encoding/binary" "net/netip" "testing" "github.com/stretchr/testify/assert" ) func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), 
netip.MustParseAddr("0.0.0.0"), []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is duped newIp4AndPortFromString("172.18.0.1:10101"), // this is duped newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe newIp4AndPortFromString("172.19.0.1:10101"), newIp4AndPortFromString("172.31.0.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"), []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, func(netip.Addr, *V6AddrPort) bool { return true }, ) rl.unlockedSetRelay( netip.MustParseAddr("0.0.0.1"), []netip.Addr{ netip.MustParseAddr("1::1"), netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1::1"), }, ) rl.Rebuild([]netip.Prefix{}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv6 first, sorted lexically within assert.Equal(t, "[1::1]:1", rl.addrs[0].String()) assert.Equal(t, "[1::1]:2", rl.addrs[1].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String()) // ipv4 last, sorted by public first, then private, lexically within them assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String()) assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String()) assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String()) assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String()) assert.Equal(t, "172.18.0.1:10101", 
rl.addrs[7].String()) assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String()) assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) // Now ensure we can hoist ipv4 up rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv4 first, public then private, lexically within them assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String()) assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String()) assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String()) assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String()) assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String()) assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String()) assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String()) // ipv6 last, sorted by public first, then private, lexically within them assert.Equal(t, "[1::1]:1", rl.addrs[7].String()) assert.Equal(t, "[1::1]:2", rl.addrs[8].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) // assert relay deduplicated assert.Len(t, rl.relays, 2) assert.Equal(t, "1.2.3.4", rl.relays[0].String()) assert.Equal(t, "1::1", rl.relays[1].String()) // Ensure we can hoist a specific ipv4 range over anything else rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // Preferred ipv4 first assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String()) assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String()) // ipv6 next assert.Equal(t, "[1::1]:1", rl.addrs[2].String()) assert.Equal(t, "[1::1]:2", rl.addrs[3].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String()) // the remaining ipv4 last assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String()) assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String()) assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String()) assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String()) assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) } func 
BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), newIp4AndPortFromString("172.18.0.1:10101"), newIp4AndPortFromString("172.19.0.1:10101"), newIp4AndPortFromString("172.31.0.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true rl.Rebuild([]netip.Prefix{}) } }) ipNet1 := netip.MustParsePrefix("172.17.0.0/16") b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true rl.Rebuild([]netip.Prefix{ipNet1}) } }) ipNet2 := netip.MustParsePrefix("70.0.0.0/8") b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true rl.Rebuild([]netip.Prefix{ipNet2}) } }) ipNet3 := netip.MustParsePrefix("0.0.0.0/0") b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), 
newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), newIp4AndPortFromString("172.18.0.1:10101"), newIp4AndPortFromString("172.19.0.1:10101"), newIp4AndPortFromString("172.31.0.1:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true rl.Rebuild([]netip.Prefix{}) } }) ipNet1 := netip.MustParsePrefix("172.17.0.0/16") rl.Rebuild([]netip.Prefix{ipNet1}) b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.Rebuild([]netip.Prefix{ipNet1}) } }) ipNet2 := netip.MustParsePrefix("70.0.0.0/8") rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) } }) ipNet3 := netip.MustParsePrefix("0.0.0.0/0") rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } func newIp4AndPortFromString(s string) *V4AddrPort { a := netip.MustParseAddrPort(s) v4Addr := a.Addr().As4() return &V4AddrPort{ Addr: binary.BigEndian.Uint32(v4Addr[:]), Port: uint32(a.Port()), } } func newIp6AndPortFromString(s string) *V6AddrPort { a := netip.MustParseAddrPort(s) v6Addr := a.Addr().As16() return &V6AddrPort{ Hi: binary.BigEndian.Uint64(v6Addr[:8]), Lo: binary.BigEndian.Uint64(v6Addr[8:]), Port: uint32(a.Port()), } } 
================================================ FILE: routing/balance.go ================================================ package routing import ( "net/netip" "github.com/slackhq/nebula/firewall" ) // Hashes the packet source and destination port and always returns a positive integer // Based on 'Prospecting for Hash Functions' // - https://nullprogram.com/blog/2018/07/31/ // - https://github.com/skeeto/hash-prospector // [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 func hashPacket(p *firewall.Packet) int { x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) x ^= x >> 16 x *= 0x21f0aaad x ^= x >> 15 x *= 0xd35a2d97 x ^= x >> 15 return int(x) & 0x7FFFFFFF } // For this function to work correctly it requires that the buckets for the gateways have been calculated // If the contract is violated balancing will not work properly and the second return value will return false func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) { hash := hashPacket(fwPacket) for i := range gateways { if hash <= gateways[i].BucketUpperBound() { return gateways[i].Addr(), true } } // If you land here then the buckets for the gateways are not properly calculated // Fallback to random routing and let the caller know return gateways[hash%len(gateways)].Addr(), false } ================================================ FILE: routing/balance_test.go ================================================ package routing import ( "net/netip" "testing" "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" ) func TestPacketsAreBalancedEqually(t *testing.T) { gateways := []Gateway{} gw1Addr := netip.MustParseAddr("1.0.0.1") gw2Addr := netip.MustParseAddr("1.0.0.2") gw3Addr := netip.MustParseAddr("1.0.0.3") gateways = append(gateways, NewGateway(gw1Addr, 1)) gateways = append(gateways, NewGateway(gw2Addr, 1)) gateways = append(gateways, NewGateway(gw3Addr, 1)) CalculateBucketsForGateways(gateways) gw1count := 0 gw2count := 0 gw3count := 0 
iterationCount := uint16(65535) for i := uint16(0); i < iterationCount; i++ { packet := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), RemoteAddr: netip.MustParseAddr("10.0.0.1"), LocalPort: i, RemotePort: 65535 - i, Protocol: 6, // TCP Fragment: false, } selectedGw, ok := BalancePacket(&packet, gateways) assert.True(t, ok) switch selectedGw { case gw1Addr: gw1count += 1 case gw2Addr: gw2count += 1 case gw3Addr: gw3count += 1 } } // Assert packets are balanced, allow variation of up to 100 packets per gateway assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count) } func TestPacketsAreBalancedByPriority(t *testing.T) { gateways := []Gateway{} gw1Addr := netip.MustParseAddr("1.0.0.1") gw2Addr := netip.MustParseAddr("1.0.0.2") gateways = append(gateways, NewGateway(gw1Addr, 10)) gateways = append(gateways, NewGateway(gw2Addr, 5)) CalculateBucketsForGateways(gateways) gw1count := 0 gw2count := 0 iterationCount := uint16(65535) for i := uint16(0); i < iterationCount; i++ { packet := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), RemoteAddr: netip.MustParseAddr("10.0.0.1"), LocalPort: i, RemotePort: 65535 - i, Protocol: 6, // TCP Fragment: false, } selectedGw, ok := BalancePacket(&packet, gateways) assert.True(t, ok) switch selectedGw { case gw1Addr: gw1count += 1 case gw2Addr: gw2count += 1 } } iterationCountAsFloat := float32(iterationCount) assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count) assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count) } func 
TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) { gateways := []Gateway{} gw1Addr := netip.MustParseAddr("1.0.0.1") gw2Addr := netip.MustParseAddr("1.0.0.2") gateways = append(gateways, NewGateway(gw1Addr, 10)) gateways = append(gateways, NewGateway(gw2Addr, 5)) iterationCount := uint16(65535) gw1count := 0 gw2count := 0 for i := uint16(0); i < iterationCount; i++ { packet := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), RemoteAddr: netip.MustParseAddr("10.0.0.1"), LocalPort: i, RemotePort: 65535 - i, Protocol: 6, // TCP Fragment: false, } selectedGw, ok := BalancePacket(&packet, gateways) assert.False(t, ok) switch selectedGw { case gw1Addr: gw1count += 1 case gw2Addr: gw2count += 1 } } assert.Equal(t, int(iterationCount), (gw1count + gw2count)) assert.NotEqual(t, 0, gw1count) assert.NotEqual(t, 0, gw2count) } ================================================ FILE: routing/gateway.go ================================================ package routing import ( "fmt" "net/netip" ) const ( // Sentinel value BucketNotCalculated = -1 ) type Gateways []Gateway func (g Gateways) String() string { str := "" for i, gw := range g { str += gw.String() if i < len(g)-1 { str += ", " } } return str } type Gateway struct { addr netip.Addr weight int bucketUpperBound int } func NewGateway(addr netip.Addr, weight int) Gateway { return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated} } func (g *Gateway) BucketUpperBound() int { return g.bucketUpperBound } func (g *Gateway) Addr() netip.Addr { return g.addr } func (g *Gateway) String() string { return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight) } // Divide and round to nearest integer func divideAndRound(v uint64, d uint64) uint64 { var tmp uint64 = v + d/2 return tmp / d } // Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel. 
// After this function returns each gateway will have a // positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX) func CalculateBucketsForGateways(gateways []Gateway) { var totalWeight int = 0 for i := range gateways { totalWeight += gateways[i].weight } var loopWeight int = 0 for i := range gateways { loopWeight += gateways[i].weight gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1 } } ================================================ FILE: routing/gateway_test.go ================================================ package routing import ( "net/netip" "testing" "github.com/stretchr/testify/assert" ) func TestRebalance3_2Split(t *testing.T) { gateways := []Gateway{} gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10}) gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5}) CalculateBucketsForGateways(gateways) assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2 assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX } func TestRebalanceEqualSplit(t *testing.T) { gateways := []Gateway{} gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1}) CalculateBucketsForGateways(gateways) assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3 assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2 assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX } ================================================ FILE: service/listener.go ================================================ package service import ( "io" "net" ) type tcpListener struct { port uint16 s *Service addr *net.TCPAddr accept chan net.Conn } func (l *tcpListener) Accept() (net.Conn, error) { conn, ok := <-l.accept if !ok { return nil, io.EOF } return conn, nil } func (l *tcpListener) Close() error { 
l.s.mu.Lock() defer l.s.mu.Unlock() delete(l.s.mu.listeners, uint16(l.addr.Port)) close(l.accept) return nil } // Addr returns the listener's network address. func (l *tcpListener) Addr() net.Addr { return l.addr } ================================================ FILE: service/service.go ================================================ package service import ( "bytes" "context" "errors" "fmt" "log" "math" "net" "net/netip" "strings" "sync" "github.com/slackhq/nebula" "github.com/slackhq/nebula/overlay" "golang.org/x/sync/errgroup" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) const nicID = 1 type Service struct { eg *errgroup.Group control *nebula.Control ipstack *stack.Stack mu struct { sync.Mutex listeners map[uint16]*tcpListener } } func New(control *nebula.Control) (*Service, error) { control.Start() ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) s := Service{ eg: eg, control: control, } s.mu.listeners = map[uint16]*tcpListener{} device, ok := control.Device().(*overlay.UserDevice) if !ok { return nil, errors.New("must be using user device") } s.ipstack = stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, }) sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { return nil, fmt.Errorf("could not enable TCP SACK: 
%v", tcpipErr) } linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "") if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) } ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4))) s.ipstack.SetRouteTable([]tcpip.Route{ { Destination: ipv4Subnet, NIC: nicID, }, }) ipNet := device.Networks() pa := tcpip.ProtocolAddress{ AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ PEB: stack.CanBePrimaryEndpoint, // zero value default ConfigType: stack.AddressConfigStatic, // zero value default }); err != nil { return nil, fmt.Errorf("error creating IP: %s", err) } const tcpReceiveBufferSize = 0 const maxInFlightConnectionAttempts = 1024 tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) reader, writer := device.Pipe() go func() { <-ctx.Done() reader.Close() writer.Close() }() // create Goroutines to forward packets between Nebula and Gvisor eg.Go(func() error { buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) for { // this will read exactly one packet n, err := reader.Read(buf) if err != nil { return err } packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(bytes.Clone(buf[:n])), }) linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) if err := ctx.Err(); err != nil { return err } } }) eg.Go(func() error { for { packet := linkEP.ReadContext(ctx) if packet == nil { if err := ctx.Err(); err != nil { return err } continue } bufView := packet.ToView() if _, err := bufView.WriteTo(writer); err != nil { return err } bufView.Release() } }) return &s, nil } func 
getProtocolNumber(addr netip.Addr) tcpip.NetworkProtocolNumber { if addr.Is6() { return ipv6.ProtocolNumber } return ipv4.ProtocolNumber } // DialContext dials the provided address. func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { switch network { case "udp", "udp4", "udp6": addr, err := net.ResolveUDPAddr(network, address) if err != nil { return nil, err } fullAddr := tcpip.FullAddress{ NIC: nicID, Addr: tcpip.AddrFromSlice(addr.IP), Port: uint16(addr.Port), } num := getProtocolNumber(addr.AddrPort().Addr()) return gonet.DialUDP(s.ipstack, nil, &fullAddr, num) case "tcp", "tcp4", "tcp6": addr, err := net.ResolveTCPAddr(network, address) if err != nil { return nil, err } fullAddr := tcpip.FullAddress{ NIC: nicID, Addr: tcpip.AddrFromSlice(addr.IP), Port: uint16(addr.Port), } num := getProtocolNumber(addr.AddrPort().Addr()) return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, num) default: return nil, fmt.Errorf("unknown network type: %s", network) } } // Dial dials the provided address func (s *Service) Dial(network, address string) (net.Conn, error) { return s.DialContext(context.Background(), network, address) } // Listen listens on the provided address. Currently only TCP with wildcard // addresses are supported. 
func (s *Service) Listen(network, address string) (net.Listener, error) { if network != "tcp" && network != "tcp4" { return nil, errors.New("only tcp is supported") } addr, err := net.ResolveTCPAddr(network, address) if err != nil { return nil, err } if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) { return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP) } if addr.Port == 0 { return nil, errors.New("specific port required, got 0") } if addr.Port < 0 || addr.Port >= math.MaxUint16 { return nil, fmt.Errorf("invalid port %d", addr.Port) } port := uint16(addr.Port) l := &tcpListener{ port: port, s: s, addr: addr, accept: make(chan net.Conn), } s.mu.Lock() defer s.mu.Unlock() if _, ok := s.mu.listeners[port]; ok { return nil, fmt.Errorf("already listening on port %d", port) } s.mu.listeners[port] = l return l, nil } func (s *Service) Wait() error { return s.eg.Wait() } func (s *Service) Close() error { s.control.Stop() return nil } func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { endpointID := r.ID() s.mu.Lock() defer s.mu.Unlock() l, ok := s.mu.listeners[endpointID.LocalPort] if !ok { r.Complete(true) return } var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) if err != nil { log.Printf("got error creating endpoint %q", err) r.Complete(true) return } r.Complete(false) ep.SocketOptions().SetKeepAlive(true) conn := gonet.NewTCPConn(&wq, ep) l.accept <- conn } ================================================ FILE: service/service_test.go ================================================ package service import ( "bytes" "context" "errors" "net/netip" "os" "testing" "time" "dario.cat/mergo" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" "go.yaml.in/yaml/v3" "golang.org/x/sync/errgroup" ) type m = map[string]any func newSimpleService(caCrt cert.Certificate, caKey []byte, 
name string, udpIp netip.Addr, overrides m) *Service { _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } mc := m{ "pki": m{ "ca": string(caB), "cert": string(myPEM), "key": string(myPrivKey), }, //"tun": m{"disabled": true}, "firewall": m{ "outbound": []m{{ "proto": "any", "port": "any", "host": "any", }}, "inbound": []m{{ "proto": "any", "port": "any", "host": "any", }}, }, "timers": m{ "pending_deletion_interval": 2, "connection_alive_interval": 2, }, "handshakes": m{ "try_interval": "200ms", }, } if overrides != nil { err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) if err != nil { panic(err) } mc = overrides } cb, err := yaml.Marshal(mc) if err != nil { panic(err) } var c config.C if err := c.LoadString(string(cb)); err != nil { panic(err) } logger := logrus.New() logger.Out = os.Stdout control, err := nebula.Main(&c, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { panic(err) } s, err := New(control) if err != nil { panic(err) } return s } func TestService(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, }, "listen": m{ "host": "0.0.0.0", "port": 4243, }, }) b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, }, "lighthouse": m{ "hosts": []string{"10.0.0.1"}, "interval": 1, }, }) ln, err := a.Listen("tcp", ":1234") if err != nil { t.Fatal(err) } var eg errgroup.Group eg.Go(func() error { conn, err := ln.Accept() if err != nil { return err } defer conn.Close() t.Log("accepted 
connection") if _, err := conn.Write([]byte("server msg")); err != nil { return err } t.Log("server: wrote message") data := make([]byte, 100) n, err := conn.Read(data) if err != nil { return err } data = data[:n] if !bytes.Equal(data, []byte("client msg")) { return errors.New("got invalid message from client") } t.Log("server: read message") return conn.Close() }) c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234") if err != nil { t.Fatal(err) } if _, err := c.Write([]byte("client msg")); err != nil { t.Fatal(err) } data := make([]byte, 100) n, err := c.Read(data) if err != nil { t.Fatal(err) } data = data[:n] if !bytes.Equal(data, []byte("server msg")) { t.Fatal("got invalid message from client") } if err := c.Close(); err != nil { t.Fatal(err) } if err := eg.Wait(); err != nil { t.Fatal(err) } } ================================================ FILE: ssh.go ================================================ package nebula import ( "bytes" "encoding/json" "errors" "flag" "fmt" "maps" "net" "net/netip" "os" "reflect" "runtime" "runtime/pprof" "sort" "strconv" "strings" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/sshd" ) type sshListHostMapFlags struct { Json bool Pretty bool ByIndex bool } type sshPrintCertFlags struct { Json bool Pretty bool Raw bool } type sshPrintTunnelFlags struct { Pretty bool } type sshChangeRemoteFlags struct { Address string } type sshCloseTunnelFlags struct { LocalOnly bool } type sshCreateTunnelFlags struct { Address string } type sshDeviceInfoFlags struct { Json bool Pretty bool } func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) { c.RegisterReloadCallback(func(c *config.C) { if c.GetBool("sshd.enabled", false) { sshRun, err := configSSH(l, ssh, c) if err != nil { l.WithError(err).Error("Failed to reconfigure the sshd") ssh.Stop() } if sshRun != nil { go sshRun() } } else { ssh.Stop() } }) } // configSSH reads the ssh 
info out of the passed-in Config and
// updates the passed-in SSHServer. On success, it returns a function
// that callers may invoke to run the configured ssh server. On
// failure, it returns nil, error.
func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
	listen := c.GetString("sshd.listen", "")
	if listen == "" {
		return nil, fmt.Errorf("sshd.listen must be provided")
	}

	_, port, err := net.SplitHostPort(listen)
	if err != nil {
		return nil, fmt.Errorf("invalid sshd.listen address: %w", err)
	}
	if port == "22" {
		return nil, fmt.Errorf("sshd.listen can not use port 22")
	}

	// sshd.host_key is either an inline PEM key or a path to one.
	hostKeyPathOrKey := c.GetString("sshd.host_key", "")
	if hostKeyPathOrKey == "" {
		return nil, fmt.Errorf("sshd.host_key must be provided")
	}

	var hostKeyBytes []byte
	if strings.Contains(hostKeyPathOrKey, "-----BEGIN") {
		hostKeyBytes = []byte(hostKeyPathOrKey)
	} else {
		hostKeyBytes, err = os.ReadFile(hostKeyPathOrKey)
		if err != nil {
			return nil, fmt.Errorf("error while loading sshd.host_key file: %w", err)
		}
	}

	err = ssh.SetHostKey(hostKeyBytes)
	if err != nil {
		return nil, fmt.Errorf("error while adding sshd.host_key: %w", err)
	}

	// Clear existing trusted CAs and authorized keys
	ssh.ClearTrustedCAs()
	ssh.ClearAuthorizedKeys()

	rawCAs := c.GetStringSlice("sshd.trusted_cas", []string{})
	for _, caAuthorizedKey := range rawCAs {
		err := ssh.AddTrustedCA(caAuthorizedKey)
		if err != nil {
			l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
			continue
		}
	}

	// Each authorized user entry is a map with a "user" string and "keys",
	// which is either a single key string or a list of key strings.
	rawKeys := c.Get("sshd.authorized_users")
	keys, ok := rawKeys.([]any)
	if ok {
		for _, rk := range keys {
			kDef, ok := rk.(map[string]any)
			if !ok {
				l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
				continue
			}

			user, ok := kDef["user"].(string)
			if !ok {
				l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
				continue
			}

			k := kDef["keys"]
			switch v := k.(type) {
			case string:
				err := ssh.AddAuthorizedKey(user, v)
				if err != nil {
					l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
					continue
				}

			case []any:
				for _, subK := range v {
					sk, ok := subK.(string)
					if !ok {
						l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
						continue
					}

					err := ssh.AddAuthorizedKey(user, sk)
					if err != nil {
						// Log the whole user entry plus the offending key, as the
						// single-key case does.
						l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", sk).Warn("Failed to authorize key")
						continue
					}
				}

			default:
				l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
			}
		}
	} else {
		l.Info("no ssh users to authorize")
	}

	// Any running server is stopped either way; a runner is only returned
	// when the ssh server is enabled.
	enabled := c.GetBool("sshd.enabled", false)
	ssh.Stop()
	if !enabled {
		return nil, nil
	}

	return func() {
		if err := ssh.Run(listen); err != nil {
			l.WithField("err", err).Warn("Failed to run the SSH server")
		}
	}, nil
}

// attachCommands registers nebula's built-in commands on the ssh server.
func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
	ssh.RegisterCommand(&sshd.Command{
		Name:             "list-hostmap",
		ShortDescription: "List all known previously connected hosts",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshListHostMapFlags{}
			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshListHostMap(f.hostMap, fs, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "list-pending-hostmap",
		ShortDescription: "List all handshaking hosts",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshListHostMapFlags{}
			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshListHostMap(f.handshakeManager, fs, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "list-lighthouse-addrmap",
		ShortDescription: "List all lighthouse map entries",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshListHostMapFlags{}
			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshListLighthouseMap(f.lightHouse, fs, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "reload",
		ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshReload(c, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "start-cpu-profile",
		ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`",
		Callback:         sshStartCpuProfile,
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "stop-cpu-profile",
		ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			pprof.StopCPUProfile()
			return w.WriteLine("If a CPU profile was running it is now stopped")
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "save-heap-profile",
		ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`",
		Callback:         sshGetHeapProfile,
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "mutex-profile-fraction",
		ShortDescription: "Gets or sets runtime.SetMutexProfileFraction",
		Callback:         sshMutexProfileFraction,
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "save-mutex-profile",
		ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`",
		Callback:         sshGetMutexProfile,
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "log-level",
		ShortDescription: "Gets or sets the current log level",
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshLogLevel(l, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "log-format",
		ShortDescription: "Gets or sets the current log format",
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshLogFormat(l, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "version",
		ShortDescription: "Prints the currently running version of nebula",
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshVersion(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "device-info",
		ShortDescription: "Prints information about the network device.",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshDeviceInfoFlags{}
			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshDeviceInfo(f, fs, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "print-cert",
		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshPrintCertFlags{}
			fl.BoolVar(&s.Json, "json", false, "outputs as json")
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
			fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshPrintCert(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "print-tunnel",
		ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshPrintTunnelFlags{}
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshPrintTunnel(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "print-relays",
		ShortDescription: "Prints json details about all relay info",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshPrintTunnelFlags{}
			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshPrintRelays(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "change-remote",
		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshChangeRemoteFlags{}
			fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshChangeRemote(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "close-tunnel",
		ShortDescription: "Closes a tunnel for the provided vpn addr",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshCloseTunnelFlags{}
			fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshCloseTunnel(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "create-tunnel",
		ShortDescription: "Creates a tunnel for the provided vpn address",
		Help:             "The lighthouses will be queried for real addresses but you can provide one as well.",
		Flags: func() (*flag.FlagSet, any) {
			fl := flag.NewFlagSet("", flag.ContinueOnError)
			s := sshCreateTunnelFlags{}
			fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
			return fl, &s
		},
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshCreateTunnel(f, fs, a, w)
		},
	})

	ssh.RegisterCommand(&sshd.Command{
		Name:             "query-lighthouse",
		ShortDescription: "Query the lighthouses for the provided vpn address",
		Help:             "This command is asynchronous. Only currently known udp addresses will be printed.",
		Callback: func(fs any, a []string, w sshd.StringWriter) error {
			return sshQueryLighthouse(f, fs, a, w)
		},
	})
}

// sshListHostMap writes the hosts listed by hl, sorted by first vpn address.
// Output is either JSON (-json/-pretty) or one "vpnAddrs: remoteAddrs" line
// per host. Flags of the wrong type are silently ignored.
func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
	fs, ok := a.(*sshListHostMapFlags)
	if !ok {
		return nil
	}

	var hm []ControlHostInfo
	if fs.ByIndex {
		hm = listHostMapIndexes(hl)
	} else {
		hm = listHostMapHosts(hl)
	}

	sort.Slice(hm, func(i, j int) bool {
		return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
	})

	if fs.Json || fs.Pretty {
		js := json.NewEncoder(w.GetWriter())
		if fs.Pretty {
			js.SetIndent("", " ")
		}

		// Report encoding/write failures instead of silently succeeding.
		if err := js.Encode(hm); err != nil {
			return err
		}
	} else {
		for _, v := range hm {
			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
			if err != nil {
				return err
			}
		}
	}

	return nil
}

// sshListLighthouseMap writes a snapshot of the lighthouse address map,
// sorted by vpn address. Output is either JSON (-json/-pretty) or one
// "vpnAddr: <json addrs>" line per entry.
func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error {
	fs, ok := a.(*sshListHostMapFlags)
	if !ok {
		return nil
	}

	type lighthouseInfo struct {
		VpnAddr string    `json:"vpnAddr"`
		Addrs   *CacheMap `json:"addrs"`
	}

	// Copy the entries under the read lock so encoding happens without it.
	lightHouse.RLock()
	addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
	x := 0
	for k, v := range lightHouse.addrMap {
		addrMap[x] = lighthouseInfo{
			VpnAddr: k.String(),
			Addrs:   v.CopyCache(),
		}
		x++
	}
	lightHouse.RUnlock()

	sort.Slice(addrMap, func(i, j int) bool {
		return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0
	})

	if fs.Json || fs.Pretty {
		js := json.NewEncoder(w.GetWriter())
		if fs.Pretty {
			js.SetIndent("", " ")
		}

		// Report encoding/write failures instead of silently succeeding.
		if err := js.Encode(addrMap); err != nil {
			return err
		}
	} else {
		for _, v := range addrMap {
			b, err := json.Marshal(v.Addrs)
			if err != nil {
				return err
			}
			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b)))
			if err != nil {
				return err
			}
		}
	}

	return nil
}

func
sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
	if len(a) == 0 {
		err := w.WriteLine("No path to write profile provided")
		return err
	}

	file, err := os.Create(a[0])
	if err != nil {
		err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
		return err
	}

	err = pprof.StartCPUProfile(file)
	if err != nil {
		// The profile never started, so nothing else will close the file.
		_ = file.Close() // best effort; the StartCPUProfile error is what gets reported
		err = w.WriteLine(fmt.Sprintf("Unable to start cpu profile: %s", err))
		return err
	}

	// NOTE(review): on success the file stays open; stop-cpu-profile only calls
	// pprof.StopCPUProfile and never closes it.
	err = w.WriteLine(fmt.Sprintf("Started cpu profile, issue stop-cpu-profile to write the output to %s", a[0]))
	return err
}

// sshVersion writes the running nebula version.
func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
	return w.WriteLine(fmt.Sprintf("%s", ifce.version))
}

// sshQueryLighthouse runs a lighthouse query for the vpn address in a[0] and
// writes, as JSON, a copy of whatever remote addresses are already cached
// (null if none).
func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
	if len(a) == 0 {
		return w.WriteLine("No vpn address was provided")
	}

	vpnAddr, err := netip.ParseAddr(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	if !vpnAddr.IsValid() {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	var cm *CacheMap
	rl := ifce.lightHouse.Query(vpnAddr)
	if rl != nil {
		cm = rl.CopyCache()
	}
	return json.NewEncoder(w.GetWriter()).Encode(cm)
}

// sshCloseTunnel closes the tunnel to the vpn address in a[0]. Unless
// -local-only was given, it first sends a CloseTunnel message to the remote.
func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
	flags, ok := fs.(*sshCloseTunnelFlags)
	if !ok {
		return nil
	}

	if len(a) == 0 {
		return w.WriteLine("No vpn address was provided")
	}

	vpnAddr, err := netip.ParseAddr(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	if !vpnAddr.IsValid() {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
	if hostInfo == nil {
		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
	}

	if !flags.LocalOnly {
		ifce.send(
			header.CloseTunnel,
			0,
			hostInfo.ConnectionState,
			hostInfo,
			[]byte{},
			make([]byte, 12, 12),
			make([]byte, mtu),
		)
	}

	ifce.closeTunnel(hostInfo)
	return w.WriteLine("Closed")
}

// sshCreateTunnel starts a handshake with the vpn address in a[0], optionally
// pinning the remote given by -address. It does nothing if a tunnel already
// exists or is handshaking.
func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
	flags, ok := fs.(*sshCreateTunnelFlags)
	if !ok {
		return nil
	}

	if len(a) == 0 {
		return w.WriteLine("No vpn address was provided")
	}

	vpnAddr, err := netip.ParseAddr(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	if !vpnAddr.IsValid() {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
	if hostInfo != nil {
		return w.WriteLine("Tunnel already exists")
	}

	hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
	if hostInfo != nil {
		return w.WriteLine("Tunnel already handshaking")
	}

	var addr netip.AddrPort
	if flags.Address != "" {
		addr, err = netip.ParseAddrPort(flags.Address)
		if err != nil {
			return w.WriteLine("Address could not be parsed")
		}
	}

	hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
	if addr.IsValid() {
		hostInfo.SetRemote(addr)
	}

	return w.WriteLine("Created")
}

// sshChangeRemote points the existing tunnel for the vpn address in a[0] at
// the remote given by -address.
func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
	flags, ok := fs.(*sshChangeRemoteFlags)
	if !ok {
		return nil
	}

	if len(a) == 0 {
		return w.WriteLine("No vpn address was provided")
	}

	if flags.Address == "" {
		return w.WriteLine("No address was provided")
	}

	addr, err := netip.ParseAddrPort(flags.Address)
	if err != nil {
		return w.WriteLine("Address could not be parsed")
	}

	vpnAddr, err := netip.ParseAddr(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	if !vpnAddr.IsValid() {
		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
	}

	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
	if hostInfo == nil {
		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
	}

	hostInfo.SetRemote(addr)
	return w.WriteLine("Changed")
}

// sshGetHeapProfile writes a heap profile to the path in a[0].
func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
	if len(a) == 0 {
		return w.WriteLine("No path to write profile provided")
	}

	file, err := os.Create(a[0])
	if err != nil {
		err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
		return err
	}
	// The profile is written synchronously below, so close the file on return.
	defer file.Close()

	err = pprof.WriteHeapProfile(file)
	if err != nil {
		err = w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err))
		return err
	}

	err = w.WriteLine(fmt.Sprintf("Mem profile created at %s", a[0]))
	return err
}

// sshMutexProfileFraction reports the current mutex profile fraction, or sets
// it to the integer in a[0] and reports the previous value.
func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
	if len(a) == 0 {
		// A negative rate reads the current value without changing it.
		rate := runtime.SetMutexProfileFraction(-1)
		return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
	}

	newRate, err := strconv.Atoi(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("Invalid argument: %s", a[0]))
	}

	oldRate := runtime.SetMutexProfileFraction(newRate)
	return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
}

// sshGetMutexProfile writes the "mutex" pprof profile to the path in a[0].
func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
	if len(a) == 0 {
		return w.WriteLine("No path to write profile provided")
	}

	file, err := os.Create(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
	}
	defer file.Close()

	mutexProfile := pprof.Lookup("mutex")
	if mutexProfile == nil {
		return w.WriteLine("Unable to get pprof.Lookup(\"mutex\")")
	}

	err = mutexProfile.WriteTo(file, 0)
	if err != nil {
		return w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err))
	}

	return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a[0]))
}

// sshLogLevel reports the current log level, or sets it to the level named in
// a[0] when one is given.
func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
	if len(a) == 0 {
		return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
	}

	level, err := logrus.ParseLevel(a[0])
	if err != nil {
		return w.WriteLine(fmt.Sprintf("Unknown log level %s. 
Possible log levels: %s", a, logrus.AllLevels)) } l.SetLevel(level) return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level)) } func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error { if len(a) == 0 { return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } logFormat := strings.ToLower(a[0]) switch logFormat { case "text": l.Formatter = &logrus.TextFormatter{} case "json": l.Formatter = &logrus.JSONFormatter{} default: return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"}) } return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter))) } func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintCertFlags) if !ok { return nil } cert := ifce.pki.getCertState().GetDefaultCertificate() if len(a) > 0 { vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } if !vpnAddr.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } cert = hostInfo.GetCert().Certificate } if args.Json || args.Pretty { b, err := cert.MarshalJSON() if err != nil { return nil } if args.Pretty { buf := new(bytes.Buffer) err := json.Indent(buf, b, "", " ") b = buf.Bytes() if err != nil { return nil } } return w.WriteBytes(b) } if args.Raw { b, err := cert.MarshalPEM() if err != nil { return nil } return w.WriteBytes(b) } return w.WriteLine(cert.String()) } func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type")) return nil } relays := map[uint32]*HostInfo{} ifce.hostMap.Lock() maps.Copy(relays, 
ifce.hostMap.Relays) ifce.hostMap.Unlock() type RelayFor struct { Error error Type string State string PeerAddr netip.Addr LocalIndex uint32 RemoteIndex uint32 RelayedThrough []netip.Addr } type RelayOutput struct { NebulaAddr netip.Addr RelayForAddrs []RelayFor } type CmdOutput struct { Relays []*RelayOutput } co := CmdOutput{} enc := json.NewEncoder(w.GetWriter()) if args.Pretty { enc.SetIndent("", " ") } for k, v := range relays { ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")}) continue } for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr) if ok { t := "" switch r.Type { case ForwardingType: t = "forwarding" case TerminalType: t = "terminal" default: t = "unknown" } s := "" switch r.State { case Requested: s = "requested" case Established: s = "established" default: s = "unknown" } rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex rf.PeerAddr = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) 
} ro.RelayForAddrs = append(ro.RelayForAddrs, rf) } } err := enc.Encode(co) if err != nil { return err } return nil } func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error { args, ok := fs.(*sshPrintTunnelFlags) if !ok { return nil } if len(a) == 0 { return w.WriteLine("No vpn address was provided") } vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } if !vpnAddr.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } enc := json.NewEncoder(w.GetWriter()) if args.Pretty { enc.SetIndent("", " ") } return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges())) } func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error { data := struct { Name string `json:"name"` Cidr []netip.Prefix `json:"cidr"` }{ Name: ifce.inside.Name(), Cidr: make([]netip.Prefix, len(ifce.inside.Networks())), } copy(data.Cidr, ifce.inside.Networks()) flags, ok := fs.(*sshDeviceInfoFlags) if !ok { return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) } if flags.Json || flags.Pretty { js := json.NewEncoder(w.GetWriter()) if flags.Pretty { js.SetIndent("", " ") } return js.Encode(data) } else { return w.WriteLine(fmt.Sprintf("name=%v cidr=%v", data.Name, data.Cidr)) } } func sshReload(c *config.C, w sshd.StringWriter) error { err := w.WriteLine("Reloading config") c.ReloadConfig() return err } ================================================ FILE: sshd/command.go ================================================ package sshd import ( "errors" "flag" "fmt" "sort" "strings" "github.com/armon/go-radix" ) // CommandFlags is a function called before help or command execution to parse command line flags // It should return a 
flag.FlagSet instance and a pointer to the struct that will contain parsed flags type CommandFlags func() (*flag.FlagSet, any) // CommandCallback is the function called when your command should execute. // fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved // and handled automatically for you. // a will be any unconsumed arguments, if no Command.Flags was available this will be all the flags passed in. // w is the writer to use when sending messages back to the client. // If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user // where appropriate type CommandCallback func(fs any, a []string, w StringWriter) error type Command struct { Name string ShortDescription string Help string Flags CommandFlags Callback CommandCallback } func execCommand(c *Command, args []string, w StringWriter) error { var ( fl *flag.FlagSet fs any ) if c.Flags != nil { fl, fs = c.Flags() if fl != nil { // SetOutput() here in case fl.Parse dumps usage. fl.SetOutput(w.GetWriter()) err := fl.Parse(args) if err != nil { // fl.Parse has dumped error information to the user via the w writer. 
return err } args = fl.Args() } } return c.Callback(fs, args, w) } func dumpCommands(c *radix.Tree, w StringWriter) { err := w.WriteLine("Available commands:") if err != nil { return } cmds := make([]string, 0) for _, l := range allCommands(c) { cmds = append(cmds, fmt.Sprintf("%s - %s", l.Name, l.ShortDescription)) } sort.Strings(cmds) _ = w.Write(strings.Join(cmds, "\n") + "\n\n") } func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) { cmd, ok := c.Get(sCmd) if !ok { return nil, nil } command, ok := cmd.(*Command) if !ok { return nil, errors.New("failed to cast command") } return command, nil } func matchCommand(c *radix.Tree, cmd string) []string { cmds := make([]string, 0) c.WalkPrefix(cmd, func(found string, v any) bool { cmds = append(cmds, found) return false }) sort.Strings(cmds) return cmds } func allCommands(c *radix.Tree) []*Command { cmds := make([]*Command, 0) c.WalkPrefix("", func(found string, v any) bool { cmd, ok := v.(*Command) if ok { cmds = append(cmds, cmd) } return false }) return cmds } func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) { // Just typed help if len(a) == 0 { dumpCommands(commands, w) return nil } // We are printing a specific commands help text cmd, err := lookupCommand(commands, a[0]) if err != nil { return } if cmd != nil { err = w.WriteLine(fmt.Sprintf("%s - %s", cmd.Name, cmd.ShortDescription)) if err != nil { return err } if cmd.Help != "" { err = w.WriteLine(fmt.Sprintf(" %s", cmd.Help)) if err != nil { return err } } if cmd.Flags != nil { fs, _ := cmd.Flags() if fs != nil { fs.SetOutput(w.GetWriter()) fs.PrintDefaults() } } return nil } err = w.WriteLine("Command not available " + a[0]) if err != nil { return err } return nil } func checkHelpArgs(args []string) bool { for _, a := range args { if a == "-h" || a == "-help" { return true } } return false } ================================================ FILE: sshd/server.go ================================================ package 
sshd import ( "bytes" "errors" "fmt" "net" "sync" "github.com/armon/go-radix" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) type SSHServer struct { config *ssh.ServerConfig l *logrus.Entry certChecker *ssh.CertChecker // Map of user -> authorized keys trustedKeys map[string]map[string]bool trustedCAs []ssh.PublicKey // List of available commands helpCommand *Command commands *radix.Tree listener net.Listener // Locks the conns/counter to avoid concurrent map access connsLock sync.Mutex conns map[int]*session counter int } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), conns: make(map[int]*session), } cc := ssh.CertChecker{ IsUserAuthority: func(auth ssh.PublicKey) bool { for _, ca := range s.trustedCAs { if bytes.Equal(ca.Marshal(), auth.Marshal()) { return true } } return false }, UserKeyFallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { pk := string(pubKey.Marshal()) fp := ssh.FingerprintSHA256(pubKey) tk, ok := s.trustedKeys[c.User()] if !ok { return nil, fmt.Errorf("unknown user %s", c.User()) } _, ok = tk[pk] if !ok { return nil, fmt.Errorf("unknown public key for %s (%s)", c.User(), fp) } return &ssh.Permissions{ // Record the public key used for authentication. 
Extensions: map[string]string{ "fp": fp, "user": c.User(), }, }, nil }, } s.config = &ssh.ServerConfig{ PublicKeyCallback: cc.Authenticate, ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"), } s.RegisterCommand(&Command{ Name: "help", ShortDescription: "prints available commands or help for specific usage info", Callback: func(a any, args []string, w StringWriter) error { return helpCallback(s.commands, args, w) }, }) return s, nil } func (s *SSHServer) SetHostKey(hostPrivateKey []byte) error { private, err := ssh.ParsePrivateKey(hostPrivateKey) if err != nil { return fmt.Errorf("failed to parse private key: %s", err) } s.config.AddHostKey(private) return nil } func (s *SSHServer) ClearTrustedCAs() { s.trustedCAs = []ssh.PublicKey{} } func (s *SSHServer) ClearAuthorizedKeys() { s.trustedKeys = make(map[string]map[string]bool) } // AddTrustedCA adds a trusted CA for user certificates func (s *SSHServer) AddTrustedCA(pubKey string) error { pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) if err != nil { return err } s.trustedCAs = append(s.trustedCAs, pk) s.l.WithField("sshKey", pubKey).Info("Trusted CA key") return nil } // AddAuthorizedKey adds an ssh public key for a user func (s *SSHServer) AddAuthorizedKey(user, pubKey string) error { pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubKey)) if err != nil { return err } tk, ok := s.trustedKeys[user] if !ok { tk = make(map[string]bool) s.trustedKeys[user] = tk } tk[string(pk.Marshal())] = true s.l.WithField("sshKey", pubKey).WithField("sshUser", user).Info("Authorized ssh key") return nil } // RegisterCommand adds a command that can be run by a user, by default only `help` is available func (s *SSHServer) RegisterCommand(c *Command) { s.commands.Insert(c.Name, c) } // Run begins listening and accepting connections func (s *SSHServer) Run(addr string) error { var err error s.listener, err = net.Listen("tcp", addr) if err != nil { return err } s.l.WithField("sshListener", addr).Info("SSH server is 
listening") // Run loops until there is an error s.run() s.closeSessions() s.l.Info("SSH server stopped listening") // We don't return an error because run logs for us return nil } func (s *SSHServer) run() { for { c, err := s.listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { s.l.WithError(err).Warn("Error in listener, shutting down") } return } conn, chans, reqs, err := ssh.NewServerConn(c, s.config) fp := "" if conn != nil { fp = conn.Permissions.Extensions["fp"] } if err != nil { l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) if conn != nil { l = l.WithField("sshUser", conn.User()) conn.Close() } if fp != "" { l = l.WithField("sshFingerprint", fp) } l.Warn("failed to handshake") continue } l := s.l.WithField("sshUser", conn.User()) l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) s.connsLock.Lock() s.counter++ counter := s.counter s.conns[counter] = session s.connsLock.Unlock() go ssh.DiscardRequests(reqs) go func() { <-session.exitChan s.l.WithField("id", counter).Debug("closing conn") s.connsLock.Lock() delete(s.conns, counter) s.connsLock.Unlock() }() } } func (s *SSHServer) Stop() { // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { s.l.WithError(err).Warn("Failed to close the sshd listener") } } } func (s *SSHServer) closeSessions() { s.connsLock.Lock() for _, c := range s.conns { c.Close() } s.connsLock.Unlock() } ================================================ FILE: sshd/session.go ================================================ package sshd import ( "fmt" "sort" "strings" "github.com/anmitsu/go-shlex" "github.com/armon/go-radix" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" "golang.org/x/term" ) type session struct { l *logrus.Entry c *ssh.ServerConn term 
*term.Terminal commands *radix.Tree exitChan chan bool } func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, c: conn, exitChan: make(chan bool), } s.commands.Insert("logout", &Command{ Name: "logout", ShortDescription: "Ends the current session", Callback: func(a any, args []string, w StringWriter) error { s.Close() return nil }, }) go s.handleChannels(chans) return s } func (s *session) handleChannels(chans <-chan ssh.NewChannel) { for newChannel := range chans { if newChannel.ChannelType() != "session" { s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChannel.Accept() if err != nil { s.l.WithError(err).Warn("could not accept channel") continue } go s.handleRequests(requests, channel) } } func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { for req := range in { var err error switch req.Type { case "shell": if s.term == nil { s.term = s.createTerm(channel) err = req.Reply(true, nil) } else { err = req.Reply(false, nil) } case "pty-req": err = req.Reply(true, nil) case "window-change": err = req.Reply(true, nil) case "exec": var payload = struct{ Value string }{} cErr := ssh.Unmarshal(req.Payload, &payload) if cErr != nil { req.Reply(false, nil) return } req.Reply(true, nil) s.dispatchCommand(payload.Value, &stringWriter{channel}) status := struct{ Status uint32 }{uint32(0)} channel.SendRequest("exit-status", false, ssh.Marshal(status)) channel.Close() return default: s.l.WithField("sshRequest", req.Type).Debug("Rejected unknown request") err = req.Reply(false, nil) } if err != nil { s.l.WithError(err).Info("Error handling ssh session requests") s.Close() return } } } func (s *session) createTerm(channel ssh.Channel) *term.Terminal { term := 
term.NewTerminal(channel, s.c.User()+"@nebula > ") term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { // key 9 is tab if key == 9 { cmds := matchCommand(s.commands, line) if len(cmds) == 1 { return cmds[0] + " ", len(cmds[0]) + 1, true } sort.Strings(cmds) term.Write([]byte(strings.Join(cmds, "\n") + "\n\n")) } return "", 0, false } go s.handleInput(channel) return term } func (s *session) handleInput(channel ssh.Channel) { defer s.Close() w := &stringWriter{w: s.term} for { line, err := s.term.ReadLine() if err != nil { break } s.dispatchCommand(line, w) } } func (s *session) dispatchCommand(line string, w StringWriter) { args, err := shlex.Split(line, true) if err != nil { return } if len(args) == 0 { dumpCommands(s.commands, w) return } c, err := lookupCommand(s.commands, args[0]) if err != nil { return } if c == nil { err := w.WriteLine(fmt.Sprintf("did not understand: %s", line)) _ = err dumpCommands(s.commands, w) return } if checkHelpArgs(args) { s.dispatchCommand(fmt.Sprintf("%s %s", "help", c.Name), w) return } _ = execCommand(c, args[1:], w) return } func (s *session) Close() { s.c.Close() s.exitChan <- true } ================================================ FILE: sshd/writer.go ================================================ package sshd import "io" type StringWriter interface { WriteLine(string) error Write(string) error WriteBytes([]byte) error GetWriter() io.Writer } type stringWriter struct { w io.Writer } func (w *stringWriter) WriteLine(s string) error { return w.Write(s + "\n") } func (w *stringWriter) Write(s string) error { _, err := w.w.Write([]byte(s)) return err } func (w *stringWriter) WriteBytes(b []byte) error { _, err := w.w.Write(b) return err } func (w *stringWriter) GetWriter() io.Writer { return w.w } ================================================ FILE: stats.go ================================================ package nebula import ( "errors" "fmt" "log" "net" "net/http" 
"runtime" "strconv" "time" graphite "github.com/cyberdelia/go-metrics-graphite" mp "github.com/nbrownus/go-metrics-prometheus" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) // startStats initializes stats from config. On success, if any further work // is needed to serve stats, it returns a func to handle that work. If no // work is needed, it'll return nil. On failure, it returns nil, error. func startStats(l *logrus.Logger, c *config.C, buildVersion string, configTest bool) (func(), error) { mType := c.GetString("stats.type", "") if mType == "" || mType == "none" { return nil, nil } interval := c.GetDuration("stats.interval", 0) if interval == 0 { return nil, fmt.Errorf("stats.interval was an invalid duration: %s", c.GetString("stats.interval", "")) } var startFn func() switch mType { case "graphite": err := startGraphiteStats(l, interval, c, configTest) if err != nil { return nil, err } case "prometheus": var err error startFn, err = startPrometheusStats(l, interval, c, buildVersion, configTest) if err != nil { return nil, err } default: return nil, fmt.Errorf("stats.type was not understood: %s", mType) } metrics.RegisterDebugGCStats(metrics.DefaultRegistry) metrics.RegisterRuntimeMemStats(metrics.DefaultRegistry) go metrics.CaptureDebugGCStats(metrics.DefaultRegistry, interval) go metrics.CaptureRuntimeMemStats(metrics.DefaultRegistry, interval) return startFn, nil } func startGraphiteStats(l *logrus.Logger, i time.Duration, c *config.C, configTest bool) error { proto := c.GetString("stats.protocol", "tcp") host := c.GetString("stats.host", "") if host == "" { return errors.New("stats.host can not be empty") } prefix := c.GetString("stats.prefix", "nebula") addr, err := net.ResolveTCPAddr(proto, host) if err != nil { return fmt.Errorf("error while setting up graphite sink: %s", err) } if !configTest { 
l.Infof("Starting graphite. Interval: %s, prefix: %s, addr: %s", i, prefix, addr) go graphite.Graphite(metrics.DefaultRegistry, i, prefix, addr) } return nil } func startPrometheusStats(l *logrus.Logger, i time.Duration, c *config.C, buildVersion string, configTest bool) (func(), error) { namespace := c.GetString("stats.namespace", "") subsystem := c.GetString("stats.subsystem", "") listen := c.GetString("stats.listen", "") if listen == "" { return nil, fmt.Errorf("stats.listen should not be empty") } path := c.GetString("stats.path", "") if path == "" { return nil, fmt.Errorf("stats.path should not be empty") } pr := prometheus.NewRegistry() pClient := mp.NewPrometheusProvider(metrics.DefaultRegistry, namespace, subsystem, pr, i) if !configTest { go pClient.UpdatePrometheusMetrics() } // Export our version information as labels on a static gauge g := prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: namespace, Subsystem: subsystem, Name: "info", Help: "Version information for the Nebula binary", ConstLabels: prometheus.Labels{ "version": buildVersion, "goversion": runtime.Version(), "boringcrypto": strconv.FormatBool(boringEnabled()), }, }) pr.MustRegister(g) g.Set(1) var startFn func() if !configTest { startFn = func() { l.Infof("Prometheus stats listening on %s at %s", listen, path) http.Handle(path, promhttp.HandlerFor(pr, promhttp.HandlerOpts{ErrorLog: l})) log.Fatal(http.ListenAndServe(listen, nil)) } } return startFn, nil } ================================================ FILE: test/assert.go ================================================ package test import ( "fmt" "net/netip" "reflect" "testing" "time" "unsafe" "github.com/stretchr/testify/assert" ) // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory // There is currently a special case for `time.loc` (as this code traverses into unexported fields) func AssertDeepCopyEqual(t *testing.T, a any, b any) { v1 := reflect.ValueOf(a) v2 := reflect.ValueOf(b) 
if !assert.Equal(t, v1.Type(), v2.Type()) { return } traverseDeepCopy(t, v1, v2, v1.Type().String()) } func traverseDeepCopy(t *testing.T, v1 reflect.Value, v2 reflect.Value, name string) bool { if v1.Type() == v2.Type() && v1.Type() == reflect.TypeOf(netip.Addr{}) { // Ignore netip.Addr types since they reuse an interned global value return false } switch v1.Kind() { case reflect.Array: for i := 0; i < v1.Len(); i++ { if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) { return false } } return true case reflect.Slice: if v1.IsNil() || v2.IsNil() { return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil %+v, %+v", name, v1, v2) } if !assert.Equal(t, v1.Len(), v2.Len(), "%s did not have the same length", name) { return false } // A slice with cap 0 if v1.Cap() != 0 && !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same slice %v == %v", name, v1.Pointer(), v2.Pointer()) { return false } v1c := v1.Cap() v2c := v2.Cap() if v1c > 0 && v2c > 0 && v1.Slice(0, v1c).Slice(v1c-1, v1c-1).Pointer() == v2.Slice(0, v2c).Slice(v2c-1, v2c-1).Pointer() { return assert.Fail(t, "", "%s share some underlying memory", name) } for i := 0; i < v1.Len(); i++ { if !traverseDeepCopy(t, v1.Index(i), v2.Index(i), fmt.Sprintf("%s[%v]", name, i)) { return false } } return true case reflect.Interface: if v1.IsNil() || v2.IsNil() { return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name) } return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name) case reflect.Ptr: local := reflect.ValueOf(time.Local).Pointer() if local == v1.Pointer() && local == v2.Pointer() { return true } if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s points to the same memory", name) { return false } return traverseDeepCopy(t, v1.Elem(), v2.Elem(), name) case reflect.Struct: for i, n := 0, v1.NumField(); i < n; i++ { if !traverseDeepCopy(t, v1.Field(i), v2.Field(i), name+"."+v1.Type().Field(i).Name) { return false } } return true case 
reflect.Map: if v1.IsNil() || v2.IsNil() { return assert.Equal(t, v1.IsNil(), v2.IsNil(), "%s are not both nil", name) } if !assert.Equal(t, v1.Len(), v2.Len(), "%s are not the same length", name) { return false } if !assert.NotEqual(t, v1.Pointer(), v2.Pointer(), "%s point to the same memory", name) { return false } for _, k := range v1.MapKeys() { val1 := v1.MapIndex(k) val2 := v2.MapIndex(k) if !assert.True(t, val1.IsValid(), "%s is an invalid key in %s", k, name) { return false } if !assert.True(t, val2.IsValid(), "%s is an invalid key in %s", k, name) { return false } if !traverseDeepCopy(t, val1, val2, name+fmt.Sprintf("%s[%s]", name, k)) { return false } } return true default: if v1.CanInterface() && v2.CanInterface() { return assert.Equal(t, v1.Interface(), v2.Interface(), "%s was not equal", name) } e1 := reflect.NewAt(v1.Type(), unsafe.Pointer(v1.UnsafeAddr())).Elem().Interface() e2 := reflect.NewAt(v2.Type(), unsafe.Pointer(v2.UnsafeAddr())).Elem().Interface() return assert.Equal(t, e1, e2, "%s (unexported) was not equal", name) } } ================================================ FILE: test/logger.go ================================================ package test import ( "io" "os" "github.com/sirupsen/logrus" ) func NewLogger() *logrus.Logger { l := logrus.New() v := os.Getenv("TEST_LOGS") if v == "" { l.SetOutput(io.Discard) return l } switch v { case "2": l.SetLevel(logrus.DebugLevel) case "3": l.SetLevel(logrus.TraceLevel) default: l.SetLevel(logrus.InfoLevel) } return l } ================================================ FILE: test/tun.go ================================================ package test import ( "errors" "io" "net/netip" "github.com/slackhq/nebula/routing" ) type NoopTun struct{} func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { return routing.Gateways{} } func (NoopTun) Activate() error { return nil } func (NoopTun) Networks() []netip.Prefix { return []netip.Prefix{} } func (NoopTun) Name() string { return "noop" } func 
(NoopTun) Read([]byte) (int, error) { return 0, nil } func (NoopTun) Write([]byte) (int, error) { return 0, nil } func (NoopTun) SupportsMultiqueue() bool { return false } func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, errors.New("unsupported") } func (NoopTun) Close() error { return nil } ================================================ FILE: timeout.go ================================================ package nebula import ( "sync" "time" ) // How many timer objects should be cached const timerCacheMax = 50000 type TimerWheel[T any] struct { // Current tick current int // Cheat on finding the length of the wheel wheelLen int // Last time we ticked, since we are lazy ticking lastTick *time.Time // Durations of a tick and the entire wheel tickDuration time.Duration wheelDuration time.Duration // The actual wheel which is just a set of singly linked lists, head/tail pointers wheel []*TimeoutList[T] // Singly linked list of items that have timed out of the wheel expired *TimeoutList[T] // Item cache to avoid garbage collect itemCache *TimeoutItem[T] itemsCached int } type LockingTimerWheel[T any] struct { m sync.Mutex t *TimerWheel[T] } // TimeoutList Represents a tick in the wheel type TimeoutList[T any] struct { Head *TimeoutItem[T] Tail *TimeoutItem[T] } // TimeoutItem Represents an item within a tick type TimeoutItem[T any] struct { Item T Next *TimeoutItem[T] } // NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values // Purge must be called once per entry to actually remove anything // The TimerWheel does not handle concurrency on its own. // Locks around access to it must be used if multiple routines are manipulating it. 
func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] { //TODO provide an error //if min >= max { // return nil //} // Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full // max duration, even if our current tick is at the maximum position and the next item to be added is at maximum // timeout wLen := int((max / min) + 2) tw := TimerWheel[T]{ wheelLen: wLen, wheel: make([]*TimeoutList[T], wLen), tickDuration: min, wheelDuration: max, expired: &TimeoutList[T]{}, } for i := range tw.wheel { tw.wheel[i] = &TimeoutList[T]{} } return &tw } // NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] { return &LockingTimerWheel[T]{ t: NewTimerWheel[T](min, max), } } // Add will add an item to the wheel in its proper timeout. // Caller should Advance the wheel prior to ensure the proper slot is used. func (tw *TimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] { i := tw.findWheel(timeout) // Try to fetch off the cache ti := tw.itemCache if ti != nil { tw.itemCache = ti.Next tw.itemsCached-- ti.Next = nil } else { ti = &TimeoutItem[T]{} } // Relink and return ti.Item = v if tw.wheel[i].Tail == nil { tw.wheel[i].Head = ti tw.wheel[i].Tail = ti } else { tw.wheel[i].Tail.Next = ti tw.wheel[i].Tail = ti } return ti } // Purge removes and returns the first available expired item from the wheel and the 2nd argument is true. // If no item is available then an empty T is returned and the 2nd argument is false. 
func (tw *TimerWheel[T]) Purge() (T, bool) { if tw.expired.Head == nil { var na T return na, false } ti := tw.expired.Head tw.expired.Head = ti.Next if tw.expired.Head == nil { tw.expired.Tail = nil } // Clear out the items references ti.Next = nil // Maybe cache it for later if tw.itemsCached < timerCacheMax { ti.Next = tw.itemCache tw.itemCache = ti tw.itemsCached++ } return ti.Item, true } // findWheel find the next position in the wheel for the provided timeout given the current tick func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) { if timeout < tw.tickDuration { // Can't track anything below the set resolution timeout = tw.tickDuration } else if timeout > tw.wheelDuration { // We aren't handling timeouts greater than the wheels duration timeout = tw.wheelDuration } // Find the next highest, rounding up tick := int(((timeout - 1) / tw.tickDuration) + 1) // Add another tick since the current tick may almost be over then map it to the wheel from our // current position tick += tw.current + 1 if tick >= tw.wheelLen { tick -= tw.wheelLen } return tick } // Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items // passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely. 
func (tw *TimerWheel[T]) Advance(now time.Time) { if tw.lastTick == nil { tw.lastTick = &now } // We want to round down ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration) adv := ticks if ticks > tw.wheelLen { ticks = tw.wheelLen } for i := 0; i < ticks; i++ { tw.current++ if tw.current >= tw.wheelLen { tw.current = 0 } if tw.wheel[tw.current].Head != nil { // We need to append the expired items as to not starve evicting the oldest ones if tw.expired.Tail == nil { tw.expired.Head = tw.wheel[tw.current].Head tw.expired.Tail = tw.wheel[tw.current].Tail } else { tw.expired.Tail.Next = tw.wheel[tw.current].Head tw.expired.Tail = tw.wheel[tw.current].Tail } tw.wheel[tw.current].Head = nil tw.wheel[tw.current].Tail = nil } } // Advance the tick based on duration to avoid losing some accuracy newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv)) tw.lastTick = &newTick } func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] { lw.m.Lock() defer lw.m.Unlock() return lw.t.Add(v, timeout) } func (lw *LockingTimerWheel[T]) Purge() (T, bool) { lw.m.Lock() defer lw.m.Unlock() return lw.t.Purge() } func (lw *LockingTimerWheel[T]) Advance(now time.Time) { lw.m.Lock() defer lw.m.Unlock() lw.t.Advance(now) } ================================================ FILE: timeout_test.go ================================================ package nebula import ( "net/netip" "testing" "time" "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" ) func TestNewTimerWheel(t *testing.T) { // Make sure we get an object we expect tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Equal(t, 12, tw.wheelLen) assert.Equal(t, 0, tw.current) assert.Nil(t, tw.lastTick) assert.Equal(t, time.Second*1, tw.tickDuration) assert.Equal(t, time.Second*10, tw.wheelDuration) assert.Len(t, tw.wheel, 12) // Assert the math is correct tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10) assert.Equal(t, 5, tw.wheelLen) tw = 
NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10) assert.Equal(t, 7, tw.wheelLen) // Test empty purge of non nil items i, ok := tw.Purge() assert.Equal(t, firewall.Packet{}, i) assert.False(t, ok) // Test empty purges of nil items tw2 := NewTimerWheel[*int](time.Second, time.Second*10) i2, ok := tw2.Purge() assert.Nil(t, i2) assert.False(t, ok) } func TestTimerWheel_findWheel(t *testing.T) { tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Len(t, tw.wheel, 12) // Current + tick + 1 since we don't know how far into current we are assert.Equal(t, 2, tw.findWheel(time.Second*1)) // Scale up to min duration assert.Equal(t, 2, tw.findWheel(time.Millisecond*1)) // Make sure we hit that last index assert.Equal(t, 11, tw.findWheel(time.Second*10)) // Scale down to max duration assert.Equal(t, 11, tw.findWheel(time.Second*11)) tw.current = 1 // Make sure we account for the current position properly assert.Equal(t, 3, tw.findWheel(time.Second*1)) assert.Equal(t, 0, tw.findWheel(time.Second*10)) } func TestTimerWheel_Add(t *testing.T) { tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) fp1 := firewall.Packet{} tw.Add(fp1, time.Second*1) // Make sure we set head and tail properly assert.NotNil(t, tw.wheel[2]) assert.Equal(t, fp1, tw.wheel[2].Head.Item) assert.Nil(t, tw.wheel[2].Head.Next) assert.Equal(t, fp1, tw.wheel[2].Tail.Item) assert.Nil(t, tw.wheel[2].Tail.Next) // Make sure we only modify head fp2 := firewall.Packet{} tw.Add(fp2, time.Second*1) assert.Equal(t, fp2, tw.wheel[2].Head.Item) assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item) assert.Equal(t, fp1, tw.wheel[2].Tail.Item) assert.Nil(t, tw.wheel[2].Tail.Next) // Make sure we use free'd items first tw.itemCache = &TimeoutItem[firewall.Packet]{} tw.itemsCached = 1 tw.Add(fp2, time.Second*1) assert.Nil(t, tw.itemCache) assert.Equal(t, 0, tw.itemsCached) // Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel for min := 
time.Duration(1); min < 100; min++ { for max := min; max < 100; max++ { tw = NewTimerWheel[firewall.Packet](min, max) for current := 0; current < tw.wheelLen; current++ { tw.current = current for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ { tick := tw.findWheel(timeout) if tick >= tw.wheelLen { t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick) } } } } } } func TestTimerWheel_Purge(t *testing.T) { // First advance should set the lastTick and do nothing else tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Nil(t, tw.lastTick) tw.Advance(time.Now()) assert.NotNil(t, tw.lastTick) assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ {LocalAddr: netip.MustParseAddr("0.0.0.1")}, {LocalAddr: netip.MustParseAddr("0.0.0.2")}, {LocalAddr: netip.MustParseAddr("0.0.0.3")}, {LocalAddr: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) tw.Add(fps[1], time.Second*1) tw.Add(fps[2], time.Second*2) tw.Add(fps[3], time.Second*2) ta := time.Now().Add(time.Second * 3) lastTick := *tw.lastTick tw.Advance(ta) assert.Equal(t, 3, tw.current) assert.True(t, tw.lastTick.After(lastTick)) // Make sure we get all 4 packets back for i := range 4 { p, has := tw.Purge() assert.True(t, has) assert.Equal(t, fps[i], p) } // Make sure there aren't any leftover _, ok := tw.Purge() assert.False(t, ok) assert.Nil(t, tw.expired.Head) assert.Nil(t, tw.expired.Tail) // Make sure we cached the free'd items assert.Equal(t, 4, tw.itemsCached) ci := tw.itemCache for range 4 { assert.NotNil(t, ci) ci = ci.Next } assert.Nil(t, ci) // Let's make sure we roll over properly ta = ta.Add(time.Second * 5) tw.Advance(ta) assert.Equal(t, 8, tw.current) ta = ta.Add(time.Second * 2) tw.Advance(ta) assert.Equal(t, 10, tw.current) ta = ta.Add(time.Second * 1) tw.Advance(ta) assert.Equal(t, 11, tw.current) ta = ta.Add(time.Second * 1) tw.Advance(ta) assert.Equal(t, 0, 
tw.current) } ================================================ FILE: udp/conn.go ================================================ package udp import ( "net/netip" "github.com/slackhq/nebula/config" ) const MTU = 9001 type EncReader func( addr netip.AddrPort, payload []byte, ) type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error } type NoopConn struct{} func (NoopConn) Rebind() error { return nil } func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } func (NoopConn) ListenOut(_ EncReader) { return } func (NoopConn) SupportsMultipleReaders() bool { return false } func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) { return } func (NoopConn) Close() error { return nil } ================================================ FILE: udp/errors.go ================================================ package udp import "errors" var ErrInvalidIPv6RemoteForSocket = errors.New("listener is IPv4, but writing to IPv6 remote") ================================================ FILE: udp/udp_android.go ================================================ //go:build !e2e_testing // +build !e2e_testing package udp import ( "fmt" "net" "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { if multi { var controlErr error err := c.Control(func(fd uintptr) { if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) return } }) if err != nil { 
return err } if controlErr != nil { return controlErr } } return nil }, } } func (u *GenericConn) Rebind() error { return nil } ================================================ FILE: udp/udp_bsd.go ================================================ //go:build (openbsd || freebsd) && !e2e_testing // +build openbsd freebsd // +build !e2e_testing package udp // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig import ( "fmt" "net" "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { if multi { var controlErr error err := c.Control(func(fd uintptr) { if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) return } }) if err != nil { return err } if controlErr != nil { return controlErr } } return nil }, } } func (u *GenericConn) Rebind() error { return nil } ================================================ FILE: udp/udp_darwin.go ================================================ //go:build !e2e_testing // +build !e2e_testing package udp import ( "context" "encoding/binary" "errors" "fmt" "net" "net/netip" "syscall" "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) type StdConn struct { *net.UDPConn isV4 bool sysFd uintptr l *logrus.Logger } var _ Conn = &StdConn{} func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { return nil, err } if uc, ok := pc.(*net.UDPConn); ok { c := 
&StdConn{UDPConn: uc, l: l} rc, err := uc.SyscallConn() if err != nil { return nil, fmt.Errorf("failed to open udp socket: %w", err) } err = rc.Control(func(fd uintptr) { c.sysFd = fd }) if err != nil { return nil, fmt.Errorf("failed to get udp fd: %w", err) } la, err := c.LocalAddr() if err != nil { return nil, err } c.isV4 = la.Addr().Is4() return c, nil } return nil, fmt.Errorf("unexpected PacketConn: %T %#v", pc, pc) } func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { if multi { var controlErr error err := c.Control(func(fd uintptr) { if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) return } }) if err != nil { return err } if controlErr != nil { return controlErr } } return nil }, } } //go:linkname sendto golang.org/x/sys/unix.sendto //go:noescape func sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen int32) (err error) func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { var sa unsafe.Pointer var addrLen int32 if u.isV4 { if ap.Addr().Is6() { return ErrInvalidIPv6RemoteForSocket } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET rsa.Addr = ap.Addr().As4() binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) sa = unsafe.Pointer(&rsa) addrLen = syscall.SizeofSockaddrInet4 } else { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 rsa.Addr = ap.Addr().As16() binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) sa = unsafe.Pointer(&rsa) addrLen = syscall.SizeofSockaddrInet6 } // Golang stdlib doesn't handle EAGAIN correctly in some situations so we do writes ourselves // See https://github.com/golang/go/issues/73919 for { //_, _, err := unix.Syscall6(unix.SYS_SENDTO, u.sysFd, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), 0, sa, addrLen) err := sendto(int(u.sysFd), b, 0, sa, 
addrLen) if err == nil { // Written, get out before the error handling return nil } if errors.Is(err, syscall.EINTR) { // Write was interrupted, retry continue } if errors.Is(err, syscall.EAGAIN) { return &net.OpError{Op: "sendto", Err: unix.EWOULDBLOCK} } if errors.Is(err, syscall.EBADF) { return net.ErrClosed } return &net.OpError{Op: "sendto", Err: err} } } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: addr, ok := netip.AddrFromSlice(v.IP) if !ok { return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) } return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } func (u *StdConn) ReloadConfig(c *config.C) { // TODO } func NewUDPStatsEmitter(udpConns []Conn) func() { // No UDP stats for non-linux return func() {} } func (u *StdConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) for { // Just read one packet at a time n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } u.l.WithError(err).Error("unexpected udp socket receive error") } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } func (u *StdConn) SupportsMultipleReaders() bool { return false } func (u *StdConn) Rebind() error { var err error if u.isV4 { err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, 0) } else { err = syscall.SetsockoptInt(int(u.sysFd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, 0) } if err != nil { u.l.WithError(err).Error("Failed to rebind udp socket") } return nil } ================================================ FILE: udp/udp_generic.go ================================================ //go:build (!linux || android) && !e2e_testing && !darwin // +build !linux android // +build !e2e_testing // +build !darwin // udp_generic 
implements the nebula UDP interface in pure Go stdlib. This // means it can be used on platforms like Darwin and Windows. package udp import ( "context" "errors" "fmt" "net" "net/netip" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) type GenericConn struct { *net.UDPConn l *logrus.Logger } var _ Conn = &GenericConn{} func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { return nil, err } if uc, ok := pc.(*net.UDPConn); ok { return &GenericConn{UDPConn: uc, l: l}, nil } return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { _, err := u.UDPConn.WriteToUDPAddrPort(b, addr) return err } func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: addr, ok := netip.AddrFromSlice(v.IP) if !ok { return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) } return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } func (u *GenericConn) ReloadConfig(c *config.C) { } func NewUDPStatsEmitter(udpConns []Conn) func() { // No UDP stats for non-linux return func() {} } type rawMessage struct { Len uint32 } func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) var lastRecvErr time.Time for { // Just read one packet at a time n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() u.l.WithError(err).Warn("unexpected udp socket receive 
error") } continue } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } func (u *GenericConn) SupportsMultipleReaders() bool { return false } ================================================ FILE: udp/udp_linux.go ================================================ //go:build !android && !e2e_testing // +build !android,!e2e_testing package udp import ( "encoding/binary" "fmt" "net" "net/netip" "syscall" "unsafe" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/unix" ) type StdConn struct { sysFd int isV4 bool l *logrus.Logger batch int } func maybeIPV4(ip net.IP) (net.IP, bool) { ip4 := ip.To4() if ip4 != nil { return ip4, true } return ip, false } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 if ip.Is4() { af = unix.AF_INET } syscall.ForkLock.RLock() fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) if err == nil { unix.CloseOnExec(fd) } syscall.ForkLock.RUnlock() if err != nil { unix.Close(fd) return nil, fmt.Errorf("unable to open socket: %s", err) } if multi { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) } } var sa unix.Sockaddr if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} sa4.Addr = ip.As4() sa = sa4 } else { sa6 := &unix.SockaddrInet6{Port: port} sa6.Addr = ip.As16() sa = sa6 } if err = unix.Bind(fd, sa); err != nil { return nil, fmt.Errorf("unable to bind to socket: %s", err) } return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) SupportsMultipleReaders() bool { return true } func (u *StdConn) Rebind() error { return nil } func (u *StdConn) SetRecvBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, 
unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { return netip.AddrPort{}, err } switch sa := sa.(type) { case *unix.SockaddrInet4: return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil case *unix.SockaddrInet6: return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil default: return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) } } func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle } for { n, err := read(msgs) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( unix.SYS_RECVMSG, uintptr(u.sysFd), uintptr(unsafe.Pointer(&(msgs[0].Hdr))), 0, 0, 0, 0, ) if err != 0 { return 0, &net.OpError{Op: "recvmsg", Err: err} } msgs[0].Len = uint32(n) return 1, nil } } func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { for { n, _, err := unix.Syscall6( 
unix.SYS_RECVMMSG, uintptr(u.sysFd), uintptr(unsafe.Pointer(&msgs[0])), uintptr(len(msgs)), unix.MSG_WAITFORONE, 0, 0, ) if err != 0 { return 0, &net.OpError{Op: "recvmmsg", Err: err} } return int(n), nil } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { return u.writeTo4(b, ip) } return u.writeTo6(b, ip) } func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 rsa.Addr = ip.Addr().As16() binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( unix.SYS_SENDTO, uintptr(u.sysFd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(0), uintptr(unsafe.Pointer(&rsa)), uintptr(unix.SizeofSockaddrInet6), ) if err != 0 { return &net.OpError{Op: "sendto", Err: err} } return nil } } func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { if !ip.Addr().Is4() { return ErrInvalidIPv6RemoteForSocket } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET rsa.Addr = ip.Addr().As4() binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) for { _, _, err := unix.Syscall6( unix.SYS_SENDTO, uintptr(u.sysFd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(0), uintptr(unsafe.Pointer(&rsa)), uintptr(unix.SizeofSockaddrInet4), ) if err != 0 { return &net.OpError{Op: "sendto", Err: err} } return nil } } func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { err := u.SetRecvBuffer(b) if err == nil { s, err := u.GetRecvBuffer() if err == nil { u.l.WithField("size", s).Info("listen.read_buffer was set") } else { u.l.WithError(err).Warn("Failed to get listen.read_buffer") } } else { u.l.WithError(err).Error("Failed to set listen.read_buffer") } } b = c.GetInt("listen.write_buffer", 0) if b > 0 { err := u.SetSendBuffer(b) if err == nil { s, err := u.GetSendBuffer() if err == nil { u.l.WithField("size", s).Info("listen.write_buffer was set") } else { 
u.l.WithError(err).Warn("Failed to get listen.write_buffer") } } else { u.l.WithError(err).Error("Failed to set listen.write_buffer") } } b = c.GetInt("listen.so_mark", 0) s, err := u.GetSoMark() if b > 0 || (err == nil && s != 0) { err := u.SetSoMark(b) if err == nil { s, err := u.GetSoMark() if err == nil { u.l.WithField("mark", s).Info("listen.so_mark was set") } else { u.l.WithError(err).Warn("Failed to get listen.so_mark") } } else { u.l.WithError(err).Error("Failed to set listen.so_mark") } } } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) if err != 0 { return err } return nil } func (u *StdConn) Close() error { return syscall.Close(u.sysFd) } func NewUDPStatsEmitter(udpConns []Conn) func() { // Check if our kernel supports SO_MEMINFO before registering the gauges var udpGauges [][unix.SK_MEMINFO_VARS]metrics.Gauge var meminfo [unix.SK_MEMINFO_VARS]uint32 if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { udpGauges = make([][unix.SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) for i := range udpConns { udpGauges[i] = [unix.SK_MEMINFO_VARS]metrics.Gauge{ metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil), metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil), } } } return func() { for i, 
gauges := range udpGauges { if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { for j := 0; j < unix.SK_MEMINFO_VARS; j++ { gauges[j].Update(int64(meminfo[j])) } } } } } ================================================ FILE: udp/udp_linux_32.go ================================================ //go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing // +build linux // +build 386 amd64p32 arm mips mipsle // +build !android // +build !e2e_testing package udp import ( "golang.org/x/sys/unix" ) type iovec struct { Base *byte Len uint32 } type msghdr struct { Name *byte Namelen uint32 Iov *iovec Iovlen uint32 Control *byte Controllen uint32 Flags int32 } type rawMessage struct { Hdr msghdr Len uint32 } func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) for i := range msgs { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iovlen = uint32(len(vs)) msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) } return msgs, buffers, names } ================================================ FILE: udp/udp_linux_64.go ================================================ //go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing // +build linux // +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64 // +build !android // +build !e2e_testing package udp import ( "golang.org/x/sys/unix" ) type iovec struct { Base *byte Len uint64 } type msghdr struct { Name *byte Namelen uint32 Pad0 [4]byte Iov *iovec Iovlen uint64 Control *byte Controllen uint64 Flags int32 Pad1 [4]byte } type rawMessage struct { Hdr msghdr Len uint32 Pad0 [4]byte } func (u *StdConn) PrepareRawMessages(n 
int) ([]rawMessage, [][]byte, [][]byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) for i := range msgs { buffers[i] = make([]byte, MTU) names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, } msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iovlen = uint64(len(vs)) msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) } return msgs, buffers, names } ================================================ FILE: udp/udp_netbsd.go ================================================ //go:build !e2e_testing // +build !e2e_testing package udp // FreeBSD support is primarily implemented in udp_generic, besides NewListenConfig import ( "fmt" "net" "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { if multi { var controlErr error err := c.Control(func(fd uintptr) { if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { controlErr = fmt.Errorf("SO_REUSEPORT failed: %v", err) return } }) if err != nil { return err } if controlErr != nil { return controlErr } } return nil }, } } func (u *GenericConn) Rebind() error { return nil } ================================================ FILE: udp/udp_rio_windows.go ================================================ //go:build !e2e_testing // +build !e2e_testing // Inspired by https://git.zx2c4.com/wireguard-go/tree/conn/bind_windows.go package udp import ( "errors" "fmt" "io" "net" "net/netip" "sync" "sync/atomic" "syscall" "time" "unsafe" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) 
// Assert we meet the standard conn interface var _ Conn = &RIOConn{} //go:linkname procyield runtime.procyield func procyield(cycles uint32) const ( packetsPerRing = 1024 bytesPerPacket = 2048 - 32 receiveSpins = 15 ) type ringPacket struct { addr windows.RawSockaddrInet6 data [bytesPerPacket]byte } type ringBuffer struct { packets uintptr head, tail uint32 id winrio.BufferId iocp windows.Handle isFull bool cq winrio.Cq mu sync.Mutex overlapped windows.Overlapped } type RIOConn struct { isOpen atomic.Bool l *logrus.Logger sock windows.Handle rx, tx ringBuffer rq winrio.Rq results [packetsPerRing]winrio.Result } func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } u := &RIOConn{l: l} err := u.bind(l, &windows.SockaddrInet6{Addr: addr.As16(), Port: port}) if err != nil { return nil, fmt.Errorf("bind: %w", err) } for i := 0; i < packetsPerRing; i++ { err = u.insertReceiveRequest() if err != nil { return nil, fmt.Errorf("init rx ring: %w", err) } } u.isOpen.Store(true) return u, nil } func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { var err error u.sock, err = winrio.Socket(windows.AF_INET6, windows.SOCK_DGRAM, windows.IPPROTO_UDP) if err != nil { return fmt.Errorf("winrio.Socket error: %w", err) } // Enable v4 for this socket syscall.SetsockoptInt(syscall.Handle(u.sock), syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, 0) // Disable reporting of PORT_UNREACHABLE and NET_UNREACHABLE errors from the UDP socket receive call. // These errors are returned on Windows during UDP receives based on the receipt of ICMP packets. Disable // the UDP receive error returns with these ioctl calls. 
ret := uint32(0) flag := uint32(0) size := uint32(unsafe.Sizeof(flag)) err = syscall.WSAIoctl(syscall.Handle(u.sock), syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0) if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. l.WithError(err).Debug("failed to set UDP_CONNRESET ioctl") } ret = 0 flag = 0 size = uint32(unsafe.Sizeof(flag)) SIO_UDP_NETRESET := uint32(syscall.IOC_IN | syscall.IOC_VENDOR | 15) err = syscall.WSAIoctl(syscall.Handle(u.sock), SIO_UDP_NETRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0) if err != nil { // This is a best-effort to prevent errors from being returned by the udp recv operation. // Quietly log a failure and continue. l.WithError(err).Debug("failed to set UDP_NETRESET ioctl") } err = u.rx.Open() if err != nil { return fmt.Errorf("error rx.Open(): %w", err) } err = u.tx.Open() if err != nil { return fmt.Errorf("error tx.Open(): %w", err) } u.rq, err = winrio.CreateRequestQueue(u.sock, packetsPerRing, 1, packetsPerRing, 1, u.rx.cq, u.tx.cq, 0) if err != nil { return fmt.Errorf("error CreateRequestQueue: %w", err) } err = windows.Bind(u.sock, sa) if err != nil { return fmt.Errorf("error windows.Bind(): %w", err) } return nil } func (u *RIOConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) var lastRecvErr time.Time for { // Just read one packet at a time n, rua, err := u.receive(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } // Dampen unexpected message warns to once per minute if lastRecvErr.IsZero() || time.Since(lastRecvErr) > time.Minute { lastRecvErr = time.Now() u.l.WithError(err).Warn("unexpected udp socket receive error") } continue } r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) } } func (u *RIOConn) insertReceiveRequest() error { 
packet := u.rx.Push() dataBuffer := &winrio.Buffer{ Id: u.rx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.rx.packets), Length: uint32(len(packet.data)), } addressBuffer := &winrio.Buffer{ Id: u.rx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.rx.packets), Length: uint32(unsafe.Sizeof(packet.addr)), } return winrio.ReceiveEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) } func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) { if !u.isOpen.Load() { return 0, windows.RawSockaddrInet6{}, net.ErrClosed } u.rx.mu.Lock() defer u.rx.mu.Unlock() var err error var count uint32 var results [1]winrio.Result retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { if tries > 0 { if !u.isOpen.Load() { return 0, windows.RawSockaddrInet6{}, net.ErrClosed } procyield(1) } count = winrio.DequeueCompletion(u.rx.cq, results[:]) } if count == 0 { err = winrio.Notify(u.rx.cq) if err != nil { return 0, windows.RawSockaddrInet6{}, err } var bytes uint32 var key uintptr var overlapped *windows.Overlapped err = windows.GetQueuedCompletionStatus(u.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) if err != nil { return 0, windows.RawSockaddrInet6{}, err } if !u.isOpen.Load() { return 0, windows.RawSockaddrInet6{}, net.ErrClosed } count = winrio.DequeueCompletion(u.rx.cq, results[:]) if count == 0 { return 0, windows.RawSockaddrInet6{}, io.ErrNoProgress } } u.rx.Return(1) err = u.insertReceiveRequest() if err != nil { return 0, windows.RawSockaddrInet6{}, err } // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to // attacker bandwidth, just like the rest of the receive path. 
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { goto retry } if results[0].Status != 0 { return 0, windows.RawSockaddrInet6{}, windows.Errno(results[0].Status) } packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) ep := packet.addr n := copy(buf, packet.data[:results[0].BytesTransferred]) return n, ep, nil } func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { if !u.isOpen.Load() { return net.ErrClosed } if len(buf) > bytesPerPacket { return io.ErrShortBuffer } u.tx.mu.Lock() defer u.tx.mu.Unlock() count := winrio.DequeueCompletion(u.tx.cq, u.results[:]) if count == 0 && u.tx.isFull { err := winrio.Notify(u.tx.cq) if err != nil { return err } var bytes uint32 var key uintptr var overlapped *windows.Overlapped err = windows.GetQueuedCompletionStatus(u.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) if err != nil { return err } if !u.isOpen.Load() { return net.ErrClosed } count = winrio.DequeueCompletion(u.tx.cq, u.results[:]) if count == 0 { return io.ErrNoProgress } } if count > 0 { u.tx.Return(count) } packet := u.tx.Push() packet.addr.Family = windows.AF_INET6 packet.addr.Addr = ip.Addr().As16() port := ip.Port() packet.addr.Port = (port >> 8) | ((port & 0xff) << 8) copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ Id: u.tx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - u.tx.packets), Length: uint32(len(buf)), } addressBuffer := &winrio.Buffer{ Id: u.tx.id, Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - u.tx.packets), Length: uint32(unsafe.Sizeof(packet.addr)), } return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { return netip.AddrPort{}, err } v6 := sa.(*windows.SockaddrInet6) return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil } func (u *RIOConn) SupportsMultipleReaders() bool { return false } func (u *RIOConn) 
Rebind() error { return nil } func (u *RIOConn) ReloadConfig(*config.C) {} func (u *RIOConn) Close() error { if !u.isOpen.CompareAndSwap(true, false) { return nil } windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) u.rx.CloseAndZero() u.tx.CloseAndZero() if u.sock != 0 { windows.CloseHandle(u.sock) } return nil } func (ring *ringBuffer) Push() *ringPacket { for ring.isFull { panic("ring is full") } ret := (*ringPacket)(unsafe.Pointer(ring.packets + (uintptr(ring.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) ring.tail += 1 if ring.tail%packetsPerRing == ring.head%packetsPerRing { ring.isFull = true } return ret } func (ring *ringBuffer) Return(count uint32) { if ring.head%packetsPerRing == ring.tail%packetsPerRing && !ring.isFull { return } ring.head += count ring.isFull = false } func (ring *ringBuffer) CloseAndZero() { if ring.cq != 0 { winrio.CloseCompletionQueue(ring.cq) ring.cq = 0 } if ring.iocp != 0 { windows.CloseHandle(ring.iocp) ring.iocp = 0 } if ring.id != 0 { winrio.DeregisterBuffer(ring.id) ring.id = 0 } if ring.packets != 0 { windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) ring.packets = 0 } ring.head = 0 ring.tail = 0 ring.isFull = false } func (ring *ringBuffer) Open() error { var err error packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) if err != nil { return err } ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) if err != nil { return err } ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { return err } ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) if err != nil { return err } return nil } ================================================ FILE: udp/udp_tester.go 
================================================ //go:build e2e_testing // +build e2e_testing package udp import ( "io" "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) type Packet struct { To netip.AddrPort From netip.AddrPort Data []byte } func (u *Packet) Copy() *Packet { n := &Packet{ To: u.To, From: u.From, Data: make([]byte, len(u.Data)), } copy(n.Data, u.Data) return n } type TesterConn struct { Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula closed atomic.Bool l *logrus.Logger } func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, }, nil } // Send will place a UdpPacket onto the receive queue for nebula to consume // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { if u.closed.Load() { return } h := &header.H{} if err := h.Parse(packet.Data); err != nil { panic(err) } if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). WithField("udpAddr", packet.From). WithField("dataLen", len(packet.Data)). 
Debug("UDP receiving injected packet") } u.RxPackets <- packet } // Get will pull a UdpPacket from the transmit queue // nebula meant to send this message on the network, it will be encrypted // packets were ingested from the tun side (in most cases), you can send them with Tun.Send func (u *TesterConn) Get(block bool) *Packet { if block { return <-u.TxPackets } select { case p := <-u.TxPackets: return p default: return nil } } //********************************************************************************************************************// // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { if u.closed.Load() { return io.ErrClosedPipe } p := &Packet{ Data: make([]byte, len(b), len(b)), From: u.Addr, To: addr, } copy(p.Data, b) u.TxPackets <- p return nil } func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets if !ok { return } r(p.From, p.Data) } } func (u *TesterConn) ReloadConfig(*config.C) {} func NewUDPStatsEmitter(_ []Conn) func() { // No UDP stats for non-linux return func() {} } func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } func (u *TesterConn) SupportsMultipleReaders() bool { return false } func (u *TesterConn) Rebind() error { return nil } func (u *TesterConn) Close() error { if u.closed.CompareAndSwap(false, true) { close(u.RxPackets) close(u.TxPackets) } return nil } ================================================ FILE: udp/udp_windows.go ================================================ //go:build !e2e_testing // +build !e2e_testing package udp import ( "fmt" "net" "net/netip" "syscall" "github.com/sirupsen/logrus" ) func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't 
be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between // Windows and Linux return nil, fmt.Errorf("multiple udp listeners not supported on windows") } rc, err := NewRIOListener(l, ip, port) if err == nil { return rc, nil } l.WithError(err).Error("Falling back to standard udp sockets") return NewGenericListener(l, ip, port, multi, batch) } func NewListenConfig(multi bool) net.ListenConfig { return net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { if multi { // There is no way to support multiple listeners safely on Windows: // https://docs.microsoft.com/en-us/windows/desktop/winsock/using-so-reuseaddr-and-so-exclusiveaddruse return fmt.Errorf("multiple udp listeners not supported on windows") } return nil }, } } func (u *GenericConn) Rebind() error { return nil } ================================================ FILE: util/error.go ================================================ package util import ( "errors" "fmt" "github.com/sirupsen/logrus" ) type ContextualError struct { RealError error Fields map[string]any Context string } func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError { return &ContextualError{Context: msg, Fields: fields, RealError: realError} } // ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one func ContextualizeIfNeeded(msg string, err error) error { switch err.(type) { case *ContextualError: return err default: return NewContextualError(msg, nil, err) } } // LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) { switch v := err.(type) { case *ContextualError: v.Log(l) default: l.WithError(err).Error(msg) } } func (ce *ContextualError) Error() string { if ce.RealError == nil { return ce.Context } return fmt.Errorf("%s (%v): %w", ce.Context, 
ce.Fields, ce.RealError).Error() } func (ce *ContextualError) Unwrap() error { if ce.RealError == nil { return errors.New(ce.Context) } return ce.RealError } func (ce *ContextualError) Log(lr *logrus.Logger) { if ce.RealError != nil { lr.WithFields(ce.Fields).WithError(ce.RealError).Error(ce.Context) } else { lr.WithFields(ce.Fields).Error(ce.Context) } } ================================================ FILE: util/error_test.go ================================================ package util import ( "errors" "fmt" "testing" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) type m = map[string]any type TestLogWriter struct { Logs []string } func NewTestLogWriter() *TestLogWriter { return &TestLogWriter{Logs: make([]string, 0)} } func (tl *TestLogWriter) Write(p []byte) (n int, err error) { tl.Logs = append(tl.Logs, string(p)) return len(p), nil } func (tl *TestLogWriter) Reset() { tl.Logs = tl.Logs[:0] } func TestContextualError_Log(t *testing.T) { l := logrus.New() l.Formatter = &logrus.TextFormatter{ DisableTimestamp: true, DisableColors: true, } tl := NewTestLogWriter() l.Out = tl // Test a full context line tl.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) e.Log(l) assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) // Test a line with an error and msg but no fields tl.Reset() e = NewContextualError("test message", nil, errors.New("error")) e.Log(l) assert.Equal(t, []string{"level=error msg=\"test message\" error=error\n"}, tl.Logs) // Test just a context and fields tl.Reset() e = NewContextualError("test message", m{"field": "1"}, nil) e.Log(l) assert.Equal(t, []string{"level=error msg=\"test message\" field=1\n"}, tl.Logs) // Test just a context tl.Reset() e = NewContextualError("test message", nil, nil) e.Log(l) assert.Equal(t, []string{"level=error msg=\"test message\"\n"}, tl.Logs) // Test just an error tl.Reset() e = NewContextualError("", nil, 
errors.New("error")) e.Log(l) assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs) } func TestLogWithContextIfNeeded(t *testing.T) { l := logrus.New() l.Formatter = &logrus.TextFormatter{ DisableTimestamp: true, DisableColors: true, } tl := NewTestLogWriter() l.Out = tl // Test ignoring fallback context tl.Reset() e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) LogWithContextIfNeeded("This should get thrown away", e, l) assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs) // Test using fallback context tl.Reset() err := fmt.Errorf("this is a normal error") LogWithContextIfNeeded("Fallback context woo", err, l) assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs) } func TestContextualizeIfNeeded(t *testing.T) { // Test ignoring fallback context e := NewContextualError("test message", m{"field": "1"}, errors.New("error")) assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e)) // Test using fallback context err := fmt.Errorf("this is a normal error") cErr := ContextualizeIfNeeded("Fallback context woo", err) switch v := cErr.(type) { case *ContextualError: assert.Equal(t, err, v.RealError) default: t.Error("Error was not wrapped") t.Fail() } } ================================================ FILE: wintun/device.go ================================================ //go:build windows // +build windows /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 
 */

//NOTE: this file was forked from https://git.zx2c4.com/wireguard-go/tree/tun/tun.go?id=851efb1bb65555e0f765a3361c8eb5ac47435b19

package wintun

import (
	"os"
)

// Device is the TUN device surface used by this package. The int argument to
// Read and Write is the offset into the supplied buffer where the packet
// begins (NativeTun reads into, and writes from, buff[offset:]).
type Device interface {
	File() *os.File                 // returns the file descriptor of the device
	Read([]byte, int) (int, error)  // read a packet from the device (without any additional headers)
	Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
	Flush() error                   // flush all previous writes to the device
	Name() (string, error)          // fetches and returns the current name
	Close() error                   // stops the device and closes the event channel
}

================================================
FILE: wintun/tun.go
================================================
//go:build windows
// +build windows

/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
 */

//NOTE: This file was forked from https://git.zx2c4.com/wireguard-go/tree/tun/tun_windows.go?id=851efb1bb65555e0f765a3361c8eb5ac47435b19
// Mainly to shed functionality we won't be using and to fix names that display in the system

package wintun

import (
	"errors"
	"fmt"
	"os"
	"sync"
	"sync/atomic"
	"time"
	_ "unsafe" // required for the //go:linkname directives below

	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wintun"
)

const (
	// rateMeasurementGranularity is the length of one throughput window, in
	// nanoseconds (500ms).
	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
	// spinloopRateThreshold is the measured rate, in bytes per second, at or
	// above which Read busy-waits before blocking on the read-wait event.
	spinloopRateThreshold = 800000000 / 8 // 800mbps
	// spinloopDuration is the longest Read will busy-wait, in nanoseconds (12.5µs).
	spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
)

// rateJuggler tracks recent packet throughput. All fields are accessed with
// sync/atomic operations.
type rateJuggler struct {
	current       uint64 // rate of the last completed window, in bytes per second
	nextByteCount uint64 // bytes counted so far in the current window
	nextStartTime int64  // nanotime() at which the current window began
	changing      int32  // 1 while one caller is rolling the window over
}

// NativeTun is a Device backed by a Wintun adapter and session.
// NOTE(review): rate is updated with 64-bit atomics; on 32-bit platforms those
// fields must be 64-bit aligned, so take care when reordering fields — confirm
// against the sync/atomic alignment notes.
type NativeTun struct {
	wt        *wintun.Adapter
	name      string
	handle    windows.Handle // only ever set to InvalidHandle in this file
	rate      rateJuggler
	session   wintun.Session
	readWait  windows.Handle // session read-wait event; Close signals it to wake Read
	running   sync.WaitGroup // in-flight Read/Write/LUID calls; Close waits on it
	closeOnce sync.Once
	close     int32 // set to 1 by Close
}

// WintunTunnelType is the tunnel type passed to wintun.CreateAdapter.
var WintunTunnelType = "Nebula"

// WintunStaticRequestedGUID is the GUID CreateTUN requests for the adapter.
// Presumably nil lets Wintun choose one — TODO confirm against wintun.CreateAdapter.
var WintunStaticRequestedGUID *windows.GUID

// procyield is runtime.procyield; Read uses it as a short busy-wait between polls.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is runtime.nanotime; it is the clock for rate windows and spin timing.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64

// CreateTUN creates a Wintun
interface with the given name. Should a Wintun
// interface with the same name exist, it is reused. The adapter is requested
// with WintunStaticRequestedGUID; mtu is passed through but not used here.
func CreateTUN(ifname string, mtu int) (Device, error) {
	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}

// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
// It starts an 8 MiB session on the adapter; if that fails, the adapter is
// closed again before the error is returned. mtu is currently unused.
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
	if err != nil {
		return nil, fmt.Errorf("Error creating interface: %w", err)
	}

	tun := &NativeTun{
		wt:     wt,
		name:   ifname,
		handle: windows.InvalidHandle,
	}

	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
	if err != nil {
		tun.wt.Close()
		return nil, fmt.Errorf("Error starting session: %w", err)
	}
	tun.readWait = tun.session.ReadWaitEvent()
	return tun, nil
}

// Name returns the interface name given at creation; it never fails.
func (tun *NativeTun) Name() (string, error) {
	return tun.name, nil
}

// File always returns nil: this device has no backing *os.File.
func (tun *NativeTun) File() *os.File {
	return nil
}

// Close shuts the device down; only the first call does any work. It does the
// following, in order:
//   - marks the device closed;
//   - signals the read-wait event so a blocked Read wakes up;
//   - waits for in-flight Read/Write/LUID calls to return;
//   - ends the session and closes the adapter.
//
// Nothing in this sequence reports an error, so Close always returns nil.
func (tun *NativeTun) Close() error {
	tun.closeOnce.Do(func() {
		atomic.StoreInt32(&tun.close, 1)
		windows.SetEvent(tun.readWait)
		tun.running.Wait()
		tun.session.End()
		if tun.wt != nil {
			tun.wt.Close()
		}
	})
	return nil
}

// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() retry: if atomic.LoadInt32(&tun.close) == 1 { return 0, os.ErrClosed } start := nanotime() shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 for { if atomic.LoadInt32(&tun.close) == 1 { return 0, os.ErrClosed } packet, err := tun.session.ReceivePacket() switch err { case nil: packetSize := len(packet) copy(buff[offset:], packet) tun.session.ReleaseReceivePacket(packet) tun.rate.update(uint64(packetSize)) return packetSize, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(tun.readWait, windows.INFINITE) goto retry } procyield(1) continue case windows.ERROR_HANDLE_EOF: return 0, os.ErrClosed case windows.ERROR_INVALID_DATA: return 0, errors.New("Send ring corrupt") } return 0, fmt.Errorf("Read failed: %w", err) } } func (tun *NativeTun) Flush() error { return nil } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() if atomic.LoadInt32(&tun.close) == 1 { return 0, os.ErrClosed } packetSize := len(buff) - offset tun.rate.update(uint64(packetSize)) packet, err := tun.session.AllocateSendPacket(packetSize) if err == nil { copy(packet, buff[offset:]) tun.session.SendPacket(packet) return packetSize, nil } switch err { case windows.ERROR_HANDLE_EOF: return 0, os.ErrClosed case windows.ERROR_BUFFER_OVERFLOW: return 0, nil // Dropping when ring is full. } return 0, fmt.Errorf("Write failed: %w", err) } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { tun.running.Add(1) defer tun.running.Done() if atomic.LoadInt32(&tun.close) == 1 { return 0 } return tun.wt.LUID() } // RunningVersion returns the running version of the Wintun driver. 
func (tun *NativeTun) RunningVersion() (version uint32, err error) { return wintun.RunningVersion() } func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() total := atomic.AddUint64(&rate.nextByteCount, packetLen) period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) if period >= rateMeasurementGranularity { if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { return } atomic.StoreInt64(&rate.nextStartTime, now) atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) atomic.StoreUint64(&rate.nextByteCount, 0) atomic.StoreInt32(&rate.changing, 0) } }