Repository: discord/libdave Branch: main Commit: 52cd56dc550f Files: 111 Total size: 368.1 KB Directory structure: gitextract_k2hmfbp_/ ├── .github/ │ ├── actions/ │ │ └── prepare-build/ │ │ └── action.yaml │ └── workflows/ │ └── main.yaml ├── .gitignore ├── .gitmodules ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── cpp/ │ ├── .clang-format │ ├── .gitignore │ ├── CMakeLists.txt │ ├── Makefile │ ├── README.md │ ├── afl-driver/ │ │ └── src/ │ │ └── main.cpp │ ├── includes/ │ │ └── dave/ │ │ ├── array_view.h │ │ ├── dave.h │ │ ├── dave_interfaces.h │ │ ├── logger.h │ │ └── version.h │ ├── src/ │ │ ├── bindings_capi.cpp │ │ ├── bindings_wasm.cpp │ │ ├── boringssl_cryptor.cpp │ │ ├── boringssl_cryptor.h │ │ ├── codec_utils.cpp │ │ ├── codec_utils.h │ │ ├── common.h │ │ ├── cryptor.cpp │ │ ├── cryptor.h │ │ ├── cryptor_manager.cpp │ │ ├── cryptor_manager.h │ │ ├── decryptor.cpp │ │ ├── decryptor.h │ │ ├── encryptor.cpp │ │ ├── encryptor.h │ │ ├── frame_processors.cpp │ │ ├── frame_processors.h │ │ ├── key_ratchet.h │ │ ├── logger.cpp │ │ ├── mls/ │ │ │ ├── detail/ │ │ │ │ ├── persisted_key_pair.h │ │ │ │ ├── persisted_key_pair_apple.cpp │ │ │ │ ├── persisted_key_pair_generic.cpp │ │ │ │ ├── persisted_key_pair_null.cpp │ │ │ │ └── persisted_key_pair_win.cpp │ │ │ ├── parameters.cpp │ │ │ ├── parameters.h │ │ │ ├── persisted_key_pair.cpp │ │ │ ├── persisted_key_pair.h │ │ │ ├── persisted_key_pair_null.cpp │ │ │ ├── session.cpp │ │ │ ├── session.h │ │ │ ├── user_credential.cpp │ │ │ ├── user_credential.h │ │ │ ├── util.cpp │ │ │ └── util.h │ │ ├── mls_key_ratchet.cpp │ │ ├── mls_key_ratchet.h │ │ ├── openssl_cryptor.cpp │ │ ├── openssl_cryptor.h │ │ ├── utils/ │ │ │ ├── clock.h │ │ │ ├── leb128.cpp │ │ │ ├── leb128.h │ │ │ └── scope_exit.h │ │ └── version.cpp │ ├── test/ │ │ ├── CMakeLists.txt │ │ ├── capi/ │ │ │ ├── CMakeLists.txt │ │ │ ├── basic_tests.c │ │ │ ├── external_sender_wrapper.cpp │ │ │ ├── external_sender_wrapper.h │ │ │ ├── test_helpers.c │ │ │ └── 
test_helpers.h │ │ ├── codec_utils_tests.cpp │ │ ├── cryptor_manager_tests.cpp │ │ ├── cryptor_tests.cpp │ │ ├── dave_test.cpp │ │ ├── dave_test.h │ │ ├── external_sender.cpp │ │ ├── external_sender.h │ │ ├── static_key_ratchet.cpp │ │ ├── static_key_ratchet.h │ │ └── xssl_cryptor_tests.cpp │ └── vcpkg-alts/ │ ├── boringssl/ │ │ ├── overlay-ports/ │ │ │ └── mlspp/ │ │ │ ├── portfile.cmake │ │ │ └── vcpkg.json │ │ └── vcpkg.json │ ├── openssl_1.1/ │ │ ├── overlay-ports/ │ │ │ └── mlspp/ │ │ │ ├── portfile.cmake │ │ │ └── vcpkg.json │ │ └── vcpkg.json │ ├── openssl_3/ │ │ ├── overlay-ports/ │ │ │ └── mlspp/ │ │ │ ├── portfile.cmake │ │ │ └── vcpkg.json │ │ └── vcpkg.json │ └── wasm/ │ ├── overlay-ports/ │ │ └── mlspp/ │ │ ├── portfile.cmake │ │ └── vcpkg.json │ └── vcpkg.json ├── js/ │ ├── .gitignore │ ├── .npmrc │ ├── README.md │ ├── __tests__/ │ │ ├── DisplayableCode-test.ts │ │ ├── KeyFingerprint-test.ts │ │ ├── KeySerialization-test.ts │ │ └── PairwiseFingerprint-test.ts │ ├── jest-setup.js │ ├── jest.config.js │ ├── package.json │ ├── src/ │ │ ├── DisplayableCode.ts │ │ ├── KeyFingerprint.ts │ │ ├── KeySerialization.ts │ │ ├── PairwiseFingerprint.ts │ │ ├── index.ts │ │ └── wasm.ts │ ├── tsconfig.json │ └── wasm/ │ └── .gitignore └── samples/ └── typescript/ ├── DaveSessionManager.ts └── README.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/actions/prepare-build/action.yaml ================================================ name: Install build prerequisites inputs: runner: description: The runner on which the action is being run required: true crypto: description: The crypto library being used required: true cache-dir: description: Where to put vcpkg cache required: true make-args: description: Additional arguments to pass to make to configure the build required: true runs: using: "composite" steps: - name: Capture vcpkg revision for 
use in cache key shell: bash run: | git -C cpp/vcpkg rev-parse HEAD > cpp/vcpkg_commit.txt - name: Restore cache id: cache-vcpkg-restore uses: actions/cache/restore@v4 with: path: ${{ inputs.cache-dir }} key: vcpkg-${{ inputs.runner }}-${{ inputs.crypto }}-v03-${{ hashFiles('cpp/vcpkg_commit.txt', 'cpp/vcpkg-alts/**') }} restore-keys: | vcpkg-${{ inputs.runner }}-${{ inputs.crypto }} v02-vcpkg-${{ inputs.runner }}-${{ inputs.crypto }} - name: vcpkg bootstrap (macOS/Linux) if: ${{ runner.os == 'macOS' || runner.os == 'Linux' }} shell: bash run: | ./cpp/vcpkg/bootstrap-vcpkg.sh - name: vcpkg bootstrap (Windows) if: ${{ runner.os == 'Windows' }} shell: cmd run: cpp\vcpkg\bootstrap-vcpkg.bat - name: Install dependencies (macOS) if: ${{ runner.os == 'macOS' }} shell: bash run: | brew install go nasm - name: Set CC and CXX environment variables (macOS) if: ${{ runner.os == 'macOS' }} shell: bash run: | echo "CC=$(brew --prefix llvm@18)/bin/clang" >> $GITHUB_ENV echo "CXX=$(brew --prefix llvm@18)/bin/clang++" >> $GITHUB_ENV - name: Update dependencies (act) if: ${{ runner.os == 'Linux' && env.ACT }} shell: bash run: | sudo apt-get update - name: Install dependencies (Ubuntu) if: ${{ runner.os == 'Linux' }} shell: bash run: | sudo apt-get install -y nasm - name: Set BUILD_DIR environment variable shell: bash run: | echo "BUILD_DIR=${{ runner.temp }}/build" >> $GITHUB_ENV - name: Configure build shell: bash working-directory: ./cpp run: make '${{ env.BUILD_DIR }}' BUILD_DIR='${{ env.BUILD_DIR }}' ${{ inputs.make-args }} - name: Cache vcpkg if: steps.cache-vcpkg-restore.outputs.cache-hit != 'true' uses: actions/cache/save@v4 with: key: ${{ steps.cache-vcpkg-restore.outputs.cache-primary-key }} path: ${{ inputs.cache-dir }} ================================================ FILE: .github/workflows/main.yaml ================================================ name: cpp on: push: branches: ["main"] pull_request: branches: ["**"] env: CMAKE_BUILD_PARALLEL_LEVEL: 3 CMAKE_TOOLCHAIN_FILE: ${{
github.workspace }}/cpp/vcpkg/scripts/buildsystems/vcpkg.cmake VCPKG_BINARY_SOURCES: clear;files,${{ github.workspace }}/vcpkg_cache,readwrite VCPKG_CACHE_DIR: ${{ github.workspace }}/vcpkg_cache defaults: run: working-directory: ./cpp shell: bash jobs: debug-test: strategy: matrix: runner: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, macos-15-intel, windows-latest] crypto: [openssl_3, openssl_1.1, boringssl] fail-fast: false runs-on: ${{matrix.runner}} steps: - uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 - uses: ./.github/actions/prepare-build with: runner: ${{ matrix.runner }} crypto: ${{ matrix.crypto }} cache-dir: ${{ env.VCPKG_CACHE_DIR }} make-args: BUILD_TYPE=Debug SSL=${{ matrix.crypto }} BUILD_SHARED_LIBS=ON TESTING=ON SANITIZERS=ON MSVC_RUNTIME_LIBRARY=MultiThreadedDebug - name: build libdave run: make dev-sanitizers BUILD_DIR='${{ env.BUILD_DIR }}' - name: test run: make dtest BUILD_DIR='${{ env.BUILD_DIR }}' BUILD_TYPE=Debug - name: test-capi run: make dtest-capi BUILD_DIR='${{ env.BUILD_DIR }}' BUILD_TYPE=Debug release-build: strategy: matrix: runner: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, macos-15-intel, windows-latest] crypto: [boringssl] fail-fast: false runs-on: ${{matrix.runner}} steps: - uses: actions/checkout@v4 with: submodules: "recursive" fetch-depth: 0 - uses: ./.github/actions/prepare-build with: runner: ${{ matrix.runner }} crypto: ${{ matrix.crypto }} cache-dir: ${{ env.VCPKG_CACHE_DIR }} make-args: BUILD_TYPE=Release SSL=${{ matrix.crypto }} BUILD_SHARED_LIBS=ON TESTING=ON INSTALL_VCPKG_LICENSES=ON MSVC_RUNTIME_LIBRARY=MultiThreaded - name: build libdave run: make all BUILD_DIR='${{ env.BUILD_DIR }}' BUILD_TYPE=Release - name: run tests if: ${{ runner.os != 'Windows' }} run: make dtest BUILD_DIR='${{ env.BUILD_DIR }}' BUILD_TYPE=Release - name: run C api tests run: make dtest-capi BUILD_DIR='${{ env.BUILD_DIR }}' BUILD_TYPE=Release - name: prepare artifacts (install) run: make install BUILD_DIR='${{ 
env.BUILD_DIR }}' BUILD_TYPE=Release - name: check licenses run: ls -la '${{ env.BUILD_DIR }}/install/licenses' | grep -q "boringssl\|openssl" - name: upload build artifacts uses: actions/upload-artifact@v6 if: ${{ github.event_name == 'push' }} with: path: ${{ env.BUILD_DIR }}/install name: libdave-${{ runner.os }}-${{ runner.arch }}-${{ matrix.crypto }} if-no-files-found: error ================================================ FILE: .gitignore ================================================ .DS_Store ================================================ FILE: .gitmodules ================================================ [submodule "cpp/vcpkg"] path = cpp/vcpkg url = https://github.com/microsoft/vcpkg.git ================================================ FILE: CONTRIBUTING.md ================================================ At this time we are not taking pull requests to this repository. We welcome reports and suggestions via Github Issues. ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Discord Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ ## libdave This repository contains the JS and C++ libraries which together implement Discord's Audio & Video End-to-End Encryption (DAVE) protocol. These libraries are leveraged by Discord's native clients to support the DAVE protocol. The DAVE protocol is described in detail in the [protocol whitepaper](https://github.com/discord/dave-protocol). See the [cpp README](/cpp/README.md) or the [js README](/js/README.md) for information specific to each library. ================================================ FILE: cpp/.clang-format ================================================ --- AccessModifierOffset: -4 AlignAfterOpenBracket: true AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false AlignEscapedNewlines: Left AlignOperands: false AlignTrailingComments: true AllowAllParametersOfDeclarationOnNextLine: false AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: InlineOnly AllowShortIfStatementsOnASingleLine: false AllowShortLoopsOnASingleLine: false AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: false AlwaysBreakTemplateDeclarations: true BinPackArguments: false BinPackParameters: false BreakBeforeBinaryOperators: None BreakBeforeBraces: Stroustrup BreakBeforeInheritanceComma: true BreakBeforeTernaryOperators: true BreakConstructorInitializers: BeforeComma BreakStringLiterals: true ColumnLimit: 100 CommentPragmas: "" CompactNamespaces: false ConstructorInitializerAllOnOneLineOrOnePerLine: false ConstructorInitializerIndentWidth: 2 ContinuationIndentWidth: 2 Cpp11BracedListStyle: true 
DerivePointerAlignment: false DisableFormat: false FixNamespaceComments: true ForEachMacros: [] IndentCaseLabels: false IncludeBlocks: Preserve IncludeCategories: - Regex: "^<(W|w)indows.h>" Priority: 1 - Regex: "^<" Priority: 2 - Regex: ".*" Priority: 3 IncludeIsMainRegex: "(_test|_win|_linux|_mac|_ios|_osx|_null)?$" IndentPPDirectives: None IndentWidth: 4 IndentWrappedFunctionNames: false KeepEmptyLinesAtTheStartOfBlocks: false MacroBlockBegin: "" MacroBlockEnd: "" MaxEmptyLinesToKeep: 1 NamespaceIndentation: None PenaltyBreakAssignment: 0 PenaltyBreakBeforeFirstCallParameter: 1 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 9999999 PointerAlignment: Left ReflowComments: true SortIncludes: true SortUsingDeclarations: true SpaceAfterCStyleCast: false SpaceAfterTemplateKeyword: true SpaceBeforeAssignmentOperators: true SpaceBeforeParens: ControlStatements SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 1 SpacesInAngles: false SpacesInCStyleCastParentheses: false SpacesInContainerLiterals: true SpacesInParentheses: false SpacesInSquareBrackets: false Standard: Cpp11 TabWidth: 4 UseTab: Never ================================================ FILE: cpp/.gitignore ================================================ build/ ================================================ FILE: cpp/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.20) project( libdave VERSION 1.0 LANGUAGES CXX C ) option(REQUIRE_BORINGSSL "Require BoringSSL instead of OpenSSL" OFF) option(TESTING "Build tests" OFF) option(PERSISTENT_KEYS "Enable storage of persistent signature keys" OFF) option(BUILD_SHARED_LIBS "Build shared libraries" OFF) option(ENABLE_SANITIZERS "Enable address and undefined behavior sanitizers" OFF) option(INSTALL_VCPKG_LICENSES "Installs license files from vcpkg deps which require it" OFF) include(CheckCXXCompilerFlag) 
include(CMakeFindDependencyMacro) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") add_compile_options(-Wall -pedantic -Wextra -Werror -Wimplicit-int-conversion) elseif (CMAKE_CXX_COMPILER_ID MATCHES "GNU") add_compile_options(-Wall -pedantic -Wextra -Werror) elseif(MSVC) add_compile_options(/W4 /WX) add_definitions(-DWINDOWS) # MSVC helpfully recommends safer equivalents for things like # getenv, but they are not portable. add_definitions(-D_CRT_SECURE_NO_WARNINGS) endif() # Configure sanitizers if (ENABLE_SANITIZERS) if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") message(FATAL_ERROR "Sanitizers are only supported for Debug builds") endif() if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU") set(SANITIZER_FLAGS "-fsanitize=address,undefined -fno-omit-frame-pointer -fno-optimize-sibling-calls") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SANITIZER_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SANITIZER_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${SANITIZER_FLAGS}") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${SANITIZER_FLAGS}") message(STATUS "Sanitizers enabled: address, undefined") elseif(MSVC) set(SANITIZER_FLAGS "/fsanitize=address /Zi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SANITIZER_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SANITIZER_FLAGS}") # Disable STL container annotations to avoid mismatch with dependencies built without ASAN add_definitions(-D_DISABLE_STRING_ANNOTATION=1 -D_DISABLE_VECTOR_ANNOTATION=1) # Find ASAN runtime DLL for copying to test directories get_filename_component(COMPILER_DIR "${CMAKE_CXX_COMPILER}" DIRECTORY) set(ASAN_RUNTIME_DLL "${COMPILER_DIR}/clang_rt.asan_dynamic-x86_64.dll" CACHE FILEPATH "Path to ASAN runtime DLL") if(EXISTS "${ASAN_RUNTIME_DLL}") message(STATUS "ASAN runtime DLL: ${ASAN_RUNTIME_DLL}") else() message(WARNING "ASAN runtime DLL not found at 
${ASAN_RUNTIME_DLL}") endif() message(STATUS "Sanitizers enabled: address") endif() endif() find_package(OpenSSL REQUIRED) if (OPENSSL_FOUND) find_path(BORINGSSL_INCLUDE_DIR openssl/is_boringssl.h HINTS ${OPENSSL_INCLUDE_DIR} NO_DEFAULT_PATH) if (BORINGSSL_INCLUDE_DIR) message(STATUS "Found OpenSSL includes are for BoringSSL") add_compile_definitions(WITH_BORINGSSL) if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID MATCHES "GNU") add_compile_options(-Wno-gnu-anonymous-struct -Wno-nested-anon-types) endif () file(STRINGS "${OPENSSL_INCLUDE_DIR}/openssl/crypto.h" boringssl_version_str REGEX "^#[\t ]*define[\t ]+OPENSSL_VERSION_TEXT[\t ]+\"OpenSSL ([0-9])+\\.([0-9])+\\.([0-9])+ .+") string(REGEX REPLACE "^.*OPENSSL_VERSION_TEXT[\t ]+\"OpenSSL ([0-9]+\\.[0-9]+\\.[0-9])+ .+$" "\\1" OPENSSL_VERSION "${boringssl_version_str}") elseif (REQUIRE_BORINGSSL) message(FATAL_ERROR "BoringSSL required but not found") endif () if (${OPENSSL_VERSION} VERSION_GREATER_EQUAL 3) add_compile_definitions(WITH_OPENSSL3) elseif(${OPENSSL_VERSION} VERSION_LESS 1.1.1) message(FATAL_ERROR "OpenSSL 1.1.1 or greater is required") endif() message(STATUS "OpenSSL Found: ${OPENSSL_VERSION}") message(STATUS "OpenSSL Include: ${OPENSSL_INCLUDE_DIR}") message(STATUS "OpenSSL Libraries: ${OPENSSL_LIBRARIES}") else() message(FATAL_ERROR "No OpenSSL library found") endif() find_package(nlohmann_json REQUIRED) find_dependency(MLSPP REQUIRED) set(CMAKE_STATIC_LIBRARY_PREFIX "") SET(LIB_NAME ${PROJECT_NAME}) file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.h" "${CMAKE_CURRENT_SOURCE_DIR}/includes/*.h") file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") # remove all of the persistent key files list(FILTER LIB_SOURCES EXCLUDE REGEX ".*persisted_key.*") if (PERSISTENT_KEYS) # persistent keys enabled list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/mls/persisted_key_pair.cpp") if (APPLE) # Apple has its own native 
and generic implementation, we just add the _apple.cpp file list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/mls/detail/persisted_key_pair_apple.cpp") else () # Other platforms share the generic implementation list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/mls/detail/persisted_key_pair_generic.cpp") if (WIN32) # Windows has a native implementation list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/mls/detail/persisted_key_pair_win.cpp") else () # We don't have a native implementation, so we include the nullified native list(APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/mls/detail/persisted_key_pair_null.cpp") endif () endif () else () # not using persistent keys, so we just need to add the null implementation list (APPEND LIB_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/src/mls/persisted_key_pair_null.cpp") endif () if (NOT WIN32) list(FILTER LIB_SOURCES EXCLUDE REGEX ".*_win.cpp") endif () if (NOT APPLE) list(FILTER LIB_SOURCES EXCLUDE REGEX ".*_apple.cpp") endif () if (NOT DEFINED EMSCRIPTEN) list(FILTER LIB_SOURCES EXCLUDE REGEX ".*_wasm.cpp") else() list(FILTER LIB_SOURCES EXCLUDE REGEX ".*_capi.cpp") endif() if (BORINGSSL_INCLUDE_DIR) list(FILTER LIB_SOURCES EXCLUDE REGEX ".*openssl_cryptor.*") else () list(FILTER LIB_SOURCES EXCLUDE REGEX ".*boringssl_cryptor.*") endif() if (DEFINED EMSCRIPTEN) add_executable(${LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) set(OPTIMIZATION "-O3") set(CONFIG "-sWASM=1 -sWASM_BIGINT -sENVIRONMENT=web -sMODULARIZE -sALLOW_MEMORY_GROWTH") set(EXPORTS "-sEXPORT_ES6=1 -sEXPORT_NAME=DaveModuleFactory -sEXPORTED_RUNTIME_METHODS='[\"ccall\"]' -sEXPORTED_FUNCTIONS='[\"_malloc\", \"_free\"]'") set(COMPILE_FLAGS "${OPTIMIZATION}") set(LINK_FLAGS "${OPTIMIZATION} ${CONFIG} ${EXPORTS} -lembind --no-entry --whole-archive --emit-tsd libdave.d.ts") set_target_properties(${LIB_NAME} PROPERTIES COMPILE_FLAGS "${COMPILE_FLAGS}") set_target_properties(${LIB_NAME} PROPERTIES LINK_FLAGS "${LINK_FLAGS}") else() 
add_library(${LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) if (BUILD_SHARED_LIBS AND NOT WIN32) # Without this the resulting file is called liblibdave.dylib set_target_properties(${LIB_NAME} PROPERTIES OUTPUT_NAME dave) endif() endif() if (TESTING) add_subdirectory(test) endif() target_include_directories( ${LIB_NAME} PUBLIC $ $ PRIVATE $ ) target_link_libraries(${LIB_NAME} PRIVATE OpenSSL::Crypto) target_link_libraries(${LIB_NAME} PRIVATE MLSPP::mlspp) if (APPLE AND PERSISTENT_KEYS) target_link_libraries(${LIB_NAME} PUBLIC "-framework CoreFoundation" "-framework Security") endif() set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY ON) install(TARGETS ${LIB_NAME} INCLUDES DESTINATION "include") install(DIRECTORY ${PROJECT_SOURCE_DIR}/includes/ DESTINATION "include") if (INSTALL_VCPKG_LICENSES) set(DEPS_NEEDING_LICENSE mlspp nlohmann-json) if (BORINGSSL_INCLUDE_DIR) list(APPEND DEPS_NEEDING_LICENSE boringssl) else () list(APPEND DEPS_NEEDING_LICENSE openssl) endif () foreach(DEP_NAME ${DEPS_NEEDING_LICENSE}) set(DEP_LICENSE_PATH "${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}/share/${DEP_NAME}") # Check the copyright file itself (it is what gets installed below) and stop configuring if it is missing; ERROR is not a message() mode, so the old call only printed a notice. if (NOT EXISTS "${DEP_LICENSE_PATH}/copyright") message(FATAL_ERROR "Could not find license file for ${DEP_NAME} at ${DEP_LICENSE_PATH}/copyright") endif() install(FILES "${DEP_LICENSE_PATH}/copyright" DESTINATION "licenses" RENAME ${DEP_NAME}) endforeach() endif() install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE DESTINATION "licenses" RENAME ${LIB_NAME}) ================================================ FILE: cpp/Makefile ================================================ # Options BUILD_DIR=build INSTALL_DIR=${BUILD_DIR}/install SANITIZERS=OFF BUILD_SHARED_LIBS=OFF PERSISTENT_KEYS=OFF TESTING=OFF SSL=openssl_3 INSTALL_VCPKG_LICENSES=OFF # Paths BORINGSSL_MANIFEST=vcpkg-alts/boringssl OPENSSL_1_1_MANIFEST=vcpkg-alts/openssl_1.1 OPENSSL_3_MANIFEST=vcpkg-alts/openssl_3 WASM_MANIFEST=vcpkg-alts/wasm TOOLCHAIN_FILE=vcpkg/scripts/buildsystems/vcpkg.cmake
EMSCRIPTEN_TOOLCHAIN_FILE=${EMSDK}/upstream/emscripten/cmake/Modules/Platform/Emscripten.cmake CLANG_FORMAT=clang-format -i -style=file:.clang-format DEFAULT_BUILD_TYPE=Debug BUILD_TYPE ?= $(DEFAULT_BUILD_TYPE) all shared install: DEFAULT_BUILD_TYPE=Release ifeq ($(SSL), boringssl) SSL_MANIFEST=${BORINGSSL_MANIFEST} else ifeq ($(SSL), openssl_1.1) SSL_MANIFEST=${OPENSSL_1_1_MANIFEST} else ifeq ($(SSL), openssl_3) SSL_MANIFEST=${OPENSSL_3_MANIFEST} else $(error Invalid SSL option: $(SSL)) endif ifeq ($(OS), Windows_NT) EXTRA_FLAGS=-DVCPKG_TARGET_TRIPLET=x64-windows-static \ -DVCPKG_TARGET_ARCHITECTURE=x86_64 ifeq ($(BUILD_TYPE), Debug) EXTRA_FLAGS+=-DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=ON endif ifdef MSVC_RUNTIME_LIBRARY EXTRA_FLAGS+=-DCMAKE_MSVC_RUNTIME_LIBRARY=${MSVC_RUNTIME_LIBRARY} endif endif all: ${BUILD_DIR} cmake --build ${BUILD_DIR} --target libdave --config ${BUILD_TYPE} ${BUILD_DIR}: CMakeLists.txt test/CMakeLists.txt test/capi/CMakeLists.txt cmake -B${BUILD_DIR} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DVCPKG_MANIFEST_DIR=${SSL_MANIFEST} \ -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS} \ -DENABLE_SANITIZERS=${SANITIZERS} \ -DPERSISTENT_KEYS=${PERSISTENT_KEYS} \ -DTESTING=${TESTING} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_DIR} \ -DINSTALL_VCPKG_LICENSES=${INSTALL_VCPKG_LICENSES} \ ${EXTRA_FLAGS} install: ${BUILD_DIR} cmake --build ${BUILD_DIR} --target install --config ${BUILD_TYPE} shared: $(MAKE) all BUILD_SHARED_LIBS=ON dev: $(MAKE) all TESTING=ON BUILD_TYPE=$(BUILD_TYPE) dev-shared: $(MAKE) dev BUILD_SHARED_LIBS=ON dev-sanitizers: $(MAKE) dev SANITIZERS=ON devB: # Like `dev`, but using OpenSSL 1.1 $(MAKE) dev SSL=openssl_1.1 devC: # Like `dev`, but using BoringSSL $(MAKE) dev SSL=boringssl wasm: check-emsdk emcmake cmake -B${BUILD_DIR} -DCMAKE_BUILD_TYPE=Release \ -DVCPKG_MANIFEST_DIR=${WASM_MANIFEST} \ -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \ -DVCPKG_CHAINLOAD_TOOLCHAIN_FILE=${EMSCRIPTEN_TOOLCHAIN_FILE} \ 
-DVCPKG_TARGET_TRIPLET=wasm32-emscripten cmake --build ${BUILD_DIR} --target libdave --config ${BUILD_TYPE} check-emsdk: @if [ -z "$$EMSDK" ]; then \ echo "Error: EMSDK environment variable is not set"; \ echo "Please set it to your emsdk installation directory"; \ echo "Example: export EMSDK=/path/to/emsdk"; \ exit 1; \ fi test: dev test/* cmake --build ${BUILD_DIR} --target libdave_test --config ${BUILD_TYPE} test-capi: dev test/capi/* cmake --build ${BUILD_DIR} --target capi_test --config ${BUILD_TYPE} test-sanitizers: dev-sanitizers test/* cmake --build ${BUILD_DIR} --target libdave_test --config ${BUILD_TYPE} test-capi-sanitizers: dev-sanitizers test/capi/* cmake --build ${BUILD_DIR} --target capi_test --config ${BUILD_TYPE} dtest: test ifeq ($(OS), Windows_NT) ${BUILD_DIR}/test/${BUILD_TYPE}/libdave_test.exe else ${BUILD_DIR}/test/libdave_test endif dtest-capi: test-capi ifeq ($(OS), Windows_NT) ${BUILD_DIR}/test/capi/${BUILD_TYPE}/capi_test.exe else ${BUILD_DIR}/test/capi/capi_test endif dtest-sanitizers: test-sanitizers ifeq ($(OS), Windows_NT) ${BUILD_DIR}/test/${BUILD_TYPE}/libdave_test.exe else ${BUILD_DIR}/test/libdave_test endif dtest-capi-sanitizers: test-capi-sanitizers ifeq ($(OS), Windows_NT) ${BUILD_DIR}/test/capi/${BUILD_TYPE}/capi_test.exe else ${BUILD_DIR}/test/capi/capi_test endif dbtest: test lldb ${BUILD_DIR}/test/libdave_test dbtest-capi: test-capi lldb ${BUILD_DIR}/test/capi/capi_test ctest: test cmake --build ${BUILD_DIR} --target test --config ${BUILD_TYPE} clean: cmake --build ${BUILD_DIR} --target clean cclean: ifeq ($(OS), Windows_NT) if exist ${BUILD_DIR} rmdir /s /q ${BUILD_DIR} else rm -rf ${BUILD_DIR} endif format: find src -iname "*.h" -or -iname "*.cpp" -or -iname "*.c" | xargs ${CLANG_FORMAT} find test -iname "*.h" -or -iname "*.cpp" -or -iname "*.c" | xargs ${CLANG_FORMAT} ================================================ FILE: cpp/README.md ================================================ ## libdave C++ Contains the libdave 
C++ library, which handles the bulk of the DAVE protocol implementation for Discord's native clients. ### Dependencies - [mlspp](https://github.com/cisco/mlspp) - Configured with `-DMLS_CXX_NAMESPACE="mlspp"` and `-DDISABLE_GREASE=ON` - One of the supported SSL backends: - [OpenSSL 1.1 or 3.0](https://github.com/openssl/openssl) - [boringssl](https://boringssl.googlesource.com/boringssl) #### Testing - [googletest](https://github.com/google/googletest) - [AFLplusplus](https://github.com/AFLplusplus/AFLplusplus) ## Building ### vcpkg Make sure the vcpkg submodule is initialized and up to date, then bootstrap vcpkg: ``` git submodule update --init --recursive ./vcpkg/bootstrap-vcpkg.sh ``` ### Compiling For a static library, run: ``` make cclean make ``` For a shared library, run: ``` make cclean make shared ``` ### SSL By default the library builds with OpenSSL 3. To build with OpenSSL 1.1 or BoringSSL instead, set the `SSL` variable of the [Makefile](Makefile) to `openssl_1.1` or `boringssl` (for example, `make SSL=boringssl`); it selects the matching vcpkg manifest under `vcpkg-alts/`. Run `make cclean` first when switching backends, because an existing build directory keeps the configuration it was created with. ================================================ FILE: cpp/afl-driver/src/main.cpp ================================================ #include #include #include #include #include #include #include "common.h" #include "utils/array_view.h" #include "decryptor.h" using namespace discord::dave; extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { FuzzedDataProvider provider(data, size); MediaType mediaType = static_cast(provider.ConsumeIntegralInRange(0, 2)); const auto InFrame = provider.ConsumeRemainingBytes(); Decryptor decryptor; const auto OutFrameSize = decryptor.GetMaxPlaintextByteSize(mediaType, InFrame.size()); auto outFrame = std::make_unique(OutFrameSize); [[maybe_unused]] auto res = decryptor.Decrypt(mediaType, MakeArrayView(InFrame.data(), InFrame.size()), MakeArrayView(outFrame.get(), OutFrameSize)); return 0; } ================================================ FILE: cpp/includes/dave/array_view.h ================================================ #pragma once #include #include #include
namespace discord { namespace dave { template class ArrayView { public: ArrayView() = default; ArrayView(T* data, size_t size) : data_(data) , size_(size) { } size_t size() const { return size_; } T* data() const { return data_; } T* begin() const { return data_; } T* end() const { return data_ + size_; } private: T* data_ = nullptr; size_t size_ = 0; }; template inline ArrayView MakeArrayView(T* data, size_t size) { return ArrayView(data, size); } template inline ArrayView MakeArrayView(std::vector& data) { return ArrayView(data.data(), data.size()); } } // namespace dave } // namespace discord ================================================ FILE: cpp/includes/dave/dave.h ================================================ /** * @file dave.h * @brief DAVE (Discord Audio/Video Encryption) C API * * This header provides the C API for end-to-end encryption of audio and video * streams using the DAVE protocol. * * All handles are opaque pointers that must be created and destroyed using * the corresponding API functions. 
Memory management rules: * - Handles from *Create functions must be freed with their *Destroy counterpart * - Output handles should be destroyed by the caller using the corresponding *Destroy function * - Functions do not take ownership of the input data unless otherwise specified * - Output byte arrays allocated by the library must be freed using daveFree() */ #pragma once #include #include #include #if (defined(_WIN32) || defined(_WIN64)) #define DAVE_EXPORT __declspec(dllexport) #else #define DAVE_EXPORT __attribute__((visibility("default"))) #endif #define DECLARE_OPAQUE_HANDLE(x) typedef struct x##_s* x #ifdef __cplusplus extern "C" { #endif /** @brief DAVE session handle for managing group encryption state */ DECLARE_OPAQUE_HANDLE(DAVESessionHandle); /** @brief Result handle from processing an MLS commit message */ DECLARE_OPAQUE_HANDLE(DAVECommitResultHandle); /** @brief Result handle from processing an MLS welcome message */ DECLARE_OPAQUE_HANDLE(DAVEWelcomeResultHandle); /** @brief Key ratchet handle for deriving encryption keys */ DECLARE_OPAQUE_HANDLE(DAVEKeyRatchetHandle); /** @brief Media frame encryptor handle */ DECLARE_OPAQUE_HANDLE(DAVEEncryptorHandle); /** @brief Media frame decryptor handle */ DECLARE_OPAQUE_HANDLE(DAVEDecryptorHandle); /** * @brief Supported media codecs for encryption */ typedef enum { DAVE_CODEC_UNKNOWN = 0, /**< Unknown or unspecified codec */ DAVE_CODEC_OPUS = 1, /**< Opus audio codec */ DAVE_CODEC_VP8 = 2, /**< VP8 video codec */ DAVE_CODEC_VP9 = 3, /**< VP9 video codec */ DAVE_CODEC_H264 = 4, /**< H.264/AVC video codec */ DAVE_CODEC_H265 = 5, /**< H.265/HEVC video codec */ DAVE_CODEC_AV1 = 6 /**< AV1 video codec */ } DAVECodec; /** * @brief Media stream type classification */ typedef enum { DAVE_MEDIA_TYPE_AUDIO = 0, /**< Audio stream */ DAVE_MEDIA_TYPE_VIDEO = 1 /**< Video stream */ } DAVEMediaType; /** * @brief Result codes returned by encryption operations */ typedef enum { DAVE_ENCRYPTOR_RESULT_CODE_SUCCESS = 0, /**< 
Encryption succeeded */ DAVE_ENCRYPTOR_RESULT_CODE_ENCRYPTION_FAILURE = 1, /**< Encryption failed */ DAVE_ENCRYPTOR_RESULT_CODE_MISSING_KEY_RATCHET = 2,/**< No key ratchet available */ DAVE_ENCRYPTOR_RESULT_CODE_MISSING_CRYPTOR = 3, /**< Missing cryptographic context */ DAVE_ENCRYPTOR_RESULT_CODE_TOO_MANY_ATTEMPTS = 4, /**< Too many attempts to encrypt the frame failed */ } DAVEEncryptorResultCode; /** * @brief Result codes returned by decryption operations */ typedef enum { DAVE_DECRYPTOR_RESULT_CODE_SUCCESS = 0, /**< Decryption succeeded */ DAVE_DECRYPTOR_RESULT_CODE_DECRYPTION_FAILURE = 1, /**< Decryption failed */ DAVE_DECRYPTOR_RESULT_CODE_MISSING_KEY_RATCHET = 2,/**< No key ratchet available */ DAVE_DECRYPTOR_RESULT_CODE_INVALID_NONCE = 3, /**< Invalid nonce in encrypted frame */ DAVE_DECRYPTOR_RESULT_CODE_MISSING_CRYPTOR = 4 /**< Missing cryptographic context */ } DAVEDecryptorResultCode; /** * @brief Log message severity levels */ typedef enum { DAVE_LOGGING_SEVERITY_VERBOSE = 0, /**< Verbose debug information */ DAVE_LOGGING_SEVERITY_INFO = 1, /**< Informational messages */ DAVE_LOGGING_SEVERITY_WARNING = 2, /**< Warning messages */ DAVE_LOGGING_SEVERITY_ERROR = 3, /**< Error messages */ DAVE_LOGGING_SEVERITY_NONE = 4 /**< Messages to be ignored */ } DAVELoggingSeverity; /** * @brief Callback invoked when an MLS protocol failure occurs * @param source The source/component where the failure occurred * @param reason Human-readable failure reason * @param userData User-provided context pointer */ typedef void (*DAVEMLSFailureCallback)(const char* source, const char* reason, void* userData); /** * @brief Callback invoked with the computed pairwise fingerprint for identity verification * @param fingerprint Pointer to fingerprint bytes (freed by the library after the callback returns) * @param length Length of fingerprint in bytes * @param userData User-provided context pointer */ typedef void (*DAVEPairwiseFingerprintCallback)(const uint8_t* fingerprint, size_t 
length, void* userData); /** * @brief Callback invoked when the encryptor's protocol version changes * @param userData User-provided context pointer */ typedef void (*DAVEEncryptorProtocolVersionChangedCallback)(void* userData); /** * @brief Custom log sink callback for receiving library log messages * @param severity Log severity level * @param file Source file name where log originated (freed by the library after the callback returns) * @param line Line number in source file * @param message Log message text (freed by the library after the callback returns) */ typedef void (*DAVELogSinkCallback)(DAVELoggingSeverity severity, const char* file, int line, const char* message); /** * @brief Statistics for encryption operations */ typedef struct DAVEEncryptorStats { uint64_t passthroughCount; /**< Frames passed through without encryption */ uint64_t encryptSuccessCount; /**< Successful encryption count */ uint64_t encryptFailureCount; /**< Failed encryption count */ uint64_t encryptDuration; /**< Total encryption duration */ uint64_t encryptAttempts; /**< Total encryption attempts */ uint64_t encryptMaxAttempts; /**< Maximum retry attempts for a single frame */ uint64_t encryptMissingKeyCount;/**< Encryptions skipped due to missing key */ } DAVEEncryptorStats; /** * @brief Statistics for decryption operations */ typedef struct DAVEDecryptorStats { uint64_t passthroughCount; /**< Frames passed through without decryption */ uint64_t decryptSuccessCount; /**< Successful decryption count */ uint64_t decryptFailureCount; /**< Failed decryption count */ uint64_t decryptDuration; /**< Total decryption duration */ uint64_t decryptAttempts; /**< Total decryption attempts */ uint64_t decryptMissingKeyCount; /**< Decryptions failed due to missing key */ uint64_t decryptInvalidNonceCount;/**< Decryptions failed due to invalid nonce */ } DAVEDecryptorStats; /******************************************************************************* * Version 
******************************************************************************/ /** * @brief Returns the maximum protocol version supported by this library * @return Maximum supported protocol version number */ DAVE_EXPORT uint16_t daveMaxSupportedProtocolVersion(void); /******************************************************************************* * Memory Management ******************************************************************************/ /** * @brief Frees memory allocated by the DAVE library * * Use this function to free any byte arrays or buffers allocated by DAVE API * functions. * * @param ptr Pointer to memory previously allocated by a DAVE API function. * If NULL, this function does nothing. * * @note This function should be used to free output byte arrays from functions * like daveSessionGetLastEpochAuthenticator, daveSessionGetMarshalledKeyPackage, * daveCommitResultGetRosterMemberIds, etc. Do NOT use this to destroy handles; * use the corresponding *Destroy functions instead. */ DAVE_EXPORT void daveFree(void* ptr); /******************************************************************************* * Session Management ******************************************************************************/ /** * @brief Creates a new DAVE session * @param context Currently unused platform-specific context pointer (can be NULL) * @param authSessionId String used to manage persistent key lifetimes (can be NULL) * @param callback Callback invoked on MLS failures * @param userData User data pointer passed to the callback * @return New session handle, or NULL on failure. 
Must be destroyed with daveSessionDestroy() */ DAVE_EXPORT DAVESessionHandle daveSessionCreate(void* context, const char* authSessionId, DAVEMLSFailureCallback callback, void* userData); /** * @brief Destroys a session and frees associated resources * @param session Session handle to destroy */ DAVE_EXPORT void daveSessionDestroy(DAVESessionHandle session); /** * @brief Initializes a session with protocol version and group information * @param session Session handle * @param version Protocol version to use * @param groupId Group identifier common to all users in the group * @param selfUserId User ID of the local user */ DAVE_EXPORT void daveSessionInit(DAVESessionHandle session, uint16_t version, uint64_t groupId, const char* selfUserId); /** * @brief Resets the session state * @param session Session handle */ DAVE_EXPORT void daveSessionReset(DAVESessionHandle session); /** * @brief Sets the protocol version for the session * @param session Session handle * @param version Protocol version to set */ DAVE_EXPORT void daveSessionSetProtocolVersion(DAVESessionHandle session, uint16_t version); /** * @brief Gets the current protocol version of the session * @param session Session handle * @return Current protocol version */ DAVE_EXPORT uint16_t daveSessionGetProtocolVersion(DAVESessionHandle session); /** * @brief Retrieves the authenticator from the last MLS epoch * @param session Session handle * @param[out] authenticator Output pointer to authenticator bytes (caller must free with daveFree) * @param[out] length Output pointer to authenticator length */ DAVE_EXPORT void daveSessionGetLastEpochAuthenticator(DAVESessionHandle session, uint8_t** authenticator, size_t* length); /** * @brief Sets the external sender credentials for the session * @param session Session handle * @param externalSender External sender credential bytes * @param length Length of external sender data */ DAVE_EXPORT void daveSessionSetExternalSender(DAVESessionHandle session, const uint8_t* 
externalSender, size_t length); /** * @brief Processes MLS proposals and generates commit/welcome messages * @param session Session handle * @param proposals Serialized proposal bytes * @param length Length of proposals * @param recognizedUserIds Array of recognized user ID strings * @param recognizedUserIdsLength Number of recognized user IDs * @param[out] commitWelcomeBytes Output buffer to commit/welcome message bytes (caller must free with daveFree) * @param[out] commitWelcomeBytesLength Output length of the commit/welcome message */ DAVE_EXPORT void daveSessionProcessProposals(DAVESessionHandle session, const uint8_t* proposals, size_t length, const char** recognizedUserIds, size_t recognizedUserIdsLength, uint8_t** commitWelcomeBytes, size_t* commitWelcomeBytesLength); /** * @brief Processes an incoming MLS commit message * @param session Session handle * @param commit Serialized commit message bytes * @param length Length of commit message * @return Commit result handle. Must be destroyed with daveCommitResultDestroy() */ DAVE_EXPORT DAVECommitResultHandle daveSessionProcessCommit(DAVESessionHandle session, const uint8_t* commit, size_t length); /** * @brief Processes an incoming MLS welcome message to join a group * @param session Session handle * @param welcome Serialized welcome message bytes * @param length Length of welcome message * @param recognizedUserIds Array of recognized user ID strings * @param recognizedUserIdsLength Number of recognized user IDs * @return Welcome result handle. 
Must be destroyed with daveWelcomeResultDestroy() */ DAVE_EXPORT DAVEWelcomeResultHandle daveSessionProcessWelcome(DAVESessionHandle session, const uint8_t* welcome, size_t length, const char** recognizedUserIds, size_t recognizedUserIdsLength); /** * @brief Gets the marshalled MLS key package for this session * @param session Session handle * @param[out] keyPackage Output buffer to key package bytes (caller must free with daveFree) * @param[out] length Output length of the key package */ DAVE_EXPORT void daveSessionGetMarshalledKeyPackage(DAVESessionHandle session, uint8_t** keyPackage, size_t* length); /** * @brief Gets a key ratchet for a specific user in the session * @param session Session handle * @param userId User ID to get key ratchet for * @return Key ratchet handle. Must be destroyed with daveKeyRatchetDestroy() */ DAVE_EXPORT DAVEKeyRatchetHandle daveSessionGetKeyRatchet(DAVESessionHandle session, const char* userId); /** * @brief Computes a pairwise fingerprint for identity verification with another user * @param session Session handle * @param version Protocol version currently in use * @param userId User ID of the remote user to compute the fingerprint for * @param callback Callback to receive the fingerprint * @param userData User data passed to callback */ DAVE_EXPORT void daveSessionGetPairwiseFingerprint(DAVESessionHandle session, uint16_t version, const char* userId, DAVEPairwiseFingerprintCallback callback, void* userData); /******************************************************************************* * Key Ratchet ******************************************************************************/ /** * @brief Destroys a key ratchet and frees associated resources * @param keyRatchet Key ratchet handle to destroy */ DAVE_EXPORT void daveKeyRatchetDestroy(DAVEKeyRatchetHandle keyRatchet); /******************************************************************************* * Commit Result 
******************************************************************************/ /** * @brief Checks if processing the commit failed * @param commitResultHandle Commit result handle * @return true if commit processing failed */ DAVE_EXPORT bool daveCommitResultIsFailed(DAVECommitResultHandle commitResultHandle); /** * @brief Checks if the commit should be ignored * @param commitResultHandle Commit result handle * @return true if commit should be ignored */ DAVE_EXPORT bool daveCommitResultIsIgnored(DAVECommitResultHandle commitResultHandle); /** * @brief Gets the list of member IDs in the roster after the commit * @param commitResultHandle Commit result handle * @param[out] rosterIds Output buffer to array of roster member IDs (caller must free with daveFree) * @param[out] rosterIdsLength Output length of the roster member IDs array */ DAVE_EXPORT void daveCommitResultGetRosterMemberIds(DAVECommitResultHandle commitResultHandle, uint64_t** rosterIds, size_t* rosterIdsLength); /** * @brief Gets the signature for a specific roster member * @param commitResultHandle Commit result handle * @param rosterId Roster member ID * @param[out] signature Output buffer to signature bytes (caller must free with daveFree) * @param[out] signatureLength Output length of the signature */ DAVE_EXPORT void daveCommitResultGetRosterMemberSignature(DAVECommitResultHandle commitResultHandle, uint64_t rosterId, uint8_t** signature, size_t* signatureLength); /** * @brief Destroys a commit result and frees associated resources * @param commitResultHandle Commit result handle to destroy */ DAVE_EXPORT void daveCommitResultDestroy(DAVECommitResultHandle commitResultHandle); /******************************************************************************* * Welcome Result ******************************************************************************/ /** * @brief Gets the list of member IDs in the roster from the welcome message * @param welcomeResultHandle Welcome result handle * @param[out] 
rosterIds Output buffer to array of roster member IDs (caller must free with daveFree) * @param[out] rosterIdsLength Output length of the roster member IDs array */ DAVE_EXPORT void daveWelcomeResultGetRosterMemberIds(DAVEWelcomeResultHandle welcomeResultHandle, uint64_t** rosterIds, size_t* rosterIdsLength); /** * @brief Gets the signature for a specific roster member * @param welcomeResultHandle Welcome result handle * @param rosterId Roster member ID * @param[out] signature Output buffer to signature bytes (caller must free with daveFree) * @param[out] signatureLength Output length of the signature */ DAVE_EXPORT void daveWelcomeResultGetRosterMemberSignature(DAVEWelcomeResultHandle welcomeResultHandle, uint64_t rosterId, uint8_t** signature, size_t* signatureLength); /** * @brief Destroys a welcome result and frees associated resources * @param welcomeResultHandle Welcome result handle to destroy */ DAVE_EXPORT void daveWelcomeResultDestroy(DAVEWelcomeResultHandle welcomeResultHandle); /******************************************************************************* * Encryptor ******************************************************************************/ /** * @brief Creates a new media frame encryptor * @return New encryptor handle. 
Must be destroyed with daveEncryptorDestroy() */ DAVE_EXPORT DAVEEncryptorHandle daveEncryptorCreate(void); /** * @brief Destroys an encryptor and frees associated resources * @param encryptor Encryptor handle to destroy */ DAVE_EXPORT void daveEncryptorDestroy(DAVEEncryptorHandle encryptor); /** * @brief Sets the key ratchet for encryption * @param encryptor Encryptor handle * @param keyRatchet Key ratchet to use for encryption (does *not* take ownership) */ DAVE_EXPORT void daveEncryptorSetKeyRatchet(DAVEEncryptorHandle encryptor, DAVEKeyRatchetHandle keyRatchet); /** * @brief Enables or disables passthrough mode (frames pass through unencrypted) * @param encryptor Encryptor handle * @param passthroughMode true to enable passthrough, false to encrypt */ DAVE_EXPORT void daveEncryptorSetPassthroughMode(DAVEEncryptorHandle encryptor, bool passthroughMode); /** * @brief Associates an SSRC (Synchronization Source) with a specific codec * @param encryptor Encryptor handle * @param ssrc SSRC identifier * @param codecType Codec type for this SSRC */ DAVE_EXPORT void daveEncryptorAssignSsrcToCodec(DAVEEncryptorHandle encryptor, uint32_t ssrc, DAVECodec codecType); /** * @brief Gets the current protocol version used by the encryptor * @param encryptor Encryptor handle * @return Protocol version number */ DAVE_EXPORT uint16_t daveEncryptorGetProtocolVersion(DAVEEncryptorHandle encryptor); /** * @brief Calculates the maximum ciphertext size for a given plaintext frame size * @param encryptor Encryptor handle * @param mediaType Media type (audio or video) * @param frameSize Size of plaintext frame in bytes * @return Maximum possible ciphertext size in bytes */ DAVE_EXPORT size_t daveEncryptorGetMaxCiphertextByteSize(DAVEEncryptorHandle encryptor, DAVEMediaType mediaType, size_t frameSize); /** * @brief Checks if the encryptor has a key ratchet * @param encryptor Encryptor handle * @return true if has key ratchet, false otherwise */ DAVE_EXPORT bool 
daveEncryptorHasKeyRatchet(DAVEEncryptorHandle encryptor); /** * @brief Checks if the encryptor is in passthrough mode * @param encryptor Encryptor handle * @return true if in passthrough mode, false otherwise */ DAVE_EXPORT bool daveEncryptorIsPassthroughMode(DAVEEncryptorHandle encryptor); /** * @brief Encrypts a media frame * @param encryptor Encryptor handle * @param mediaType Media type (audio or video) * @param ssrc SSRC of the stream * @param frame Pointer to plaintext frame data * @param frameLength Length of plaintext frame * @param[out] encryptedFrame Pointer to the output buffer the encrypted frame will be written to * @param encryptedFrameCapacity Capacity of the output buffer * @param[out] bytesWritten Number of bytes written to the output buffer * @return Result code indicating success or failure */ DAVE_EXPORT DAVEEncryptorResultCode daveEncryptorEncrypt(DAVEEncryptorHandle encryptor, DAVEMediaType mediaType, uint32_t ssrc, const uint8_t* frame, size_t frameLength, uint8_t* encryptedFrame, size_t encryptedFrameCapacity, size_t* bytesWritten); /** * @brief Sets a callback to be notified when the protocol version changes * @param encryptor Encryptor handle * @param callback Callback function * @param userData User data passed to callback */ DAVE_EXPORT void daveEncryptorSetProtocolVersionChangedCallback( DAVEEncryptorHandle encryptor, DAVEEncryptorProtocolVersionChangedCallback callback, void* userData); /** * @brief Gets encryption statistics * @param encryptor Encryptor handle * @param mediaType Media type (audio or video) * @param[out] stats Pointer to the stats structure to be filled */ DAVE_EXPORT void daveEncryptorGetStats(DAVEEncryptorHandle encryptor, DAVEMediaType mediaType, DAVEEncryptorStats* stats); /******************************************************************************* * Decryptor ******************************************************************************/ /** * @brief Creates a new media frame decryptor * @return New decryptor 
handle. Must be destroyed with daveDecryptorDestroy() */ DAVE_EXPORT DAVEDecryptorHandle daveDecryptorCreate(void); /** * @brief Destroys a decryptor and frees associated resources * @param decryptor Decryptor handle to destroy */ DAVE_EXPORT void daveDecryptorDestroy(DAVEDecryptorHandle decryptor); /** * @brief Transitions the decryptor to use a new key ratchet * @param decryptor Decryptor handle * @param keyRatchet New key ratchet to transition to (does *not* take ownership) */ DAVE_EXPORT void daveDecryptorTransitionToKeyRatchet(DAVEDecryptorHandle decryptor, DAVEKeyRatchetHandle keyRatchet); /** * @brief Transitions to or from passthrough mode * @param decryptor Decryptor handle * @param passthroughMode true to enable passthrough, false to decrypt */ DAVE_EXPORT void daveDecryptorTransitionToPassthroughMode(DAVEDecryptorHandle decryptor, bool passthroughMode); /** * @brief Decrypts an encrypted media frame * @param decryptor Decryptor handle * @param mediaType Media type (audio or video) * @param encryptedFrame Pointer to the encrypted frame data * @param encryptedFrameLength Length of the encrypted frame * @param[out] frame Pointer to the output buffer the decrypted frame will be written to * @param frameCapacity Capacity of the output buffer * @param[out] bytesWritten Number of bytes written to the output buffer * @return Result code indicating success or failure */ DAVE_EXPORT DAVEDecryptorResultCode daveDecryptorDecrypt(DAVEDecryptorHandle decryptor, DAVEMediaType mediaType, const uint8_t* encryptedFrame, size_t encryptedFrameLength, uint8_t* frame, size_t frameCapacity, size_t* bytesWritten); /** * @brief Calculates the maximum plaintext size for a given ciphertext frame size * @param decryptor Decryptor handle * @param mediaType Media type (audio or video) * @param encryptedFrameSize Size of encrypted frame in bytes * @return Maximum possible plaintext size in bytes */ DAVE_EXPORT size_t daveDecryptorGetMaxPlaintextByteSize(DAVEDecryptorHandle decryptor, 
DAVEMediaType mediaType, size_t encryptedFrameSize); /** * @brief Gets decryption statistics * @param decryptor Decryptor handle * @param mediaType Media type (audio or video) * @param[out] stats Pointer to the stats structure to be filled */ DAVE_EXPORT void daveDecryptorGetStats(DAVEDecryptorHandle decryptor, DAVEMediaType mediaType, DAVEDecryptorStats* stats); /******************************************************************************* * Logging ******************************************************************************/ /** * @brief Sets a global callback for receiving log messages from the library * @param callback Log sink callback function */ DAVE_EXPORT void daveSetLogSinkCallback(DAVELogSinkCallback callback); #ifdef __cplusplus } #endif ================================================ FILE: cpp/includes/dave/dave_interfaces.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include namespace mlspp { namespace bytes_ns { struct bytes; }; // namespace bytes_ns struct SignaturePrivateKey; } // namespace mlspp namespace discord { namespace dave { using EncryptorStats = DAVEEncryptorStats; using DecryptorStats = DAVEDecryptorStats; using KeyGeneration = uint32_t; using EncryptionKey = ::mlspp::bytes_ns::bytes; class MlsKeyRatchet; enum MediaType : uint8_t { Audio, Video }; enum Codec : uint8_t { Unknown, Opus, VP8, VP9, H264, H265, AV1 }; enum LoggingSeverity { LS_VERBOSE, LS_INFO, LS_WARNING, LS_ERROR, LS_NONE, }; // Returned in std::variant when a message is hard-rejected and should trigger a reset struct failed_t {}; // Returned in std::variant when a message is soft-rejected and should not trigger a reset struct ignored_t {}; // Map of ID-key pairs. // In ProcessCommit, this lists IDs whose keys have been added, changed, or removed; // an empty value value means a key was removed. 
using RosterMap = std::map>; // Return type for functions producing RosterMap or hard or soft failures using RosterVariant = std::variant; constexpr auto kDefaultTransitionDuration = std::chrono::seconds(10); class IKeyRatchet { public: virtual ~IKeyRatchet() noexcept = default; virtual EncryptionKey GetKey(KeyGeneration generation) noexcept = 0; virtual void DeleteKey(KeyGeneration generation) noexcept = 0; }; namespace mls { #if defined(__ANDROID__) typedef JNIEnv* KeyPairContextType; #else typedef const char* KeyPairContextType; #endif class ISession { public: virtual ~ISession() noexcept = default; virtual void Init(ProtocolVersion version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept = 0; virtual void Reset() noexcept = 0; virtual void SetProtocolVersion(ProtocolVersion version) noexcept = 0; virtual ProtocolVersion GetProtocolVersion() const noexcept = 0; virtual std::vector GetLastEpochAuthenticator() const noexcept = 0; virtual void SetExternalSender(std::vector const& externalSenderPackage) noexcept = 0; virtual std::optional> ProcessProposals( std::vector proposals, std::set const& recognizedUserIDs) noexcept = 0; virtual RosterVariant ProcessCommit(std::vector commit) noexcept = 0; virtual std::optional ProcessWelcome( std::vector welcome, std::set const& recognizedUserIDs) noexcept = 0; virtual std::vector GetMarshalledKeyPackage() noexcept = 0; virtual std::unique_ptr GetKeyRatchet( std::string const& userId) const noexcept = 0; using PairwiseFingerprintCallback = std::function const&)>; virtual void GetPairwiseFingerprint(uint16_t version, std::string const& userId, PairwiseFingerprintCallback callback) const noexcept = 0; }; using MLSFailureCallback = std::function; std::unique_ptr CreateSession(KeyPairContextType context, std::string authSessionId, MLSFailureCallback callback) noexcept; } // namespace mls class IEncryptor { public: enum ResultCode { Success, EncryptionFailure, 
MissingKeyRatchet, MissingCryptor, TooManyAttempts, }; virtual ~IEncryptor() = default; virtual void SetKeyRatchet(std::unique_ptr keyRatchet) = 0; virtual void SetPassthroughMode(bool passthroughMode) = 0; virtual bool HasKeyRatchet() const = 0; virtual bool IsPassthroughMode() const = 0; virtual void AssignSsrcToCodec(uint32_t ssrc, Codec codecType) = 0; virtual Codec CodecForSsrc(uint32_t ssrc) = 0; virtual ResultCode Encrypt(MediaType mediaType, uint32_t ssrc, ArrayView frame, ArrayView encryptedFrame, size_t* bytesWritten) = 0; virtual size_t GetMaxCiphertextByteSize(MediaType mediaType, size_t frameSize) = 0; virtual EncryptorStats GetStats(MediaType mediaType) const = 0; using ProtocolVersionChangedCallback = std::function; virtual void SetProtocolVersionChangedCallback(ProtocolVersionChangedCallback callback) = 0; virtual ProtocolVersion GetProtocolVersion() const = 0; }; std::unique_ptr CreateEncryptor(); class IDecryptor { public: using Duration = std::chrono::seconds; enum ResultCode { Success, DecryptionFailure, MissingKeyRatchet, InvalidNonce, MissingCryptor, }; virtual ~IDecryptor() = default; virtual void TransitionToKeyRatchet(std::unique_ptr keyRatchet, Duration transitionExpiry = kDefaultTransitionDuration) = 0; virtual void TransitionToPassthroughMode( bool passthroughMode, Duration transitionExpiry = kDefaultTransitionDuration) = 0; virtual ResultCode Decrypt(MediaType mediaType, ArrayView encryptedFrame, ArrayView frame, size_t* bytesWritten) = 0; virtual size_t GetMaxPlaintextByteSize(MediaType mediaType, size_t encryptedFrameSize) = 0; virtual DecryptorStats GetStats(MediaType mediaType) const = 0; }; std::unique_ptr CreateDecryptor(); static_assert(DAVE_CODEC_UNKNOWN == static_cast(Codec::Unknown)); static_assert(DAVE_CODEC_OPUS == static_cast(Codec::Opus)); static_assert(DAVE_CODEC_VP8 == static_cast(Codec::VP8)); static_assert(DAVE_CODEC_VP9 == static_cast(Codec::VP9)); static_assert(DAVE_CODEC_H264 == static_cast(Codec::H264)); 
static_assert(DAVE_CODEC_H265 == static_cast(Codec::H265)); static_assert(DAVE_CODEC_AV1 == static_cast(Codec::AV1)); static_assert(DAVE_MEDIA_TYPE_AUDIO == static_cast(MediaType::Audio)); static_assert(DAVE_MEDIA_TYPE_VIDEO == static_cast(MediaType::Video)); static_assert(DAVE_ENCRYPTOR_RESULT_CODE_SUCCESS == static_cast(IEncryptor::Success)); static_assert(DAVE_ENCRYPTOR_RESULT_CODE_ENCRYPTION_FAILURE == static_cast(IEncryptor::EncryptionFailure)); static_assert(DAVE_ENCRYPTOR_RESULT_CODE_MISSING_KEY_RATCHET == static_cast(IEncryptor::MissingKeyRatchet)); static_assert(DAVE_ENCRYPTOR_RESULT_CODE_MISSING_CRYPTOR == static_cast(IEncryptor::MissingCryptor)); static_assert(DAVE_ENCRYPTOR_RESULT_CODE_TOO_MANY_ATTEMPTS == static_cast(IEncryptor::TooManyAttempts)); static_assert(DAVE_DECRYPTOR_RESULT_CODE_SUCCESS == static_cast(IDecryptor::Success)); static_assert(DAVE_DECRYPTOR_RESULT_CODE_DECRYPTION_FAILURE == static_cast(IDecryptor::DecryptionFailure)); static_assert(DAVE_DECRYPTOR_RESULT_CODE_MISSING_KEY_RATCHET == static_cast(IDecryptor::MissingKeyRatchet)); static_assert(DAVE_DECRYPTOR_RESULT_CODE_INVALID_NONCE == static_cast(IDecryptor::InvalidNonce)); static_assert(DAVE_DECRYPTOR_RESULT_CODE_MISSING_CRYPTOR == static_cast(IDecryptor::MissingCryptor)); static_assert(DAVE_LOGGING_SEVERITY_VERBOSE == static_cast(LoggingSeverity::LS_VERBOSE)); static_assert(DAVE_LOGGING_SEVERITY_INFO == static_cast(LoggingSeverity::LS_INFO)); static_assert(DAVE_LOGGING_SEVERITY_WARNING == static_cast(LoggingSeverity::LS_WARNING)); static_assert(DAVE_LOGGING_SEVERITY_ERROR == static_cast(LoggingSeverity::LS_ERROR)); static_assert(DAVE_LOGGING_SEVERITY_NONE == static_cast(LoggingSeverity::LS_NONE)); } // namespace dave } // namespace discord ================================================ FILE: cpp/includes/dave/logger.h ================================================ #pragma once #include #include #if !defined(DISCORD_LOG) #define DISCORD_LOG_FILE_LINE(sev, file, line) 
::discord::dave::LogStreamer(sev, file, line) /* DISCORD_LOG(sev): streams a log message at severity ::discord::dave::sev, tagged with the current __FILE__ and __LINE__. */ #define DISCORD_LOG(sev) DISCORD_LOG_FILE_LINE(::discord::dave::sev, __FILE__, __LINE__) #endif namespace discord { namespace dave { /* LogSink: signature of a function that receives a log message with its severity and source location; installed with SetLogSink. */ using LogSink = void (*)(LoggingSeverity severity, const char* file, int line, const std::string& message); void SetLogSink(LogSink sink); /* LogStreamer: buffers every value written with operator<< into stream_, alongside the severity, file and line given to the constructor. NOTE(review): ~LogStreamer is defined elsewhere; presumably it hands the buffered text to the installed LogSink - confirm. */ class LogStreamer { public: LogStreamer(LoggingSeverity severity, const char* file, int line); ~LogStreamer(); template LogStreamer& operator<<(const T& value) { stream_ << value; return *this; } private: LoggingSeverity severity_; const char* file_; int line_; std::ostringstream stream_; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/includes/dave/version.h ================================================ #pragma once #include #include namespace discord { namespace dave { /* Protocol and signature versions are plain unsigned integer aliases. */ using ProtocolVersion = uint16_t; using SignatureVersion = uint8_t; ProtocolVersion MaxSupportedProtocolVersion(); } // namespace dave } // namespace discord ================================================ FILE: cpp/src/bindings_capi.cpp ================================================ /* C API bindings: each opaque handle is a reinterpret_cast of a heap-allocated C++ object, and output buffers are malloc'd for the caller to release with daveFree. */ #include #include #include #include #include #include #include #include "mls_key_ratchet.h" /* ARG_CHECK and ARG_CHECK_RET: when arg is nullptr, print an error naming the argument to stderr, assert(false) (a no-op when NDEBUG is defined) and return early - with ret in the _RET variant. */ #define ARG_CHECK(arg) \ if (arg == nullptr) { \ fprintf(stderr, "ERROR: %s is null\n", #arg); \ assert(false); \ return; \ } #define ARG_CHECK_RET(arg, ret) \ if (arg == nullptr) { \ fprintf(stderr, "ERROR: %s is null\n", #arg); \ assert(false); \ return ret; \ } /* CopyKeyRatchet: builds a fresh key ratchet from the cipher suite and next_secret of the handle's hash ratchet, returning nullptr for a null handle; the handle itself is neither owned nor freed here. */ std::unique_ptr CopyKeyRatchet(DAVEKeyRatchetHandle keyRatchet) { auto mlsKeyRatchet = reinterpret_cast(keyRatchet); if (!mlsKeyRatchet) { return nullptr; } auto hashRatchet = mlsKeyRatchet->GetHashRatchet(); auto cipherSuite = hashRatchet.suite; auto baseSecret = hashRatchet.next_secret; return std::make_unique(cipherSuite, std::move(baseSecret)); } /* CopyVectorToOutputBuffer: copies vector into a malloc'd buffer returned through data and length (the caller frees it, e.g. with daveFree). Does nothing if either out-pointer is null; an empty vector yields nullptr and 0. NOTE(review): the malloc result is not checked before memcpy. */ void CopyVectorToOutputBuffer(std::vector const& vector, uint8_t** data, size_t* length) { if (data == nullptr || length == nullptr) { return; } if
(vector.empty()) { *data = nullptr; *length = 0; return; } *data = reinterpret_cast(malloc(vector.size())); memcpy(*data, vector.data(), vector.size()); *length = vector.size(); } /* GetRosterMemberIds: writes the roster map's keys (the member ids), in map iteration order, into a malloc'd uint64_t array. NOTE(review): neither the out-pointers nor the malloc result are null-checked, and an empty roster still calls malloc(0). */ void GetRosterMemberIds(const discord::dave::RosterMap& rosterMap, uint64_t** rosterIds, size_t* rosterIdsLength) { *rosterIdsLength = rosterMap.size(); *rosterIds = reinterpret_cast(malloc(*rosterIdsLength * sizeof(uint64_t))); size_t i = 0; for (const auto& [key, value] : rosterMap) { (*rosterIds)[i++] = key; } } /* GetRosterMemberSignature: copies the bytes stored for rosterId out via CopyVectorToOutputBuffer. NOTE(review): rosterMap.at() throws std::out_of_range for an unknown rosterId, and nothing on the C API call path catches it. */ void GetRosterMemberSignature(const discord::dave::RosterMap& rosterMap, uint64_t rosterId, uint8_t** signature, size_t* signatureLength) { CopyVectorToOutputBuffer(rosterMap.at(rosterId), signature, signatureLength); } uint16_t daveMaxSupportedProtocolVersion(void) { return discord::dave::MaxSupportedProtocolVersion(); } /* daveFree: releases a buffer that this API allocated with malloc. */ void daveFree(void* ptr) { free(ptr); } /* daveSessionCreate: adapts the optional C failure callback into an MLSFailureCallback (left default-constructed when callback is null), treats a null authSessionId as an empty string, and returns the created session as an owning handle (presumably released with daveSessionDestroy). */ DAVESessionHandle daveSessionCreate(void* context, const char* authSessionId, DAVEMLSFailureCallback callback, void* userData) { discord::dave::mls::MLSFailureCallback mlsFailureCallback; if (callback != nullptr) { mlsFailureCallback = [callback, userData](std::string source, std::string reason) { callback(source.c_str(), reason.c_str(), userData); }; }; auto contextType = static_cast(context); auto authSessionIdStr = authSessionId ? std::string(authSessionId) : std::string(); auto session = discord::dave::mls::CreateSession(contextType, authSessionIdStr, mlsFailureCallback); return reinterpret_cast(session.release()); } void daveSessionDestroy(DAVESessionHandle sessionHandle) { auto session = reinterpret_cast(sessionHandle); delete session; } /* daveSessionInit: a null selfUserId is treated as an empty string, and Init receives an empty transientKey shared_ptr. */ void daveSessionInit(DAVESessionHandle sessionHandle, uint16_t version, uint64_t groupId, const char* selfUserId) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); auto selfUserIdStr = selfUserId ?
std::string(selfUserId) : std::string(); std::shared_ptr<::mlspp::SignaturePrivateKey> transientKey; session->Init(version, groupId, selfUserIdStr, transientKey); } void daveSessionReset(DAVESessionHandle sessionHandle) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); session->Reset(); } void daveSessionSetProtocolVersion(DAVESessionHandle sessionHandle, uint16_t version) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); session->SetProtocolVersion(version); } /* Returns 0 for a null session handle. */ uint16_t daveSessionGetProtocolVersion(DAVESessionHandle sessionHandle) { ARG_CHECK_RET(sessionHandle, 0); auto session = reinterpret_cast(sessionHandle); return session->GetProtocolVersion(); } void daveSessionGetLastEpochAuthenticator(DAVESessionHandle sessionHandle, uint8_t** authenticator, size_t* length) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); auto lastEpochAuthenticator = session->GetLastEpochAuthenticator(); CopyVectorToOutputBuffer(lastEpochAuthenticator, authenticator, length); } /* NOTE(review): externalSender is not null-checked; a null pointer with a non-zero length is undefined behaviour. */ void daveSessionSetExternalSender(DAVESessionHandle sessionHandle, const uint8_t* externalSender, size_t length) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); auto externalSenderVec = std::vector(externalSender, externalSender + length); session->SetExternalSender(externalSenderVec); } /* daveSessionProcessProposals: copies the resulting commit and welcome bytes into a malloc'd output buffer only when ProcessProposals returns a value. NOTE(review): otherwise the out-parameters are left untouched, so callers must pre-initialise them; the recognizedUserIds entries are not null-checked either. */ void daveSessionProcessProposals(DAVESessionHandle sessionHandle, const uint8_t* proposals, size_t length, const char** recognizedUserIds, size_t recognizedUserIdsLength, uint8_t** commitWelcomeBytes, size_t* commitWelcomeBytesLength) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); auto proposalsVec = std::vector(proposals, proposals + length); auto recognizedUserIdsSet = std::set(recognizedUserIds, recognizedUserIds + recognizedUserIdsLength); auto result = session->ProcessProposals(std::move(proposalsVec), std::move(recognizedUserIdsSet)); if (result) { CopyVectorToOutputBuffer(*result,
commitWelcomeBytes, commitWelcomeBytesLength); } } /* daveSessionProcessCommit: for a non-null session, always returns a newly heap-allocated RosterVariant, presumably released with daveCommitResultDestroy. */ DAVECommitResultHandle daveSessionProcessCommit(DAVESessionHandle sessionHandle, const uint8_t* commit, size_t length) { ARG_CHECK_RET(sessionHandle, nullptr); auto session = reinterpret_cast(sessionHandle); auto commitVec = std::vector(commit, commit + length); auto rosterVariant = session->ProcessCommit(std::move(commitVec)); auto rosterVariantPtr = new discord::dave::RosterVariant(std::move(rosterVariant)); return reinterpret_cast(rosterVariantPtr); } /* daveSessionProcessWelcome: returns nullptr when ProcessWelcome yields no roster, otherwise a newly heap-allocated RosterMap, presumably released with daveWelcomeResultDestroy. */ DAVEWelcomeResultHandle daveSessionProcessWelcome(DAVESessionHandle sessionHandle, const uint8_t* welcome, size_t length, const char** recognizedUserIds, size_t recognizedUserIdsLength) { ARG_CHECK_RET(sessionHandle, nullptr); auto session = reinterpret_cast(sessionHandle); auto welcomeVec = std::vector(welcome, welcome + length); auto recognizedUserIdsSet = std::set(recognizedUserIds, recognizedUserIds + recognizedUserIdsLength); auto result = session->ProcessWelcome(std::move(welcomeVec), std::move(recognizedUserIdsSet)); if (!result) { return nullptr; } auto rosterMapPtr = new discord::dave::RosterMap(std::move(*result)); return reinterpret_cast(rosterMapPtr); } void daveSessionGetMarshalledKeyPackage(DAVESessionHandle sessionHandle, uint8_t** keyPackage, size_t* length) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); auto keyPackageVec = session->GetMarshalledKeyPackage(); CopyVectorToOutputBuffer(keyPackageVec, keyPackage, length); } /* daveSessionGetKeyRatchet: returns the ratchet released from GetKeyRatchet(userId) as an owning handle, presumably released with daveKeyRatchetDestroy; a null userId is treated as an empty string. */ DAVEKeyRatchetHandle daveSessionGetKeyRatchet(DAVESessionHandle sessionHandle, const char* userId) { ARG_CHECK_RET(sessionHandle, nullptr); auto session = reinterpret_cast(sessionHandle); auto userIdStr = userId ?
std::string(userId) : std::string(); auto keyRatchetPtr = session->GetKeyRatchet(userIdStr); return reinterpret_cast(keyRatchetPtr.release()); } /* daveSessionGetPairwiseFingerprint: forwards the fingerprint bytes to the C callback together with userData. NOTE(review): callback is not null-checked before it is captured and invoked. */ void daveSessionGetPairwiseFingerprint(DAVESessionHandle sessionHandle, uint16_t version, const char* userId, DAVEPairwiseFingerprintCallback callback, void* userData) { ARG_CHECK(sessionHandle); auto session = reinterpret_cast(sessionHandle); auto userIdStr = userId ? std::string(userId) : std::string(); session->GetPairwiseFingerprint( version, userIdStr, [callback, userData](std::vector const& fingerprint) { callback(fingerprint.data(), fingerprint.size(), userData); }); } void daveKeyRatchetDestroy(DAVEKeyRatchetHandle keyRatchet) { delete reinterpret_cast(keyRatchet); } /* IsFailed and IsIgnored report which alternative the RosterVariant holds, returning false for a null handle. */ bool daveCommitResultIsFailed(DAVECommitResultHandle commitResultHandle) { ARG_CHECK_RET(commitResultHandle, false); auto commitResult = reinterpret_cast(commitResultHandle); return std::holds_alternative(*commitResult); } bool daveCommitResultIsIgnored(DAVECommitResultHandle commitResultHandle) { ARG_CHECK_RET(commitResultHandle, false); auto commitResult = reinterpret_cast(commitResultHandle); return std::holds_alternative(*commitResult); } /* When the commit result does not hold a RosterMap, the roster getters set their outputs to nullptr and 0. */ void daveCommitResultGetRosterMemberIds(DAVECommitResultHandle commitResultHandle, uint64_t** rosterIds, size_t* rosterIdsLength) { ARG_CHECK(commitResultHandle); auto commitResult = reinterpret_cast(commitResultHandle); if (!std::holds_alternative(*commitResult)) { *rosterIds = nullptr; *rosterIdsLength = 0; return; } GetRosterMemberIds( std::get(*commitResult), rosterIds, rosterIdsLength); } void daveCommitResultGetRosterMemberSignature(DAVECommitResultHandle commitResultHandle, uint64_t rosterId, uint8_t** signature, size_t* signatureLength) { ARG_CHECK(commitResultHandle); auto commitResult = reinterpret_cast(commitResultHandle); if (!std::holds_alternative(*commitResult)) { *signature = nullptr; *signatureLength = 0; return; } GetRosterMemberSignature( std::get(*commitResult), rosterId, signature,
signatureLength); } /* The Destroy functions delete the object behind the handle; a null handle is harmless because deleting a null pointer is a no-op. */ void daveCommitResultDestroy(DAVECommitResultHandle commitResultHandle) { auto commitResult = reinterpret_cast(commitResultHandle); delete commitResult; } void daveWelcomeResultGetRosterMemberIds(DAVEWelcomeResultHandle welcomeResultHandle, uint64_t** rosterIds, size_t* rosterIdsLength) { ARG_CHECK(welcomeResultHandle); auto welcomeResult = reinterpret_cast(welcomeResultHandle); GetRosterMemberIds(*welcomeResult, rosterIds, rosterIdsLength); } void daveWelcomeResultGetRosterMemberSignature(DAVEWelcomeResultHandle welcomeResultHandle, uint64_t rosterId, uint8_t** signature, size_t* signatureLength) { ARG_CHECK(welcomeResultHandle); auto welcomeResult = reinterpret_cast(welcomeResultHandle); GetRosterMemberSignature(*welcomeResult, rosterId, signature, signatureLength); } void daveWelcomeResultDestroy(DAVEWelcomeResultHandle welcomeResultHandle) { auto welcomeResult = reinterpret_cast(welcomeResultHandle); delete welcomeResult; } DAVEEncryptorHandle daveEncryptorCreate() { auto encryptor = discord::dave::CreateEncryptor(); return reinterpret_cast(encryptor.release()); } void daveEncryptorDestroy(DAVEEncryptorHandle encryptorHandle) { auto encryptor = reinterpret_cast(encryptorHandle); delete encryptor; } /* daveEncryptorSetKeyRatchet: the encryptor receives its own copy made by CopyKeyRatchet, so the caller keeps ownership of keyRatchet; a null keyRatchet results in nullptr being passed to SetKeyRatchet. */ void daveEncryptorSetKeyRatchet(DAVEEncryptorHandle encryptorHandle, DAVEKeyRatchetHandle keyRatchet) { ARG_CHECK(encryptorHandle); auto encryptor = reinterpret_cast(encryptorHandle); auto keyRatchetCopy = CopyKeyRatchet(keyRatchet); encryptor->SetKeyRatchet(std::move(keyRatchetCopy)); } void daveEncryptorSetPassthroughMode(DAVEEncryptorHandle encryptorHandle, bool passthroughMode) { ARG_CHECK(encryptorHandle); auto encryptor = reinterpret_cast(encryptorHandle); encryptor->SetPassthroughMode(passthroughMode); } void daveEncryptorAssignSsrcToCodec(DAVEEncryptorHandle encryptorHandle, uint32_t ssrc, DAVECodec codecType) { ARG_CHECK(encryptorHandle); auto encryptor = reinterpret_cast(encryptorHandle); encryptor->AssignSsrcToCodec(ssrc,
static_cast(codecType)); } uint16_t daveEncryptorGetProtocolVersion(DAVEEncryptorHandle encryptorHandle) { ARG_CHECK_RET(encryptorHandle, 0); auto encryptor = reinterpret_cast(encryptorHandle); return encryptor->GetProtocolVersion(); } size_t daveEncryptorGetMaxCiphertextByteSize(DAVEEncryptorHandle encryptorHandle, DAVEMediaType mediaType, size_t frameSize) { ARG_CHECK_RET(encryptorHandle, 0); auto encryptor = reinterpret_cast(encryptorHandle); return encryptor->GetMaxCiphertextByteSize(static_cast(mediaType), frameSize); } bool daveEncryptorHasKeyRatchet(DAVEEncryptorHandle encryptorHandle) { ARG_CHECK_RET(encryptorHandle, false); auto encryptor = reinterpret_cast(encryptorHandle); return encryptor->HasKeyRatchet(); } bool daveEncryptorIsPassthroughMode(DAVEEncryptorHandle encryptorHandle) { ARG_CHECK_RET(encryptorHandle, false); auto encryptor = reinterpret_cast(encryptorHandle); return encryptor->IsPassthroughMode(); } /* daveEncryptorEncrypt: wraps the raw buffers in array views and casts the C++ result code straight to DAVEEncryptorResultCode; a null handle reports DAVE_ENCRYPTOR_RESULT_CODE_ENCRYPTION_FAILURE. */ DAVEEncryptorResultCode daveEncryptorEncrypt(DAVEEncryptorHandle encryptorHandle, DAVEMediaType mediaType, uint32_t ssrc, const uint8_t* frame, size_t frameLength, uint8_t* encryptedFrame, size_t encryptedFrameCapacity, size_t* bytesWritten) { ARG_CHECK_RET(encryptorHandle, DAVE_ENCRYPTOR_RESULT_CODE_ENCRYPTION_FAILURE); auto encryptor = reinterpret_cast(encryptorHandle); auto frameView = discord::dave::MakeArrayView(frame, frameLength); auto encryptedFrameView = discord::dave::MakeArrayView(encryptedFrame, encryptedFrameCapacity); auto result = encryptor->Encrypt(static_cast(mediaType), ssrc, frameView, encryptedFrameView, bytesWritten); return static_cast(result); } /* NOTE(review): callback is not null-checked before it is captured. */ void daveEncryptorSetProtocolVersionChangedCallback( DAVEEncryptorHandle encryptorHandle, DAVEEncryptorProtocolVersionChangedCallback callback, void* userData) { ARG_CHECK(encryptorHandle); auto encryptor = reinterpret_cast(encryptorHandle); encryptor->SetProtocolVersionChangedCallback([callback, userData]() { callback(userData); }); } /* NOTE(review): stats is written through without a null check. */ void daveEncryptorGetStats(DAVEEncryptorHandle
encryptorHandle, DAVEMediaType mediaType, DAVEEncryptorStats* stats) { ARG_CHECK(encryptorHandle); auto encryptor = reinterpret_cast(encryptorHandle); *stats = encryptor->GetStats(static_cast(mediaType)); } DAVEDecryptorHandle daveDecryptorCreate() { auto decryptor = discord::dave::CreateDecryptor(); return reinterpret_cast(decryptor.release()); } void daveDecryptorDestroy(DAVEDecryptorHandle decryptorHandle) { auto decryptor = reinterpret_cast(decryptorHandle); delete decryptor; } /* daveDecryptorTransitionToKeyRatchet: the decryptor receives its own copy made by CopyKeyRatchet, so the caller keeps ownership of keyRatchet. */ void daveDecryptorTransitionToKeyRatchet(DAVEDecryptorHandle decryptorHandle, DAVEKeyRatchetHandle keyRatchet) { ARG_CHECK(decryptorHandle); auto decryptor = reinterpret_cast(decryptorHandle); auto keyRatchetCopy = CopyKeyRatchet(keyRatchet); decryptor->TransitionToKeyRatchet(std::move(keyRatchetCopy)); } void daveDecryptorTransitionToPassthroughMode(DAVEDecryptorHandle decryptorHandle, bool passthroughMode) { ARG_CHECK(decryptorHandle); auto decryptor = reinterpret_cast(decryptorHandle); decryptor->TransitionToPassthroughMode(passthroughMode); } /* daveDecryptorDecrypt: a null handle reports DAVE_DECRYPTOR_RESULT_CODE_DECRYPTION_FAILURE; otherwise the C++ result code is cast straight to DAVEDecryptorResultCode. */ DAVEDecryptorResultCode daveDecryptorDecrypt(DAVEDecryptorHandle decryptorHandle, DAVEMediaType mediaType, const uint8_t* encryptedFrame, size_t encryptedFrameLength, uint8_t* frame, size_t frameCapacity, size_t* bytesWritten) { ARG_CHECK_RET(decryptorHandle, DAVE_DECRYPTOR_RESULT_CODE_DECRYPTION_FAILURE); auto decryptor = reinterpret_cast(decryptorHandle); auto encryptedFrameView = discord::dave::MakeArrayView(encryptedFrame, encryptedFrameLength); auto frameView = discord::dave::MakeArrayView(frame, frameCapacity); auto result = decryptor->Decrypt(static_cast(mediaType), encryptedFrameView, frameView, bytesWritten); return static_cast(result); } size_t daveDecryptorGetMaxPlaintextByteSize(DAVEDecryptorHandle decryptorHandle, DAVEMediaType mediaType, size_t encryptedFrameSize) { ARG_CHECK_RET(decryptorHandle, 0); auto decryptor = reinterpret_cast(decryptorHandle); return decryptor->GetMaxPlaintextByteSize(static_cast(mediaType),
encryptedFrameSize); } /* NOTE(review): stats is written through without a null check. */ void daveDecryptorGetStats(DAVEDecryptorHandle decryptorHandle, DAVEMediaType mediaType, DAVEDecryptorStats* stats) { ARG_CHECK(decryptorHandle); auto decryptor = reinterpret_cast(decryptorHandle); *stats = decryptor->GetStats(static_cast(mediaType)); } /* The C log callback is held in an atomic; LogSinkCallback forwards the messages it receives to it, and daveSetLogSinkCallback(nullptr) also clears the C++ sink. */ static std::atomic gLogSinkCallback{nullptr}; void LogSinkCallback(discord::dave::LoggingSeverity severity, const char* file, int line, const std::string& message) { auto callback = gLogSinkCallback.load(); if (callback) { callback(static_cast(severity), file, line, message.c_str()); } } void daveSetLogSinkCallback(DAVELogSinkCallback callback) { gLogSinkCallback.store(callback); discord::dave::SetLogSink(callback ? LogSinkCallback : nullptr); } ================================================ FILE: cpp/src/bindings_wasm.cpp ================================================ /* Emscripten bindings: JS-facing wrapper classes around the session, encryptor and decryptor, exported via EMSCRIPTEN_BINDINGS. */ #include #include #include #include #include #include #include #include #include #include #include #include "common.h" #include "decryptor.h" #include "encryptor.h" #include "mls/parameters.h" #include "mls/session.h" #include "mls_key_ratchet.h" using namespace emscripten; namespace discord { namespace dave { /* ToOwnedTypedArray: copies the bytes into a new plain JS Array with one push call per byte - despite the name, the result is not a TypedArray. */ val ToOwnedTypedArray(const uint8_t* data, size_t size) { val array = val::array(); for (size_t i = 0; i < size; i++) { array.call("push", data[i]); } return array; } val ToOwnedTypedArray(const ::mlspp::bytes_ns::bytes& data) { return ToOwnedTypedArray(data.data(), data.size()); } val ToOwnedTypedArray(const std::vector& data) { return ToOwnedTypedArray(data.data(), data.size()); } /* MlsKeyRatchetToJS and MlsKeyRatchetFromJS convert between a key ratchet and a JS object of the form { cipherSuite, baseSecret }; a null ratchet maps to JS null and back. */ val MlsKeyRatchetToJS(std::unique_ptr keyRatchet) { if (!keyRatchet) { return val::null(); } auto hashRatchet = keyRatchet->GetHashRatchet(); auto value = val::object(); value.set("cipherSuite", static_cast(hashRatchet.suite.cipher_suite())); value.set("baseSecret", ToOwnedTypedArray(hashRatchet.next_secret)); return value; } std::unique_ptr MlsKeyRatchetFromJS(val keyRatchet) { if (keyRatchet.isNull()) { return nullptr; } auto cipherSuite
= ::mlspp::CipherSuite( static_cast<::mlspp::CipherSuite::ID>(keyRatchet["cipherSuite"].as())); auto baseSecret = emscripten::convertJSArrayToNumberVector(keyRatchet["baseSecret"]); auto baseSecretBytes = ::mlspp::bytes_ns::bytes(baseSecret); return std::make_unique(cipherSuite, baseSecretBytes); } namespace mls { /* TransientKeys: lazily generates one signature private key per protocol version (from that version's ciphersuite) and returns the cached key on later calls; Clear() drops every cached key. */ class TransientKeys { public: std::shared_ptr<::mlspp::SignaturePrivateKey> GetTransientPrivateKey(ProtocolVersion version) { auto it = keys_.find(version); if (it == keys_.end()) { auto ciphersuite = CiphersuiteForProtocolVersion(version); auto key = std::make_shared<::mlspp::SignaturePrivateKey>( ::mlspp::SignaturePrivateKey::generate(ciphersuite)); it = keys_.emplace(version, key).first; } return it->second; } void Clear() { keys_.clear(); } private: std::map> keys_; }; /* SessionWrapper: JS-facing facade over the MLS session that converts JS arrays into byte vectors and results back into JS values; the JS callback is invoked with (source, message), presumably on MLS failures. */ class SessionWrapper { public: SessionWrapper(std::string ctx, std::string authSessionId, val callback) { session_ = std::make_unique( ctx.c_str(), authSessionId, [callback](std::string source, std::string message) { callback(source, message); }); } void Init(ProtocolVersion version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) { session_->Init(version, groupId, selfUserId, transientKey); } void Reset() { session_->Reset(); } void SetProtocolVersion(ProtocolVersion version) { session_->SetProtocolVersion(version); } ProtocolVersion GetProtocolVersion() { return session_->GetProtocolVersion(); } val GetLastEpochAuthenticator() { return ToOwnedTypedArray(session_->GetLastEpochAuthenticator()); } void SetExternalSender(val externalSender) { if (!externalSender.isNull()) { std::vector externalSenderVec = emscripten::convertJSArrayToNumberVector(externalSender); session_->SetExternalSender(externalSenderVec); } else { DISCORD_LOG(LS_ERROR) << "External sender is null"; } } val ProcessProposals(val proposals, val recognizedUserIDs) { std::vector proposalsVec = emscripten::convertJSArrayToNumberVector(proposals); auto
recognizedUserIDsVec = emscripten::vecFromJSArray(recognizedUserIDs); auto recognizedUserIDsSet = std::set(recognizedUserIDsVec.begin(), recognizedUserIDsVec.end()); auto bytes = session_->ProcessProposals(proposalsVec, recognizedUserIDsSet); if (!bytes) { return val::null(); } return ToOwnedTypedArray(*bytes); } /* ProcessCommit: returns { failed, ignored, rosterUpdate }; rosterUpdate is null unless GetOptional yields a roster, in which case it maps each roster key to its byte array. */ val ProcessCommit(val commit) { std::vector commitVec = emscripten::convertJSArrayToNumberVector(commit); auto processedCommit = session_->ProcessCommit(commitVec); auto failed = std::holds_alternative(processedCommit); auto ignored = std::holds_alternative(processedCommit); auto rosterUpdate = GetOptional(std::move(processedCommit)); val result = val::object(); result.set("failed", failed); result.set("ignored", ignored); val rosterObj = val::null(); if (rosterUpdate) { rosterObj = val::object(); for (const auto& [key, value] : *rosterUpdate) { rosterObj.set(key, ToOwnedTypedArray(value)); } } result.set("rosterUpdate", rosterObj); return result; } /* ProcessWelcome: returns null when ProcessWelcome yields no roster, otherwise an object mapping each roster key to its byte array. */ val ProcessWelcome(val welcome, val recognizedUserIDs) { auto welcomeVec = emscripten::convertJSArrayToNumberVector(welcome); auto recognizedUserIDsVec = emscripten::vecFromJSArray(recognizedUserIDs); auto recognizedUserIDsSet = std::set(recognizedUserIDsVec.begin(), recognizedUserIDsVec.end()); auto roster = session_->ProcessWelcome(welcomeVec, recognizedUserIDsSet); if (!roster) { return val::null(); } val rosterObj = val::object(); for (const auto& [key, value] : *roster) { rosterObj.set(key, ToOwnedTypedArray(value)); } return rosterObj; } val GetMarshalledKeyPackage() { return ToOwnedTypedArray(session_->GetMarshalledKeyPackage()); } /* NOTE(review): GetKeyRatchet static_casts the released ratchet to (presumably) MlsKeyRatchet without a runtime type check. */ val GetKeyRatchet(std::string const& userId) { auto keyRatchet = session_->GetKeyRatchet(userId); auto mlsKeyRatchet = std::unique_ptr(static_cast(keyRatchet.release())); return MlsKeyRatchetToJS(std::move(mlsKeyRatchet)); } private: std::unique_ptr session_; }; } // namespace mls /* EncryptorWrapper: JS-facing facade over the encryptor; key ratchets cross the boundary as { cipherSuite, baseSecret } objects. */ class EncryptorWrapper { public: EncryptorWrapper() { encryptor_ = std::make_unique(); } void
SetKeyRatchet(val keyRatchet) { encryptor_->SetKeyRatchet(MlsKeyRatchetFromJS(keyRatchet)); } void SetPassthroughMode(bool passthroughMode) { encryptor_->SetPassthroughMode(passthroughMode); } void AssignSsrcToCodec(uint32_t ssrc, Codec codecType) { encryptor_->AssignSsrcToCodec(ssrc, codecType); } ProtocolVersion GetProtocolVersion() { return encryptor_->GetProtocolVersion(); } size_t GetMaxCiphertextByteSize(MediaType mediaType, size_t plaintextByteSize) { return encryptor_->GetMaxCiphertextByteSize(mediaType, plaintextByteSize); } /* Encrypt: framePtr is reinterpreted as a pointer into wasm memory and the frame is encrypted in place (the plaintext and ciphertext views share that buffer), so the buffer must hold at least GetMaxCiphertextByteSize(mediaType, frameLength) bytes. Returns the number of bytes written, or 0 on any failure. */ size_t Encrypt(MediaType mediaType, uint32_t ssrc, int framePtr, size_t frameLength, size_t frameCapacity) { auto frame = reinterpret_cast(framePtr); auto frameView = MakeArrayView(const_cast(frame), frameLength); auto encryptedFrameMaxSize = GetMaxCiphertextByteSize(mediaType, frameLength); if (frameCapacity < encryptedFrameMaxSize) { DISCORD_LOG(LS_ERROR) << "Frame capacity is less than the maximum ciphertext size"; return 0; } auto encryptedFrameView = MakeArrayView(frame, encryptedFrameMaxSize); size_t bytesWritten = 0; auto result = encryptor_->Encrypt(mediaType, ssrc, frameView, encryptedFrameView, &bytesWritten); if (result != 0) { return 0; } return bytesWritten; } void SetProtocolVersionChangedCallback(val callback) { encryptor_->SetProtocolVersionChangedCallback([callback]() { callback(); }); } private: std::unique_ptr encryptor_; }; /* DecryptorWrapper: JS-facing facade over the decryptor, mirroring EncryptorWrapper. */ class DecryptorWrapper { public: DecryptorWrapper() { decryptor_ = std::make_unique(); } void TransitionToKeyRatchet(val keyRatchet) { decryptor_->TransitionToKeyRatchet(MlsKeyRatchetFromJS(keyRatchet)); } void TransitionToPassthroughMode(bool passthroughMode) { decryptor_->TransitionToPassthroughMode(passthroughMode); } size_t GetMaxPlaintextByteSize(MediaType mediaType, size_t ciphertextByteSize) { return decryptor_->GetMaxPlaintextByteSize(mediaType, ciphertextByteSize); } /* Decrypt: decrypts in place in the wasm-memory buffer at framePtr; returns the plaintext size, or 0 on failure or insufficient capacity. */ size_t Decrypt(MediaType mediaType, int framePtr, size_t frameLength, size_t frameCapacity) { auto frame =
reinterpret_cast(framePtr); auto frameView = MakeArrayView(const_cast(frame), frameLength); auto maxPlaintextByteSize = decryptor_->GetMaxPlaintextByteSize(mediaType, frameLength); if (frameCapacity < maxPlaintextByteSize) { DISCORD_LOG(LS_ERROR) << "Frame capacity is less than the maximum plaintext size"; return 0; } auto plaintextView = MakeArrayView(frame, maxPlaintextByteSize); size_t bytesWritten = 0; auto result = decryptor_->Decrypt(mediaType, frameView, plaintextView, &bytesWritten); if (result != Decryptor::ResultCode::Success) { return 0; } return bytesWritten; } private: std::unique_ptr decryptor_; }; } // namespace dave } // namespace discord /* EMSCRIPTEN_BINDINGS: exposes the constants, enums, free function and wrapper classes above to JavaScript under the names given here. */ EMSCRIPTEN_BINDINGS(dave) { constant("kInitTransitionId", discord::dave::kInitTransitionId); constant("kDisabledVersion", discord::dave::kDisabledVersion); enum_("MediaType") .value("Audio", discord::dave::MediaType::Audio) .value("Video", discord::dave::MediaType::Video); enum_("Codec") .value("Unknown", discord::dave::Codec::Unknown) .value("Opus", discord::dave::Codec::Opus) .value("VP8", discord::dave::Codec::VP8) .value("VP9", discord::dave::Codec::VP9) .value("H264", discord::dave::Codec::H264) .value("H265", discord::dave::Codec::H265) .value("AV1", discord::dave::Codec::AV1); function("MaxSupportedProtocolVersion", &discord::dave::MaxSupportedProtocolVersion); class_<::mlspp::SignaturePrivateKey>("SignaturePrivateKey") .smart_ptr>("SignaturePrivateKeyPtr"); class_("TransientKeys") .constructor<>() .function("GetTransientPrivateKey", &discord::dave::mls::TransientKeys::GetTransientPrivateKey) .function("Clear", &discord::dave::mls::TransientKeys::Clear); class_("Session") .constructor() .function("Init", &discord::dave::mls::SessionWrapper::Init) .function("Reset", &discord::dave::mls::SessionWrapper::Reset) .function("SetProtocolVersion", &discord::dave::mls::SessionWrapper::SetProtocolVersion) .function("GetProtocolVersion", &discord::dave::mls::SessionWrapper::GetProtocolVersion)
.function("GetLastEpochAuthenticator", &discord::dave::mls::SessionWrapper::GetLastEpochAuthenticator) .function("SetExternalSender", &discord::dave::mls::SessionWrapper::SetExternalSender) .function("ProcessProposals", &discord::dave::mls::SessionWrapper::ProcessProposals) .function("ProcessCommit", &discord::dave::mls::SessionWrapper::ProcessCommit) .function("ProcessWelcome", &discord::dave::mls::SessionWrapper::ProcessWelcome) .function("GetMarshalledKeyPackage", &discord::dave::mls::SessionWrapper::GetMarshalledKeyPackage) .function("GetKeyRatchet", &discord::dave::mls::SessionWrapper::GetKeyRatchet); class_("Encryptor") .constructor<>() .function("SetKeyRatchet", &discord::dave::EncryptorWrapper::SetKeyRatchet) .function("SetPassthroughMode", &discord::dave::EncryptorWrapper::SetPassthroughMode) .function("AssignSsrcToCodec", &discord::dave::EncryptorWrapper::AssignSsrcToCodec) .function("GetProtocolVersion", &discord::dave::EncryptorWrapper::GetProtocolVersion) .function("GetMaxCiphertextByteSize", &discord::dave::EncryptorWrapper::GetMaxCiphertextByteSize) .function( "Encrypt", &discord::dave::EncryptorWrapper::Encrypt, emscripten::allow_raw_pointers()) .function("SetProtocolVersionChangedCallback", &discord::dave::EncryptorWrapper::SetProtocolVersionChangedCallback); class_("Decryptor") .constructor<>() .function("TransitionToKeyRatchet", &discord::dave::DecryptorWrapper::TransitionToKeyRatchet) .function("TransitionToPassthroughMode", &discord::dave::DecryptorWrapper::TransitionToPassthroughMode) .function("GetMaxPlaintextByteSize", &discord::dave::DecryptorWrapper::GetMaxPlaintextByteSize) .function( "Decrypt", &discord::dave::DecryptorWrapper::Decrypt, emscripten::allow_raw_pointers()); } ================================================ FILE: cpp/src/boringssl_cryptor.cpp ================================================ /* BoringSSL implementation of the AES-128-GCM cryptor. */ #include "boringssl_cryptor.h" #include #include #include #include "common.h" namespace discord { namespace dave { /* PrintSSLErrors: logs every queued SSL error string at LS_ERROR via ERR_print_errors_cb. */ void
PrintSSLErrors() { ERR_print_errors_cb( [](const char* str, size_t len, [[maybe_unused]] void* ctx) { DISCORD_LOG(LS_ERROR) << std::string(str, len); return 1; }, nullptr); } /* Constructor: zeroes the AEAD context, then initialises it for AES-128-GCM with the truncated tag length kAesGcm128TruncatedTagBytes. A failure is logged, not thrown; Encrypt and Decrypt refuse to run while cipherCtx_.aead is null. */ BoringSSLCryptor::BoringSSLCryptor(const EncryptionKey& encryptionKey) { EVP_AEAD_CTX_zero(&cipherCtx_); auto initResult = EVP_AEAD_CTX_init(&cipherCtx_, EVP_aead_aes_128_gcm(), encryptionKey.data(), encryptionKey.size(), kAesGcm128TruncatedTagBytes, nullptr); if (initResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to initialize AEAD context"; PrintSSLErrors(); } } BoringSSLCryptor::~BoringSSLCryptor() { EVP_AEAD_CTX_cleanup(&cipherCtx_); } /* Encrypt: seals plaintextBuffer into ciphertextBufferOut and writes the truncated tag separately to tagBufferOut (scatter). The nonce length is fixed at kAesGcm128NonceBytes and buffer sizes are not validated here, so callers must size the views correctly; tagSizeOut is filled in but never inspected. */ bool BoringSSLCryptor::Encrypt(ArrayView ciphertextBufferOut, ArrayView plaintextBuffer, ArrayView nonceBuffer, ArrayView additionalData, ArrayView tagBufferOut) { if (cipherCtx_.aead == nullptr) { DISCORD_LOG(LS_ERROR) << "Encrypt: AEAD context is not initialized"; return false; } size_t tagSizeOut; auto encryptResult = EVP_AEAD_CTX_seal_scatter(&cipherCtx_, ciphertextBufferOut.data(), tagBufferOut.data(), &tagSizeOut, kAesGcm128TruncatedTagBytes, nonceBuffer.data(), kAesGcm128NonceBytes, plaintextBuffer.data(), plaintextBuffer.size(), nullptr, 0, additionalData.data(), additionalData.size()); if (encryptResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to encrypt data"; PrintSSLErrors(); } return encryptResult == 1; } /* Decrypt: opens ciphertextBuffer with a separately supplied tag (gather). Unlike Encrypt, a failed open is reported only through the false return value, without logging. */ bool BoringSSLCryptor::Decrypt(ArrayView plaintextBufferOut, ArrayView ciphertextBuffer, ArrayView tagBuffer, ArrayView nonceBuffer, ArrayView additionalData) { if (cipherCtx_.aead == nullptr) { DISCORD_LOG(LS_ERROR) << "Decrypt: AEAD context is not initialized"; return false; } auto decryptResult = EVP_AEAD_CTX_open_gather(&cipherCtx_, plaintextBufferOut.data(), nonceBuffer.data(), kAesGcm128NonceBytes, ciphertextBuffer.data(), ciphertextBuffer.size(), tagBuffer.data(), kAesGcm128TruncatedTagBytes, additionalData.data(), additionalData.size()); return decryptResult == 1; } } // namespace dave } // namespace discord
================================================ FILE: cpp/src/boringssl_cryptor.h ================================================ #pragma once #include #include "cryptor.h" namespace discord { namespace dave { /* BoringSSLCryptor: ICryptor backed by a BoringSSL EVP_AEAD_CTX. IsValid() reports whether cipherCtx_.aead is non-null, i.e. presumably whether AEAD initialisation in the constructor succeeded. */ class BoringSSLCryptor : public ICryptor { public: BoringSSLCryptor(const EncryptionKey& encryptionKey); ~BoringSSLCryptor(); bool IsValid() const { return cipherCtx_.aead != nullptr; } bool Encrypt(ArrayView ciphertextBufferOut, ArrayView plaintextBuffer, ArrayView nonceBuffer, ArrayView additionalData, ArrayView tagBufferOut) override; bool Decrypt(ArrayView plaintextBufferOut, ArrayView ciphertextBuffer, ArrayView tagBuffer, ArrayView nonceBuffer, ArrayView additionalData) override; private: EVP_AEAD_CTX cipherCtx_; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/codec_utils.cpp ================================================ /* Codec-specific helpers that decide which bytes of an outbound frame stay unencrypted so that packetizers and depacketizers can still parse them. */ #include "codec_utils.h" #include #include #include #include #include "common.h" #include "utils/leb128.h" namespace discord { namespace dave { namespace codec_utils { /* BytesCoveringH264PPS: returns how many leading payload bytes must stay unencrypted so the first three Exp-Golomb fields (through pps_id) remain readable, skipping RBSP emulation-prevention bytes. Returns 0 for a run of 32 or more leading zero bits, or when the count does not fit UnencryptedFrameHeaderSize. NOTE(review): the count is derived from the bit position where parsing stopped and is not clamped to sizeRemaining, so truncated or malformed input can yield more bytes than the buffer holds. */ UnencryptedFrameHeaderSize BytesCoveringH264PPS(const uint8_t* payload, const uint64_t sizeRemaining) { // the payload starts with three exponential golomb encoded values // (first_mb_in_slice, sps_id, pps_id) // the depacketizer needs the pps_id unencrypted // and the payload has RBSP encoding that we need to work around constexpr uint8_t kEmulationPreventionByte = 0x03; uint64_t payloadBitIndex = 0; auto zeroBitCount = 0; auto parsedExpGolombValues = 0; while (payloadBitIndex < sizeRemaining * 8 && parsedExpGolombValues < 3) { auto bitIndex = payloadBitIndex % 8; auto byteIndex = payloadBitIndex / 8; auto payloadByte = payload[byteIndex]; // if we're starting a new byte // check if this is an emulation prevention byte // which we skip over if (bitIndex == 0) { if (byteIndex >= 2 && payloadByte == kEmulationPreventionByte && payload[byteIndex - 1] == 0 && payload[byteIndex - 2] == 0) { payloadBitIndex +=
8; continue; } } if ((payloadByte & (1 << (7 - bitIndex))) == 0) { // still in the run of leading zero bits ++zeroBitCount; ++payloadBitIndex; if (zeroBitCount >= 32) { assert(false && "Unexpectedly large exponential golomb encoded value"); return 0; } } else { // we hit a one // skip forward the number of bits dictated by the leading number of zeroes parsedExpGolombValues += 1; payloadBitIndex += 1 + zeroBitCount; zeroBitCount = 0; } } // return the number of bytes that covers the last exp golomb encoded value auto result = (payloadBitIndex / 8) + 1; if (result > std::numeric_limits::max()) { DISCORD_LOG(LS_WARNING) << "BytesCoveringH264PPS result cannot fit in UnencryptedFrameHeaderSize"; return 0; } else { return static_cast(result); } } /* H.26x start codes: the 4-byte form is what gets written to the processor, and the 3-byte form {0, 0, 1} is the shortest sequence a scan has to match. */ const uint8_t kH26XNaluLongStartCode[] = {0, 0, 0, 1}; constexpr uint8_t kH26XNaluShortStartSequenceSize = 3; using IndexStartCodeSizePair = std::pair; /* FindNextH26XNaluIndex: scans buffer from searchStartIndex for the next {0, 0, 1} sequence and returns (index of the first NAL unit byte after it, start code size), where the size is 4 when the sequence is preceded by another zero byte and 3 otherwise; returns nullopt when none is found. The scan looks at the third byte of each window and jumps ahead by 3 when it cannot end a start code. NOTE(review): the loop bound i < bufferSize - 3 means a start sequence ending on the final byte is never reported. */ std::optional FindNextH26XNaluIndex(const uint8_t* buffer, const size_t bufferSize, const size_t searchStartIndex = 0) { constexpr uint8_t kH26XStartCodeHighestPossibleValue = 1; constexpr uint8_t kH26XStartCodeEndByteValue = 1; constexpr uint8_t kH26XStartCodeLeadingBytesValue = 0; if (bufferSize < kH26XNaluShortStartSequenceSize) { return std::nullopt; } // look for NAL unit 3 or 4 byte start code for (size_t i = searchStartIndex; i < bufferSize - kH26XNaluShortStartSequenceSize;) { if (buffer[i + 2] > kH26XStartCodeHighestPossibleValue) { // third byte is not 0 or 1, can't be a start code i += kH26XNaluShortStartSequenceSize; } else if (buffer[i + 2] == kH26XStartCodeEndByteValue) { // third byte matches the start code end byte, might be a start code sequence if (buffer[i + 1] == kH26XStartCodeLeadingBytesValue && buffer[i] == kH26XStartCodeLeadingBytesValue) { // confirmed start sequence {0, 0, 1} auto nalUnitStartIndex = i + kH26XNaluShortStartSequenceSize; if (i >= 1 && buffer[i - 1] == kH26XStartCodeLeadingBytesValue) { // 4 byte start code return
std::optional({nalUnitStartIndex, 4}); } else { // 3 byte start code return std::optional({nalUnitStartIndex, 3}); } } i += kH26XNaluShortStartSequenceSize; } else { // third byte is 0, might be a four byte start code ++i; } } return std::nullopt; } /* ProcessFrameOpus: the whole Opus frame is encrypted; nothing is left in the clear. */ bool ProcessFrameOpus(OutboundFrameProcessor& processor, ArrayView frame) { processor.AddEncryptedBytes(frame.data(), frame.size()); return true; } /* ProcessFrameVp8: leaves 10 bytes (key frame) or 1 byte (delta frame) unencrypted and encrypts the rest. NOTE(review): frame.data()[0] is read without checking frame.size(), and a frame shorter than the chosen header size makes AddUnencryptedBytes over-read and frame.size() - unencryptedHeaderBytes wrap around. */ bool ProcessFrameVp8(OutboundFrameProcessor& processor, ArrayView frame) { constexpr uint8_t kVP8KeyFrameUnencryptedBytes = 10; constexpr uint8_t kVP8DeltaFrameUnencryptedBytes = 1; // parse the VP8 payload header to determine if it's a key frame // https://datatracker.ietf.org/doc/html/rfc7741#section-4.3 // 0 1 2 3 4 5 6 7 // +-+-+-+-+-+-+-+-+ // |Size0|H| VER |P| // +-+-+-+-+-+-+-+-+ // P is an inverse key frame flag // if this is a key frame the depacketizer will read 10 bytes into the payload header // if this is a delta frame the depacketizer only needs the first byte of the payload // header (since that's where the key frame flag is) size_t unencryptedHeaderBytes = 0; if ((frame.data()[0] & 0x01) == 0) { unencryptedHeaderBytes = kVP8KeyFrameUnencryptedBytes; } else { unencryptedHeaderBytes = kVP8DeltaFrameUnencryptedBytes; } processor.AddUnencryptedBytes(frame.data(), unencryptedHeaderBytes); processor.AddEncryptedBytes(frame.data() + unencryptedHeaderBytes, frame.size() - unencryptedHeaderBytes); return true; } /* ProcessFrameVp9: the whole VP9 frame is encrypted. */ bool ProcessFrameVp9(OutboundFrameProcessor& processor, ArrayView frame) { // payload descriptor is unencrypted in each packet // and includes all information the depacketizer needs processor.AddEncryptedBytes(frame.data(), frame.size()); return true; } /* ProcessFrameH264: writes each NAL unit behind a 4-byte start code. Slice and IDR NAL units keep the NAL header plus the bytes through pps_id unencrypted and encrypt the remainder; every other NAL unit type is passed through unencrypted in full, and bytes before the first start code are not passed to the processor. Returns false only when the frame is too small to contain a NAL unit. NOTE(review): nalUnitPPSBytes is computed against the rest of the frame rather than the current NAL unit, so a malformed slice can make the encrypted length nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes wrap around. */ bool ProcessFrameH264(OutboundFrameProcessor& processor, ArrayView frame) { // minimize the amount of unencrypted header data for H264 depending on the NAL unit // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h264.cc // src/common_video/h264/h264_common.cc //
src/modules/rtp_rtcp/source/video_rtp_depacketizer_h264.cc // constexpr uint8_t kH264SBit = 0x80; /* NOTE(review): kH264SBit is never used in this function. */ constexpr uint8_t kH264NalHeaderTypeMask = 0x1F; constexpr uint8_t kH264NalTypeSlice = 1; constexpr uint8_t kH264NalTypeIdr = 5; constexpr uint8_t kH264NalUnitHeaderSize = 1; // this frame can be packetized as a STAP-A or a FU-A // so we need to look at the first NAL units to determine how many bytes // the packetizer/depacketizer will need into the payload if (frame.size() < kH26XNaluShortStartSequenceSize + kH264NalUnitHeaderSize) { assert(false && "H264 frame is too small to contain a NAL unit"); DISCORD_LOG(LS_WARNING) << "H264 frame is too small to contain a NAL unit"; return false; } auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { auto [nalUnitStartIndex, startCodeSize] = *naluIndexPair; auto nalType = frame.data()[nalUnitStartIndex] & kH264NalHeaderTypeMask; // copy the start code and then the NAL unit // Because WebRTC will convert them all start codes to 4-byte on the receiver side // always write a long start code and then the NAL unit processor.AddUnencryptedBytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); auto nextNaluStart = nextNaluIndexPair.has_value() ?
nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size(); if (nalType == kH264NalTypeSlice || nalType == kH264NalTypeIdr) { // once we've hit a slice or an IDR // we just need to cover getting to the PPS ID auto nalUnitPayloadStart = nalUnitStartIndex + kH264NalUnitHeaderSize; auto nalUnitPPSBytes = BytesCoveringH264PPS(frame.data() + nalUnitPayloadStart, frame.size() - nalUnitPayloadStart); processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, kH264NalUnitHeaderSize + nalUnitPPSBytes); processor.AddEncryptedBytes( frame.data() + nalUnitStartIndex + kH264NalUnitHeaderSize + nalUnitPPSBytes, nextNaluStart - nalUnitStartIndex - kH264NalUnitHeaderSize - nalUnitPPSBytes); } else { // copy the whole NAL unit processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); } naluIndexPair = nextNaluIndexPair; } return true; } bool ProcessFrameH265(OutboundFrameProcessor& processor, ArrayView frame) { // minimize the amount of unencrypted header data for H265 depending on the NAL unit // type from WebRTC, see: src/modules/rtp_rtcp/source/rtp_format_h265.cc // src/common_video/h265/h265_common.cc // src/modules/rtp_rtcp/source/video_rtp_depacketizer_h265.cc constexpr uint8_t kH265NalHeaderTypeMask = 0x7E; constexpr uint8_t kH265NalTypeVclCutoff = 32; constexpr uint8_t kH265NalUnitHeaderSize = 2; // this frame can be packetized as a STAP-A or a FU-A // so we need to look at the first NAL units to determine how many bytes // the packetizer/depacketizer will need into the payload if (frame.size() < kH26XNaluShortStartSequenceSize + kH265NalUnitHeaderSize) { assert(false && "H265 frame is too small to contain a NAL unit"); DISCORD_LOG(LS_WARNING) << "H265 frame is too small to contain a NAL unit"; return false; } // look for NAL unit 3 or 4 byte start code auto naluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size()); while (naluIndexPair && naluIndexPair->first < frame.size() - 1) { auto [nalUnitStartIndex,
startCodeSize] = *naluIndexPair; uint8_t nalType = (frame.data()[nalUnitStartIndex] & kH265NalHeaderTypeMask) >> 1; // copy the start code and then the NAL unit // Because WebRTC will convert them all start codes to 4-byte on the receiver side // always write a long start code and then the NAL unit processor.AddUnencryptedBytes(kH26XNaluLongStartCode, sizeof(kH26XNaluLongStartCode)); auto nextNaluIndexPair = FindNextH26XNaluIndex(frame.data(), frame.size(), nalUnitStartIndex); auto nextNaluStart = nextNaluIndexPair.has_value() ? nextNaluIndexPair->first - nextNaluIndexPair->second : frame.size(); if (nalType < kH265NalTypeVclCutoff) { // found a VCL NAL, encrypt the payload only processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, kH265NalUnitHeaderSize); processor.AddEncryptedBytes(frame.data() + nalUnitStartIndex + kH265NalUnitHeaderSize, nextNaluStart - nalUnitStartIndex - kH265NalUnitHeaderSize); } else { // copy the whole NAL unit processor.AddUnencryptedBytes(frame.data() + nalUnitStartIndex, nextNaluStart - nalUnitStartIndex); } naluIndexPair = nextNaluIndexPair; } return true; } bool ProcessFrameAv1(OutboundFrameProcessor& processor, ArrayView frame) { constexpr uint8_t kAv1ObuHeaderHasExtensionMask = 0b0'0000'100; constexpr uint8_t kAv1ObuHeaderHasSizeMask = 0b0'0000'010; constexpr uint8_t kAv1ObuHeaderTypeMask = 0b0'1111'000; constexpr uint8_t kObuTypeTemporalDelimiter = 2; constexpr uint8_t kObuTypeTileList = 8; constexpr uint8_t kObuTypePadding = 15; constexpr uint8_t kObuExtensionSizeBytes = 1; size_t i = 0; while (i < frame.size()) { // Read the OBU header. 
size_t obuHeaderIndex = i; uint8_t obuHeader = frame.data()[obuHeaderIndex]; i += sizeof(obuHeader); bool obuHasExtension = obuHeader & kAv1ObuHeaderHasExtensionMask; bool obuHasSize = obuHeader & kAv1ObuHeaderHasSizeMask; int obuType = (obuHeader & kAv1ObuHeaderTypeMask) >> 3; if (obuHasExtension) { // Skip extension byte i += kObuExtensionSizeBytes; } if (i >= frame.size()) { // Malformed frame assert(false && "Malformed AV1 frame: header overflows frame"); DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: header overflows frame"; return false; } size_t obuPayloadSize = 0; if (obuHasSize) { // Read payload size const uint8_t* start = frame.data() + i; const uint8_t* ptr = start; obuPayloadSize = ReadLeb128(ptr, frame.end()); if (!ptr) { // Malformed frame assert(false && "Malformed AV1 frame: invalid LEB128 size"); DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: invalid LEB128 size"; return false; } i += ptr - start; } else { // If the size is not present, the OBU extends to the end of the frame. 
obuPayloadSize = frame.size() - i; } const auto obuPayloadIndex = i; if (i + obuPayloadSize > frame.size()) { // Malformed frame assert(false && "Malformed AV1 frame: payload overflows frame"); DISCORD_LOG(LS_WARNING) << "Malformed AV1 frame: payload overflows frame"; return false; } i += obuPayloadSize; // We only copy the OBUs that will not get dropped by the packetizer if (obuType != kObuTypeTemporalDelimiter && obuType != kObuTypeTileList && obuType != kObuTypePadding) { // if this is the last OBU, we may need to flip the "has size" bit // which allows us to append necessary protocol data to the frame bool rewrittenWithoutSize = false; if (i == frame.size() && obuHasSize) { // Flip the "has size" bit obuHeader &= ~kAv1ObuHeaderHasSizeMask; rewrittenWithoutSize = true; } // write the OBU header unencrypted processor.AddUnencryptedBytes(&obuHeader, sizeof(obuHeader)); if (obuHasExtension) { // write the extension byte unencrypted processor.AddUnencryptedBytes(frame.data() + obuHeaderIndex + sizeof(obuHeader), kObuExtensionSizeBytes); } // write the OBU payload size unencrypted if it was present and we didn't rewrite // without it if (obuHasSize && !rewrittenWithoutSize) { // The AMD AV1 encoder may pad LEB128 encoded sizes with a zero byte which the // webrtc packetizer removes. 
To prevent the packetizer from changing the frame, // we sanitize the size by re-writing it ourselves uint8_t leb128Buffer[Leb128MaxSize]; size_t additionalBytesToWrite = WriteLeb128(obuPayloadSize, leb128Buffer); processor.AddUnencryptedBytes(leb128Buffer, additionalBytesToWrite); } // add the OBU payload, encrypted processor.AddEncryptedBytes(frame.data() + obuPayloadIndex, obuPayloadSize); } } return true; } bool ValidateEncryptedFrame(OutboundFrameProcessor& processor, ArrayView frame) { auto codec = processor.GetCodec(); if (codec != Codec::H264 && codec != Codec::H265) { return true; } static_assert(kH26XNaluShortStartSequenceSize - 1 >= 0, "Padding will overflow!"); constexpr size_t Padding = kH26XNaluShortStartSequenceSize - 1; const auto& unencryptedRanges = processor.GetUnencryptedRanges(); // H264 and H265 ciphertexts cannot contain a 3 or 4 byte start code {0, 0, 1} // otherwise the packetizer gets confused // and the frame we get on the decryption side will be shifted and fail to decrypt size_t encryptedSectionStart = 0; for (auto& range : unencryptedRanges) { if (encryptedSectionStart == range.offset) { encryptedSectionStart += range.size; continue; } auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); auto end = std::min(range.offset + Padding, frame.size()); if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { return false; } encryptedSectionStart = range.offset + range.size; } if (encryptedSectionStart == frame.size()) { return true; } auto start = encryptedSectionStart - std::min(encryptedSectionStart, size_t{Padding}); auto end = frame.size(); if (FindNextH26XNaluIndex(frame.data() + start, end - start)) { return false; } return true; } } // namespace codec_utils } // namespace dave } // namespace discord ================================================ FILE: cpp/src/codec_utils.h ================================================ #pragma once #include #include "common.h" #include "frame_processors.h" 
namespace discord {
namespace dave {
namespace codec_utils {

// Per-codec frame splitters. Each walks an outbound frame and hands its bytes to
// `processor` through AddUnencryptedBytes / AddEncryptedBytes, leaving in the clear only
// what the codec's packetizer/depacketizer needs. A false return means the frame could
// not be split; the H264, H265 and AV1 implementations return false for frames that are
// too small or malformed.
bool ProcessFrameOpus(OutboundFrameProcessor& processor, ArrayView frame);
bool ProcessFrameVp8(OutboundFrameProcessor& processor, ArrayView frame);
bool ProcessFrameVp9(OutboundFrameProcessor& processor, ArrayView frame);
bool ProcessFrameH264(OutboundFrameProcessor& processor, ArrayView frame);
bool ProcessFrameH265(OutboundFrameProcessor& processor, ArrayView frame);
bool ProcessFrameAv1(OutboundFrameProcessor& processor, ArrayView frame);

// For H264/H265 frames, returns false when a start code sequence is found in an encrypted
// section of the finished frame (checked with a couple of bytes of overlap into the
// neighbouring unencrypted ranges). Always returns true for other codecs.
bool ValidateEncryptedFrame(OutboundFrameProcessor& processor, ArrayView frame);

} // namespace codec_utils
} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/common.h
================================================
#pragma once

#include
#include
#include
#include
#include
#include
#include
#include

namespace discord {
namespace dave {

// Upper bound for the unencrypted-header byte count BytesCoveringH264PPS may report.
using UnencryptedFrameHeaderSize = uint16_t;
// Per-frame nonce carried in the encrypted frame's trailer (written LEB128-encoded).
using TruncatedSyncNonce = uint32_t;
using MagicMarker = uint16_t;
using TransitionId = uint16_t;
// Type of the single trailer byte that records the size of the supplemental section.
using SupplementalBytesSize = uint8_t;

// Written as the very last bytes of every encrypted frame.
constexpr MagicMarker kMarkerBytes = 0xFAFA;

// Layout constants
constexpr size_t kAesGcm128KeyBytes = 16;
constexpr size_t kAesGcm128NonceBytes = 12;
constexpr size_t kAesGcm128TruncatedSyncNonceBytes = 4;
// The truncated nonce is copied into the trailing kAesGcm128TruncatedSyncNonceBytes bytes
// of the full nonce buffer (offset 12 - 4 = 8); the leading bytes stay zero.
constexpr size_t kAesGcm128TruncatedSyncNonceOffset =
  kAesGcm128NonceBytes - kAesGcm128TruncatedSyncNonceBytes;
// Number of tag bytes written directly after the reconstructed frame.
constexpr size_t kAesGcm128TruncatedTagBytes = 8;
// The key generation occupies the top kRatchetGenerationBytes of the truncated nonce;
// it is recovered with `truncatedNonce >> kRatchetGenerationShiftBits` (24 bits here).
constexpr size_t kRatchetGenerationBytes = 1;
constexpr size_t kRatchetGenerationShiftBits =
  8 * (kAesGcm128TruncatedSyncNonceBytes - kRatchetGenerationBytes);
// Fixed-size part of the supplemental trailer: tag + size byte + magic marker. The
// encryptor adds the LEB128 nonce length and the serialized unencrypted ranges to it.
constexpr size_t kSupplementalBytes =
  kAesGcm128TruncatedTagBytes + sizeof(SupplementalBytesSize) + sizeof(MagicMarker);
constexpr size_t kTransformPaddingBytes = 64;

// Timing constants
// How long a cryptor for an older generation stays usable once a newer generation has
// been used successfully (and for cryptors created late for an old generation).
constexpr auto kCryptorExpiry = std::chrono::seconds(10);

// Behavior constants
constexpr auto kInitTransitionId = 0;
constexpr auto kDisabledVersion = 0;
// Generations more than this far ahead of the newest successfully used generation are
// rejected by the cryptor manager.
constexpr auto kMaxGenerationGap = 250;
// Maximum number of skipped nonces remembered by CryptorManager::ReportCryptorSuccess.
constexpr auto kMaxMissingNonces = 1000;
// Number of distinct values of the wrapped generation carried in the top
// kRatchetGenerationBytes of the truncated nonce; generations repeat modulo this value.
constexpr auto kGenerationWrap = 1 << (8 * kRatchetGenerationBytes);
constexpr auto kMaxFramesPerSecond = 50 + 2 * 60; // 50 audio frames + 2 * 60fps video streams

// Opus silence frame: Decryptor::Decrypt copies audio frames equal to these bytes through
// without attempting decryption.
constexpr std::array kOpusSilencePacket = {0xF8, 0xFF, 0xFE};

// Utility routine for variant return types
// Returns the requested alternative if `variant` holds it (presumably moved out when
// `variant` is an rvalue, copied otherwise), or std::nullopt when another alternative is
// active.
template inline std::optional GetOptional(V&& variant)
{
    if (auto map = std::get_if(&variant)) {
        if constexpr (std::is_rvalue_reference_v) {
            return std::move(*map);
        }
        else {
            return *map;
        }
    }
    else {
        return std::nullopt;
    }
}

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/cryptor.cpp
================================================
#include "cryptor.h"

#ifdef WITH_BORINGSSL
#include "boringssl_cryptor.h"
#else
#include "openssl_cryptor.h"
#endif

namespace discord {
namespace dave {

// Constructs the backend cryptor chosen at compile time (boringssl_cryptor.h under
// WITH_BORINGSSL, openssl_cryptor.h otherwise). Returns nullptr when the new cryptor
// reports !IsValid().
std::unique_ptr CreateCryptor(const EncryptionKey& encryptionKey)
{
#ifdef WITH_BORINGSSL
    auto cryptor = std::make_unique(encryptionKey);
#else
    auto cryptor = std::make_unique(encryptionKey);
#endif

    return cryptor->IsValid() ? std::move(cryptor) : nullptr;
}

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/cryptor.h
================================================
#pragma once

#include
#include
#include

namespace discord {
namespace dave {

// Authenticated-encryption interface (nonce + additional data + tag). Callers treat a
// false return from either call as a failure.
class ICryptor {
public:
    virtual ~ICryptor() = default;

    virtual bool Encrypt(ArrayView ciphertextBufferOut,
                         ArrayView plaintextBuffer,
                         ArrayView nonceBuffer,
                         ArrayView additionalData,
                         ArrayView tagBufferOut) = 0;
    virtual bool Decrypt(ArrayView plaintextBufferOut,
                         ArrayView ciphertextBuffer,
                         ArrayView tagBuffer,
                         ArrayView nonceBuffer,
                         ArrayView additionalData) = 0;
};

std::unique_ptr CreateCryptor(const EncryptionKey& encryptionKey);

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/cryptor_manager.cpp
================================================
#include "cryptor_manager.h"

#include
#include
#include

using namespace std::chrono_literals;

namespace discord {
namespace dave {

// Expands a wrapped generation (as carried in the truncated nonce, in
// [0, kGenerationWrap)) into a full generation. The result lies in
// [oldest, oldest + kGenerationWrap).
KeyGeneration ComputeWrappedGeneration(KeyGeneration oldest, KeyGeneration generation)
{
    // Assume generation is greater than or equal to oldest, this may be wrong in a few cases but
    // will be caught by the max generation gap check.
    auto remainder = oldest % kGenerationWrap;
    auto factor = oldest / kGenerationWrap + (generation < remainder ? 1 : 0);
    return factor * kGenerationWrap + generation;
}

// Builds a 64-bit nonce that stays comparable across generation wraps by replacing the
// generation bits at the top of the truncated nonce with the full generation.
// CanProcessNonce and ReportCryptorSuccess compare these values.
BigNonce ComputeWrappedBigNonce(KeyGeneration generation, TruncatedSyncNonce nonce)
{
    // Remove the generation bits from the nonce
    auto maskedNonce = nonce & ((1 << kRatchetGenerationShiftBits) - 1);
    // Add the wrapped generation bits back in
    return static_cast(generation) << kRatchetGenerationShiftBits | maskedNonce;
}

// Records the creation time, which GetCryptor's lifetime check uses. The manager never
// expires until UpdateExpiry is called. Holds a reference to `clock`.
CryptorManager::CryptorManager(const IClock& clock, std::unique_ptr keyRatchet)
  : clock_(clock)
  , keyRatchet_(std::move(keyRatchet))
  , ratchetCreation_(clock.Now())
  , ratchetExpiry_(TimePoint::max())
{
}

// Returns true if nothing has been processed yet, if the nonce is newer than the newest
// successfully processed nonce, or if it is one of the skipped nonces still tracked in
// missingNonces_. Any other nonce is refused.
bool CryptorManager::CanProcessNonce(KeyGeneration generation, TruncatedSyncNonce nonce) const
{
    if (!newestProcessedNonce_) {
        return true;
    }

    auto bigNonce = ComputeWrappedBigNonce(generation, nonce);
    return bigNonce > *newestProcessedNonce_ ||
      std::find(missingNonces_.rbegin(), missingNonces_.rend(), bigNonce) != missingNonces_.rend();
}

// Returns a non-owning cryptor for `generation`, creating it on first use. Returns nullptr
// in three cases:
// - the generation is older than the oldest retained one;
// - it is more than kMaxGenerationGap ahead of the newest one;
// - it exceeds what the ratchet's age allows at kMaxFramesPerSecond.
// Expired cryptors are purged first.
ICryptor* CryptorManager::GetCryptor(KeyGeneration generation)
{
    CleanupExpiredCryptors();

    if (generation < oldestGeneration_) {
        DISCORD_LOG(LS_INFO) << "Received frame with old generation: " << generation
                             << ", oldest generation: " << oldestGeneration_;
        return nullptr;
    }

    if (generation > newestGeneration_ + kMaxGenerationGap) {
        DISCORD_LOG(LS_INFO) << "Received frame with future generation: " << generation
                             << ", newest generation: " << newestGeneration_;
        return nullptr;
    }

    // Upper bound on the generation reachable so far. This assumes one nonce per frame and
    // one generation per 2^kRatchetGenerationShiftBits nonces.
    auto ratchetLifetimeSec =
      std::chrono::duration_cast(clock_.Now() - ratchetCreation_).count();
    auto maxLifetimeFrames = kMaxFramesPerSecond * ratchetLifetimeSec;
    auto maxLifetimeGenerations = maxLifetimeFrames >> kRatchetGenerationShiftBits;
    if (generation > maxLifetimeGenerations) {
        DISCORD_LOG(LS_INFO) << "Received frame with generation " << generation
                             << " beyond ratchet max lifetime generations: "
                             << maxLifetimeGenerations << ", ratchet lifetime: " << ratchetLifetimeSec
                             << "s";
        return nullptr;
    }

    auto it = cryptors_.find(generation);
    if (it == cryptors_.end()) {
        // We don't have a cryptor for this generation, create one
        std::tie(it, std::ignore) = cryptors_.emplace(generation, MakeExpiringCryptor(generation));
    }

    // Return a non-owning pointer to the cryptor
    auto& [cryptor, expiry] = it->second;
    return cryptor.get();
}

// Records a successful decryption. It updates the nonce bookkeeping as follows:
// - a newer nonce queues the skipped nonces in between (at most kMaxMissingNonces,
//   oldest dropped first);
// - an older nonce is removed from the missing list.
// If this is the first success for a newer generation that has a cryptor, that generation
// becomes the newest. Every older cryptor is then capped to expire within kCryptorExpiry.
void CryptorManager::ReportCryptorSuccess(KeyGeneration generation, TruncatedSyncNonce nonce)
{
    auto bigNonce = ComputeWrappedBigNonce(generation, nonce);

    // Add any missing nonces to the queue
    if (!newestProcessedNonce_) {
        newestProcessedNonce_ = bigNonce;
    }
    else if (bigNonce > *newestProcessedNonce_) {
        auto missingNonces = std::min(bigNonce - *newestProcessedNonce_ - 1,
                                      static_cast(kMaxMissingNonces));

        while (!missingNonces_.empty() &&
               missingNonces_.size() + missingNonces > kMaxMissingNonces) {
            missingNonces_.pop_front();
        }

        for (auto i = bigNonce - missingNonces; i < bigNonce; ++i) {
            missingNonces_.push_back(i);
        }

        // Update the newest processed nonce
        newestProcessedNonce_ = bigNonce;
    }
    else {
        auto it = std::find(missingNonces_.begin(), missingNonces_.end(), bigNonce);
        if (it != missingNonces_.end()) {
            missingNonces_.erase(it);
        }
    }

    if (generation <= newestGeneration_ || cryptors_.find(generation) == cryptors_.end()) {
        return;
    }
    DISCORD_LOG(LS_INFO) << "Reporting cryptor success, generation: " << generation;
    newestGeneration_ = generation;

    // Update the expiry time for all old cryptors
    const auto expiryTime = clock_.Now() + kCryptorExpiry;
    for (auto& [gen, cryptor] : cryptors_) {
        if (gen < newestGeneration_) {
            DISCORD_LOG(LS_INFO) << "Updating expiry for cryptor, generation: " << gen;
            cryptor.expiry = std::min(cryptor.expiry, expiryTime);
        }
    }
}

// Expands a wrapped generation relative to the oldest generation still retained.
KeyGeneration CryptorManager::ComputeWrappedGeneration(KeyGeneration generation) const
{
    return ::discord::dave::ComputeWrappedGeneration(oldestGeneration_, generation);
}

// Derives the key for `generation` from the ratchet and wraps it in a cryptor with an
// expiry time.
CryptorManager::ExpiringCryptor CryptorManager::MakeExpiringCryptor(KeyGeneration generation)
{
    // Get the new key from the ratchet
    auto encryptionKey = keyRatchet_->GetKey(generation);
    auto expiryTime = TimePoint::max();

    // If we got frames out of order, we might have to create a cryptor for an old generation
    // In that case, create it with a non-infinite expiry time as we have already transitioned
    // to a newer generation
    if (generation < newestGeneration_) {
        DISCORD_LOG(LS_INFO) << "Creating cryptor for old generation: " << generation;
        expiryTime = clock_.Now() + kCryptorExpiry;
    }
    else {
        DISCORD_LOG(LS_INFO) << "Creating cryptor for new generation: " << generation;
    }

    return {CreateCryptor(encryptionKey), expiryTime};
}

// Erases expired cryptors. It then advances oldestGeneration_ past generations (below the
// newest) that no longer have a cryptor, deleting their ratchet keys as it goes.
void CryptorManager::CleanupExpiredCryptors()
{
    for (auto it = cryptors_.begin(); it != cryptors_.end();) {
        auto& [generation, cryptor] = *it;

        bool expired = cryptor.expiry < clock_.Now();
        if (expired) {
            DISCORD_LOG(LS_INFO) << "Removing expired cryptor, generation: " << generation;
        }

        it = expired ? cryptors_.erase(it) : ++it;
    }

    while (oldestGeneration_ < newestGeneration_ &&
           cryptors_.find(oldestGeneration_) == cryptors_.end()) {
        DISCORD_LOG(LS_INFO) << "Deleting key for old generation: " << oldestGeneration_;
        keyRatchet_->DeleteKey(oldestGeneration_);
        ++oldestGeneration_;
    }
}

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/cryptor_manager.h
================================================
#pragma once

#include
#include
#include
#include

#include "common.h"
#include "cryptor.h"
#include "utils/clock.h"

namespace discord {
namespace dave {

KeyGeneration ComputeWrappedGeneration(KeyGeneration oldest, KeyGeneration generation);

using BigNonce = uint64_t;
BigNonce ComputeWrappedBigNonce(KeyGeneration generation, TruncatedSyncNonce nonce);

// Manages the per-generation cryptors derived from a single key ratchet, together with the
// nonce bookkeeping used to refuse already-processed nonces.
class CryptorManager {
public:
    using TimePoint = typename IClock::TimePoint;

    CryptorManager(const IClock& clock, std::unique_ptr keyRatchet);

    // Overwrites (does not min-combine) the time after which IsExpired() returns true.
    void UpdateExpiry(TimePoint expiry) { ratchetExpiry_ = expiry; }
    bool IsExpired() const { return clock_.Now() > ratchetExpiry_; }

    bool CanProcessNonce(KeyGeneration generation, TruncatedSyncNonce nonce) const;
    KeyGeneration ComputeWrappedGeneration(KeyGeneration generation) const;

    ICryptor* GetCryptor(KeyGeneration generation);
    void ReportCryptorSuccess(KeyGeneration generation, TruncatedSyncNonce nonce);

private:
    struct ExpiringCryptor {
        std::unique_ptr cryptor;
        TimePoint expiry;
    };

    ExpiringCryptor MakeExpiringCryptor(KeyGeneration generation);
    void CleanupExpiredCryptors();

    const IClock& clock_;
    std::unique_ptr keyRatchet_;
    std::unordered_map cryptors_;

    TimePoint ratchetCreation_;
    TimePoint ratchetExpiry_;
    KeyGeneration oldestGeneration_{0};
    KeyGeneration newestGeneration_{0};

    // Newest nonce that decrypted successfully, and the skipped nonces older than it that
    // may still be accepted (bounded by kMaxMissingNonces).
    std::optional newestProcessedNonce_;
    std::deque missingNonces_;
};

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/decryptor.cpp
================================================
#include "decryptor.h"

#include
#include
#include

#include "common.h"
#include "utils/leb128.h"
#include "utils/scope_exit.h"

using namespace std::chrono_literals;

namespace discord {
namespace dave {

constexpr auto kStatsInterval = 10s;

std::unique_ptr CreateDecryptor()
{
    return std::make_unique();
}

// Caps the expiry of every existing cryptor manager at now + transitionExpiry. If
// `keyRatchet` is non-null, a manager for it is appended as the newest one.
void Decryptor::TransitionToKeyRatchet(std::unique_ptr keyRatchet,
                                       Duration transitionExpiry)
{
    DISCORD_LOG(LS_INFO) << "Transitioning to new key ratchet: " << keyRatchet.get()
                         << ", expiry: " << transitionExpiry.count();

    // Update the expiry time for all existing cryptor managers
    UpdateCryptorManagerExpiry(transitionExpiry);

    if (keyRatchet) {
        cryptorManagers_.emplace_back(clock_, std::move(keyRatchet));
    }
}

// Enabling passthrough allows unencrypted frames indefinitely. Disabling it keeps
// allowing them for at most transitionExpiry from now; it never extends an earlier
// deadline.
void Decryptor::TransitionToPassthroughMode(bool passthroughMode, Duration transitionExpiry)
{
    if (passthroughMode) {
        allowPassThroughUntil_ = TimePoint::max();
    }
    else {
        // Update the pass through mode expiry
        auto maxExpiry = clock_.Now() + transitionExpiry;
        allowPassThroughUntil_ = std::min(allowPassThroughUntil_, maxExpiry);
    }
}

// Decrypts `encryptedFrame` into `frame` and reports the plaintext size via bytesWritten.
// Opus silence frames, and unencrypted frames while passthrough is allowed, are copied as-is.
// Otherwise each cryptor manager is tried from newest to oldest. When no manager exists,
// the result is MissingKeyRatchet.
Decryptor::ResultCode Decryptor::Decrypt(MediaType mediaType,
                                         ArrayView encryptedFrame,
                                         ArrayView frame,
                                         size_t* bytesWritten)
{
    if (mediaType != Audio && mediaType != Video) {
        DISCORD_LOG(LS_WARNING) << "Decrypt failed, invalid media type: "
                                << static_cast(mediaType);
        *bytesWritten = 0;
        return ResultCode::DecryptionFailure;
    }

    auto& stats = stats_[mediaType];

    auto start = clock_.Now();

    // Borrow a frame processor from the pool; the ScopeExit guard presumably hands it back
    // through ReturnFrameProcessor when this function returns.
    auto localFrame = GetOrCreateFrameProcessor();
    ScopeExit cleanup([&] { ReturnFrameProcessor(std::move(localFrame)); });

    // Skip decrypting for silence frames
    if (mediaType == Audio && encryptedFrame.size() == kOpusSilencePacket.size() &&
        memcmp(encryptedFrame.data(), kOpusSilencePacket.data(), kOpusSilencePacket.size()) == 0) {
        DISCORD_LOG(LS_VERBOSE) << "Decrypt skipping silence of size: " << encryptedFrame.size();
        auto copySize = std::min(frame.size(), encryptedFrame.size());
        if (encryptedFrame.data() != frame.data()) {
            memcpy(frame.data(), encryptedFrame.data(), copySize);
        }
        *bytesWritten = copySize;
        return ResultCode::Success;
    }

    // Remove any expired cryptor manager
    CleanupExpiredCryptorManagers();

    // Process the incoming frame
    // This will check whether it looks like a valid encrypted frame
    // and if so it will parse it into its different components
    localFrame->ParseFrame(encryptedFrame);

    // If the frame is not encrypted and we can pass it through, do it
    bool canUsePassThrough = allowPassThroughUntil_ > start;
    if (!localFrame->IsEncrypted() && canUsePassThrough) {
        auto copySize = std::min(frame.size(), encryptedFrame.size());
        if (encryptedFrame.data() != frame.data()) {
            memcpy(frame.data(), encryptedFrame.data(), copySize);
        }
        stats_[mediaType].passthroughCount++;
        *bytesWritten = copySize;
        return ResultCode::Success;
    }

    // If the frame is not encrypted and we can't pass it through, fail
    if (!localFrame->IsEncrypted()) {
        DISCORD_LOG(LS_INFO)
          << "Decrypt failed, frame is not encrypted and pass through is disabled";
        stats_[mediaType].decryptFailureCount++;
        *bytesWritten = 0;
        return ResultCode::DecryptionFailure;
    }

    // Try and decrypt with each valid cryptor
    // reverse iterate to try the newest cryptors first
    auto result = ResultCode::MissingKeyRatchet;
    for (auto it = cryptorManagers_.rbegin(); it != cryptorManagers_.rend(); ++it) {
        auto& cryptorManager = *it;
        result = DecryptImpl(cryptorManager, mediaType, *localFrame);
        if (result == ResultCode::Success) {
            break;
        }
    }

    size_t reconstructedFrameSize = 0;
    if (result == ResultCode::Success) {
        stats.decryptSuccessCount++;
        reconstructedFrameSize = localFrame->ReconstructFrame(frame);
    }
    else {
        stats.decryptFailureCount++;
        DISCORD_LOG(LS_WARNING) << "Decrypt failed, no valid cryptor found, type: "
                                << (mediaType ? "video" : "audio")
                                << ", encrypted frame size: " << encryptedFrame.size()
                                << ", plaintext frame size: " << frame.size()
                                << ", number of cryptor managers: " << cryptorManagers_.size()
                                << ", pass through enabled: " << (canUsePassThrough ? "yes" : "no");

        if (result == ResultCode::InvalidNonce) {
            stats.decryptInvalidNonceCount++;
        }
        else if (result == ResultCode::MissingKeyRatchet) {
            stats.decryptMissingKeyCount++;
        }
    }

    // Log aggregate counts at most once per kStatsInterval.
    auto end = clock_.Now();
    if (end > lastStatsTime_ + kStatsInterval) {
        lastStatsTime_ = end;
        DISCORD_LOG(LS_INFO) << "Decrypted audio: " << stats_[Audio].decryptSuccessCount
                             << ", video: " << stats_[Video].decryptSuccessCount
                             << ". Failed audio: " << stats_[Audio].decryptFailureCount
                             << ", video: " << stats_[Video].decryptFailureCount;
    }
    stats.decryptDuration += std::chrono::duration_cast(end - start).count();

    *bytesWritten = reconstructedFrameSize;
    return result;
}

// Attempts one decryption with a single cryptor manager. The truncated nonce is expanded
// into the full nonce buffer, and the generation is taken from its top bits. Returns
// InvalidNonce if the manager refuses the nonce and MissingCryptor if no cryptor is
// available for the generation. Otherwise it returns Success or DecryptionFailure, and
// reports successes back to the manager.
Decryptor::ResultCode Decryptor::DecryptImpl(CryptorManager& cryptorManager,
                                             MediaType mediaType,
                                             InboundFrameProcessor& encryptedFrame)
{
    auto tag = encryptedFrame.GetTag();
    auto truncatedNonce = encryptedFrame.GetTruncatedNonce();

    auto authenticatedData = encryptedFrame.GetAuthenticatedData();
    auto ciphertext = encryptedFrame.GetCiphertext();
    auto plaintext = encryptedFrame.GetPlaintext();

    // expand the truncated nonce to the full sized one needed for decryption
    auto nonceBuffer = std::array();
    memcpy(nonceBuffer.data() + kAesGcm128TruncatedSyncNonceOffset,
           &truncatedNonce,
           kAesGcm128TruncatedSyncNonceBytes);

    auto nonceBufferView = MakeArrayView(nonceBuffer.data(), nonceBuffer.size());

    auto generation =
      cryptorManager.ComputeWrappedGeneration(truncatedNonce >> kRatchetGenerationShiftBits);

    if (!cryptorManager.CanProcessNonce(generation, truncatedNonce)) {
        DISCORD_LOG(LS_INFO) << "Decrypt failed, cannot process nonce: " << truncatedNonce;
        return ResultCode::InvalidNonce;
    }

    // Get the cryptor for this generation
    ICryptor* cryptor = cryptorManager.GetCryptor(generation);

    if (!cryptor) {
        DISCORD_LOG(LS_INFO) << "Decrypt failed, no cryptor found for generation: " << generation;
        return ResultCode::MissingCryptor;
    }

    // perform the decryption
    bool success =
      cryptor->Decrypt(plaintext, ciphertext, tag, nonceBufferView, authenticatedData);
    stats_[mediaType].decryptAttempts++;

    if (success) {
        cryptorManager.ReportCryptorSuccess(generation, truncatedNonce);
    }

    return success ? ResultCode::Success : ResultCode::DecryptionFailure;
}

// The plaintext is never larger than the encrypted frame, so the encrypted size is used
// as the bound.
size_t Decryptor::GetMaxPlaintextByteSize([[maybe_unused]] MediaType mediaType,
                                          size_t encryptedFrameSize)
{
    return encryptedFrameSize;
}

// Sets every cryptor manager's expiry to now + expiry (overwriting any earlier value).
void Decryptor::UpdateCryptorManagerExpiry(Duration expiry)
{
    auto maxExpiryTime = clock_.Now() + expiry;
    for (auto& cryptorManager : cryptorManagers_) {
        cryptorManager.UpdateExpiry(maxExpiryTime);
    }
}

// Pops expired managers from the front (oldest end) of the deque. It stops at the first
// manager that has not expired.
void Decryptor::CleanupExpiredCryptorManagers()
{
    while (!cryptorManagers_.empty() && cryptorManagers_.front().IsExpired()) {
        DISCORD_LOG(LS_INFO) << "Removing expired cryptor manager.";
        cryptorManagers_.pop_front();
    }
}

// Takes a frame processor from the mutex-guarded pool, or allocates a new one if the pool
// is empty.
std::unique_ptr Decryptor::GetOrCreateFrameProcessor()
{
    std::lock_guard lock(frameProcessorsMutex_);
    if (frameProcessors_.empty()) {
        return std::make_unique();
    }
    auto frameProcessor = std::move(frameProcessors_.back());
    frameProcessors_.pop_back();
    return frameProcessor;
}

// Puts a frame processor back into the mutex-guarded pool for reuse.
void Decryptor::ReturnFrameProcessor(std::unique_ptr frameProcessor)
{
    std::lock_guard lock(frameProcessorsMutex_);
    frameProcessors_.push_back(std::move(frameProcessor));
}

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/decryptor.h
================================================
#pragma once

#include
#include
#include
#include
#include
#include
#include
#include

#include "codec_utils.h"
#include "common.h"
#include "cryptor.h"
#include "cryptor_manager.h"
#include "frame_processors.h"
#include "utils/clock.h"

namespace discord {
namespace dave {

class IKeyRatchet;

class Decryptor final : public IDecryptor {
public:
    using Duration = std::chrono::seconds;

    virtual ~Decryptor() noexcept = default;

    virtual void TransitionToKeyRatchet(
      std::unique_ptr keyRatchet,
      Duration transitionExpiry = kDefaultTransitionDuration) override;
    virtual void TransitionToPassthroughMode(
      bool passthroughMode,
      Duration transitionExpiry = kDefaultTransitionDuration) override;

    virtual ResultCode Decrypt(MediaType mediaType,
                               ArrayView encryptedFrame,
                               ArrayView frame,
                               size_t* bytesWritten) override;

    virtual size_t GetMaxPlaintextByteSize(MediaType mediaType,
                                           size_t encryptedFrameSize) override;
    virtual DecryptorStats GetStats(MediaType mediaType) const override
    {
        return stats_[mediaType];
    }

private:
    using TimePoint = IClock::TimePoint;

    Decryptor::ResultCode DecryptImpl(CryptorManager& cryptor,
                                      MediaType mediaType,
                                      InboundFrameProcessor& encryptedFrame);

    void UpdateCryptorManagerExpiry(Duration expiry);
    void CleanupExpiredCryptorManagers();

    std::unique_ptr GetOrCreateFrameProcessor();
    void ReturnFrameProcessor(std::unique_ptr frameProcessor);

    Clock clock_;
    // Oldest manager at the front, newest at the back; Decrypt tries them newest first.
    std::deque cryptorManagers_;

    // Pool of reusable frame processors, guarded by frameProcessorsMutex_.
    std::mutex frameProcessorsMutex_;
    std::vector> frameProcessors_;

    // Unencrypted frames are passed through while this time is still in the future.
    TimePoint allowPassThroughUntil_{TimePoint::min()};

    TimePoint lastStatsTime_{TimePoint::min()};
    std::array stats_;
};

} // namespace dave
} // namespace discord

================================================
FILE: cpp/src/encryptor.cpp
================================================
#include "encryptor.h"

#include
#include
#include
#include
#include

#include "codec_utils.h"
#include "common.h"
#include "cryptor_manager.h"
#include "utils/leb128.h"
#include "utils/scope_exit.h"

using namespace std::chrono_literals;

namespace discord {
namespace dave {

constexpr auto kStatsInterval = 10s;

std::unique_ptr CreateEncryptor()
{
    return std::make_unique();
}

// Installs a new key ratchet under keyGenMutex_ and resets the cached cryptor, the key
// generation and the truncated nonce counter.
void Encryptor::SetKeyRatchet(std::unique_ptr keyRatchet)
{
    std::lock_guard lock(keyGenMutex_);
    keyRatchet_ = std::move(keyRatchet);
    cryptor_ = nullptr;
    currentKeyGeneration_ = 0;
    truncatedNonce_ = 0;
}

// Passthrough advertises protocol version 0; otherwise the maximum supported version is
// advertised.
void Encryptor::SetPassthroughMode(bool passthroughMode)
{
    passthroughMode_ = passthroughMode;
    UpdateCurrentProtocolVersion(passthroughMode ? 0 : MaxSupportedProtocolVersion());
}

// Encrypts `frame` (sent on `ssrc`) into `encryptedFrame`. In passthrough mode the frame is
// copied unchanged. Otherwise the frame is split by the codec's frame processor and the
// encrypted parts are encrypted. The output is followed by a trailer: the tag, the LEB128
// nonce, the serialized unencrypted ranges, the supplemental size byte and the magic
// marker. Encryption is retried with a fresh nonce (up to
// MAX_CIPHERTEXT_VALIDATION_RETRIES attempts) while codec_utils::ValidateEncryptedFrame
// rejects the result.
Encryptor::ResultCode Encryptor::Encrypt(MediaType mediaType,
                                         uint32_t ssrc,
                                         ArrayView frame,
                                         ArrayView encryptedFrame,
                                         size_t* bytesWritten)
{
    // NOTE(review): *bytesWritten is not written on this path, whereas Decryptor::Decrypt
    // sets it to 0 in its equivalent branch; confirm callers ignore it on failure.
    if (mediaType != Audio && mediaType != Video) {
        DISCORD_LOG(LS_WARNING) << "Encrypt failed, invalid media type: "
                                << static_cast(mediaType);
        return ResultCode::EncryptionFailure;
    }

    auto& stats = stats_[mediaType];

    if (passthroughMode_) {
        // Pass frame through without encrypting
        // NOTE(review): unlike the copies in Decryptor::Decrypt, this memcpy is not skipped
        // when frame and encryptedFrame are the same buffer (memcpy on overlapping memory
        // is undefined); confirm callers never pass aliasing views.
        auto copySize = std::min(encryptedFrame.size(), frame.size());
        memcpy(encryptedFrame.data(), frame.data(), copySize);
        *bytesWritten = copySize;
        stats.passthroughCount++;
        return ResultCode::Success;
    }

    {
        std::lock_guard lock(keyGenMutex_);
        if (!keyRatchet_) {
            stats.encryptFailureCount++;
            stats.encryptMissingKeyCount++;
            return ResultCode::MissingKeyRatchet;
        }
    }

    auto start = std::chrono::steady_clock::now();
    auto result = ResultCode::Success;

    // write the codec identifier
    auto codec = CodecForSsrc(ssrc);

    auto frameProcessor = GetOrCreateFrameProcessor();
    ScopeExit cleanup([&] { ReturnFrameProcessor(std::move(frameProcessor)); });

    frameProcessor->ProcessFrame(frame, codec);

    const auto& unencryptedBytes = frameProcessor->GetUnencryptedBytes();
    const auto& encryptedBytes = frameProcessor->GetEncryptedBytes();
    auto& ciphertextBytes = frameProcessor->GetCiphertextBytes();

    const auto& unencryptedRanges = frameProcessor->GetUnencryptedRanges();
    auto unencryptedRangesSize = UnencryptedRangesSize(unencryptedRanges);

    auto additionalData = MakeArrayView(unencryptedBytes.data(), unencryptedBytes.size());
    auto plaintextBuffer = MakeArrayView(encryptedBytes.data(), encryptedBytes.size());
    auto ciphertextBuffer = MakeArrayView(ciphertextBytes.data(), ciphertextBytes.size());

    // The tag is written directly after the reconstructed frame in the output buffer.
    auto frameSize = encryptedBytes.size() + unencryptedBytes.size();
    auto tagBuffer = MakeArrayView(encryptedFrame.data() + frameSize, kAesGcm128TruncatedTagBytes);

    auto nonceBuffer = std::array();
    auto nonceBufferView = MakeArrayView(nonceBuffer.data(), nonceBuffer.size());

    constexpr auto MAX_CIPHERTEXT_VALIDATION_RETRIES = 10;

    // some codecs (e.g. H26X) have packetizers that cannot handle specific byte sequences
    // so we attempt up to MAX_CIPHERTEXT_VALIDATION_RETRIES to encrypt the frame
    // calling into codec utils to validate the ciphertext + supplemental section
    // and re-rolling the truncated nonce if it fails

    // the nonce increment will definitely change the ciphertext and the tag
    // incrementing the nonce will also change the appropriate bytes
    // in the tail end of the nonce
    // which can remove start codes from the last 1 or 2 bytes of the nonce
    // and the two bytes of the unencrypted header bytes
    for (auto attempt = 1; attempt <= MAX_CIPHERTEXT_VALIDATION_RETRIES; ++attempt) {
        auto [cryptor, truncatedNonce] = GetNextCryptorAndNonce();

        if (!cryptor) {
            stats.encryptMissingKeyCount++;
            result = ResultCode::MissingCryptor;
            break;
        }

        // write the truncated nonce to our temporary full nonce array
        // (since the encryption call expects a full size nonce)
        memcpy(nonceBuffer.data() + kAesGcm128TruncatedSyncNonceOffset,
               &truncatedNonce,
               kAesGcm128TruncatedSyncNonceBytes);

        // encrypt the plaintext, adding the unencrypted header to the tag
        bool success = cryptor->Encrypt(
          ciphertextBuffer, plaintextBuffer, nonceBufferView, additionalData, tagBuffer);

        stats.encryptAttempts++;
        stats.encryptMaxAttempts = std::max(stats.encryptMaxAttempts, (uint64_t)attempt);

        if (!success) {
            assert(false && "Failed to encrypt frame");
            result = ResultCode::EncryptionFailure;
            break;
        }

        auto reconstructedFrameSize = frameProcessor->ReconstructFrame(encryptedFrame);
        assert(reconstructedFrameSize == frameSize && "Failed to reconstruct frame");

        // Lay out the trailer that follows the tag: nonce, ranges, size byte, marker.
        auto nonceSize = Leb128Size(truncatedNonce);

        auto truncatedNonceBuffer = MakeArrayView(tagBuffer.end(), nonceSize);
        auto unencryptedRangesBuffer =
          MakeArrayView(truncatedNonceBuffer.end(), unencryptedRangesSize);
        auto supplementalBytesBuffer =
          MakeArrayView(unencryptedRangesBuffer.end(), sizeof(SupplementalBytesSize));
        auto markerBytesBuffer = MakeArrayView(supplementalBytesBuffer.end(), sizeof(MagicMarker));

        // write the nonce
        auto res = WriteLeb128(truncatedNonce, truncatedNonceBuffer.begin());
        if (res != nonceSize) {
            assert(false && "Failed to write truncated nonce");
            result = ResultCode::EncryptionFailure;
            break;
        }

        // write the unencrypted ranges
        res = SerializeUnencryptedRanges(
          unencryptedRanges, unencryptedRangesBuffer.begin(), unencryptedRangesBuffer.size());
        if (res != unencryptedRangesSize) {
            assert(false && "Failed to write unencrypted ranges");
            result = ResultCode::EncryptionFailure;
            break;
        }

        // write the supplemental bytes size
        uint64_t supplementalBytesLarge = kSupplementalBytes + nonceSize + unencryptedRangesSize;

        if (supplementalBytesLarge > std::numeric_limits::max()) {
            assert(false && "Supplemental bytes size too large");
            result = ResultCode::EncryptionFailure;
            break;
        }

        SupplementalBytesSize supplementalBytes =
          static_cast(supplementalBytesLarge);
        memcpy(supplementalBytesBuffer.data(), &supplementalBytes, sizeof(SupplementalBytesSize));

        // write the marker bytes, ends the frame
        memcpy(markerBytesBuffer.data(), &kMarkerBytes, sizeof(MagicMarker));

        auto encryptedFrameBytes = reconstructedFrameSize + kAesGcm128TruncatedTagBytes +
          nonceSize + unencryptedRangesSize + sizeof(SupplementalBytesSize) + sizeof(MagicMarker);

        if (codec_utils::ValidateEncryptedFrame(
              *frameProcessor, MakeArrayView(encryptedFrame.data(), encryptedFrameBytes))) {
            *bytesWritten = encryptedFrameBytes;
            break;
        }
        else if (attempt >= MAX_CIPHERTEXT_VALIDATION_RETRIES) {
            assert(false && "Failed to validate encrypted section for codec");
            result = ResultCode::TooManyAttempts;
            break;
        }
    }

    auto now = std::chrono::steady_clock::now();
    stats.encryptDuration += std::chrono::duration_cast(now - start).count();
    if (result == ResultCode::Success) {
        stats.encryptSuccessCount++;
    }
    else {
        stats.encryptFailureCount++;
    }

    if (now > lastStatsTime_ + kStatsInterval) {
        lastStatsTime_ =
now; DISCORD_LOG(LS_INFO) << "Encrypted audio: " << stats_[Audio].encryptSuccessCount << ", video: " << stats_[Video].encryptSuccessCount << ". Failed audio: " << stats_[Audio].encryptFailureCount << ", video: " << stats_[Video].encryptFailureCount; DISCORD_LOG(LS_INFO) << "Last encrypted frame, type: " << (mediaType == Audio ? "audio" : "video") << ", ssrc: " << ssrc << ", size: " << frame.size(); } return result; } size_t Encryptor::GetMaxCiphertextByteSize([[maybe_unused]] MediaType mediaType, size_t frameSize) { return frameSize + kSupplementalBytes + kTransformPaddingBytes; } void Encryptor::AssignSsrcToCodec(uint32_t ssrc, Codec codecType) { auto existingCodecIt = std::find_if( ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { return pair.first == ssrc; }); if (existingCodecIt == ssrcCodecPairs_.end()) { ssrcCodecPairs_.emplace_back(ssrc, codecType); } else { existingCodecIt->second = codecType; } } Codec Encryptor::CodecForSsrc(uint32_t ssrc) { auto existingCodecIt = std::find_if( ssrcCodecPairs_.begin(), ssrcCodecPairs_.end(), [ssrc](const SsrcCodecPair& pair) { return pair.first == ssrc; }); if (existingCodecIt != ssrcCodecPairs_.end()) { return existingCodecIt->second; } else { return Codec::Unknown; } } std::unique_ptr Encryptor::GetOrCreateFrameProcessor() { std::lock_guard lock(frameProcessorsMutex_); if (frameProcessors_.empty()) { return std::make_unique(); } auto frameProcessor = std::move(frameProcessors_.back()); frameProcessors_.pop_back(); return frameProcessor; } void Encryptor::ReturnFrameProcessor(std::unique_ptr frameProcessor) { std::lock_guard lock(frameProcessorsMutex_); frameProcessors_.push_back(std::move(frameProcessor)); } Encryptor::CryptorAndNonce Encryptor::GetNextCryptorAndNonce() { std::lock_guard lock(keyGenMutex_); if (!keyRatchet_) { return {nullptr, 0}; } auto generation = ComputeWrappedGeneration(currentKeyGeneration_, ++truncatedNonce_ >> kRatchetGenerationShiftBits); if (generation != 
currentKeyGeneration_ || !cryptor_) { currentKeyGeneration_ = generation; auto encryptionKey = keyRatchet_->GetKey(currentKeyGeneration_); cryptor_ = CreateCryptor(encryptionKey); } return {cryptor_, truncatedNonce_}; } void Encryptor::UpdateCurrentProtocolVersion(ProtocolVersion version) { if (version == currentProtocolVersion_) { return; } currentProtocolVersion_ = version; if (protocolVersionChangedCallback_) { protocolVersionChangedCallback_(); } } } // namespace dave } // namespace discord ================================================ FILE: cpp/src/encryptor.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include "codec_utils.h" #include "common.h" #include "cryptor.h" #include "frame_processors.h" namespace discord { namespace dave { class Encryptor final : public IEncryptor { public: virtual ~Encryptor() noexcept = default; virtual void SetKeyRatchet(std::unique_ptr keyRatchet) override; virtual void SetPassthroughMode(bool passthroughMode) override; virtual bool HasKeyRatchet() const override { return keyRatchet_ != nullptr; } virtual bool IsPassthroughMode() const override { return passthroughMode_; } virtual void AssignSsrcToCodec(uint32_t ssrc, Codec codecType) override; virtual Codec CodecForSsrc(uint32_t ssrc) override; virtual ResultCode Encrypt(MediaType mediaType, uint32_t ssrc, ArrayView frame, ArrayView encryptedFrame, size_t* bytesWritten) override; virtual size_t GetMaxCiphertextByteSize(MediaType mediaType, size_t frameSize) override; virtual EncryptorStats GetStats(MediaType mediaType) const override { return stats_[mediaType]; } using ProtocolVersionChangedCallback = std::function; virtual void SetProtocolVersionChangedCallback(ProtocolVersionChangedCallback callback) override { protocolVersionChangedCallback_ = std::move(callback); } virtual ProtocolVersion GetProtocolVersion() const override { return currentProtocolVersion_; } private: 
std::unique_ptr GetOrCreateFrameProcessor(); void ReturnFrameProcessor(std::unique_ptr frameProcessor); using CryptorAndNonce = std::pair, TruncatedSyncNonce>; CryptorAndNonce GetNextCryptorAndNonce(); void UpdateCurrentProtocolVersion(ProtocolVersion version); std::atomic_bool passthroughMode_{false}; std::mutex keyGenMutex_; std::unique_ptr keyRatchet_; std::shared_ptr cryptor_; KeyGeneration currentKeyGeneration_{0}; TruncatedSyncNonce truncatedNonce_{0}; std::mutex frameProcessorsMutex_; std::vector> frameProcessors_; using SsrcCodecPair = std::pair; std::vector ssrcCodecPairs_; using TimePoint = std::chrono::time_point; TimePoint lastStatsTime_{TimePoint::min()}; std::array stats_; ProtocolVersionChangedCallback protocolVersionChangedCallback_; ProtocolVersion currentProtocolVersion_{MaxSupportedProtocolVersion()}; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/frame_processors.cpp ================================================ #include "frame_processors.h" #include #include #include #include #include #include #include "codec_utils.h" #include "utils/leb128.h" #if defined(_MSC_VER) #include #endif namespace discord { namespace dave { std::pair OverflowAdd(size_t a, size_t b) { size_t res; #if defined(_MSC_VER) && defined(_M_X64) bool didOverflow = _addcarry_u64(0, a, b, &res); #elif defined(_MSC_VER) && defined(_M_IX86) bool didOverflow = _addcarry_u32(0, a, b, &res); #else bool didOverflow = __builtin_add_overflow(a, b, &res); #endif return {didOverflow, res}; } uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges) { size_t size = 0; for (const auto& range : unencryptedRanges) { size += Leb128Size(range.offset); size += Leb128Size(range.size); } assert(size <= std::numeric_limits::max() && "Unencrypted ranges size exceeds 255 bytes"); return static_cast(size); } uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, uint8_t* buffer, size_t bufferSize) { auto writeAt = 
buffer; auto end = buffer + bufferSize; for (const auto& range : unencryptedRanges) { auto rangeSize = Leb128Size(range.offset) + Leb128Size(range.size); if (rangeSize > static_cast(end - writeAt)) { assert(false && "Buffer is too small to serialize unencrypted ranges"); break; } writeAt += WriteLeb128(range.offset, writeAt); writeAt += WriteLeb128(range.size, writeAt); } assert(writeAt >= buffer); return static_cast(writeAt - buffer); } uint8_t DeserializeUnencryptedRanges(const uint8_t*& readAt, const uint8_t bufferSize, Ranges& unencryptedRanges) { auto start = readAt; auto end = readAt + bufferSize; while (readAt < end) { size_t offset = ReadLeb128(readAt, end); if (readAt == nullptr) { break; } size_t size = ReadLeb128(readAt, end); if (readAt == nullptr) { break; } unencryptedRanges.push_back({offset, size}); } if (readAt != end) { DISCORD_LOG(LS_WARNING) << "Failed to deserialize unencrypted ranges"; unencryptedRanges.clear(); readAt = nullptr; return 0; } return static_cast(readAt - start); } bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize) { if (unencryptedRanges.empty()) { return true; } // validate that the ranges are in order and don't overlap for (auto i = 0u; i < unencryptedRanges.size(); ++i) { auto current = unencryptedRanges[i]; // The current range should not overflow into the next range // or if it is the last range, the end of the frame auto maxEnd = i + 1 < unencryptedRanges.size() ? 
unencryptedRanges[i + 1].offset : frameSize; auto [didOverflow, currentEnd] = OverflowAdd(current.offset, current.size); if (didOverflow || currentEnd > maxEnd) { DISCORD_LOG(LS_WARNING) << "Unencrypted range may overlap or be out of order: current offset: " << current.offset << ", current size: " << current.size << ", maximum end: " << maxEnd << ", frame size: " << frameSize; return false; } } return true; } size_t Reconstruct(Ranges ranges, const std::vector& rangeBytes, const std::vector& otherBytes, const ArrayView& output) { size_t frameIndex = 0; size_t rangeBytesIndex = 0; size_t otherBytesIndex = 0; const auto CopyRangeBytes = [&](size_t size) { assert(rangeBytesIndex + size <= rangeBytes.size()); assert(frameIndex + size <= output.size()); if ((rangeBytes.size() - rangeBytesIndex < size) || (output.size() - frameIndex < size)) { return; } memcpy(output.data() + frameIndex, rangeBytes.data() + rangeBytesIndex, size); rangeBytesIndex += size; frameIndex += size; }; const auto CopyOtherBytes = [&](size_t size) { assert(otherBytesIndex + size <= otherBytes.size()); assert(frameIndex + size <= output.size()); if ((otherBytes.size() - otherBytesIndex < size) || (output.size() - frameIndex < size)) { return; } memcpy(output.data() + frameIndex, otherBytes.data() + otherBytesIndex, size); otherBytesIndex += size; frameIndex += size; }; for (const auto& range : ranges) { if (range.offset > frameIndex) { CopyOtherBytes(range.offset - frameIndex); } CopyRangeBytes(range.size); } if (otherBytesIndex < otherBytes.size()) { CopyOtherBytes(otherBytes.size() - otherBytesIndex); } assert(rangeBytesIndex == rangeBytes.size()); assert(otherBytesIndex == otherBytes.size()); assert(frameIndex <= output.size()); return frameIndex; } void InboundFrameProcessor::Clear() { isEncrypted_ = false; originalSize_ = 0; truncatedNonce_ = std::numeric_limits::max(); unencryptedRanges_.clear(); authenticated_.clear(); ciphertext_.clear(); plaintext_.clear(); } void 
InboundFrameProcessor::ParseFrame(ArrayView frame) { Clear(); constexpr auto MinSupplementalBytesSize = kAesGcm128TruncatedTagBytes + sizeof(SupplementalBytesSize) + sizeof(MagicMarker); if (frame.size() < MinSupplementalBytesSize) { DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain min supplemental bytes"; return; } // Check the frame ends with the magic marker auto magicMarkerBuffer = frame.end() - sizeof(MagicMarker); if (memcmp(magicMarkerBuffer, &kMarkerBytes, sizeof(MagicMarker)) != 0) { return; } // Read the supplemental bytes size SupplementalBytesSize supplementalBytesSize; auto supplementalBytesSizeBuffer = magicMarkerBuffer - sizeof(SupplementalBytesSize); assert(frame.begin() <= supplementalBytesSizeBuffer && supplementalBytesSizeBuffer <= frame.end()); memcpy(&supplementalBytesSize, supplementalBytesSizeBuffer, sizeof(SupplementalBytesSize)); // Check the frame is large enough to contain the supplemental bytes if (frame.size() < supplementalBytesSize) { DISCORD_LOG(LS_WARNING) << "Encrypted frame is too small to contain supplemental bytes"; return; } // Check that supplemental bytes size is large enough to contain the supplemental bytes if (supplementalBytesSize < MinSupplementalBytesSize) { DISCORD_LOG(LS_WARNING) << "Supplemental bytes size is too small to contain supplemental bytes"; return; } auto supplementalBytesBuffer = frame.end() - supplementalBytesSize; assert(frame.begin() <= supplementalBytesBuffer && supplementalBytesBuffer <= frame.end()); // Read the tag tag_ = MakeArrayView(supplementalBytesBuffer, kAesGcm128TruncatedTagBytes); // Read the nonce auto nonceBuffer = supplementalBytesBuffer + kAesGcm128TruncatedTagBytes; assert(frame.begin() <= nonceBuffer && nonceBuffer <= frame.end()); auto readAt = nonceBuffer; auto end = supplementalBytesSizeBuffer; truncatedNonce_ = static_cast(ReadLeb128(readAt, end)); if (readAt == nullptr) { DISCORD_LOG(LS_WARNING) << "Failed to read truncated nonce"; return; } // Read the 
unencrypted ranges assert(nonceBuffer <= readAt && readAt <= end && end - readAt <= std::numeric_limits::max()); auto unencryptedRangesSize = static_cast(end - readAt); DeserializeUnencryptedRanges(readAt, unencryptedRangesSize, unencryptedRanges_); if (readAt == nullptr) { DISCORD_LOG(LS_WARNING) << "Failed to read unencrypted ranges"; return; } if (!ValidateUnencryptedRanges(unencryptedRanges_, frame.size())) { DISCORD_LOG(LS_WARNING) << "Invalid unencrypted ranges"; return; } // This is overly aggressive but will keep reallocations to a minimum authenticated_.reserve(frame.size()); ciphertext_.reserve(frame.size()); plaintext_.reserve(frame.size()); originalSize_ = frame.size(); // Split the frame into authenticated and ciphertext bytes size_t frameIndex = 0; for (const auto& range : unencryptedRanges_) { auto encryptedBytes = range.offset - frameIndex; if (encryptedBytes > 0) { assert(frameIndex + encryptedBytes <= frame.size()); AddCiphertextBytes(frame.data() + frameIndex, encryptedBytes); } assert(range.offset + range.size <= frame.size()); AddAuthenticatedBytes(frame.data() + range.offset, range.size); frameIndex = range.offset + range.size; } auto actualFrameSize = frame.size() - supplementalBytesSize; if (frameIndex < actualFrameSize) { AddCiphertextBytes(frame.data() + frameIndex, actualFrameSize - frameIndex); } // Make sure the plaintext buffer is the same size as the ciphertext buffer plaintext_.resize(ciphertext_.size()); // We've successfully parsed the frame // Mark the frame as encrypted isEncrypted_ = true; } size_t InboundFrameProcessor::ReconstructFrame(ArrayView frame) const { if (!isEncrypted_) { DISCORD_LOG(LS_WARNING) << "Cannot reconstruct an invalid encrypted frame"; return 0; } if (authenticated_.size() + plaintext_.size() > frame.size()) { DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the decrypted frame"; return 0; } return Reconstruct(unencryptedRanges_, authenticated_, plaintext_, frame); } void 
InboundFrameProcessor::AddAuthenticatedBytes(const uint8_t* data, size_t size) { authenticated_.resize(authenticated_.size() + size); memcpy(authenticated_.data() + authenticated_.size() - size, data, size); } void InboundFrameProcessor::AddCiphertextBytes(const uint8_t* data, size_t size) { ciphertext_.resize(ciphertext_.size() + size); memcpy(ciphertext_.data() + ciphertext_.size() - size, data, size); } void OutboundFrameProcessor::Reset() { codec_ = Codec::Unknown; frameIndex_ = 0; unencryptedBytes_.clear(); encryptedBytes_.clear(); unencryptedRanges_.clear(); } void OutboundFrameProcessor::ProcessFrame(ArrayView frame, Codec codec) { Reset(); codec_ = codec; unencryptedBytes_.reserve(frame.size()); encryptedBytes_.reserve(frame.size()); bool success = false; switch (codec) { case Codec::Opus: success = codec_utils::ProcessFrameOpus(*this, frame); break; case Codec::VP8: success = codec_utils::ProcessFrameVp8(*this, frame); break; case Codec::VP9: success = codec_utils::ProcessFrameVp9(*this, frame); break; case Codec::H264: success = codec_utils::ProcessFrameH264(*this, frame); break; case Codec::H265: success = codec_utils::ProcessFrameH265(*this, frame); break; case Codec::AV1: success = codec_utils::ProcessFrameAv1(*this, frame); break; default: assert(false && "Unsupported codec for frame encryption"); break; } if (!success) { frameIndex_ = 0; unencryptedBytes_.clear(); encryptedBytes_.clear(); unencryptedRanges_.clear(); AddEncryptedBytes(frame.data(), frame.size()); } ciphertextBytes_.resize(encryptedBytes_.size()); } size_t OutboundFrameProcessor::ReconstructFrame(ArrayView frame) { if (unencryptedBytes_.size() + ciphertextBytes_.size() > frame.size()) { DISCORD_LOG(LS_WARNING) << "Frame is too small to contain the encrypted frame"; return 0; } return Reconstruct(unencryptedRanges_, unencryptedBytes_, ciphertextBytes_, frame); } void OutboundFrameProcessor::AddUnencryptedBytes(const uint8_t* bytes, size_t size) { if (!unencryptedRanges_.empty() && 
unencryptedRanges_.back().offset + unencryptedRanges_.back().size == frameIndex_) { // extend the last range unencryptedRanges_.back().size += size; } else { // add a new range (offset, size) unencryptedRanges_.push_back({frameIndex_, size}); } unencryptedBytes_.resize(unencryptedBytes_.size() + size); memcpy(unencryptedBytes_.data() + unencryptedBytes_.size() - size, bytes, size); frameIndex_ += size; } void OutboundFrameProcessor::AddEncryptedBytes(const uint8_t* bytes, size_t size) { encryptedBytes_.resize(encryptedBytes_.size() + size); memcpy(encryptedBytes_.data() + encryptedBytes_.size() - size, bytes, size); frameIndex_ += size; } } // namespace dave } // namespace discord ================================================ FILE: cpp/src/frame_processors.h ================================================ #pragma once #include #include #include #include #include #include #include "common.h" namespace discord { namespace dave { struct Range { size_t offset; size_t size; }; using Ranges = std::vector; uint8_t UnencryptedRangesSize(const Ranges& unencryptedRanges); uint8_t SerializeUnencryptedRanges(const Ranges& unencryptedRanges, uint8_t* buffer, size_t bufferSize); uint8_t DeserializeUnencryptedRanges(const uint8_t*& buffer, const uint8_t bufferSize, Ranges& unencryptedRanges); bool ValidateUnencryptedRanges(const Ranges& unencryptedRanges, size_t frameSize); class InboundFrameProcessor { public: void ParseFrame(ArrayView frame); size_t ReconstructFrame(ArrayView frame) const; bool IsEncrypted() const { return isEncrypted_; } size_t Size() const { return originalSize_; } void Clear(); ArrayView GetTag() const { return tag_; } TruncatedSyncNonce GetTruncatedNonce() const { return truncatedNonce_; } ArrayView GetAuthenticatedData() const { return MakeArrayView(authenticated_.data(), authenticated_.size()); } ArrayView GetCiphertext() const { return MakeArrayView(ciphertext_.data(), ciphertext_.size()); } ArrayView GetPlaintext() { return 
MakeArrayView(plaintext_); } private: void AddAuthenticatedBytes(const uint8_t* data, size_t size); void AddCiphertextBytes(const uint8_t* data, size_t size); bool isEncrypted_{false}; size_t originalSize_{0}; ArrayView tag_; TruncatedSyncNonce truncatedNonce_; Ranges unencryptedRanges_; std::vector authenticated_; std::vector ciphertext_; std::vector plaintext_; }; class OutboundFrameProcessor { public: void ProcessFrame(ArrayView frame, Codec codec); size_t ReconstructFrame(ArrayView frame); Codec GetCodec() const { return codec_; } const std::vector& GetUnencryptedBytes() const { return unencryptedBytes_; } const std::vector& GetEncryptedBytes() const { return encryptedBytes_; } std::vector& GetCiphertextBytes() { return ciphertextBytes_; } const Ranges& GetUnencryptedRanges() const { return unencryptedRanges_; } void Reset(); void AddUnencryptedBytes(const uint8_t* bytes, size_t size); void AddEncryptedBytes(const uint8_t* bytes, size_t size); private: Codec codec_{Codec::Unknown}; size_t frameIndex_{0}; std::vector unencryptedBytes_; std::vector encryptedBytes_; std::vector ciphertextBytes_; Ranges unencryptedRanges_; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/key_ratchet.h ================================================ #pragma once #include #include "common.h" namespace discord { namespace dave { } // namespace dave } // namespace discord ================================================ FILE: cpp/src/logger.cpp ================================================ #include #include #include #include namespace discord { namespace dave { std::atomic gLogSink = nullptr; void SetLogSink(LogSink sink) { gLogSink = sink; } LogStreamer::LogStreamer(LoggingSeverity severity, const char* file, int line) : severity_(severity) , file_(file) , line_(line) { } LogStreamer::~LogStreamer() { std::string logLine = stream_.str(); if (logLine.empty()) { return; } auto sink = gLogSink.load(); if (sink) { 
sink(severity_, file_, line_, logLine); return; } switch (severity_) { case LS_VERBOSE: case LS_INFO: case LS_WARNING: case LS_ERROR: { const char* file = file_; if (auto separator = strrchr(file, '/')) { file = separator + 1; } std::cout << "(" << file << ":" << line_ << ") " << logLine << std::endl; break; } case LS_NONE: break; } } } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/detail/persisted_key_pair.h ================================================ #pragma once #include #include #include #include "mls/persisted_key_pair.h" namespace discord { namespace dave { namespace mls { namespace detail { std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair(KeyPairContextType ctx, const std::string& keyID, ::mlspp::CipherSuite suite, bool& supported); std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair( KeyPairContextType ctx, const std::string& keyID, ::mlspp::CipherSuite suite); bool DeleteNativePersistedKeyPair(KeyPairContextType ctx, const std::string& keyID); bool DeleteGenericPersistedKeyPair(KeyPairContextType ctx, const std::string& keyID); } // namespace detail } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/detail/persisted_key_pair_apple.cpp ================================================ #include "mls/detail/persisted_key_pair.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include "mls/parameters.h" static const CFStringRef KeyServiceLabel = CFSTR("Discord Secure Frames Key"); static const std::string KeyLabelPrefix = "Discord Secure Frames Key: "; static const std::string KeyTagPrefix = "discord-secure-frames-key-"; #ifdef KEYCHAIN_ACCESS_GROUP_ID_SYMBOL extern CFStringRef KEYCHAIN_ACCESS_GROUP_ID_SYMBOL; #endif static void AddAccessGroup([[maybe_unused]] CFMutableDictionaryRef dict) { #ifdef 
KEYCHAIN_ACCESS_GROUP_ID_SYMBOL CFDictionaryAddValue(dict, kSecAttrAccessGroup, KEYCHAIN_ACCESS_GROUP_ID_SYMBOL); #elif defined(KEYCHAIN_ACCESS_GROUP_ID) CFDictionaryAddValue(dict, kSecAttrAccessGroup, CFSTR(#KEYCHAIN_ACCESS_GROUP_ID)); #endif } template struct ScopedCFTypeRef { ScopedCFTypeRef() = default; ScopedCFTypeRef(T ref) : ref_(ref) { } ScopedCFTypeRef(ScopedCFTypeRef& other) : ref_(other.ref_) { if (ref_) { CFRetain(ref_); } } ScopedCFTypeRef(ScopedCFTypeRef&& other) : ref_(std::exchange(other.ref_, nullptr)) { } ~ScopedCFTypeRef() { release(); } ScopedCFTypeRef& operator=(T ref) { release(); ref_ = ref; return *this; } void release() { if (ref_) { CFRelease(ref_); } ref_ = nullptr; } T& get() { return ref_; } T* getPtr() { return &ref_; } CFTypeRef* getGenericPtr() { return (CFTypeRef*)getPtr(); } operator T&() { return get(); } explicit operator bool() { return ref_ != nullptr; } T ref_ = nullptr; }; static std::string ConvertCFString(CFStringRef string) { if (const char* str = CFStringGetCStringPtr(string, kCFStringEncodingUTF8)) { return str; } CFIndex len = CFStringGetLength(string); std::string ret(CFStringGetMaximumSizeForEncoding(len, kCFStringEncodingUTF8), 0); CFStringGetBytes(string, CFRangeMake(0, len), kCFStringEncodingUTF8, '?', false, (UInt8*)ret.data(), ret.size(), &len); ret.resize(len); return ret; } static std::string SecStatusToString(OSStatus status) { std::string ret = std::to_string(status); if (__builtin_available(macOS 10.3, iOS 11.3, *)) { ScopedCFTypeRef string = SecCopyErrorMessageString(status, NULL); if (string) { ret += " ("; ret += ConvertCFString(string); ret += ")"; } } return ret; } static std::string ErrorToString(CFErrorRef error) { if (!error) { return "(null)"; } if (__builtin_available(macOS 10.3, iOS 11.3, *)) { CFIndex status = CFErrorGetCode(error); ScopedCFTypeRef string = CFErrorCopyFailureReason(error); if (string) { std::string ret = std::to_string(status); ret += " ("; ret += ConvertCFString(string); ret += 
")"; return ret; } } if (ScopedCFTypeRef string = CFErrorCopyDescription(error)) { return ConvertCFString(string); } return "(unknown)"; } namespace discord { namespace dave { namespace mls { namespace detail { std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair( [[maybe_unused]] KeyPairContextType ctx, const std::string& id, ::mlspp::CipherSuite suite, bool& supported) { std::shared_ptr<::mlspp::SignaturePrivateKey> ret; CFStringRef keyType = nullptr; int keySize = 0; std::function convertKey; ScopedCFTypeRef query = CFDictionaryCreateMutable( NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); CFDictionaryAddValue(query, kSecReturnRef, kCFBooleanTrue); CFDictionaryAddValue(query, kSecUseAuthenticationUI, kSecUseAuthenticationUISkip); AddAccessGroup(query); auto suiteId = suite.cipher_suite(); switch (suiteId) { case ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256: case ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384: case ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521: supported = true; keyType = kSecAttrKeyTypeECSECPrimeRandom; if (suiteId == ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521) { keySize = 521; } else if (suiteId == ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384) { keySize = 384; } else { keySize = 256; } convertKey = [keySize](CFDataRef data) { // https://developer.apple.com/documentation/security/1643698-seckeycopyexternalrepresentation // Input has a 1-byte header (always 0x04, per ANSI X9.63), followed by 3 // keySize-bit left-padded byte-aligned big-endian integers: X, Y, and K. // X and Y are the public key (represented as the coordinates); // K is the private key. 
bytes ret; constexpr size_t HeaderSize = 1; constexpr size_t ValueCount = 3; constexpr size_t PublicValues = 2; constexpr uint8_t HeaderByte = 0x04; // Convert keySize from bits to bytes (rounding up) CFIndex byteLen = (keySize + 7) / 8; CFIndex len = CFDataGetLength(data); if (len < 0 || (size_t)len < HeaderSize + ValueCount * byteLen) { DISCORD_LOG(LS_ERROR) << "Exported key blob too small in GetPersistedKeyPair/convertKey: " << len; return ret; } const uint8_t* ptr = CFDataGetBytePtr(data); if (ptr[0] != HeaderByte) { DISCORD_LOG(LS_ERROR) << "Exported key blob has unexpected format in GetPersistedKeyPair/convertKey: " << ptr[0]; return ret; } // Skip header, X, and Y, and extract K. ptr += HeaderSize + PublicValues * byteLen; ret.as_vec().assign(ptr, ptr + byteLen); return ret; }; break; default: // Other suites will need to store keys as generic data items return nullptr; } assert(keyType && keySize && convertKey); ScopedCFTypeRef sizeRef = CFNumberCreate(NULL, kCFNumberIntType, &keySize); std::string labelString = KeyLabelPrefix + id; std::string tagString = KeyTagPrefix + id; ScopedCFTypeRef labelStringRef = CFStringCreateWithCString(NULL, labelString.c_str(), kCFStringEncodingUTF8); ScopedCFTypeRef tagDataRef = CFDataCreate(NULL, (const UInt8*)tagString.c_str(), tagString.size()); CFDictionaryAddValue(query, kSecClass, kSecClassKey); CFDictionaryAddValue(query, kSecAttrKeyType, keyType); CFDictionaryAddValue(query, kSecAttrApplicationTag, tagDataRef); CFDictionaryAddValue(query, kSecAttrCanSign, kCFBooleanTrue); ScopedCFTypeRef cfError; ScopedCFTypeRef key; // If we get errSecMissingEntitlement, try again with the file-based keychain constexpr int AttemptCount = 2; for (int attempt = 0; attempt < AttemptCount && !key; attempt++) { cfError.release(); CFBooleanRef useDataProtection = attempt == 0 ? 
kCFBooleanTrue : kCFBooleanFalse; if (__builtin_available(macOS 10.15, *)) { CFDictionarySetValue(query, kSecUseDataProtectionKeychain, useDataProtection); } else if (attempt == 1) { return nullptr; } OSStatus status = SecItemCopyMatching(query, key.getGenericPtr()); if (status == errSecSuccess) { ScopedCFTypeRef updateQuery = CFDictionaryCreateMutableCopy(NULL, 0, query); ScopedCFTypeRef updateAttrs = CFDictionaryCreateMutable( NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); CFDictionaryRemoveValue(updateQuery, kSecReturnRef); CFDictionaryAddValue( updateAttrs, kSecAttrAccessible, kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly); // Best effort OSStatus updateStatus = SecItemUpdate(query, updateAttrs); DISCORD_LOG(LS_INFO) << "Attempted to update permissions on existing key: " << SecStatusToString(updateStatus); } if (status == errSecItemNotFound) { DISCORD_LOG(LS_INFO) << "Item not found in GetPersistedKeyPair; generating new: " << SecStatusToString(status); ScopedCFTypeRef params = CFDictionaryCreateMutable( NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); AddAccessGroup(params); CFDictionaryAddValue(params, kSecAttrKeyType, keyType); CFDictionaryAddValue(params, kSecAttrKeySizeInBits, sizeRef); CFDictionaryAddValue(params, kSecAttrCanEncrypt, kCFBooleanFalse); CFDictionaryAddValue(params, kSecAttrCanDecrypt, kCFBooleanFalse); CFDictionaryAddValue(params, kSecAttrCanWrap, kCFBooleanFalse); CFDictionaryAddValue(params, kSecAttrCanUnwrap, kCFBooleanFalse); CFDictionaryAddValue( params, kSecAttrAccessible, kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly); if (__builtin_available(macOS 10.15, *)) { CFDictionaryAddValue(params, kSecUseDataProtectionKeychain, useDataProtection); } ScopedCFTypeRef privParams = CFDictionaryCreateMutable( NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks); CFDictionaryAddValue(privParams, kSecAttrIsPermanent, kCFBooleanTrue); CFDictionaryAddValue(privParams, 
kSecAttrLabel, labelStringRef); CFDictionaryAddValue(privParams, kSecAttrApplicationTag, tagDataRef); CFDictionaryAddValue(params, kSecPrivateKeyAttrs, privParams); key = SecKeyCreateRandomKey(params, cfError.getPtr()); if (!key || cfError) { DISCORD_LOG(LS_WARNING) << "Failed to create key in GetPersistedKeyPair: " << ErrorToString(cfError); if (!cfError || CFErrorGetCode(cfError) != errSecMissingEntitlement) { return nullptr; } key.release(); } } else if (status != 0 || !key) { DISCORD_LOG(LS_WARNING) << "Item not found GetPersistedKeyPair: " << SecStatusToString(status); if (status != errSecMissingEntitlement) { return nullptr; } } } if (!key) { return nullptr; } ScopedCFTypeRef data = SecKeyCopyExternalRepresentation(key, cfError.getPtr()); if (!data) { DISCORD_LOG(LS_ERROR) << "Failed to export key in GetPersistedKeyPair: " << ErrorToString(cfError); return nullptr; } bytes converted = convertKey(data); if (converted.empty()) { DISCORD_LOG(LS_ERROR) << "Failed to convert key in GetPersistedKeyPair"; return nullptr; } return std::make_shared<::mlspp::SignaturePrivateKey>( ::mlspp::SignaturePrivateKey::parse(suite, converted)); }

// Loads the signing key for `id` from a generic-password keychain item (service
// KeyServiceLabel, account `id`) holding the key as a JWK string, generating and storing a new
// key when no item exists. The data-protection keychain is tried first; if that reports
// errSecMissingEntitlement, one more attempt is made against the file-based keychain (only
// where macOS 10.15 APIs are available). Returns nullptr on any other failure.
std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(
  [[maybe_unused]] KeyPairContextType ctx,
  const std::string& id,
  ::mlspp::CipherSuite suite)
{
    ::mlspp::SignaturePrivateKey ret;

    // Lookup query; reused unchanged by every attempt (only the keychain selector is updated).
    ScopedCFTypeRef query = CFDictionaryCreateMutable(
      NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks);
    ScopedCFTypeRef accountString =
      CFStringCreateWithCString(NULL, id.c_str(), kCFStringEncodingUTF8);

    CFDictionaryAddValue(query, kSecReturnData, kCFBooleanTrue);
    CFDictionaryAddValue(query, kSecUseAuthenticationUI, kSecUseAuthenticationUISkip);
    CFDictionaryAddValue(query, kSecAttrService, KeyServiceLabel);
    CFDictionaryAddValue(query, kSecAttrAccount, accountString);
    CFDictionaryAddValue(query, kSecClass, kSecClassGenericPassword);
    AddAccessGroup(query);

    // If we get errSecMissingEntitlement, try again with the file-based keychain
    constexpr int AttemptCount = 2;
    for (int attempt = 0; attempt < AttemptCount && ret.public_key.data.empty(); attempt++) {
        if (__builtin_available(macOS 10.15, *)) {
            CFDictionarySetValue(query,
                                 kSecUseDataProtectionKeychain,
                                 attempt == 0 ? kCFBooleanTrue : kCFBooleanFalse);
        }
        else if (attempt == 1) {
            return nullptr;
        }

        ScopedCFTypeRef result;
        OSStatus status = SecItemCopyMatching(query, result.getGenericPtr());

        if (status == 0 && result) {
            std::string curstr;
            curstr.assign((char*)CFDataGetBytePtr(result), CFDataGetLength(result));

            try {
                ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr);
            }
            catch (std::exception& ex) {
                DISCORD_LOG(LS_WARNING)
                  << "Failed to parse key in GetPersistedKeyPair: " << ex.what();
                return nullptr;
            }
        }
        else if (status == errSecItemNotFound) {
            DISCORD_LOG(LS_INFO) << "Did not receive item in GetPersistedKeyPair; generating new: "
                                 << SecStatusToString(status);

            ret = ::mlspp::SignaturePrivateKey::generate(suite);

            std::string newstr = ret.to_jwk(suite);
            ScopedCFTypeRef data =
              CFDataCreate(NULL, (const UInt8*)newstr.c_str(), newstr.length());

            // Build the add request on a copy. Mutating `query` itself would leave it without
            // kSecReturnData and with a kSecValueData entry that CFDictionaryAddValue can never
            // replace, so a retry against the file-based keychain would store this attempt's
            // key while returning the next attempt's freshly generated one.
            ScopedCFTypeRef addQuery = CFDictionaryCreateMutableCopy(NULL, 0, query);
            CFDictionaryRemoveValue(addQuery, kSecReturnData);
            CFDictionarySetValue(addQuery, kSecValueData, data);

            status = SecItemAdd(addQuery, nullptr);
            if (status) {
                DISCORD_LOG(LS_WARNING)
                  << "Failed to create keychain item in GetPersistedKeyPair: "
                  << SecStatusToString(status);
                if (status != errSecMissingEntitlement) {
                    return nullptr;
                }
                // Discard the unsaved key so the loop condition allows the next attempt.
                ret = ::mlspp::SignaturePrivateKey();
            }
        }
        else {
            DISCORD_LOG(LS_WARNING) << "Failed to retrieve item in GetPersistedKeyPair: "
                                    << SecStatusToString(status);
            if (status != errSecMissingEntitlement) {
                return nullptr;
            }
        }
    }

    if (!ret.public_key.data.empty()) {
        return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret));
    }
    else {
        return nullptr;
    }
}

// Deletes the keychain items matching `query`. Outside iOS (with macOS 10.15 APIs available)
// the data-protection keychain is targeted first and the file-based keychain is retried on
// errSecMissingEntitlement. Note: overwrites the kSecUseDataProtectionKeychain entry of `query`.
// Returns true only if the final SecItemDelete call succeeded.
static bool DeleteWithQuery(CFMutableDictionaryRef query)
{
#if !TARGET_OS_IPHONE
    if (__builtin_available(macOS 10.15, *)) {
        CFDictionarySetValue(query, kSecUseDataProtectionKeychain, kCFBooleanTrue);
    }
#endif

    auto ret = SecItemDelete(query);

#if !TARGET_OS_IPHONE
    if (__builtin_available(macOS 10.15, *)) {
        if (ret == errSecMissingEntitlement) {
            CFDictionarySetValue(query, kSecUseDataProtectionKeychain, kCFBooleanFalse);
            ret = SecItemDelete(query);
        }
    }
#endif

    return ret == errSecSuccess;
}

// Deletes the native key tagged KeyTagPrefix + id (kSecClassKey items).
bool DeleteNativePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id)
{
    std::string tagString = KeyTagPrefix + id;
    ScopedCFTypeRef tagDataRef =
      CFDataCreate(NULL, (const UInt8*)tagString.c_str(), tagString.size());

    ScopedCFTypeRef query = CFDictionaryCreateMutable(
      NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks);
    CFDictionaryAddValue(query, kSecClass, kSecClassKey);
    CFDictionaryAddValue(query, kSecAttrApplicationTag, tagDataRef);
    AddAccessGroup(query);

    return DeleteWithQuery(query);
}

// Deletes the generic-password item (service KeyServiceLabel, account `id`) holding a JWK key.
bool DeleteGenericPersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id)
{
    ScopedCFTypeRef accountString =
      CFStringCreateWithCString(NULL, id.c_str(), kCFStringEncodingUTF8);

    ScopedCFTypeRef query = CFDictionaryCreateMutable(
      NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks);
    CFDictionaryAddValue(query, kSecAttrService, KeyServiceLabel);
    CFDictionaryAddValue(query, kSecAttrAccount, accountString);
    CFDictionaryAddValue(query, kSecClass, kSecClassGenericPassword);
    AddAccessGroup(query);

    return DeleteWithQuery(query);
}

} // namespace detail
} // namespace mls
} // namespace dave
} // namespace discord

================================================ FILE: cpp/src/mls/detail/persisted_key_pair_generic.cpp ================================================ #include "mls/detail/persisted_key_pair.h" #include #include #include #include #include #include #include #ifdef _WIN32 #include #else #include #include #endif #include #include #include #include #include "mls/parameters.h"

static const std::string_view KeyStorageDir = "Discord Key Storage";

// Returns the directory holding JWK key files, or an empty path if no base directory can be
// determined:
//   Android: /data/data/<process name from /proc/self/cmdline>/Discord Key Storage
//   Windows: %LOCALAPPDATA%/Discord Key Storage
//   others:  $XDG_CONFIG_HOME/Discord Key Storage, else $HOME/.config/Discord Key Storage
static std::filesystem::path GetKeyStorageDirectory()
{
    std::filesystem::path dir;

#if defined(__ANDROID__)
    dir = std::filesystem::path("/data/data");

    {
        std::ifstream idFile("/proc/self/cmdline", std::ios_base::in);
        std::string appId;
        std::getline(idFile, appId, '\0');
        if (appId.empty()) {
            // Could not read the process name; report "unknown" rather than silently using
            // /data/data itself as the app directory.
            return {};
        }
        dir /= appId;
    }
#else // __ANDROID__
#if defined(_WIN32)
    if (const wchar_t* appdata = _wgetenv(L"LOCALAPPDATA")) {
        dir = std::filesystem::path(appdata);
    }
#else  // _WIN32
    if (const char* xdg = getenv("XDG_CONFIG_HOME")) {
        dir = std::filesystem::path(xdg);
    }
    else if (const char* home = getenv("HOME")) {
        dir = std::filesystem::path(home);
        dir /= ".config";
    }
#endif // !_WIN32
    else {
        return dir;
    }
#endif // !__ANDROID__

    return dir / KeyStorageDir;
}

namespace discord {
namespace dave {
namespace mls {
namespace detail {

// Loads the JWK-encoded key stored at <GetKeyStorageDirectory()>/<id>.key. If the file does not
// exist, generates a new key and writes it to "<id>.key.tmp", then renames that into place.
// Returns nullptr (after logging) on any failure.
std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(
  [[maybe_unused]] KeyPairContextType ctx,
  const std::string& id,
  ::mlspp::CipherSuite suite)
{
    ::mlspp::SignaturePrivateKey ret;
    std::string curstr;
    std::filesystem::path dir = GetKeyStorageDirectory();

    if (dir.empty()) {
        DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair";
        return nullptr;
    }

    std::error_code errc;
    std::filesystem::create_directories(dir, errc);
    if (errc) {
        DISCORD_LOG(LS_ERROR) << "Failed to create key storage directory in GetPersistedKeyPair: "
                              << errc;
        return nullptr;
    }

    std::filesystem::path file = dir / (id + ".key");

    // Non-throwing overload: filesystem errors are reported like every other failure here.
    bool fileExists = std::filesystem::exists(file, errc);
    if (errc) {
        DISCORD_LOG(LS_ERROR) << "Failed to check for existing key in GetPersistedKeyPair: "
                              << errc;
        return nullptr;
    }

    if (fileExists) {
        std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary);
        if (!ifs) {
            DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair";
            return nullptr;
        }

        std::stringstream contents;
        contents << ifs.rdbuf();
        curstr = contents.str();
        if (!ifs) {
            DISCORD_LOG(LS_ERROR) << "Failed to read key in GetPersistedKeyPair";
            return nullptr;
        }

        try {
            ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr);
        }
        catch (std::exception& ex) {
            DISCORD_LOG(LS_ERROR) << "Failed to parse key in GetPersistedKeyPair: " << ex.what();
            return nullptr;
        }
    }
    else {
        ret = ::mlspp::SignaturePrivateKey::generate(suite);

        std::string newstr = ret.to_jwk(suite);

        std::filesystem::path tmpfile = file;
        tmpfile += ".tmp";

#ifdef _WIN32
        // _O_BINARY: store the JWK byte-for-byte, matching the binary-mode read above.
        int fd = _wopen(
          tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC | _O_BINARY, _S_IREAD | _S_IWRITE);
#else
        int fd = open(tmpfile.c_str(),
                      O_WRONLY | O_CLOEXEC | O_NOFOLLOW | O_CREAT | O_TRUNC,
                      S_IRUSR | S_IWUSR);
#endif
        if (fd < 0) {
            DISCORD_LOG(LS_ERROR) << "Failed to open output file in GetPersistedKeyPair: " << errno
                                  << "(" << tmpfile << ")";
            return nullptr;
        }

#ifdef _WIN32
        int wret = _write(fd, newstr.c_str(), static_cast(newstr.size()));
#else
        ssize_t wret = write(fd, newstr.c_str(), newstr.size());
#endif
        // Capture errno before closing: close() is allowed to overwrite it.
        int writeErrno = errno;
#ifdef _WIN32
        _close(fd);
#else
        close(fd);
#endif

        if (wret < 0 || (size_t)wret != newstr.size()) {
            DISCORD_LOG(LS_ERROR) << "Failed to write output file in GetPersistedKeyPair: "
                                  << writeErrno;
            // Best effort: don't leave a partially written key file behind.
            std::error_code removeErrc;
            std::filesystem::remove(tmpfile, removeErrc);
            return nullptr;
        }

        std::filesystem::rename(tmpfile, file, errc);
        if (errc) {
            DISCORD_LOG(LS_ERROR) << "Failed to rename output file in GetPersistedKeyPair: "
                                  << errc;
            std::error_code removeErrc;
            std::filesystem::remove(tmpfile, removeErrc);
            return nullptr;
        }
    }

    if (!ret.public_key.data.empty()) {
        return std::make_shared<::mlspp::SignaturePrivateKey>(std::move(ret));
    }
    else {
        return nullptr;
    }
}

// Removes <GetKeyStorageDirectory()>/<id>.key. Returns true only if a file was removed.
bool DeleteGenericPersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id)
{
    std::error_code errc;
    std::filesystem::path dir = GetKeyStorageDirectory();
    if (dir.empty()) {
        DISCORD_LOG(LS_ERROR)
          << "Failed to determine key storage directory in DeletePersistedKeyPair";
        return false;
    }

    std::filesystem::path file = dir / (id + ".key");
    return std::filesystem::remove(file, errc);
}

} // namespace detail
} // namespace mls
} // namespace dave
} // namespace discord

================================================ FILE: cpp/src/mls/detail/persisted_key_pair_null.cpp ================================================ #include "mls/detail/persisted_key_pair.h"
namespace discord {
namespace dave {
namespace mls {
namespace detail {
std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair( [[maybe_unused]]
KeyPairContextType ctx, [[maybe_unused]] const std::string& keyID, [[maybe_unused]] ::mlspp::CipherSuite suite, bool& supported) { supported = false; return nullptr; } bool DeleteNativePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, [[maybe_unused]] const std::string& keyID) { return false; } } // namespace detail } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/detail/persisted_key_pair_win.cpp ================================================ #include "mls/detail/persisted_key_pair.h" #include #include #include #include #include #include #include #ifndef SECURITY_WIN32 #define SECURITY_WIN32 1 #endif #include #include #include #include #include #include #include #include #include "mls/parameters.h" static const std::string KeyTagPrefix = "discord-secure-frames-key-"; template struct ScopedNCryptHandle { ScopedNCryptHandle() = default; ScopedNCryptHandle(T handle) : handle_(handle) { } ScopedNCryptHandle(const ScopedNCryptHandle& other) = delete; ScopedNCryptHandle(ScopedNCryptHandle&& other) : handle_(std::exchange(other.handle_, T())) { } ~ScopedNCryptHandle() { finalize(); } ScopedNCryptHandle& operator=(T handle) { finalize(); handle_ = handle; return *this; } T release() { return std::exchange(handle_, T()); } void finalize() { if (auto handle = release()) { NCryptFreeObject(handle); } } T& get() { return handle_; } T* getPtr() { return &handle_; } operator T&() { return get(); } explicit operator bool() { return handle_ != T(); } T handle_ = T(); }; namespace discord { namespace dave { namespace mls { namespace detail { std::shared_ptr<::mlspp::SignaturePrivateKey> GetNativePersistedKeyPair( [[maybe_unused]] KeyPairContextType ctx, const std::string& id, ::mlspp::CipherSuite suite, bool& supported) { LPCWSTR keyType = nullptr; ULONG keyBlobMagic = 0; std::function convertBlob; auto suiteId = suite.cipher_suite(); switch (suiteId) { case 
::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256: case ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384: case ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521: supported = true; if (suiteId == ::mlspp::CipherSuite::ID::P521_AES256GCM_SHA512_P521) { keyType = BCRYPT_ECDSA_P521_ALGORITHM; keyBlobMagic = BCRYPT_ECDSA_PRIVATE_P521_MAGIC; } else if (suiteId == ::mlspp::CipherSuite::ID::P384_AES256GCM_SHA384_P384) { keyType = BCRYPT_ECDSA_P384_ALGORITHM; keyBlobMagic = BCRYPT_ECDSA_PRIVATE_P384_MAGIC; } else { keyType = BCRYPT_ECDSA_P256_ALGORITHM; keyBlobMagic = BCRYPT_ECDSA_PRIVATE_P256_MAGIC; } convertBlob = [](bytes& blob) { // https://learn.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_ecckey_blob // Input has an PBCRYPT_ECCKEY_BLOB header, followed by 3 cbKey-byte big-endian // integers: X, Y, and d. X and Y are the public key (represented as the coordinates); // d is the private key. constexpr size_t ValueCount = 3; constexpr size_t PublicValues = 2; if (blob.size() < sizeof(BCRYPT_ECCKEY_BLOB)) { DISCORD_LOG(LS_ERROR) << "Exported key blob too small in GetPersistedKeyPair/convertBlob: " << blob.size(); return false; } PBCRYPT_ECCKEY_BLOB header = (PBCRYPT_ECCKEY_BLOB)blob.data(); ULONG keySize = header->cbKey; if (blob.size() < sizeof(BCRYPT_ECCKEY_BLOB) + keySize * ValueCount) { DISCORD_LOG(LS_ERROR) << "Exported key blob too small in GetPersistedKeyPair/convertBlob: " << blob.size(); return false; } blob.resize(sizeof(BCRYPT_ECCKEY_BLOB) + keySize * ValueCount); blob.as_vec().erase(blob.begin(), blob.begin() + sizeof(BCRYPT_ECCKEY_BLOB) + keySize * PublicValues); return true; }; break; default: // Other suites will need to store keys as JWK files on disk return nullptr; } assert(keyType && keyBlobMagic && convertBlob); ScopedNCryptHandle provider; SECURITY_STATUS status = NCryptOpenStorageProvider(provider.getPtr(), MS_KEY_STORAGE_PROVIDER, 0); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to open storage provider 
in GetPersistedKeyPair: " << status; return nullptr; } std::filesystem::path keyName = KeyTagPrefix + id; ScopedNCryptHandle key; status = NCryptOpenKey(provider, key.getPtr(), keyName.c_str(), AT_SIGNATURE, NCRYPT_SILENT_FLAG); if (status == NTE_BAD_KEYSET) { DISCORD_LOG(LS_INFO) << "No key found in GetPersistedKeyPair; generating new"; status = NCryptCreatePersistedKey( provider, key.getPtr(), keyType, keyName.c_str(), AT_SIGNATURE, 0); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to create key in GetPersistedKeyPair: " << status; return nullptr; } DWORD exportPolicyValue = NCRYPT_ALLOW_EXPORT_FLAG | NCRYPT_ALLOW_PLAINTEXT_EXPORT_FLAG; status = NCryptSetProperty(key, NCRYPT_EXPORT_POLICY_PROPERTY, (PBYTE)&exportPolicyValue, sizeof(exportPolicyValue), NCRYPT_PERSIST_FLAG | NCRYPT_SILENT_FLAG); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to configure key export policy in GetPersistedKeyPair: " << status; return nullptr; } // struct { // DWORD dwVersion; // DWORD dwFlags; // LPCWSTR pszCreationTitle; // LPCWSTR pszFriendlyName; // LPCWSTR pszDescription; // } NCRYPT_UI_POLICY; NCRYPT_UI_POLICY uiPolicyValue = {1, 0, nullptr, nullptr, nullptr}; status = NCryptSetProperty(key, NCRYPT_UI_POLICY_PROPERTY, (PBYTE)&uiPolicyValue, sizeof(uiPolicyValue), NCRYPT_PERSIST_FLAG | NCRYPT_SILENT_FLAG); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to configure key UI policy in GetPersistedKeyPair: " << status; return nullptr; } status = NCryptFinalizeKey(key, NCRYPT_SILENT_FLAG); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to finalize key in GetPersistedKeyPair: " << status; return nullptr; } } else if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to open key in GetPersistedKeyPair: " << status; return nullptr; } DWORD keySize = 0; status = NCryptExportKey( key, NULL, BCRYPT_PRIVATE_KEY_BLOB, NULL, NULL, 0, &keySize, NCRYPT_SILENT_FLAG); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) 
<< "Failed to size key in GetPersistedKeyPair: " << status; return nullptr; } bytes keyData(keySize); status = NCryptExportKey(key, NULL, BCRYPT_PRIVATE_KEY_BLOB, NULL, keyData.data(), keySize, &keySize, NCRYPT_SILENT_FLAG); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to export key in GetPersistedKeyPair: " << status; return nullptr; } if (keyData.size() < sizeof(BCRYPT_KEY_BLOB)) { DISCORD_LOG(LS_ERROR) << "Exported key blob too small in GetPersistedKeyPair/convertBlob: " << keyData.size(); return nullptr; } BCRYPT_KEY_BLOB* header = (BCRYPT_KEY_BLOB*)keyData.data(); if (header->Magic != keyBlobMagic) { DISCORD_LOG(LS_ERROR) << "Exported key blob has unexpected magic in GetPersistedKeyPair: " << header->Magic; return nullptr; } if (!convertBlob(keyData)) { DISCORD_LOG(LS_ERROR) << "Failed to convert key in GetPersistedKeyPair"; return nullptr; } return std::make_shared<::mlspp::SignaturePrivateKey>( ::mlspp::SignaturePrivateKey::parse(suite, keyData)); } bool DeleteNativePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx, const std::string& id) { ScopedNCryptHandle provider; SECURITY_STATUS status = NCryptOpenStorageProvider(provider.getPtr(), MS_KEY_STORAGE_PROVIDER, 0); if (status != ERROR_SUCCESS) { DISCORD_LOG(LS_ERROR) << "Failed to open storage provider in DeletePersistedKeyPair: " << status; return false; } std::filesystem::path keyName = KeyTagPrefix + id; ScopedNCryptHandle key; status = NCryptOpenKey(provider, key.getPtr(), keyName.c_str(), AT_SIGNATURE, NCRYPT_SILENT_FLAG); if (status != ERROR_SUCCESS) { return false; } auto ret = NCryptDeleteKey(key, NCRYPT_SILENT_FLAG); if (ret == ERROR_SUCCESS) { // If NCryptDeleteKey succeeds, it frees the handle, so our wrapper shouldn't also do so. 
key.release(); return true; } else { return false; } } } // namespace detail } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/parameters.cpp ================================================ #include "parameters.h" namespace discord { namespace dave { namespace mls { ::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion( [[maybe_unused]] ProtocolVersion version) noexcept { return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; } ::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept { return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; } ::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion( [[maybe_unused]] SignatureVersion version) noexcept { return ::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256; } ::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept { return ::mlspp::CipherSuite{CiphersuiteIDForProtocolVersion(version)}; } ::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept { auto capabilities = ::mlspp::Capabilities::create_default(); capabilities.cipher_suites = {CiphersuiteIDForProtocolVersion(version)}; capabilities.credentials = {::mlspp::CredentialType::basic}; return capabilities; } ::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion( [[maybe_unused]] ProtocolVersion version) noexcept { return ::mlspp::ExtensionList{}; } ::mlspp::ExtensionList GroupExtensionsForProtocolVersion( [[maybe_unused]] ProtocolVersion version, const ::mlspp::ExternalSender& externalSender) noexcept { auto extensionList = ::mlspp::ExtensionList{}; extensionList.add(::mlspp::ExternalSendersExtension{{ {externalSender.signature_key, externalSender.credential}, }}); return extensionList; } } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/parameters.h 
================================================ #pragma once #include #include #include #include namespace discord { namespace dave { namespace mls { ::mlspp::CipherSuite::ID CiphersuiteIDForProtocolVersion(ProtocolVersion version) noexcept; ::mlspp::CipherSuite CiphersuiteForProtocolVersion(ProtocolVersion version) noexcept; ::mlspp::CipherSuite::ID CiphersuiteIDForSignatureVersion(SignatureVersion version) noexcept; ::mlspp::CipherSuite CiphersuiteForSignatureVersion(SignatureVersion version) noexcept; ::mlspp::Capabilities LeafNodeCapabilitiesForProtocolVersion(ProtocolVersion version) noexcept; ::mlspp::ExtensionList LeafNodeExtensionsForProtocolVersion(ProtocolVersion version) noexcept; ::mlspp::ExtensionList GroupExtensionsForProtocolVersion( ProtocolVersion version, const ::mlspp::ExternalSender& externalSender) noexcept; } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/persisted_key_pair.cpp ================================================ #include "mls/detail/persisted_key_pair.h" #include #include #include #include #include #include #include #include #include #include "mls/parameters.h"

static const std::string SelfSignatureLabel = "DiscordSelfSignature";

// Storage identifier for a key: "<sessionID>-<numeric cipher suite id>-<KeyVersion>".
static std::string MakeKeyID(const std::string& sessionID, ::mlspp::CipherSuite suite)
{
    const auto suiteNumber = static_cast<uint16_t>(suite.cipher_suite());

    std::string keyId = sessionID;
    keyId += "-";
    keyId += std::to_string(suiteNumber);
    keyId += "-";
    keyId += std::to_string(discord::dave::mls::KeyVersion);
    return keyId;
}

// Process-wide cache of loaded keys, keyed by MakeKeyID(); every access below holds mtx.
static std::mutex mtx;
static std::map> map;

namespace discord {
namespace dave {
namespace mls {

// Returns the key for (sessionID, suite), serving it from the cache when possible. On a cache
// miss the platform-native store is consulted first; the generic store is used only when the
// native path did not claim the suite. Successful lookups are cached; failures return nullptr
// and are not cached.
static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(
  [[maybe_unused]] KeyPairContextType ctx,
  const std::string& sessionID,
  ::mlspp::CipherSuite suite)
{
    std::lock_guard lk(mtx);

    const std::string id = MakeKeyID(sessionID, suite);

    if (const auto cached = map.find(id); cached != map.end()) {
        return cached->second;
    }

    bool nativeSupported = false;
    auto key =
      ::discord::dave::mls::detail::GetNativePersistedKeyPair(ctx, id, suite, nativeSupported);

    if (!key) {
        if (nativeSupported) {
            // The native store owns this suite, so its failure is final: no generic fallback.
            DISCORD_LOG(LS_ERROR)
              << "Encountered error in native key handling in GetPersistedKeyPair";
            return nullptr;
        }

        key = ::discord::dave::mls::detail::GetGenericPersistedKeyPair(ctx, id, suite);
        if (!key) {
            DISCORD_LOG(LS_ERROR) << "Failed to get key in GetPersistedKeyPair";
            return nullptr;
        }
    }

    map.emplace(id, key);
    return key;
}

std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx,
                                                                  const std::string& sessionID,
                                                                  ProtocolVersion version)
{
    const auto suite = CiphersuiteForProtocolVersion(version);
    return GetPersistedKeyPair(ctx, sessionID, suite);
}

// Returns the persisted public key together with a self-signature (label SelfSignatureLabel)
// over "<sessionID>:" followed by the public key bytes, or an empty result if the key is
// unavailable.
KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx,
                                          const std::string& sessionID,
                                          SignatureVersion version)
{
    const auto suite = CiphersuiteForSignatureVersion(version);
    const auto keyPair = GetPersistedKeyPair(ctx, sessionID, suite);
    if (!keyPair) {
        return {};
    }

    bytes signedContent = from_ascii(sessionID + ":") + keyPair->public_key.data;
    auto signature = keyPair->sign(suite, SelfSignatureLabel, signedContent);

    return {
      keyPair->public_key.data.as_vec(),
      std::move(signature.as_vec()),
    };
}

// Evicts the cached key and deletes it from both the native and the generic store.
// Returns true if either store reported a deletion.
bool DeletePersistedKeyPair([[maybe_unused]] KeyPairContextType ctx,
                            const std::string& sessionID,
                            SignatureVersion version)
{
    const std::string id = MakeKeyID(sessionID, CiphersuiteForSignatureVersion(version));

    std::lock_guard lk(mtx);
    map.erase(id);

    // Both deletions always run; GetPersistedKeyPair may have stored the key in either place.
    const bool deletedNative = ::discord::dave::mls::detail::DeleteNativePersistedKeyPair(ctx, id);
    const bool deletedGeneric =
      ::discord::dave::mls::detail::DeleteGenericPersistedKeyPair(ctx, id);

    return deletedNative || deletedGeneric;
}

} // namespace mls
} // namespace dave
} // namespace discord

================================================ FILE: cpp/src/mls/persisted_key_pair.h ================================================ #pragma once #include #include #include #ifdef __ANDROID__ #include #endif #include #include namespace mlspp { struct SignaturePrivateKey; }; namespace discord { namespace
dave { namespace mls { std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(KeyPairContextType ctx, const std::string& sessionID, ProtocolVersion version); struct KeyAndSelfSignature { std::vector key; std::vector signature; }; KeyAndSelfSignature GetPersistedPublicKey(KeyPairContextType ctx, const std::string& sessionID, SignatureVersion version); bool DeletePersistedKeyPair(KeyPairContextType ctx, const std::string& sessionID, SignatureVersion version); constexpr unsigned KeyVersion = 1; } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/persisted_key_pair_null.cpp ================================================ #include "persisted_key_pair.h" namespace discord { namespace dave { namespace mls { std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair( [[maybe_unused]] KeyPairContextType, [[maybe_unused]] const std::string&, [[maybe_unused]] ProtocolVersion) { return nullptr; } bool DeletePersistedKeyPair([[maybe_unused]] KeyPairContextType, [[maybe_unused]] const std::string&, [[maybe_unused]] SignatureVersion) { return false; } } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/session.cpp ================================================ #include "session.h" #include #include #include #include #include #include #include #include #include #include "common.h" #include "mls/parameters.h" #include "mls/persisted_key_pair.h" #include "mls/user_credential.h" #include "mls/util.h" #include "mls_key_ratchet.h" #include "openssl/evp.h" #define TRACK_MLS_ERROR(reason) \ if (onMLSFailureCallback_) { \ onMLSFailureCallback_(__FUNCTION__, reason); \ } namespace discord { namespace dave { namespace mls { struct QueuedProposal { ::mlspp::ValidatedContent content; ::mlspp::bytes_ns::bytes ref; }; std::unique_ptr CreateSession(KeyPairContextType context, std::string authSessionId, MLSFailureCallback 
callback) noexcept { return std::make_unique(context, authSessionId, callback); } Session::Session(KeyPairContextType context, std::string authSessionId, MLSFailureCallback callback) noexcept : signingKeyId_(authSessionId) , keyPairContext_(context) , onMLSFailureCallback_(std::move(callback)) { DISCORD_LOG(LS_INFO) << "Creating a new MLS session"; } Session::~Session() noexcept = default; void Session::Init(ProtocolVersion protocolVersion, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept { Reset(); selfUserId_ = selfUserId; DISCORD_LOG(LS_INFO) << "Initializing MLS session with protocol version " << protocolVersion << " and group ID " << groupId; protocolVersion_ = protocolVersion; groupId_ = std::move(BigEndianBytesFrom(groupId).as_vec()); InitLeafNode(selfUserId, transientKey); if (externalSender_) { CreatePendingGroup(); } else { DISCORD_LOG(LS_INFO) << "Waiting for external sender to create a pending group"; } } void Session::Reset() noexcept { DISCORD_LOG(LS_INFO) << "Resetting MLS session"; ClearPendingState(); currentState_.reset(); outboundCachedGroupState_.reset(); protocolVersion_ = 0; groupId_.clear(); } void Session::SetProtocolVersion(ProtocolVersion version) noexcept { if (version != protocolVersion_) { // when we need to retain backwards compatibility // there may be some changes to the MLS objects required here // until then we can just update the stored version protocolVersion_ = version; } } std::vector Session::GetLastEpochAuthenticator() const noexcept { if (!currentState_) { DISCORD_LOG(LS_ERROR) << "Cannot get epoch authenticator without an established MLS group"; return {}; } return std::move(currentState_->epoch_authenticator().as_vec()); } void Session::SetExternalSender(const std::vector& marshalledExternalSender) noexcept try { if (currentState_) { DISCORD_LOG(LS_ERROR) << "Cannot set external sender after joining/creating an MLS group"; return; } DISCORD_LOG(LS_INFO) 
<< "Unmarshalling MLS external sender"; DISCORD_LOG(LS_INFO) << "Sender: " << ::mlspp::bytes_ns::bytes(marshalledExternalSender); externalSender_ = std::make_unique<::mlspp::ExternalSender>( ::mlspp::tls::get<::mlspp::ExternalSender>(marshalledExternalSender)); if (!groupId_.empty()) { CreatePendingGroup(); } else { DISCORD_LOG(LS_INFO) << "Waiting for group ID to create a pending group"; } } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to unmarshal external sender: " << e.what(); TRACK_MLS_ERROR(e.what()); return; } std::optional> Session::ProcessProposals( std::vector proposals, std::set const& recognizedUserIDs) noexcept try { if (!pendingGroupState_ && !currentState_) { DISCORD_LOG(LS_ERROR) << "Cannot process proposals without any pending or established MLS group state"; return std::nullopt; } if (!stateWithProposals_) { stateWithProposals_ = std::make_unique<::mlspp::State>( pendingGroupState_ ? *pendingGroupState_ : *currentState_); } DISCORD_LOG(LS_INFO) << "Processing MLS proposals message of " << proposals.size() << " bytes"; DISCORD_LOG(LS_INFO) << "Proposals: " << ::mlspp::bytes_ns::bytes(proposals); ::mlspp::tls::istream inStream(proposals); bool isRevoke = false; inStream >> isRevoke; DISCORD_LOG(LS_INFO) << "Revoking: " << isRevoke; const auto suite = stateWithProposals_->cipher_suite(); if (isRevoke) { std::vector<::mlspp::bytes_ns::bytes> refs; inStream >> refs; for (const auto& ref : refs) { bool found = false; for (auto it = proposalQueue_.begin(); it != proposalQueue_.end(); it++) { if (it->ref == ref) { found = true; proposalQueue_.erase(it); break; } } if (!found) { DISCORD_LOG(LS_ERROR) << "Cannot revoke unrecognized proposal ref"; TRACK_MLS_ERROR("Unrecognized proposal revocation"); return std::nullopt; } } stateWithProposals_ = std::make_unique<::mlspp::State>( pendingGroupState_ ? 
*pendingGroupState_ : *currentState_); for (auto& prop : proposalQueue_) { // success will queue the proposal, failure will throw stateWithProposals_->handle(prop.content); } } else { std::vector<::mlspp::MLSMessage> messages; inStream >> messages; for (const auto& proposalMessage : messages) { auto validatedMessage = stateWithProposals_->unwrap(proposalMessage); if (!ValidateProposalMessage(validatedMessage.authenticated_content(), *stateWithProposals_, recognizedUserIDs)) { return std::nullopt; } // success will queue the proposal, failure will throw stateWithProposals_->handle(validatedMessage); auto ref = suite.ref(validatedMessage.authenticated_content()); proposalQueue_.push_back({ std::move(validatedMessage), std::move(ref), }); } } // generate a commit auto commitSecret = ::mlspp::hpke::random_bytes(suite.secret_size()); auto commitOpts = ::mlspp::CommitOpts{ {}, // no extra proposals true, // inline tree in welcome false, // do not force path {} // default leaf node options }; auto [commitMessage, welcomeMessage, newState] = stateWithProposals_->commit(commitSecret, commitOpts, {}); DISCORD_LOG(LS_INFO) << "Prepared commit/welcome/next state for MLS group from received proposals"; // combine the commit and welcome messages into a single buffer auto outStream = ::mlspp::tls::ostream(); outStream << commitMessage; // keep a copy of the commit, we can check incoming pending group commit later for a match pendingGroupCommit_ = std::make_unique<::mlspp::MLSMessage>(std::move(commitMessage)); // if there were any add proposals in this commit, then we also include the welcome message if (welcomeMessage.secrets.size() > 0) { outStream << welcomeMessage; } // cache the outbound state in case we're the winning sender outboundCachedGroupState_ = std::make_unique<::mlspp::State>(std::move(newState)); DISCORD_LOG(LS_INFO) << "Output: " << ::mlspp::bytes_ns::bytes(outStream.bytes()); return outStream.bytes(); } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << 
"Failed to parse MLS proposals: " << e.what(); TRACK_MLS_ERROR(e.what()); return std::nullopt; } bool Session::IsRecognizedUserID(const ::mlspp::Credential& cred, std::set const& recognizedUserIDs) const { std::string uid = UserCredentialToString(cred, protocolVersion_); if (uid.empty()) { DISCORD_LOG(LS_ERROR) << "Attempted to verify credential of unexpected type"; return false; } if (recognizedUserIDs.find(uid) == recognizedUserIDs.end()) { DISCORD_LOG(LS_ERROR) << "Attempted to verify credential for unrecognized user ID: " << uid; return false; } return true; } bool Session::ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& targetState, std::set const& recognizedUserIDs) const { if (message.wire_format != ::mlspp::WireFormat::mls_public_message) { DISCORD_LOG(LS_ERROR) << "MLS proposal message must be PublicMessage"; TRACK_MLS_ERROR("Invalid proposal wire format"); return false; } if (message.content.epoch != targetState.epoch()) { DISCORD_LOG(LS_ERROR) << "MLS proposal message must be for current epoch (" << message.content.epoch << " != " << targetState.epoch() << ")"; TRACK_MLS_ERROR("Proposal epoch mismatch"); return false; } if (message.content.content_type() != ::mlspp::ContentType::proposal) { DISCORD_LOG(LS_ERROR) << "ProcessProposals called with non-proposal message"; TRACK_MLS_ERROR("Unexpected message type"); return false; } if (message.content.sender.sender_type() != ::mlspp::SenderType::external) { DISCORD_LOG(LS_ERROR) << "MLS proposal must be from external sender"; TRACK_MLS_ERROR("Unexpected proposal sender type"); return false; } const auto& proposal = ::mlspp::tls::var::get<::mlspp::Proposal>(message.content.content); switch (proposal.proposal_type()) { case ::mlspp::ProposalType::add: { const auto& credential = ::mlspp::tls::var::get<::mlspp::Add>(proposal.content).key_package.leaf_node.credential; if (!IsRecognizedUserID(credential, recognizedUserIDs)) { DISCORD_LOG(LS_ERROR) << "MLS add proposal must 
be for recognized user"; TRACK_MLS_ERROR("Unexpected user ID in add proposal"); return false; } break; } case ::mlspp::ProposalType::remove: // Remove proposals are always allowed (mlspp will validate that it's a recognized user) break; default: DISCORD_LOG(LS_ERROR) << "MLS proposal must be add or remove"; TRACK_MLS_ERROR("Unexpected proposal type"); return false; } return true; } bool Session::CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept { if (!stateWithProposals_) { return false; } if (commit.group_id() != groupId_) { DISCORD_LOG(LS_ERROR) << "MLS commit message was for unexpected group"; return false; } return true; } RosterVariant Session::ProcessCommit(std::vector commit) noexcept try { DISCORD_LOG(LS_INFO) << "Processing commit"; DISCORD_LOG(LS_INFO) << "Commit: " << ::mlspp::bytes_ns::bytes(commit); auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit); if (!CanProcessCommit(commitMessage)) { DISCORD_LOG(LS_ERROR) << "ProcessCommit called with unprocessable MLS commit"; return ignored_t{}; } // in case we're the sender of this commit // we need to pull the cached state from our outbound cache std::optional<::mlspp::State> optionalCachedState = std::nullopt; if (outboundCachedGroupState_) { optionalCachedState = *(outboundCachedGroupState_.get()); } auto validatedMessage = stateWithProposals_->unwrap(commitMessage); const auto& authenticatedContent = validatedMessage.authenticated_content(); if (authenticatedContent.wire_format != ::mlspp::WireFormat::mls_public_message) { throw std::invalid_argument("Invalid commit wire format"); } if (authenticatedContent.content.epoch != stateWithProposals_->epoch()) { throw std::invalid_argument("Commit epoch mismatch"); } if (authenticatedContent.content.content_type() != ::mlspp::ContentType::commit) { throw std::invalid_argument("Unexpected message type"); } if (authenticatedContent.content.sender.sender_type() != ::mlspp::SenderType::member) { throw std::invalid_argument("Unexpected 
commit sender type"); } const auto& commitContent = ::mlspp::tls::var::get<::mlspp::Commit>(authenticatedContent.content.content); for (const auto& proposalOrRef : commitContent.proposals) { if (::mlspp::tls::variant<::mlspp::ProposalOrRefType>::type(proposalOrRef.content) != ::mlspp::ProposalOrRefType::reference) { throw std::invalid_argument("Unexpected non-ref proposal"); } } auto newState = stateWithProposals_->handle(validatedMessage, optionalCachedState); if (!newState) { DISCORD_LOG(LS_ERROR) << "MLS commit handling did not produce a new state"; return failed_t{}; } DISCORD_LOG(LS_INFO) << "Successfully processed MLS commit, updating state; our leaf index is " << newState->index().val << "; current epoch is " << newState->epoch(); RosterMap ret = ReplaceState(std::make_unique<::mlspp::State>(std::move(*newState))); // reset the outbound cached group since we handled the commit for this epoch outboundCachedGroupState_.reset(); ClearPendingState(); return ret; } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to process MLS commit: " << e.what(); TRACK_MLS_ERROR(e.what()); return failed_t{}; } std::optional Session::ProcessWelcome( std::vector welcome, std::set const& recognizedUserIDs) noexcept try { if (!HasCryptographicStateForWelcome()) { DISCORD_LOG(LS_ERROR) << "Missing local cyrpto state necessary to process MLS welcome"; return std::nullopt; } if (!externalSender_) { DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome without an external sender"; return std::nullopt; } if (currentState_) { DISCORD_LOG(LS_ERROR) << "Cannot process MLS welcome after joining/creating an MLS group"; return std::nullopt; } DISCORD_LOG(LS_INFO) << "Processing welcome: " << ::mlspp::bytes_ns::bytes(welcome); // unmarshal the incoming welcome auto unmarshalledWelcome = ::mlspp::tls::get<::mlspp::Welcome>(welcome); // construct the state from the unmarshalled welcome auto newState = std::make_unique<::mlspp::State>( *joinInitPrivateKey_, 
*selfHPKEPrivateKey_, *selfSigPrivateKey_, *joinKeyPackage_, unmarshalledWelcome, std::nullopt, std::map<::mlspp::bytes_ns::bytes, ::mlspp::bytes_ns::bytes>()); // perform application-level verification of the new state if (!VerifyWelcomeState(*newState, recognizedUserIDs)) { DISCORD_LOG(LS_ERROR) << "Group received in MLS welcome is not valid"; return std::nullopt; } DISCORD_LOG(LS_INFO) << "Successfully welcomed to MLS Group, our leaf index is " << newState->index().val << "; current epoch is " << newState->epoch(); // make the verified state our new (and only) state RosterMap ret = ReplaceState(std::move(newState)); // clear out any pending state for creating/joining a group ClearPendingState(); return ret; } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to create group state from MLS welcome: " << e.what(); TRACK_MLS_ERROR(e.what()); return std::nullopt; } RosterMap Session::ReplaceState(std::unique_ptr<::mlspp::State>&& state) { RosterMap newRoster; for (const ::mlspp::LeafNode& node : state->roster()) { if (node.credential.type() != ::mlspp::CredentialType::basic) { throw std::invalid_argument("Unexpected credential type in roster"); } const auto& cred = node.credential.template get<::mlspp::BasicCredential>(); if (newRoster.emplace(FromBigEndianBytes(cred.identity), node.signature_key.data.as_vec()) .second != true) { throw std::invalid_argument("Duplicate identity in roster"); } } RosterMap changeMap; std::set_difference(newRoster.begin(), newRoster.end(), roster_.begin(), roster_.end(), std::inserter(changeMap, changeMap.end())); struct MissingItemWrapper { RosterMap& changeMap_; using iterator = RosterMap::iterator; using const_iterator = RosterMap::const_iterator; using value_type = RosterMap::value_type; iterator insert(const_iterator it, const value_type& value) { return changeMap_.try_emplace(it, value.first, std::vector{}); } iterator begin() { return changeMap_.begin(); } iterator end() { return changeMap_.end(); } }; 
MissingItemWrapper wrapper{changeMap}; std::set_difference(roster_.begin(), roster_.end(), newRoster.begin(), newRoster.end(), std::inserter(wrapper, wrapper.end())); roster_ = std::move(newRoster); currentState_ = std::move(state); return changeMap; } bool Session::HasCryptographicStateForWelcome() const noexcept { return joinKeyPackage_ && joinInitPrivateKey_ && selfSigPrivateKey_ && selfHPKEPrivateKey_; } bool Session::VerifyWelcomeState(::mlspp::State const& state, std::set const& recognizedUserIDs) const { if (!externalSender_) { DISCORD_LOG(LS_ERROR) << "Cannot verify MLS welcome without an external sender"; TRACK_MLS_ERROR("Missing external sender when processing Welcome"); return false; } auto ext = state.extensions().template find(); if (!ext) { DISCORD_LOG(LS_ERROR) << "MLS welcome missing external senders extension"; TRACK_MLS_ERROR("Welcome message missing external sender extension"); return false; } if (ext->senders.size() != 1) { DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected number of external senders: " << ext->senders.size(); TRACK_MLS_ERROR("Welcome message lists unexpected external sender count"); return false; } if (ext->senders.front() != *externalSender_) { DISCORD_LOG(LS_ERROR) << "MLS welcome lists unexpected external sender"; TRACK_MLS_ERROR("Welcome message lists unexpected external sender"); return false; } // TODO: Until we leverage revocation in the protocol // if we re-enable this change we will refuse welcome messages // because someone was previously supposed to be added but disconnected // before all in-flight proposals were handled. 
for (const auto& leaf : state.roster()) { if (!IsRecognizedUserID(leaf.credential, recognizedUserIDs)) { DISCORD_LOG(LS_ERROR) << "MLS welcome lists unrecognized user ID"; // TRACK_MLS_ERROR("Welcome message lists unrecognized user ID"); // return false; } } return true; } void Session::InitLeafNode(std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept try { auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); if (!transientKey) { if (!signingKeyId_.empty()) { transientKey = GetPersistedKeyPair(keyPairContext_, signingKeyId_, protocolVersion_); if (!transientKey) { DISCORD_LOG(LS_ERROR) << "Did not receive MLS signature private key from " "GetPersistedKeyPair; aborting"; return; } } else { transientKey = std::make_shared<::mlspp::SignaturePrivateKey>( ::mlspp::SignaturePrivateKey::generate(ciphersuite)); } } selfSigPrivateKey_ = transientKey; auto selfCredential = CreateUserCredential(selfUserId, protocolVersion_); selfHPKEPrivateKey_ = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); selfLeafNode_ = std::make_unique<::mlspp::LeafNode>(ciphersuite, selfHPKEPrivateKey_->public_key, selfSigPrivateKey_->public_key, std::move(selfCredential), LeafNodeCapabilitiesForProtocolVersion(protocolVersion_), ::mlspp::Lifetime::create_default(), LeafNodeExtensionsForProtocolVersion(protocolVersion_), *selfSigPrivateKey_); DISCORD_LOG(LS_INFO) << "Created MLS leaf node"; } catch (const std::exception& e) { DISCORD_LOG(LS_INFO) << "Failed to initialize MLS leaf node: " << e.what(); TRACK_MLS_ERROR(e.what()); } void Session::ResetJoinKeyPackage() noexcept try { if (!selfLeafNode_) { DISCORD_LOG(LS_ERROR) << "Cannot initialize join key package without a leaf node"; return; } auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); joinInitPrivateKey_ = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite)); joinKeyPackage_ = 
std::make_unique<::mlspp::KeyPackage>(ciphersuite, joinInitPrivateKey_->public_key, *selfLeafNode_, LeafNodeExtensionsForProtocolVersion(protocolVersion_), *selfSigPrivateKey_); DISCORD_LOG(LS_INFO) << "Generated key package: " << ::mlspp::bytes_ns::bytes(::mlspp::tls::marshal(*joinKeyPackage_)); } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to initialize join key package: " << e.what(); TRACK_MLS_ERROR(e.what()); } void Session::CreatePendingGroup() noexcept try { if (groupId_.empty()) { DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without a group ID"; return; } if (!externalSender_) { DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without ExternalSender"; return; } if (!selfLeafNode_) { DISCORD_LOG(LS_ERROR) << "Cannot create MLS group without self leaf node"; return; } DISCORD_LOG(LS_INFO) << "Creating a pending MLS group"; auto ciphersuite = CiphersuiteForProtocolVersion(protocolVersion_); pendingGroupState_ = std::make_unique<::mlspp::State>( groupId_, ciphersuite, *selfHPKEPrivateKey_, *selfSigPrivateKey_, *selfLeafNode_, GroupExtensionsForProtocolVersion(protocolVersion_, *externalSender_)); DISCORD_LOG(LS_INFO) << "Created a pending MLS group"; } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to create MLS group: " << e.what(); TRACK_MLS_ERROR(e.what()); return; } std::vector Session::GetMarshalledKeyPackage() noexcept try { // key packages are not meant to be re-used // so every time the client asks for a key package we create a new one ResetJoinKeyPackage(); if (!joinKeyPackage_) { DISCORD_LOG(LS_ERROR) << "Cannot marshal an uninitialized key package"; return {}; } return ::mlspp::tls::marshal(*joinKeyPackage_); } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to marshal join key package: " << e.what(); TRACK_MLS_ERROR(e.what()); return {}; } std::unique_ptr Session::GetKeyRatchet(std::string const& userId) const noexcept { if (!currentState_) { DISCORD_LOG(LS_INFO) << "Cannot get key 
ratchet without an established MLS group"; return nullptr; } // change the string user ID to a little endian 64 bit user ID auto u64userId = strtoull(userId.c_str(), nullptr, 10); auto userIdBytes = ::mlspp::bytes_ns::bytes(sizeof(u64userId)); memcpy(userIdBytes.data(), &u64userId, sizeof(u64userId)); // generate the base secret for the hash ratchet auto baseSecret = currentState_->do_export(Session::USER_MEDIA_KEY_BASE_LABEL, userIdBytes, kAesGcm128KeyBytes); // this assumes the MLS ciphersuite produces a kAesGcm128KeyBytes sized key // would need to be updated to a different ciphersuite if there's a future mismatch return std::make_unique(currentState_->cipher_suite(), std::move(baseSecret)); } void Session::GetPairwiseFingerprint(uint16_t version, std::string const& userId, PairwiseFingerprintCallback callback) const noexcept try { if (!currentState_ || !selfSigPrivateKey_) { throw std::invalid_argument("No established MLS group"); } uint64_t u64RemoteUserId = strtoull(userId.c_str(), nullptr, 10); uint64_t u64SelfUserId = strtoull(selfUserId_.c_str(), nullptr, 10); auto it = roster_.find(u64RemoteUserId); if (it == roster_.end()) { throw std::invalid_argument("Unknown user ID: " + userId); } ::mlspp::tls::ostream toHash1; ::mlspp::tls::ostream toHash2; toHash1 << version; toHash1.write_raw(it->second); toHash1 << u64RemoteUserId; toHash2 << version; toHash2.write_raw(selfSigPrivateKey_->public_key.data); toHash2 << u64SelfUserId; std::vector> keyData = { toHash1.bytes(), toHash2.bytes(), }; std::sort(keyData.begin(), keyData.end()); std::thread([callback = std::move(callback), data = ::mlspp::bytes_ns::bytes(std::move(keyData[0])) + keyData[1]] { static constexpr uint8_t salt[] = { 0x24, 0xca, 0xb1, 0x7a, 0x7a, 0xf8, 0xec, 0x2b, 0x82, 0xb4, 0x12, 0xb9, 0x2d, 0xab, 0x19, 0x2e, }; constexpr uint64_t N = 16384, r = 8, p = 2, max_mem = 32 * 1024 * 1024; constexpr size_t hash_len = 64; std::vector out(hash_len); int ret = EVP_PBE_scrypt((const char*)data.data(), 
data.size(), salt, sizeof(salt), N, r, p, max_mem, out.data(), out.size()); if (ret == 1) { callback(out); } else { callback({}); } }).detach(); } catch (const std::exception& e) { DISCORD_LOG(LS_ERROR) << "Failed to generate pairwise fingerprint: " << e.what(); callback({}); } void Session::ClearPendingState() { pendingGroupState_.reset(); pendingGroupCommit_.reset(); joinInitPrivateKey_.reset(); joinKeyPackage_.reset(); selfHPKEPrivateKey_.reset(); selfLeafNode_.reset(); stateWithProposals_.reset(); proposalQueue_.clear(); } } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/session.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include "mls/persisted_key_pair.h" #include "mls_key_ratchet.h" namespace mlspp { struct AuthenticatedContent; struct Credential; struct ExternalSender; struct HPKEPrivateKey; struct KeyPackage; struct LeafNode; struct MLSMessage; struct SignaturePrivateKey; class State; } // namespace mlspp namespace discord { namespace dave { namespace mls { struct QueuedProposal; class Session final : public ISession { public: Session(KeyPairContextType context, std::string authSessionId, MLSFailureCallback callback) noexcept; virtual ~Session() noexcept; virtual void Init( ProtocolVersion version, uint64_t groupId, std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept override; virtual void Reset() noexcept override; virtual void SetProtocolVersion(ProtocolVersion version) noexcept override; virtual ProtocolVersion GetProtocolVersion() const noexcept override { return protocolVersion_; } virtual std::vector GetLastEpochAuthenticator() const noexcept override; virtual void SetExternalSender( std::vector const& externalSenderPackage) noexcept override; virtual std::optional> ProcessProposals( std::vector proposals, std::set const& 
recognizedUserIDs) noexcept override; virtual RosterVariant ProcessCommit(std::vector commit) noexcept override; virtual std::optional ProcessWelcome( std::vector welcome, std::set const& recognizedUserIDs) noexcept override; virtual std::vector GetMarshalledKeyPackage() noexcept override; virtual std::unique_ptr GetKeyRatchet( std::string const& userId) const noexcept override; using PairwiseFingerprintCallback = std::function const&)>; virtual void GetPairwiseFingerprint( uint16_t version, std::string const& userId, PairwiseFingerprintCallback callback) const noexcept override; private: void InitLeafNode(std::string const& selfUserId, std::shared_ptr<::mlspp::SignaturePrivateKey>& transientKey) noexcept; void ResetJoinKeyPackage() noexcept; void CreatePendingGroup() noexcept; bool HasCryptographicStateForWelcome() const noexcept; bool IsRecognizedUserID(const ::mlspp::Credential& cred, std::set const& recognizedUserIDs) const; bool ValidateProposalMessage(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& targetState, std::set const& recognizedUserIDs) const; bool VerifyWelcomeState(::mlspp::State const& state, std::set const& recognizedUserIDs) const; bool CanProcessCommit(const ::mlspp::MLSMessage& commit) noexcept; RosterMap ReplaceState(std::unique_ptr<::mlspp::State>&& state); void ClearPendingState(); inline static const std::string USER_MEDIA_KEY_BASE_LABEL = "Discord Secure Frames v0"; ProtocolVersion protocolVersion_; std::vector groupId_; std::string signingKeyId_; std::string selfUserId_; KeyPairContextType keyPairContext_{nullptr}; std::unique_ptr<::mlspp::LeafNode> selfLeafNode_; std::shared_ptr<::mlspp::SignaturePrivateKey> selfSigPrivateKey_; std::unique_ptr<::mlspp::HPKEPrivateKey> selfHPKEPrivateKey_; std::unique_ptr<::mlspp::HPKEPrivateKey> joinInitPrivateKey_; std::unique_ptr<::mlspp::KeyPackage> joinKeyPackage_; std::unique_ptr<::mlspp::ExternalSender> externalSender_; std::unique_ptr<::mlspp::State> pendingGroupState_; 
std::unique_ptr<::mlspp::MLSMessage> pendingGroupCommit_; std::unique_ptr<::mlspp::State> outboundCachedGroupState_; std::unique_ptr<::mlspp::State> currentState_; RosterMap roster_; std::unique_ptr<::mlspp::State> stateWithProposals_; std::list proposalQueue_; MLSFailureCallback onMLSFailureCallback_{}; }; } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/user_credential.cpp ================================================ #include "user_credential.h" #include #include "mls/util.h" namespace discord { namespace dave { namespace mls { ::mlspp::Credential CreateUserCredential(const std::string& userId, [[maybe_unused]] ProtocolVersion version) { // convert the string user ID to a big endian uint64_t auto userID = std::stoull(userId); auto credentialBytes = BigEndianBytesFrom(userID); return ::mlspp::Credential::basic(credentialBytes); } std::string UserCredentialToString(const ::mlspp::Credential& cred, [[maybe_unused]] ProtocolVersion version) { if (cred.type() != ::mlspp::CredentialType::basic) { return ""; } const auto& basic = cred.template get<::mlspp::BasicCredential>(); auto uidVal = FromBigEndianBytes(basic.identity); return std::to_string(uidVal); } } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/user_credential.h ================================================ #pragma once #include #include #include namespace discord { namespace dave { namespace mls { ::mlspp::Credential CreateUserCredential(const std::string& userId, ProtocolVersion version); std::string UserCredentialToString(const ::mlspp::Credential& cred, ProtocolVersion version); } // namespace mls } // namespace dave } // namespace discord ================================================ FILE: cpp/src/mls/util.cpp ================================================ #include "util.h" namespace discord { namespace dave { namespace mls { 
// Serializes `value` as exactly sizeof(value) bytes, most significant byte first.
::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept
{
    auto buffer = ::mlspp::bytes_ns::bytes();
    buffer.reserve(sizeof(value));
    // Walk from the highest byte index down to 0 so the output is big endian.
    for (int i = sizeof(value) - 1; i >= 0; --i) {
        // NOTE(review): the cast's target type is missing in this copy (apparently lost
        // in extraction); presumably uint8_t, keeping only the low byte of the shift.
        buffer.push_back(static_cast(value >> (i * 8)));
    }
    return buffer;
}

// Inverse of BigEndianBytesFrom: folds `buffer` most-significant-byte first into a
// uint64_t. A buffer shorter than sizeof(uint64_t) is read as a smaller big-endian
// number; an empty buffer, or one longer than sizeof(uint64_t), yields 0.
uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& buffer) noexcept
{
    uint64_t val = 0;
    // Oversized input is rejected silently (result stays 0) rather than truncated.
    if (buffer.size() <= sizeof(val)) {
        for (uint8_t byte : buffer) {
            val = (val << 8) | byte;
        }
    }
    return val;
}

} // namespace mls
} // namespace dave
} // namespace discord
================================================
FILE: cpp/src/mls/util.h
================================================
#pragma once

// NOTE(review): the targets of the two includes below are missing in this copy.
#include
#include

namespace discord {
namespace dave {
namespace mls {

// Encodes `value` as sizeof(uint64_t) big-endian bytes.
::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept;
// Decodes up to sizeof(uint64_t) big-endian bytes; returns 0 for longer inputs.
uint64_t FromBigEndianBytes(const ::mlspp::bytes_ns::bytes& value) noexcept;

} // namespace mls
} // namespace dave
} // namespace discord
================================================
FILE: cpp/src/mls_key_ratchet.cpp
================================================
#include "mls_key_ratchet.h"

// NOTE(review): the targets of the two includes below are missing in this copy.
#include
#include

#include "common.h"

namespace discord {
namespace dave {

// Constructs the underlying hash ratchet from `suite` and `baseSecret` (moved in).
MlsKeyRatchet::MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept
  : hashRatchet_(suite, std::move(baseSecret))
{
}

MlsKeyRatchet::~MlsKeyRatchet() noexcept = default;

// Returns the key the hash ratchet yields for `generation`. If the ratchet throws a
// std::exception, the error is logged and a value-initialized EncryptionKey (`{}`)
// is returned instead.
EncryptionKey MlsKeyRatchet::GetKey(KeyGeneration generation) noexcept
{
    DISCORD_LOG(LS_INFO) << "Retrieving key for generation " << generation << " from HashRatchet";

    try {
        auto keyAndNonce = hashRatchet_.get(generation);
        // Sanity check only in builds where assert is enabled (no NDEBUG): the key must
        // be at least AES-128-GCM sized.
        assert(keyAndNonce.key.size() >= kAesGcm128KeyBytes);
        // NOTE(review): if as_vec() returns a const reference, std::move still copies
        // here; confirm against the mlspp bytes API.
        return std::move(keyAndNonce.key.as_vec());
    }
    catch (const std::exception& e) {
        DISCORD_LOG(LS_ERROR) << "Failed to retrieve key for generation " << generation << ": "
                              << e.what();
        return {};
    }
}

// Forwards to hashRatchet_.erase(generation), presumably discarding that
// generation's cached key material; confirm against mlspp's HashRatchet.
void MlsKeyRatchet::DeleteKey(KeyGeneration generation) noexcept
{
    hashRatchet_.erase(generation);
}

} // namespace dave
} // namespace discord
================================================ FILE: cpp/src/mls_key_ratchet.h ================================================ #pragma once #include #include #include namespace discord { namespace dave { class MlsKeyRatchet : public IKeyRatchet { public: MlsKeyRatchet(::mlspp::CipherSuite suite, bytes baseSecret) noexcept; ~MlsKeyRatchet() noexcept override; EncryptionKey GetKey(KeyGeneration generation) noexcept override; void DeleteKey(KeyGeneration generation) noexcept override; const ::mlspp::HashRatchet& GetHashRatchet() const noexcept { return hashRatchet_; } private: ::mlspp::HashRatchet hashRatchet_; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/openssl_cryptor.cpp ================================================ #include "openssl_cryptor.h" #include #include #include #include "common.h" namespace discord { namespace dave { void PrintSSLErrors() { ERR_print_errors_cb( [](const char* str, size_t len, [[maybe_unused]] void* ctx) { DISCORD_LOG(LS_ERROR) << std::string(str, len); return 1; }, nullptr); } OpenSSLCryptor::OpenSSLCryptor(const EncryptionKey& encryptionKey) { if (!cipherCtx_) { cipherCtx_ = EVP_CIPHER_CTX_new(); } else { EVP_CIPHER_CTX_reset(cipherCtx_); } auto initResult = EVP_CipherInit_ex(cipherCtx_, EVP_aes_128_gcm(), nullptr, encryptionKey.data(), nullptr, 0); if (initResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to initialize AEAD context"; PrintSSLErrors(); } } OpenSSLCryptor::~OpenSSLCryptor() { EVP_CIPHER_CTX_free(cipherCtx_); } bool OpenSSLCryptor::Encrypt(ArrayView ciphertextBufferOut, ArrayView plaintextBuffer, ArrayView nonceBuffer, ArrayView additionalData, ArrayView tagBufferOut) { if (!cipherCtx_) { DISCORD_LOG(LS_ERROR) << "Encrypt: AEAD context is not initialized"; return false; } auto contextResult = EVP_EncryptInit_ex(cipherCtx_, nullptr, nullptr, nullptr, nonceBuffer.data()); if (contextResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to set nonce for 
encryption"; PrintSSLErrors(); return false; } int ciphertextOutSize = 0; if (additionalData.size() > 0) { if (additionalData.size() > std::numeric_limits::max()) { DISCORD_LOG(LS_ERROR) << "Additional data size exceeds the maximum supported size"; return false; } auto aadResult = EVP_EncryptUpdate(cipherCtx_, nullptr, &ciphertextOutSize, additionalData.data(), static_cast(additionalData.size())); if (aadResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to update encryption with additional data"; PrintSSLErrors(); return false; } } if (plaintextBuffer.size() > std::numeric_limits::max()) { DISCORD_LOG(LS_ERROR) << "Plaintext buffer size exceeds the maximum supported size"; return false; } auto updateResult = EVP_EncryptUpdate(cipherCtx_, ciphertextBufferOut.data(), &ciphertextOutSize, plaintextBuffer.data(), static_cast(plaintextBuffer.size())); if (updateResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to encrypt plaintext"; PrintSSLErrors(); return false; } auto finalizeResult = EVP_EncryptFinal_ex(cipherCtx_, ciphertextBufferOut.data(), &ciphertextOutSize); if (finalizeResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to finalize encryption"; PrintSSLErrors(); return false; } auto tagResult = EVP_CIPHER_CTX_ctrl( cipherCtx_, EVP_CTRL_GCM_GET_TAG, kAesGcm128TruncatedTagBytes, tagBufferOut.data()); if (tagResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to get truncated authentication tag"; PrintSSLErrors(); return false; } return true; } bool OpenSSLCryptor::Decrypt(ArrayView plaintextBufferOut, ArrayView ciphertextBuffer, ArrayView tagBuffer, ArrayView nonceBuffer, ArrayView additionalData) { if (!cipherCtx_) { DISCORD_LOG(LS_ERROR) << "Decrypt: AEAD context is not initialized"; return false; } auto contextResult = EVP_DecryptInit_ex(cipherCtx_, nullptr, nullptr, nullptr, nonceBuffer.data()); if (contextResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to set nonce for decryption"; PrintSSLErrors(); return false; } int plaintextOutSize = 0; if (additionalData.size() > 
0) { if (additionalData.size() > std::numeric_limits::max()) { DISCORD_LOG(LS_ERROR) << "Additional data size exceeds the maximum supported size"; return false; } auto aadResult = EVP_DecryptUpdate(cipherCtx_, nullptr, &plaintextOutSize, additionalData.data(), static_cast(additionalData.size())); if (aadResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to update decryption with additional data"; PrintSSLErrors(); return false; } } if (ciphertextBuffer.size() > std::numeric_limits::max()) { DISCORD_LOG(LS_ERROR) << "Ciphertext buffer size exceeds the maximum supported size"; return false; } auto updateResult = EVP_DecryptUpdate(cipherCtx_, plaintextBufferOut.data(), &plaintextOutSize, ciphertextBuffer.data(), static_cast(ciphertextBuffer.size())); if (updateResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to decrypt ciphertext"; PrintSSLErrors(); return false; } // make a copy of the tag since the interface expects a const tag for decryption std::vector tagBufferCopy(tagBuffer.begin(), tagBuffer.end()); auto tagResult = EVP_CIPHER_CTX_ctrl( cipherCtx_, EVP_CTRL_GCM_SET_TAG, kAesGcm128TruncatedTagBytes, tagBufferCopy.data()); if (tagResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to set expected truncated authentication tag for decryption"; PrintSSLErrors(); return false; } auto finalizeResult = EVP_DecryptFinal_ex(cipherCtx_, plaintextBufferOut.data(), &plaintextOutSize); if (finalizeResult != 1) { DISCORD_LOG(LS_ERROR) << "Failed to finalize decryption"; PrintSSLErrors(); return false; } return true; } } // namespace dave } // namespace discord ================================================ FILE: cpp/src/openssl_cryptor.h ================================================ #pragma once #include #include "cryptor.h" namespace discord { namespace dave { class OpenSSLCryptor : public ICryptor { public: OpenSSLCryptor(const EncryptionKey& encryptionKey); ~OpenSSLCryptor(); bool IsValid() const { return cipherCtx_ != nullptr; } bool Encrypt(ArrayView ciphertextBufferOut, 
ArrayView plaintextBuffer, ArrayView nonceBuffer, ArrayView additionalData, ArrayView tagBufferOut) override; bool Decrypt(ArrayView plaintextBufferOut, ArrayView ciphertextBuffer, ArrayView tagBuffer, ArrayView nonceBuffer, ArrayView additionalData) override; private: EVP_CIPHER_CTX* cipherCtx_ = nullptr; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/utils/clock.h ================================================ #pragma once #include namespace discord { namespace dave { class IClock { public: using BaseClock = std::chrono::steady_clock; using TimePoint = BaseClock::time_point; using Duration = BaseClock::duration; virtual ~IClock() = default; virtual TimePoint Now() const = 0; }; class Clock : public IClock { public: TimePoint Now() const override { return BaseClock::now(); } }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/utils/leb128.cpp ================================================ #include "leb128.h" // The following code was copied from the webrtc source code: // https://webrtc.googlesource.com/src/+/refs/heads/main/modules/rtp_rtcp/source/leb128.cc namespace discord { namespace dave { size_t Leb128Size(uint64_t value) { int size = 0; while (value >= 0x80) { ++size; value >>= 7; } return size + 1; } uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end) { uint64_t value = 0; int fillBits = 0; while (readAt != end && fillBits < 64 - 7) { uint8_t leb128Byte = *readAt; value |= uint64_t{leb128Byte & 0x7Fu} << fillBits; ++readAt; fillBits += 7; if ((leb128Byte & 0x80) == 0) { return value; } } // Read 9 bytes and didn't find the terminator byte. Check if 10th byte // is that terminator, however to fit result into uint64_t it may carry only // single bit. if (readAt != end && *readAt <= 1) { value |= uint64_t{*readAt} << fillBits; ++readAt; return value; } // Failed to find terminator leb128 byte. 
readAt = nullptr; return 0; } size_t WriteLeb128(uint64_t value, uint8_t* buffer) { int size = 0; while (value >= 0x80) { buffer[size] = 0x80 | (value & 0x7F); ++size; value >>= 7; } buffer[size] = static_cast(value); ++size; return size; } } // namespace dave } // namespace discord ================================================ FILE: cpp/src/utils/leb128.h ================================================ #pragma once #include #include namespace discord { namespace dave { constexpr size_t Leb128MaxSize = 10; // Returns number of bytes needed to store `value` in leb128 format. size_t Leb128Size(uint64_t value); // Reads leb128 encoded value and advance read_at by number of bytes consumed. // Sets read_at to nullptr on error. uint64_t ReadLeb128(const uint8_t*& readAt, const uint8_t* end); // Encodes `value` in leb128 format. Assumes buffer has size of at least // Leb128Size(value). Returns number of bytes consumed. size_t WriteLeb128(uint64_t value, uint8_t* buffer); } // namespace dave } // namespace discord ================================================ FILE: cpp/src/utils/scope_exit.h ================================================ #pragma once #include #include #include namespace discord { namespace dave { class [[nodiscard]] ScopeExit final { public: template explicit ScopeExit(Cleanup&& cleanup) : cleanup_{std::forward(cleanup)} { } ScopeExit(ScopeExit&& rhs) : cleanup_{std::move(rhs.cleanup_)} { rhs.cleanup_ = nullptr; } ~ScopeExit() { if (cleanup_) { cleanup_(); } } ScopeExit& operator=(ScopeExit&& rhs) { cleanup_ = std::move(rhs.cleanup_); rhs.cleanup_ = nullptr; return *this; } void Dismiss() { cleanup_ = nullptr; } private: ScopeExit(ScopeExit const&) = delete; ScopeExit& operator=(ScopeExit const&) = delete; std::function cleanup_; }; } // namespace dave } // namespace discord ================================================ FILE: cpp/src/version.cpp ================================================ #include namespace discord { namespace dave { 
constexpr ProtocolVersion CurrentDaveProtocolVersion = 1; ProtocolVersion MaxSupportedProtocolVersion() { return CurrentDaveProtocolVersion; } } // namespace dave } // namespace discord ================================================ FILE: cpp/test/CMakeLists.txt ================================================ enable_testing() find_package(GTest CONFIG REQUIRED) SET(TEST_APP_NAME "libdave_test") file(GLOB_RECURSE TEST_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.h") file(GLOB_RECURSE TEST_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") add_executable(${TEST_APP_NAME} ${TEST_HEADERS} ${TEST_SOURCES}) add_dependencies(${TEST_APP_NAME} ${LIB_NAME}) target_include_directories(${TEST_APP_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/src) target_link_libraries(libdave_test PRIVATE ${LIB_NAME} GTest::gtest_main GTest::gmock MLSPP::bytes MLSPP::mlspp) add_test(NAME ${TEST_APP_NAME} COMMAND ${TEST_APP_NAME}) if(WIN32 AND BUILD_SHARED_LIBS) add_custom_command(TARGET ${TEST_APP_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ COMMENT "Copying ${LIB_NAME}.dll to test directory" ) endif() if(WIN32 AND ENABLE_SANITIZERS) if (NOT EXISTS "${ASAN_RUNTIME_DLL}") message(FATAL_ERROR "ASAN runtime DLL not found at ${ASAN_RUNTIME_DLL}") endif() add_custom_command(TARGET ${TEST_APP_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different "${ASAN_RUNTIME_DLL}" $ COMMENT "Copying ASAN runtime DLL to test directory" ) endif() add_subdirectory(capi) ================================================ FILE: cpp/test/capi/CMakeLists.txt ================================================ set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED ON) # Small wrapper library for the external sender SET(EXTERNAL_SENDER_WRAPPER_LIB "external_sender") SET(EXTERNAL_SENDER_WRAPPER_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../external_sender.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/external_sender_wrapper.cpp") SET(EXTERNAL_SENDER_WRAPPER_HEADERS 
"${CMAKE_CURRENT_SOURCE_DIR}/../external_sender.h" "${CMAKE_CURRENT_SOURCE_DIR}/external_sender_wrapper.h") add_library(${EXTERNAL_SENDER_WRAPPER_LIB} ${EXTERNAL_SENDER_WRAPPER_SOURCES} ${EXTERNAL_SENDER_WRAPPER_HEADERS}) target_include_directories(${EXTERNAL_SENDER_WRAPPER_LIB} PRIVATE ${PROJECT_SOURCE_DIR}/includes ${PROJECT_SOURCE_DIR}/src) target_link_libraries(${EXTERNAL_SENDER_WRAPPER_LIB} PRIVATE ${LIB_NAME} MLSPP::bytes MLSPP::mlspp) SET(TEST_APP_NAME "capi_test") file(GLOB_RECURSE TEST_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.h") file(GLOB_RECURSE TEST_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.c") add_executable(${TEST_APP_NAME} ${TEST_HEADERS} ${TEST_SOURCES}) add_dependencies(${TEST_APP_NAME} ${LIB_NAME} ${EXTERNAL_SENDER_WRAPPER_LIB}) target_include_directories(${TEST_APP_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/includes) if (BUILD_SHARED_LIBS) set(LINKER_LANG C) else() set(LINKER_LANG CXX) endif() set_target_properties(${TEST_APP_NAME} PROPERTIES LINKER_LANGUAGE ${LINKER_LANG}) target_link_libraries(${TEST_APP_NAME} PRIVATE ${LIB_NAME} ${EXTERNAL_SENDER_WRAPPER_LIB}) if(WIN32 AND BUILD_SHARED_LIBS) add_custom_command(TARGET ${TEST_APP_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ COMMENT "Copying ${LIB_NAME}.dll to test directory" ) endif() if(WIN32 AND ENABLE_SANITIZERS) if (NOT EXISTS "${ASAN_RUNTIME_DLL}") message(FATAL_ERROR "ASAN runtime DLL not found at ${ASAN_RUNTIME_DLL}") endif() add_custom_command(TARGET ${TEST_APP_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different "${ASAN_RUNTIME_DLL}" $ COMMENT "Copying ASAN runtime DLL to test directory" ) endif() add_test(NAME ${TEST_APP_NAME} COMMAND ${TEST_APP_NAME}) ================================================ FILE: cpp/test/capi/basic_tests.c ================================================ #include #include #include #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #include typedef CRITICAL_SECTION mutex_t; typedef CONDITION_VARIABLE cond_t; 
#else #include #include typedef pthread_mutex_t mutex_t; typedef pthread_cond_t cond_t; #endif #include #include "external_sender_wrapper.h" #include "test_helpers.h" #define RUN_TEST(test) \ do { \ printf("Running %s...\n", #test); \ if (test()) { \ printf(" PASSED\n"); \ passed++; \ } \ else { \ printf(" FAILED\n"); \ failed++; \ } \ } while (0) static int TestEncryptorCreateDestroy(void) { DAVEEncryptorHandle encryptor = daveEncryptorCreate(); TEST_ASSERT(encryptor != NULL, "Failed to create encryptor"); daveEncryptorDestroy(encryptor); return 1; } static int TestDecryptorCreateDestroy(void) { DAVEDecryptorHandle decryptor = daveDecryptorCreate(); TEST_ASSERT(decryptor != NULL, "Failed to create decryptor"); daveDecryptorDestroy(decryptor); return 1; } static int TestMaxProtocolVersion(void) { uint16_t maxProtocolVersion = daveMaxSupportedProtocolVersion(); TEST_ASSERT_EQ(maxProtocolVersion, 1, "Max protocol version should be 1"); return 1; } static int TestEncryptorPassthrough(void) { DAVEEncryptorHandle encryptor = daveEncryptorCreate(); TEST_ASSERT(encryptor != NULL, "Failed to create encryptor"); TEST_ASSERT_EQ(daveEncryptorHasKeyRatchet(encryptor), false, "Encryptor should not have a key ratchet"); TEST_ASSERT_EQ(daveEncryptorIsPassthroughMode(encryptor), false, "Encryptor should not be in passthrough mode"); // Set passthrough mode daveEncryptorSetPassthroughMode(encryptor, true); TEST_ASSERT_EQ(daveEncryptorIsPassthroughMode(encryptor), true, "Encryptor should be in passthrough mode"); TEST_ASSERT_EQ(daveEncryptorHasKeyRatchet(encryptor), false, "Encryptor should not have a key ratchet"); daveEncryptorAssignSsrcToCodec(encryptor, 0, DAVE_CODEC_OPUS); // Create test data const char* hexData = "0dc5aedd5bdc3f20be5697e54dd1f437"; size_t inputDataLength = 0; uint8_t* inputData = GetBufferFromHex(hexData, &inputDataLength); TEST_ASSERT(inputData != NULL, "Failed to get input data"); // Allocate output buffer size_t outputDataLength = 
daveEncryptorGetMaxCiphertextByteSize(encryptor, DAVE_MEDIA_TYPE_AUDIO, inputDataLength); uint8_t* outputData = (uint8_t*)malloc(outputDataLength); size_t bytesWritten = 0; // Encrypt in passthrough mode DAVEEncryptorResultCode result = daveEncryptorEncrypt(encryptor, DAVE_MEDIA_TYPE_AUDIO, 0, inputData, inputDataLength, outputData, outputDataLength, &bytesWritten); TEST_ASSERT_EQ(result, DAVE_ENCRYPTOR_RESULT_CODE_SUCCESS, "Encryption should succeed"); TEST_ASSERT_EQ(bytesWritten, inputDataLength, "Bytes written should match input length"); TEST_ASSERT(memcmp(inputData, outputData, inputDataLength) == 0, "Output should match input in passthrough mode"); // Cleanup free(inputData); free(outputData); daveEncryptorDestroy(encryptor); return 1; } static int TestDecryptorPassthrough(void) { DAVEDecryptorHandle decryptor = daveDecryptorCreate(); TEST_ASSERT(decryptor != NULL, "Decryptor should be created"); // Set passthrough mode daveDecryptorTransitionToPassthroughMode(decryptor, 1); // Create test data const char* hexData = "0dc5aedd5bdc3f20be5697e54dd1f437"; size_t inputDataLength = 0; uint8_t* inputData = GetBufferFromHex(hexData, &inputDataLength); TEST_ASSERT(inputData != NULL, "Input data should be allocated"); // Allocate output buffer size_t outputDataLength = daveDecryptorGetMaxPlaintextByteSize(decryptor, DAVE_MEDIA_TYPE_AUDIO, inputDataLength); uint8_t* outputData = (uint8_t*)malloc(outputDataLength); size_t bytesWritten = 0; // Decrypt in passthrough mode DAVEDecryptorResultCode result = daveDecryptorDecrypt(decryptor, DAVE_MEDIA_TYPE_AUDIO, inputData, inputDataLength, outputData, outputDataLength, &bytesWritten); TEST_ASSERT_EQ(result, DAVE_DECRYPTOR_RESULT_CODE_SUCCESS, "Decryption should succeed"); TEST_ASSERT_EQ(bytesWritten, inputDataLength, "Bytes written should match input length"); TEST_ASSERT(memcmp(inputData, outputData, inputDataLength) == 0, "Output should match input in passthrough mode"); // Cleanup free(inputData); free(outputData); 
daveDecryptorDestroy(decryptor); return 1; } static int TestPassthroughInOutBuffer(void) { const char* RandomBytes = "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4" "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9"; size_t incomingFrameLength = 0; uint8_t* incomingFrame = GetBufferFromHex(RandomBytes, &incomingFrameLength); TEST_ASSERT(incomingFrame != NULL, "Failed to allocate incoming frame"); uint8_t* frameCopy = (uint8_t*)malloc(incomingFrameLength); TEST_ASSERT(frameCopy != NULL, "Failed to allocate frame copy"); memcpy(frameCopy, incomingFrame, incomingFrameLength); // Encryptor test DAVEEncryptorHandle encryptor = daveEncryptorCreate(); TEST_ASSERT(encryptor != NULL, "Failed to create encryptor"); daveEncryptorAssignSsrcToCodec(encryptor, 0, DAVE_CODEC_OPUS); daveEncryptorSetPassthroughMode(encryptor, true); size_t bytesWritten = 0; DAVEEncryptorResultCode encryptResult = daveEncryptorEncrypt(encryptor, DAVE_MEDIA_TYPE_AUDIO, 0, incomingFrame, incomingFrameLength, incomingFrame, incomingFrameLength, &bytesWritten); TEST_ASSERT_EQ(encryptResult, DAVE_ENCRYPTOR_RESULT_CODE_SUCCESS, "Encryption should succeed"); TEST_ASSERT_EQ(bytesWritten, incomingFrameLength, "Bytes written should match input length"); TEST_ASSERT(memcmp(incomingFrame, frameCopy, bytesWritten) == 0, "Encrypted data should match input in passthrough mode"); // Decryptor test DAVEDecryptorHandle decryptor = daveDecryptorCreate(); TEST_ASSERT(decryptor != NULL, "Failed to create decryptor"); daveDecryptorTransitionToPassthroughMode(decryptor, true); bytesWritten = 0; DAVEDecryptorResultCode decryptResult = daveDecryptorDecrypt(decryptor, DAVE_MEDIA_TYPE_AUDIO, incomingFrame, incomingFrameLength, incomingFrame, incomingFrameLength, &bytesWritten); TEST_ASSERT_EQ(decryptResult, DAVE_DECRYPTOR_RESULT_CODE_SUCCESS, "Decryption should succeed"); TEST_ASSERT_EQ(bytesWritten, incomingFrameLength, "Bytes written should match input length"); 
TEST_ASSERT(memcmp(incomingFrame, frameCopy, bytesWritten) == 0, "Decrypted data should match input in passthrough mode"); // Cleanup free(incomingFrame); free(frameCopy); daveEncryptorDestroy(encryptor); daveDecryptorDestroy(decryptor); return 1; } static int TestPassthroughTwoBuffers(void) { const char* RandomBytes = "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4" "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9"; size_t incomingFrameLength = 0; uint8_t* incomingFrame = GetBufferFromHex(RandomBytes, &incomingFrameLength); TEST_ASSERT(incomingFrame != NULL, "Failed to allocate incoming frame"); uint8_t* encryptedFrame = (uint8_t*)malloc(incomingFrameLength * 2); TEST_ASSERT(encryptedFrame != NULL, "Failed to allocate encrypted frame"); uint8_t* decryptedFrame = (uint8_t*)malloc(incomingFrameLength); TEST_ASSERT(decryptedFrame != NULL, "Failed to allocate decrypted frame"); // Encryptor test DAVEEncryptorHandle encryptor = daveEncryptorCreate(); TEST_ASSERT(encryptor != NULL, "Failed to create encryptor"); daveEncryptorAssignSsrcToCodec(encryptor, 0, DAVE_CODEC_OPUS); daveEncryptorSetPassthroughMode(encryptor, true); size_t bytesWritten = 0; DAVEEncryptorResultCode encryptResult = daveEncryptorEncrypt(encryptor, DAVE_MEDIA_TYPE_AUDIO, 0, incomingFrame, incomingFrameLength, encryptedFrame, incomingFrameLength * 2, &bytesWritten); TEST_ASSERT_EQ(encryptResult, DAVE_ENCRYPTOR_RESULT_CODE_SUCCESS, "Encryption should succeed"); TEST_ASSERT_EQ(bytesWritten, incomingFrameLength, "Bytes written should match input length"); TEST_ASSERT(memcmp(incomingFrame, encryptedFrame, bytesWritten) == 0, "Encrypted data should match input in passthrough mode"); // Decryptor test DAVEDecryptorHandle decryptor = daveDecryptorCreate(); TEST_ASSERT(decryptor != NULL, "Failed to create decryptor"); daveDecryptorTransitionToPassthroughMode(decryptor, true); size_t bytesDecrypted = 0; DAVEDecryptorResultCode decryptResult = 
daveDecryptorDecrypt(decryptor, DAVE_MEDIA_TYPE_AUDIO, encryptedFrame, bytesWritten, decryptedFrame, incomingFrameLength, &bytesDecrypted); TEST_ASSERT_EQ(decryptResult, DAVE_DECRYPTOR_RESULT_CODE_SUCCESS, "Decryption should succeed"); TEST_ASSERT_EQ( bytesDecrypted, incomingFrameLength, "Bytes decrypted should match input length"); TEST_ASSERT(memcmp(encryptedFrame, decryptedFrame, bytesDecrypted) == 0, "Decrypted data should match encrypted data"); // Cleanup free(incomingFrame); free(encryptedFrame); free(decryptedFrame); daveEncryptorDestroy(encryptor); daveDecryptorDestroy(decryptor); return 1; } static void TestSessionFailureCallback(const char* source, const char* reason, void* userData) { (void)userData; printf("Session failure: %s: %s\n", source, reason); } typedef struct { mutex_t mutex; cond_t cond; uint8_t* pairwiseFingerprint; size_t pairwiseFingerprintLength; } PairwiseFingerprintData; static void PairwiseFingerprintDataInit(PairwiseFingerprintData* data) { #ifdef _WIN32 InitializeCriticalSection(&data->mutex); InitializeConditionVariable(&data->cond); #else pthread_mutex_init(&data->mutex, NULL); pthread_cond_init(&data->cond, NULL); #endif data->pairwiseFingerprint = NULL; data->pairwiseFingerprintLength = 0; } static void PairwiseFingerprintDataDestroy(PairwiseFingerprintData* data) { #ifdef _WIN32 DeleteCriticalSection(&data->mutex); // CONDITION_VARIABLE does not need cleanup #else pthread_mutex_destroy(&data->mutex); pthread_cond_destroy(&data->cond); #endif free(data->pairwiseFingerprint); data->pairwiseFingerprint = NULL; data->pairwiseFingerprintLength = 0; } static void PairwiseFingerprintDataWait(PairwiseFingerprintData* data) { #ifdef _WIN32 EnterCriticalSection(&data->mutex); if (data->pairwiseFingerprint == NULL) { SleepConditionVariableCS(&data->cond, &data->mutex, INFINITE); } LeaveCriticalSection(&data->mutex); #else pthread_mutex_lock(&data->mutex); if (data->pairwiseFingerprint == NULL) { pthread_cond_wait(&data->cond, 
&data->mutex); } pthread_mutex_unlock(&data->mutex); #endif } static void PairwiseFingerprintCallback(const uint8_t* pairwiseFingerprint, size_t pairwiseFingerprintLength, void* userData) { if (userData == NULL) { return; } PairwiseFingerprintData* data = (PairwiseFingerprintData*)userData; #ifdef _WIN32 EnterCriticalSection(&data->mutex); data->pairwiseFingerprint = (uint8_t*)malloc(pairwiseFingerprintLength); memcpy(data->pairwiseFingerprint, pairwiseFingerprint, pairwiseFingerprintLength); data->pairwiseFingerprintLength = pairwiseFingerprintLength; WakeConditionVariable(&data->cond); LeaveCriticalSection(&data->mutex); #else pthread_mutex_lock(&data->mutex); data->pairwiseFingerprint = (uint8_t*)malloc(pairwiseFingerprintLength); memcpy(data->pairwiseFingerprint, pairwiseFingerprint, pairwiseFingerprintLength); data->pairwiseFingerprintLength = pairwiseFingerprintLength; pthread_cond_signal(&data->cond); pthread_mutex_unlock(&data->mutex); #endif } static int TestSession(void) { uint64_t groupId = 1234567890; const char* userA = "1234123412341234"; const char* userB = "5678567856785678"; printf("Creating external sender\n"); DAVEExternalSenderHandle externalSender = daveExternalSenderCreate(groupId); TEST_ASSERT(externalSender != NULL, "Failed to create external sender"); // Create sessions printf("Creating sessions\n"); DAVESessionHandle sessionA = daveSessionCreate(NULL, NULL, TestSessionFailureCallback, NULL); DAVESessionHandle sessionB = daveSessionCreate(NULL, NULL, TestSessionFailureCallback, NULL); TEST_ASSERT(sessionA != NULL, "Failed to create session"); TEST_ASSERT(sessionB != NULL, "Failed to create session"); // Set external sender printf("Setting external sender\n"); uint8_t* marshalledExternalSender = NULL; size_t marshalledExternalSenderLength = 0; daveExternalSenderGetMarshalledExternalSender( externalSender, &marshalledExternalSender, &marshalledExternalSenderLength); TEST_ASSERT(marshalledExternalSender != NULL, "Failed to get marshalled 
external sender"); daveSessionSetExternalSender( sessionA, marshalledExternalSender, marshalledExternalSenderLength); daveSessionSetExternalSender( sessionB, marshalledExternalSender, marshalledExternalSenderLength); daveFree(marshalledExternalSender); // Init sessions daveSessionInit(sessionA, 1, groupId, userA); daveSessionInit(sessionB, 1, groupId, userB); TEST_ASSERT_EQ(daveSessionGetProtocolVersion(sessionA), 1, "Protocol version should be 1"); TEST_ASSERT_EQ(daveSessionGetProtocolVersion(sessionB), 1, "Protocol version should be 1"); // Get key packages printf("Getting key packages\n"); uint8_t* keyPackageA = NULL; size_t keyPackageALength = 0; daveSessionGetMarshalledKeyPackage(sessionA, &keyPackageA, &keyPackageALength); TEST_ASSERT(keyPackageA != NULL, "Failed to get key package"); uint8_t* keyPackageB = NULL; size_t keyPackageBLength = 0; daveSessionGetMarshalledKeyPackage(sessionB, &keyPackageB, &keyPackageBLength); TEST_ASSERT(keyPackageB != NULL, "Failed to get key package"); // Make add proposal for user B printf("Proposing add\n"); uint8_t* proposal = NULL; size_t proposalLength = 0; daveExternalSenderProposeAdd( externalSender, 0, keyPackageB, keyPackageBLength, &proposal, &proposalLength); TEST_ASSERT(proposal != NULL, "Failed to propose add user B"); daveFree(keyPackageA); daveFree(keyPackageB); // Process proposal of user B printf("Processing proposals\n"); uint8_t* commitWelcome = NULL; size_t commitWelcomeLength = 0; const char* recognizedUserIds[] = {userA, userB}; daveSessionProcessProposals(sessionA, proposal, proposalLength, recognizedUserIds, 2, &commitWelcome, &commitWelcomeLength); TEST_ASSERT(commitWelcome != NULL, "Failed to process proposals"); daveFree(proposal); // Split commit welcome printf("Splitting commit welcome\n"); uint8_t* commit = NULL; size_t commitLength = 0; uint8_t* welcome = NULL; size_t welcomeLength = 0; daveExternalSenderSplitCommitWelcome(externalSender, commitWelcome, commitWelcomeLength, &commit, &commitLength, 
&welcome, &welcomeLength); TEST_ASSERT(commit != NULL, "Failed to split commit welcome"); TEST_ASSERT(welcome != NULL, "Failed to split commit welcome"); daveFree(commitWelcome); // Process commit generated by user A printf("Processing commit welcome\n"); DAVECommitResultHandle commitResult = daveSessionProcessCommit(sessionA, commit, commitLength); DAVEWelcomeResultHandle welcomeResult = daveSessionProcessWelcome(sessionB, welcome, welcomeLength, recognizedUserIds, 2); daveFree(commit); daveFree(welcome); // Check commit welcome results printf("Checking commit welcome results\n"); TEST_ASSERT_EQ(daveCommitResultIsFailed(commitResult), false, "Commit should not be failed"); TEST_ASSERT_EQ(daveCommitResultIsIgnored(commitResult), false, "Commit should not be ignored"); uint64_t* rosterIds = NULL; size_t rosterIdsLength = 0; daveCommitResultGetRosterMemberIds(commitResult, &rosterIds, &rosterIdsLength); TEST_ASSERT(rosterIds != NULL, "Failed to get roster member ids"); TEST_ASSERT_EQ(rosterIdsLength, 2, "Roster member ids length should be 2"); TEST_ASSERT(rosterIds[0] == 1234123412341234, "Roster member id should be user A"); TEST_ASSERT(rosterIds[1] == 5678567856785678, "Roster member id should be user B"); daveFree(rosterIds); daveWelcomeResultGetRosterMemberIds(welcomeResult, &rosterIds, &rosterIdsLength); TEST_ASSERT(rosterIds != NULL, "Failed to get roster member ids"); TEST_ASSERT_EQ(rosterIdsLength, 2, "Roster member ids length should be 2"); TEST_ASSERT(rosterIds[0] == 1234123412341234, "Roster member id should be user A"); TEST_ASSERT(rosterIds[1] == 5678567856785678, "Roster member id should be user B"); uint8_t* signature = NULL; size_t signatureLength = 0; daveCommitResultGetRosterMemberSignature( commitResult, rosterIds[0], &signature, &signatureLength); TEST_ASSERT(signature != NULL, "Failed to get signature"); TEST_ASSERT(signatureLength > 0, "Signature length should be greater than 0"); daveFree(signature); daveCommitResultGetRosterMemberSignature( 
commitResult, rosterIds[1], &signature, &signatureLength); TEST_ASSERT(signature != NULL, "Failed to get signature"); TEST_ASSERT(signatureLength > 0, "Signature length should be greater than 0"); daveFree(signature); daveWelcomeResultGetRosterMemberSignature( welcomeResult, rosterIds[0], &signature, &signatureLength); TEST_ASSERT(signature != NULL, "Failed to get signature"); TEST_ASSERT(signatureLength > 0, "Signature length should be greater than 0"); daveFree(signature); daveWelcomeResultGetRosterMemberSignature( welcomeResult, rosterIds[1], &signature, &signatureLength); TEST_ASSERT(signature != NULL, "Failed to get signature"); TEST_ASSERT(signatureLength > 0, "Signature length should be greater than 0"); daveFree(signature); daveFree(rosterIds); daveCommitResultDestroy(commitResult); daveWelcomeResultDestroy(welcomeResult); // Match authenticators printf("Matching authenticators\n"); uint8_t* authenticatorA = NULL; size_t authenticatorALength = 0; daveSessionGetLastEpochAuthenticator(sessionA, &authenticatorA, &authenticatorALength); TEST_ASSERT(authenticatorA != NULL, "Failed to get authenticator"); uint8_t* authenticatorB = NULL; size_t authenticatorBLength = 0; daveSessionGetLastEpochAuthenticator(sessionB, &authenticatorB, &authenticatorBLength); TEST_ASSERT(authenticatorB != NULL, "Failed to get authenticator"); TEST_ASSERT(memcmp(authenticatorA, authenticatorB, authenticatorALength) == 0, "Authenticators should match"); daveFree(authenticatorA); daveFree(authenticatorB); // Get pairwise fingerprints printf("Matching pairwise fingerprints\n"); PairwiseFingerprintData pairwiseFingerprintDataA; PairwiseFingerprintDataInit(&pairwiseFingerprintDataA); PairwiseFingerprintData pairwiseFingerprintDataB; PairwiseFingerprintDataInit(&pairwiseFingerprintDataB); daveSessionGetPairwiseFingerprint( sessionA, 1, userB, &PairwiseFingerprintCallback, &pairwiseFingerprintDataA); daveSessionGetPairwiseFingerprint( sessionB, 1, userA, &PairwiseFingerprintCallback, 
&pairwiseFingerprintDataB); PairwiseFingerprintDataWait(&pairwiseFingerprintDataA); PairwiseFingerprintDataWait(&pairwiseFingerprintDataB); TEST_ASSERT(pairwiseFingerprintDataA.pairwiseFingerprintLength == pairwiseFingerprintDataB.pairwiseFingerprintLength, "Pairwise fingerprint lengths should match"); TEST_ASSERT(memcmp(pairwiseFingerprintDataA.pairwiseFingerprint, pairwiseFingerprintDataB.pairwiseFingerprint, pairwiseFingerprintDataA.pairwiseFingerprintLength) == 0, "Pairwise fingerprint should match"); PairwiseFingerprintDataDestroy(&pairwiseFingerprintDataA); PairwiseFingerprintDataDestroy(&pairwiseFingerprintDataB); // Get key ratchets printf("Getting key ratchets\n"); DAVEKeyRatchetHandle keyRatchetA = daveSessionGetKeyRatchet(sessionA, userA); DAVEKeyRatchetHandle keyRatchetB = daveSessionGetKeyRatchet(sessionB, userA); TEST_ASSERT(keyRatchetA != NULL, "Failed to get key ratchet"); TEST_ASSERT(keyRatchetB != NULL, "Failed to get key ratchet"); // Setup encryptor printf("Setting up encryptor\n"); DAVEEncryptorHandle encryptorA = daveEncryptorCreate(); daveEncryptorAssignSsrcToCodec(encryptorA, 0, DAVE_CODEC_OPUS); daveEncryptorSetPassthroughMode(encryptorA, false); daveEncryptorSetKeyRatchet(encryptorA, keyRatchetA); daveKeyRatchetDestroy(keyRatchetA); TEST_ASSERT_EQ(daveEncryptorHasKeyRatchet(encryptorA), true, "Encryptor should have a key ratchet"); TEST_ASSERT_EQ(daveEncryptorIsPassthroughMode(encryptorA), false, "Encryptor should not be in passthrough mode"); // Setup decryptor printf("Setting up decryptor\n"); DAVEDecryptorHandle decryptorA = daveDecryptorCreate(); daveDecryptorTransitionToPassthroughMode(decryptorA, false); daveDecryptorTransitionToKeyRatchet(decryptorA, keyRatchetB); daveKeyRatchetDestroy(keyRatchetB); // Create test data printf("Creating test data\n"); const char* hexData = "0dc5aedd5bdc3f20be5697e54dd1f437"; size_t inputDataLength = 0; uint8_t* inputData = GetBufferFromHex(hexData, &inputDataLength); TEST_ASSERT(inputData != NULL, 
"Failed to get input data"); // Encrypt data printf("Encrypting data\n"); size_t encryptedFrameLength = daveEncryptorGetMaxCiphertextByteSize(encryptorA, DAVE_MEDIA_TYPE_AUDIO, inputDataLength); uint8_t* encryptedFrame = (uint8_t*)malloc(encryptedFrameLength); daveEncryptorEncrypt(encryptorA, DAVE_MEDIA_TYPE_AUDIO, 0, inputData, inputDataLength, encryptedFrame, encryptedFrameLength, &encryptedFrameLength); TEST_ASSERT(encryptedFrame != NULL, "Failed to encrypt data"); TEST_ASSERT(encryptedFrameLength > inputDataLength, "Encrypted data length should be greater than input data length"); TEST_ASSERT(memcmp(inputData, encryptedFrame, inputDataLength) != 0, "Encrypted data should not match input data"); // Decrypt data printf("Decrypting data\n"); size_t decryptedFrameLength = daveDecryptorGetMaxPlaintextByteSize(decryptorA, DAVE_MEDIA_TYPE_AUDIO, encryptedFrameLength); uint8_t* decryptedFrame = (uint8_t*)malloc(decryptedFrameLength); daveDecryptorDecrypt(decryptorA, DAVE_MEDIA_TYPE_AUDIO, encryptedFrame, encryptedFrameLength, decryptedFrame, decryptedFrameLength, &decryptedFrameLength); TEST_ASSERT(decryptedFrame != NULL, "Failed to decrypt data"); TEST_ASSERT_EQ(decryptedFrameLength, inputDataLength, "Decrypted data length should be equal to input data length"); TEST_ASSERT(memcmp(inputData, decryptedFrame, inputDataLength) == 0, "Decrypted data should match input data"); // Check encryptor stats printf("Checking encryptor stats\n"); DAVEEncryptorStats encryptorStats; daveEncryptorGetStats(encryptorA, DAVE_MEDIA_TYPE_AUDIO, &encryptorStats); TEST_ASSERT_EQ(encryptorStats.encryptSuccessCount, 1, "Encryptor should have at least one successful encryption"); TEST_ASSERT_EQ( encryptorStats.encryptFailureCount, 0, "Encryptor should have no failed encryptions"); TEST_ASSERT(encryptorStats.encryptDuration > 0, "Encryptor should have a duration"); TEST_ASSERT_EQ( encryptorStats.encryptAttempts, 1, "Encryptor should have at least one encryption attempt"); 
TEST_ASSERT_EQ(encryptorStats.encryptMaxAttempts, 1, "Encryptor should have a maximum number of encryption attempts"); TEST_ASSERT_EQ( encryptorStats.encryptMissingKeyCount, 0, "Encryptor should have no missing keys"); // Check decryptor stats printf("Checking decryptor stats\n"); DAVEDecryptorStats decryptorStats; daveDecryptorGetStats(decryptorA, DAVE_MEDIA_TYPE_AUDIO, &decryptorStats); TEST_ASSERT_EQ(decryptorStats.decryptSuccessCount, 1, "Decryptor should have at least one successful decryption"); TEST_ASSERT_EQ( decryptorStats.decryptFailureCount, 0, "Decryptor should have no failed decryptions"); TEST_ASSERT(decryptorStats.decryptDuration > 0, "Decryptor should have a duration"); TEST_ASSERT_EQ( decryptorStats.decryptAttempts, 1, "Decryptor should have at least one decryption attempt"); TEST_ASSERT_EQ( decryptorStats.decryptMissingKeyCount, 0, "Decryptor should have no missing keys"); TEST_ASSERT_EQ( decryptorStats.decryptInvalidNonceCount, 0, "Decryptor should have no invalid nonces"); // Clean up printf("Cleaning up\n"); free(inputData); free(encryptedFrame); free(decryptedFrame); daveEncryptorDestroy(encryptorA); daveDecryptorDestroy(decryptorA); daveSessionDestroy(sessionA); daveSessionDestroy(sessionB); daveExternalSenderDestroy(externalSender); return 1; } static int TestExceptions(void) { printf("Testing exception catching\n"); DAVESessionHandle session = daveSessionCreate(NULL, NULL, TestSessionFailureCallback, NULL); TEST_ASSERT(session != NULL, "Failed to create session"); PairwiseFingerprintData pairwiseFingerprintData; PairwiseFingerprintDataInit(&pairwiseFingerprintData); daveSessionGetPairwiseFingerprint( session, 1, "1234123412341234", &PairwiseFingerprintCallback, &pairwiseFingerprintData); PairwiseFingerprintDataWait(&pairwiseFingerprintData); TEST_ASSERT_EQ(pairwiseFingerprintData.pairwiseFingerprintLength, 0, "Expected empty fingerprint when exception is caught"); PairwiseFingerprintDataDestroy(&pairwiseFingerprintData); 
daveSessionDestroy(session); return 1; } int main(void) { int passed = 0; int failed = 0; printf("\n=== Running C API Tests ===\n\n"); RUN_TEST(TestEncryptorCreateDestroy); RUN_TEST(TestDecryptorCreateDestroy); RUN_TEST(TestMaxProtocolVersion); RUN_TEST(TestEncryptorPassthrough); RUN_TEST(TestDecryptorPassthrough); RUN_TEST(TestPassthroughInOutBuffer); RUN_TEST(TestPassthroughTwoBuffers); RUN_TEST(TestSession); RUN_TEST(TestExceptions); printf("\n=== Test Results ===\n"); printf("Passed: %d\n", passed); printf("Failed: %d\n", failed); printf("Total: %d\n", passed + failed); return (failed == 0) ? 0 : 1; } ================================================ FILE: cpp/test/capi/external_sender_wrapper.cpp ================================================ #include "external_sender_wrapper.h" #include #include #include "../external_sender.h" namespace { void CopyVectorToOutputBuffer(std::vector const& vector, uint8_t** data, size_t* length) { if (data == nullptr || length == nullptr) { return; } if (vector.empty()) { *data = nullptr; *length = 0; return; } *data = reinterpret_cast(malloc(vector.size())); memcpy(*data, vector.data(), vector.size()); *length = vector.size(); } } // anonymous namespace DAVEExternalSenderHandle daveExternalSenderCreate(uint64_t groupId) { auto protocolVersion = daveMaxSupportedProtocolVersion(); auto externalSender = std::make_unique(protocolVersion, groupId); return reinterpret_cast(externalSender.release()); } void daveExternalSenderDestroy(DAVEExternalSenderHandle externalSenderHandle) { auto externalSender = reinterpret_cast(externalSenderHandle); delete externalSender; } void daveExternalSenderGetMarshalledExternalSender(DAVEExternalSenderHandle externalSenderHandle, uint8_t** marshalledExternalSender, size_t* length) { auto externalSender = reinterpret_cast(externalSenderHandle); auto externalSenderVec = externalSender->GetMarshalledExternalSender(); CopyVectorToOutputBuffer(externalSenderVec, marshalledExternalSender, length); } void 
daveExternalSenderProposeAdd(DAVEExternalSenderHandle externalSenderHandle, uint32_t epoch, uint8_t* keyPackage, size_t keyPackageLength, uint8_t** proposal, size_t* proposalLength) { auto externalSender = reinterpret_cast(externalSenderHandle); auto keyPackageVec = std::vector(keyPackage, keyPackage + keyPackageLength); auto result = externalSender->ProposeAdd(epoch, std::move(keyPackageVec)); CopyVectorToOutputBuffer(result, proposal, proposalLength); } void daveExternalSenderSplitCommitWelcome(DAVEExternalSenderHandle externalSenderHandle, uint8_t* commitWelcome, size_t commitWelcomeLength, uint8_t** commit, size_t* commitLength, uint8_t** welcome, size_t* welcomeLength) { auto externalSender = reinterpret_cast(externalSenderHandle); auto commitWelcomeVec = std::vector(commitWelcome, commitWelcome + commitWelcomeLength); auto [commitBytes, welcomeBytes] = externalSender->SplitCommitWelcome(std::move(commitWelcomeVec)); CopyVectorToOutputBuffer(commitBytes, commit, commitLength); CopyVectorToOutputBuffer(welcomeBytes, welcome, welcomeLength); } ================================================ FILE: cpp/test/capi/external_sender_wrapper.h ================================================ #include #ifdef __cplusplus extern "C" { #endif DECLARE_OPAQUE_HANDLE(DAVEExternalSenderHandle); DAVE_EXPORT DAVEExternalSenderHandle daveExternalSenderCreate(uint64_t groupId); DAVE_EXPORT void daveExternalSenderDestroy(DAVEExternalSenderHandle externalSender); DAVE_EXPORT void daveExternalSenderGetMarshalledExternalSender( DAVEExternalSenderHandle externalSender, uint8_t** marshalledExternalSender, size_t* length); DAVE_EXPORT void daveExternalSenderProposeAdd(DAVEExternalSenderHandle externalSender, uint32_t epoch, uint8_t* keyPackage, size_t keyPackageLength, uint8_t** proposal, size_t* proposalLength); DAVE_EXPORT void daveExternalSenderSplitCommitWelcome(DAVEExternalSenderHandle externalSender, uint8_t* commitWelcome, size_t commitWelcomeLength, uint8_t** commit, size_t* 
commitLength, uint8_t** welcome, size_t* welcomeLength); #ifdef __cplusplus } #endif ================================================ FILE: cpp/test/capi/test_helpers.c ================================================ #include "test_helpers.h" #include #include static int HexDigitToValue(char c) { if (c >= '0' && c <= '9') { return c - '0'; } if (c >= 'a' && c <= 'f') { return c - 'a' + 10; } if (c >= 'A' && c <= 'F') { return c - 'A' + 10; } return -1; } uint8_t* GetBufferFromHex(const char* hex, size_t* outLength) { if (!hex || !outLength) { return NULL; } size_t hexLength = strlen(hex); if (hexLength % 2 != 0) { *outLength = 0; return NULL; } size_t bufferLength = hexLength / 2; uint8_t* buffer = (uint8_t*)malloc(bufferLength); if (!buffer) { *outLength = 0; return NULL; } for (size_t i = 0; i < hexLength; i += 2) { int high = HexDigitToValue(hex[i]); int low = HexDigitToValue(hex[i + 1]); if (high < 0 || low < 0) { free(buffer); *outLength = 0; return NULL; } buffer[i / 2] = (uint8_t)((high << 4) | low); } *outLength = bufferLength; return buffer; } ================================================ FILE: cpp/test/capi/test_helpers.h ================================================ #ifndef TEST_HELPERS_H #define TEST_HELPERS_H #include #include #define TEST_ASSERT(condition, message) \ do { \ if (!(condition)) { \ fprintf(stderr, "FAILED: %s (at %s:%d)\n", message, __FILE__, __LINE__); \ return 0; \ } \ } while (0) #define TEST_ASSERT_EQ(a, b, message) \ do { \ if ((a) != (b)) { \ fprintf(stderr, \ "FAILED: %s - expected %lld, got %lld (at %s:%d)\n", \ message, \ (long long)(b), \ (long long)(a), \ __FILE__, \ __LINE__); \ return 0; \ } \ } while (0) uint8_t* GetBufferFromHex(const char* hex, size_t* outLength); #endif /* TEST_HELPERS_H */ ================================================ FILE: cpp/test/codec_utils_tests.cpp ================================================ #include #include "gtest/gtest.h" #include #include "codec_utils.h" #include "decryptor.h" 
// NOTE(review): this repository extract lost all angle-bracket text (template argument lists
// such as `std::vector<...>` / `std::make_unique<...>` and system `#include <...>` names). The
// code below keeps the extract's tokens as-is apart from the fixes noted inline; restore those
// parts from the repository before compiling.
#include "encryptor.h"
#include "frame_processors.h"
#include "dave_test.h"
#include "static_key_ratchet.h"

namespace discord {
namespace dave {
namespace test {

// For Opus the processor produces no unencrypted ranges: the whole 76-byte frame is routed to
// the encrypted bytes.
TEST_F(DaveTests, RandomOpusFrame) {
    constexpr std::string_view randomBytes =
      "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4"
      "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(randomBytes);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::Opus);

    auto& unencryptedBytes = frameProcessor.GetUnencryptedBytes();
    auto& encryptedBytes = frameProcessor.GetEncryptedBytes();
    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();

    EXPECT_EQ(incomingFrame.size(), 76u);
    EXPECT_EQ(unencryptedBytes.size(), 0u);
    EXPECT_EQ(encryptedBytes.size(), incomingFrame.size());
    EXPECT_EQ(unencryptedRanges.size(), 0u);
}

// Parses a 76-byte media payload followed by extra metadata bytes (presumably the frame trailer
// ending in the `fafa` marker -- confirm against frame_processors), copies the ciphertext
// straight into the plaintext buffer as an identity "decryption", and checks that
// ReconstructFrame reproduces the original 76 media bytes.
TEST_F(DaveTests, SplitReconstruct) {
    std::string randomBytes =
      "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4"
      "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9"
      "0000000000000000 00 000a 140a 280a 3c0a 14 fafa";
    // the spaces above only group the metadata fields for readability
    randomBytes.erase(std::remove(randomBytes.begin(), randomBytes.end(), ' '),
                      randomBytes.end());

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(randomBytes);
    auto reconstructedFrame = std::make_unique(incomingFrame.size());

    InboundFrameProcessor frameProcessor;
    frameProcessor.ParseFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()));

    memcpy(frameProcessor.GetPlaintext().data(),
           frameProcessor.GetCiphertext().data(),
           frameProcessor.GetCiphertext().size());

    auto bytesWritten = frameProcessor.ReconstructFrame(
      MakeArrayView(reconstructedFrame.get(), incomingFrame.size()));

    EXPECT_EQ(bytesWritten, 76u);
    EXPECT_EQ(memcmp(incomingFrame.data(), reconstructedFrame.get(), bytesWritten), 0);
}

TEST_F(DaveTests, H264SliceOneByteExpGolomb) {
    // start code, nal unit header
    // 3 exponential golomb values (first_mb_in_slice, slice_type, pic_parameter_set_id)
    // then slice payloads
    constexpr std::string_view kH264SliceHex = "0000000161e0fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264SliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 6u);
}

TEST_F(DaveTests, H264ShortIDROneByteExpGolomb) {
    // SPS NAL UNIT, PPS NAL UNIT, then IDR NAL Unit
    // for IDR: nal unit header, then 3 exponential golomb values (first_mb_in_slice, slice_type,
    // pic_parameter_set_id) then IDR payloads
    constexpr std::string_view kH264ShortIDR =
      "000000016742c00d8c8d40d0fbc900f08846a00000000168ce3c800000000165b8fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264ShortIDR);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 33u);
}

TEST_F(DaveTests, H264ShortIDRTwoByteExpGolomb) {
    // SPS NAL UNIT, PPS NAL UNIT, then IDR NAL Unit
    // for IDR: nal unit header, then 3 exponential golomb values (first_mb_in_slice, slice_type,
    // pic_parameter_set_id) then IDR payloads
    constexpr std::string_view kH264ShortIDR =
      "000000016742c00d8c8d40d0fbc900f08846a00000000168ce3c8000000001654760fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264ShortIDR);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 34u);
}

TEST_F(DaveTests, H264LongIDROneByteExpGolomb) {
    // SPS NAL UNIT, PPS NAL UNIT, SEI NAL unit, then IDR NAL Unit
    // which has nal unit header,
    // then 3 exponential golomb values (first_mb_in_slice, slice_type, pic_parameter_set_id)
    // then IDR payloads
    constexpr std::string_view kH264LongIDR =
      "00000001274d0033ab402802dd00da08846a000000000128ee3c800000000106051a47564adc5c4c433f94efc511"
      "3cd143a801ffccccff020004ca90800000000125b8fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264LongIDR);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 67u);
}

TEST_F(DaveTests, H264LongIDRTwoByteExpGolomb) {
    // SPS NAL UNIT, PPS NAL UNIT, SEI NAL unit, then IDR NAL Unit
    // which has nal unit header, then 3 exponential golomb values
    // (first_mb_in_slice, slice_type, pic_parameter_set_id) then IDR payloads
    constexpr std::string_view kH264LongIDR =
      "00000001274d0033ab402802dd00da08846a000000000128ee3c800000000106051a47564adc5c4c433f94efc5"
      "11"
      "3cd143a801ffccccff020004ca908000000001254760fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264LongIDR);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 68u);
}

// The slice header's early exp-Golomb fields contain an emulation-prevention sequence
// (00 00 03); the whole header (11 bytes) must still be kept in the clear.
TEST_F(DaveTests, H264EmulationPreventionInEarlyExpGolomb) {
    constexpr std::string_view kH264SliceHex = "00000001610000038000e0fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264SliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 11u);
}

// A frame mixing four-byte and three-byte (00 00 01) start codes. After reconstruction the
// unencrypted header is expected to carry four-byte start codes everywhere (compare the input
// hex with kExpectedUnencryptedHeaderHex).
TEST_F(DaveTests, H264ThreeByteShortCodeExtension) {
    constexpr std::string_view kH264MixedShortCodes =
      "000000012764001fac2b602802dd8088000003000800000301b46d0e1970"
      "00000128ee3cb0000001258880ababab";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264MixedShortCodes);
    auto encryptedFrame = std::make_unique(incomingFrame.size() * 2);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 45u);

    auto bytesToEncrypt = frameProcessor.GetEncryptedBytes();
    auto encryptedBytes = frameProcessor.GetCiphertextBytes();
    EXPECT_EQ(bytesToEncrypt.size(), encryptedBytes.size());

    memcpy(encryptedFrame.get(), bytesToEncrypt.data(), bytesToEncrypt.size());

    frameProcessor.ReconstructFrame(MakeArrayView(
      encryptedFrame.get(),
      bytesToEncrypt.size() + frameProcessor.GetUnencryptedBytes().size()));

    constexpr std::string_view kExpectedUnencryptedHeaderHex =
      "000000012764001fac2b602802dd8088000003000800000301b46d0e19700000000128ee3cb000000001258880";
    auto expectedUnencryptedHeader = GetBufferFromHex(kExpectedUnencryptedHeaderHex);

    auto compareResultExpected = memcmp(
      encryptedFrame.get(), expectedUnencryptedHeader.data(), expectedUnencryptedHeader.size());
    EXPECT_EQ(compareResultExpected, 0);
}

TEST_F(DaveTests, H264TwoSliceTest) {
    // start code, nal unit header
    // 3 exponential golomb values (first_mb_in_slice, slice_type, pic_parameter_set_id)
    // then slice payload
    // and repeated again
    constexpr std::string_view kH264TwoSliceHex = "0000000161e0fafafa0000000161e0fafafa";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH264TwoSliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H264);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 2u);
    EXPECT_EQ(unencryptedRanges[0].offset, 0u);
    EXPECT_EQ(unencryptedRanges[0].size, 6u);
    EXPECT_EQ(unencryptedRanges[1].offset, 9u);
    EXPECT_EQ(unencryptedRanges[1].size, 6u);
}

// Several H.265 NAL units ahead of the slice; expects one unencrypted range of 119 bytes
// starting at the beginning of the frame.
TEST_F(DaveTests, H265IdrSlice) {
    constexpr std::string_view kH265IdrSliceHex =
      "0000000140010c01ffff016000000300b0000003000003005d17024"
      "000000001420101016000000300b0000003000003005da00280802d16205ee45914bff2e7f13fa2"
      "000000014401c072f05324000000014e01051a47564adc5c4c433f94efc5113cd143a803ee0000ee02001fc8b88"
      "0000000012801abab";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH265IdrSliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H265);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 119u);
}

// Four-byte start code + two-byte NAL header (6 bytes) stay in the clear.
TEST_F(DaveTests, H265TsaSlice) {
    constexpr std::string_view kH265TsaSliceHex = "000000010201abab";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH265TsaSliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H265);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 6u);
}

// Same NAL as above but behind a three-byte start code: the input header is 5 bytes while the
// expected unencrypted range is 6, i.e. the start code is apparently widened to four bytes.
TEST_F(DaveTests, H265SimpleThreeByteCodeExtension) {
    constexpr std::string_view kH265TsaSliceHexShort = "0000010201abab";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH265TsaSliceHexShort);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H265);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 6u);
}

// Same content as H265IdrSlice but every start code is three bytes; the expected unencrypted
// range (119 bytes) matches the four-byte-start-code variant.
TEST_F(DaveTests, H265MultipleThreeByteCodeExtensions) {
    constexpr std::string_view kH265IdrSliceHex =
      "00000140010c01ffff016000000300b0000003000003005d17024"
      "0000001420101016000000300b0000003000003005da00280802d16205ee45914bff2e7f13fa2"
      "000000014401c072f05324000000014e01051a47564adc5c4c433f94efc5113cd143a803ee0000ee02001fc8b88"
      "00000012801abab";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH265IdrSliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H265);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 1u);
    EXPECT_EQ(unencryptedRanges.front().offset, 0u);
    EXPECT_EQ(unencryptedRanges.front().size, 119u);
}

// Two 7-byte NAL units behind three-byte start codes. The second range starts at offset 8
// rather than 7 because the first start code is widened to four bytes.
TEST_F(DaveTests, H265TwoIdrSlice) {
    constexpr std::string_view kH265TwoIdrSliceHex = "0000010201abab0000010201abab";

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(kH265TwoIdrSliceHex);

    OutboundFrameProcessor frameProcessor;
    frameProcessor.ProcessFrame(MakeArrayView(incomingFrame.data(), incomingFrame.size()),
                                Codec::H265);

    auto unencryptedRanges = frameProcessor.GetUnencryptedRanges();
    EXPECT_EQ(unencryptedRanges.size(), 2u);
    EXPECT_EQ(unencryptedRanges[0].offset, 0u);
    EXPECT_EQ(unencryptedRanges[0].size, 6u);
    EXPECT_EQ(unencryptedRanges[1].offset, 8u);
    EXPECT_EQ(unencryptedRanges[1].size, 6u);
}

} // namespace test
} // namespace dave
} // namespace discord
================================================
FILE: cpp/test/cryptor_manager_tests.cpp
================================================
#include
#include
#include
#include

#include "common.h"
#include "cryptor_manager.h"
#include "utils/clock.h"

#include "dave_test.h"
#include "static_key_ratchet.h"

using namespace testing;
using namespace std::chrono_literals;

namespace discord {
namespace dave {
namespace test {

// Gap can't be larger than the amount of bits allocated for it if we want to handle wraparound
// correctly
static_assert(kMaxGenerationGap < kGenerationWrap, "Gap can't be larger than wraparound value");

// gmock key ratchet: by default GetKey returns a static key built from a user-id string so the
// tests can count GetKey/DeleteKey calls per generation.
class MockKeyRatchet : public IKeyRatchet {
public:
    MockKeyRatchet()
    {
        ON_CALL(*this, GetKey).WillByDefault([](KeyGeneration generation) {
            // NOTE(review): the concatenated string has at least 21 digits, which exceeds the
            // uint64 range, so strtoull inside MakeStaticSenderKey saturates and every generation
            // likely maps to the same key. These tests only count calls, so that is harmless.
            auto userId = std::string("12345678901234567890");
            return MakeStaticSenderKey(userId + std::to_string(generation));
        });
    }

    MOCK_METHOD(EncryptionKey, GetKey, (KeyGeneration generation), (override, noexcept));
    MOCK_METHOD(void, DeleteKey, (KeyGeneration generation), (override, noexcept));
};

// Manually driven clock so the tests control cryptor expiry deterministically.
class MockClock : public IClock {
public:
    TimePoint Now() const override { return now_; }

    void SetNow(TimePoint now) { now_ = now; }
    void Advance(Duration duration) { now_ += duration; }

private:
    TimePoint now_{std::chrono::steady_clock::now()};
};

// Generations more than kMaxGenerationGap ahead of the newest successful one are refused until
// a success is reported for a newer generation.
TEST_F(DaveTests, CryptorManagerCheckMaxGap) {
    auto mockKeyRatchet = std::make_unique();
    EXPECT_CALL(*mockKeyRatchet, GetKey(0));
    EXPECT_CALL(*mockKeyRatchet, GetKey(kMaxGenerationGap));
    EXPECT_CALL(*mockKeyRatchet, GetKey(kMaxGenerationGap + 1));

    MockClock clock;
    CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)};

    // Give plenty of room to not trigger the max lifetime generations check
    clock.Advance(kMaxGenerationGap * 48h);

    auto cryptor = cryptorManager.GetCryptor(0);
    EXPECT_NE(cryptor, nullptr);
    EXPECT_EQ(cryptorManager.GetCryptor(0), cryptor);
    EXPECT_NE(cryptorManager.GetCryptor(kMaxGenerationGap), nullptr);
    EXPECT_EQ(cryptorManager.GetCryptor(kMaxGenerationGap + 1), nullptr);

    cryptorManager.ReportCryptorSuccess(
      kMaxGenerationGap, static_cast(kMaxGenerationGap << kRatchetGenerationShiftBits));

    EXPECT_NE(cryptorManager.GetCryptor(kMaxGenerationGap + 1), nullptr);
}

// After a newer generation succeeds, the older cryptor stays usable for kCryptorExpiry and is
// then dropped (its key is deleted).
TEST_F(DaveTests, CryptorManagerCheckExpiry) {
    auto mockKeyRatchet = std::make_unique();
    EXPECT_CALL(*mockKeyRatchet, GetKey(0));
    EXPECT_CALL(*mockKeyRatchet, GetKey(1));
    EXPECT_CALL(*mockKeyRatchet, DeleteKey(0));

    MockClock clock;
    CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)};

    EXPECT_NE(cryptorManager.GetCryptor(0), nullptr);
    clock.Advance(1000000h);
    EXPECT_NE(cryptorManager.GetCryptor(0), nullptr);
    EXPECT_NE(cryptorManager.GetCryptor(1), nullptr);

    cryptorManager.ReportCryptorSuccess(1, 1 << kRatchetGenerationShiftBits);

    clock.Advance(kCryptorExpiry - 1us);
    EXPECT_NE(cryptorManager.GetCryptor(0), nullptr);
    clock.Advance(2us);
    EXPECT_EQ(cryptorManager.GetCryptor(0), nullptr);
}

// Reporting success on generation 5 and letting the expiry elapse deletes the keys of every
// older generation (0 through 4).
TEST_F(DaveTests, CryptorManagerDeleteOldKeys) {
    auto mockKeyRatchet = std::make_unique();
    EXPECT_CALL(*mockKeyRatchet, GetKey(0));
    EXPECT_CALL(*mockKeyRatchet, GetKey(5));
    EXPECT_CALL(*mockKeyRatchet, DeleteKey(0));
    EXPECT_CALL(*mockKeyRatchet, DeleteKey(1));
    EXPECT_CALL(*mockKeyRatchet, DeleteKey(2));
    EXPECT_CALL(*mockKeyRatchet, DeleteKey(3));
    EXPECT_CALL(*mockKeyRatchet, DeleteKey(4));

    MockClock clock;
    CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)};

    // Give plenty of room to not trigger the max lifetime generations check
    clock.Advance(kMaxGenerationGap * 48h);

    EXPECT_NE(cryptorManager.GetCryptor(0), nullptr);
    EXPECT_NE(cryptorManager.GetCryptor(5), nullptr);

    cryptorManager.ReportCryptorSuccess(5, 5 << kRatchetGenerationShiftBits);
    clock.Advance(kCryptorExpiry + 1us);

    EXPECT_NE(cryptorManager.GetCryptor(5), nullptr);
}

// A truncated generation below the reference's position within the current wrap rolls over
// into the next wrap; otherwise it stays in the current one.
TEST_F(DaveTests, CryptorManagerGenerationWrap) {
    EXPECT_EQ(ComputeWrappedGeneration(0, 0), KeyGeneration{0});
    EXPECT_EQ(ComputeWrappedGeneration(0, 1), KeyGeneration{1});
    EXPECT_EQ(ComputeWrappedGeneration(0, 250), KeyGeneration{250});
    EXPECT_EQ(ComputeWrappedGeneration(11 * kGenerationWrap + 42, 42),
              KeyGeneration{11 * kGenerationWrap + 42});
    EXPECT_EQ(ComputeWrappedGeneration(11 * kGenerationWrap + 42, 50),
              KeyGeneration{11 * kGenerationWrap + 50});
    EXPECT_EQ(ComputeWrappedGeneration(11 * kGenerationWrap + 42, 10),
              KeyGeneration{12 * kGenerationWrap + 10});
}

// The big nonce combines the full generation (shifted) with the low nonce bits; generation bits
// already present in the truncated nonce are replaced by the supplied generation.
TEST_F(DaveTests, CryptorManagerBigNonce) {
    EXPECT_EQ(ComputeWrappedBigNonce(0, 0), 0u);
    EXPECT_EQ(ComputeWrappedBigNonce(0, 1), 1u);
    EXPECT_EQ(ComputeWrappedBigNonce(0, 250), 250u);
    EXPECT_EQ(ComputeWrappedBigNonce(11, 10), 11 << kRatchetGenerationShiftBits | 10u);
    EXPECT_EQ(ComputeWrappedBigNonce(11, 42), 11 << kRatchetGenerationShiftBits | 42u);
    EXPECT_EQ(ComputeWrappedBigNonce(11, 50), 11 << kRatchetGenerationShiftBits | 50u);
    EXPECT_EQ(ComputeWrappedBigNonce(11, 2 << kRatchetGenerationShiftBits | 34),
              11 << kRatchetGenerationShiftBits | 34u);
    EXPECT_EQ(ComputeWrappedBigNonce(11, 37 << kRatchetGenerationShiftBits | 139),
              11 << kRatchetGenerationShiftBits | 139u);
    EXPECT_EQ(ComputeWrappedBigNonce(11, 89 << kRatchetGenerationShiftBits | 294),
              11 << kRatchetGenerationShiftBits | 294u);
}

// Nonces already reported as successful cannot be processed again, and nonces more than
// kMaxMissingNonces behind the newest successful nonce are rejected.
TEST_F(DaveTests, CryptorManagerNoReprocess) {
    auto mockKeyRatchet = std::make_unique();
    EXPECT_CALL(*mockKeyRatchet, GetKey(0));

    MockClock clock;
    CryptorManager cryptorManager{clock, std::move(mockKeyRatchet)};

    // Give plenty of room to not trigger the max lifetime generations check
    clock.Advance(kMaxGenerationGap * 48h);

    auto cryptor = cryptorManager.GetCryptor(0);
    EXPECT_NE(cryptor, nullptr);

    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 0));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 1));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 2));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, std::numeric_limits::max()));

    cryptorManager.ReportCryptorSuccess(0, 0);

    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 0));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 1));

    cryptorManager.ReportCryptorSuccess(0, 1);
    cryptorManager.ReportCryptorSuccess(0, 2);
    cryptorManager.ReportCryptorSuccess(0, 5);
    cryptorManager.ReportCryptorSuccess(0, 7);

    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 0));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 1));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 2));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 5));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 7));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 4));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 6));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 8));

    cryptorManager.ReportCryptorSuccess(0, 4);

    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 4));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 6));

    cryptorManager.ReportCryptorSuccess(0, 6);

    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 3));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 6));

    cryptorManager.ReportCryptorSuccess(0, 10 + kMaxMissingNonces);

    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 3));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 7));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 8));
    EXPECT_FALSE(cryptorManager.CanProcessNonce(0, 9));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 10));
    EXPECT_TRUE(cryptorManager.CanProcessNonce(0, 11));
}

} // namespace test
} // namespace dave
} // namespace discord
================================================
FILE: cpp/test/cryptor_tests.cpp
================================================
#include

#include "decryptor.h"
#include "encryptor.h"
#include "frame_processors.h"

#include "dave_test.h"
#include "static_key_ratchet.h"

using namespace testing;
using namespace std::chrono_literals;

namespace discord {
namespace dave {
namespace test {

constexpr std::string_view RandomBytes =
  "0dc5aedd5bdc3f20be5697e54dd1f437b896a36f858c6f20bbd69e2a493ca170c4f0c1b9acd4"
  "9d324b92afa788d09b12b29115a2feb3552b60fff983234a6c9608af3933683efc6b0f5579a9";

// In passthrough mode both Encryptor and Decryptor leave the frame unchanged, even when the
// input and output views alias the same buffer.
TEST_F(DaveTests, PassthroughInOutBuffer) {
    auto incomingFrame = GetBufferFromHex(RandomBytes);
    auto frameCopy = incomingFrame;

    auto frameViewIn = MakeArrayView(incomingFrame.data(), incomingFrame.size());
    auto frameViewOut = MakeArrayView(incomingFrame.data(), incomingFrame.size());

    EXPECT_NE(incomingFrame.data(), frameCopy.data());

    Encryptor encryptor;
    encryptor.AssignSsrcToCodec(0, Codec::Opus);
    encryptor.SetPassthroughMode(true);

    size_t bytesWritten = 0;
    auto encryptResult =
      encryptor.Encrypt(MediaType::Audio, 0, frameViewIn, frameViewOut, &bytesWritten);
    EXPECT_EQ(encryptResult, 0);
    EXPECT_EQ(bytesWritten, frameCopy.size());
    EXPECT_EQ(memcmp(incomingFrame.data(), frameCopy.data(), bytesWritten), 0);

    Decryptor decryptor;
    decryptor.TransitionToPassthroughMode(true, 0s);

    bytesWritten = 0;
    auto decryptResult =
      decryptor.Decrypt(MediaType::Audio, frameViewIn, frameViewOut, &bytesWritten);
    EXPECT_EQ(decryptResult, Decryptor::ResultCode::Success);
    EXPECT_EQ(bytesWritten, frameCopy.size());
    EXPECT_EQ(memcmp(incomingFrame.data(), frameCopy.data(), bytesWritten), 0);
}

// Passthrough with separate input/output buffers: the "encrypted" frame equals the input and
// decrypting it yields the same bytes again.
TEST_F(DaveTests, PassthroughTwoBuffers) {
    auto incomingFrame = GetBufferFromHex(RandomBytes);
    auto encryptedFrame = std::vector(incomingFrame.size() * 2);
    auto decryptedFrame = std::vector(incomingFrame.size());

    Encryptor encryptor;
    encryptor.AssignSsrcToCodec(0, Codec::Opus);
    encryptor.SetPassthroughMode(true);

    size_t bytesWritten = 0;
    auto encryptResult = encryptor.Encrypt(MediaType::Audio,
                                           0,
                                           {incomingFrame.data(), incomingFrame.size()},
                                           {encryptedFrame.data(), encryptedFrame.size()},
                                           &bytesWritten);
    EXPECT_EQ(encryptResult, 0);
    EXPECT_EQ(bytesWritten, incomingFrame.size());
    EXPECT_EQ(memcmp(incomingFrame.data(), encryptedFrame.data(), bytesWritten), 0);

    Decryptor decryptor;
    decryptor.TransitionToPassthroughMode(true, 0s);

    size_t bytesDecrypted = 0;
    auto decryptResult = decryptor.Decrypt(MediaType::Audio,
                                           {encryptedFrame.data(), bytesWritten},
                                           {decryptedFrame.data(), decryptedFrame.size()},
                                           &bytesDecrypted);
    EXPECT_EQ(decryptResult, Decryptor::ResultCode::Success);
    EXPECT_EQ(bytesDecrypted, incomingFrame.size());
    // Compare the bytes actually produced. This previously passed `decryptResult` (a result
    // code, not a byte count) as the memcmp length, so the comparison checked nothing.
    EXPECT_EQ(memcmp(encryptedFrame.data(), decryptedFrame.data(), bytesDecrypted), 0);
}

// The 3-byte silence packet {248, 255, 254} passes through the decryptor untouched even when a
// key ratchet is installed.
TEST_F(DaveTests, SilencePacketPassthrough) {
    const std::vector WorkerSilencePacket = {248, 255, 254};

    Decryptor decryptor;
    decryptor.TransitionToKeyRatchet(std::make_unique("0123456789876543210"), 0s);

    auto decryptedFrame = std::vector(WorkerSilencePacket.size());
    size_t bytesWritten = 0;
    auto decryptResult =
      decryptor.Decrypt(MediaType::Audio,
                        {WorkerSilencePacket.data(), WorkerSilencePacket.size()},
                        {decryptedFrame.data(), decryptedFrame.size()},
                        &bytesWritten);
    EXPECT_EQ(decryptResult, Decryptor::ResultCode::Success);
    EXPECT_EQ(bytesWritten, WorkerSilencePacket.size());
    // Compare the bytes actually written (previously the result code was used as the length).
    EXPECT_EQ(memcmp(WorkerSilencePacket.data(), decryptedFrame.data(), bytesWritten), 0);
}

// Round-trips an Opus frame through an Encryptor/Decryptor pair sharing the same static key.
TEST_F(DaveTests, RandomOpusFrameEncryptDecrypt) {
    Encryptor encryptor;
    Decryptor decryptor;

    // set static key ratchet for testing
    encryptor.SetKeyRatchet(std::make_unique("0123456789876543210"));
    decryptor.TransitionToKeyRatchet(std::make_unique("0123456789876543210"), 0s);

    // load the hex encoded sample frame to a buffer
    auto incomingFrame = GetBufferFromHex(RandomBytes);
    auto encryptedFrame = std::vector(incomingFrame.size() * 2);
    auto decryptedFrame = std::vector(incomingFrame.size());

    for (size_t i = 0; i < 1; i++) {
        // encrypt frame
        size_t bytesWritten = 0;
        encryptor.AssignSsrcToCodec(0, Codec::Opus);
        auto encryptResult = encryptor.Encrypt(MediaType::Audio,
                                               0,
                                               {incomingFrame.data(), incomingFrame.size()},
                                               {encryptedFrame.data(), encryptedFrame.size()},
                                               &bytesWritten);
        EXPECT_EQ(encryptResult, 0);
        EXPECT_GE(bytesWritten, incomingFrame.size());

        // decrypt frame
        size_t bytesDecrypted = 0;
        auto decryptResult = decryptor.Decrypt(MediaType::Audio,
                                               {encryptedFrame.data(), bytesWritten},
                                               {decryptedFrame.data(), decryptedFrame.size()},
                                               &bytesDecrypted);
        EXPECT_EQ(decryptResult, Decryptor::ResultCode::Success);
        EXPECT_EQ(bytesDecrypted, incomingFrame.size());
        EXPECT_EQ(memcmp(incomingFrame.data(), decryptedFrame.data(), incomingFrame.size()), 0);
    }
}

} // namespace test
} // namespace dave
} // namespace discord
================================================
FILE: cpp/test/dave_test.cpp
================================================
#include "dave_test.h"

namespace discord {
namespace dave {
namespace test {

// Decodes a hex string (two characters per byte) into a byte vector.
// Returns an empty vector when the input has an odd length. Digits are not validated here:
// std::stoi throws if a pair has no leading hex digit and otherwise parses only its valid prefix.
std::vector GetBufferFromHex(const std::string_view& hex)
{
    auto hexLength = hex.length();
    if (hexLength % 2 != 0) {
        return {};
    }

    auto buffer = std::vector(hexLength / 2);
    // size_t index matches hex.length()'s type (was unsigned int).
    for (size_t i = 0; i < hexLength; i += 2) {
        auto byte = std::string(hex.substr(i, 2));
        buffer[i / 2] = static_cast(std::stoi(byte, nullptr, 16));
    }

    return buffer;
}

} // namespace test
} // namespace dave
} // namespace discord
================================================
FILE: cpp/test/dave_test.h
================================================
// Include guard added for consistency with the other test headers (external_sender.h,
// static_key_ratchet.h) so DaveTests is never defined twice.
#pragma once

#include "gtest/gtest.h"

#include "common.h"

namespace discord {
namespace dave {
namespace test {

// Decodes a hex string into bytes; empty result for odd-length input.
std::vector GetBufferFromHex(const std::string_view& hex);

// Shared gtest fixture for the DAVE test suites; it has no per-test setup or teardown.
class DaveTests : public ::testing::Test {
public:
    void SetUp() override {}
    void TearDown() override {}
};

} // namespace test
} // namespace dave
} // namespace discord
================================================
FILE: cpp/test/external_sender.cpp
================================================
#include "external_sender.h"

#include "mls/parameters.h"
#include "mls/util.h"

namespace discord {
namespace dave {
namespace test {

// Serializes `value` as sizeof(uint64_t) bytes, most significant byte first (big-endian).
::mlspp::bytes_ns::bytes BigEndianBytesFrom(uint64_t value) noexcept
{
    auto buffer = ::mlspp::bytes_ns::bytes();
    buffer.reserve(sizeof(value));

    for (int i = sizeof(value) - 1; i >= 0; --i) {
        buffer.push_back(static_cast(value >> (i * 8)));
    }

    return buffer;
}

// Every protocol version currently maps to the P256_AES128GCM_SHA256_P256 ciphersuite.
::mlspp::CipherSuite CiphersuiteForProtocolVersion(
  [[maybe_unused]] ProtocolVersion version) noexcept
{
    return ::mlspp::CipherSuite{::mlspp::CipherSuite::ID::P256_AES128GCM_SHA256_P256};
}

// Builds a test external sender: derives the ciphersuite, encodes the group id big-endian,
// generates a fresh signature key, and attaches a basic credential.
ExternalSender::ExternalSender(discord::dave::ProtocolVersion protocolVersion, uint64_t groupId)
{
    ciphersuite_ = CiphersuiteForProtocolVersion(protocolVersion);
    groupId_ = std::move(BigEndianBytesFrom(groupId).as_vec());

    signaturePrivateKey_ = std::make_shared<::mlspp::SignaturePrivateKey>(
      ::mlspp::SignaturePrivateKey::generate(ciphersuite_));

    externalSender_.signature_key = signaturePrivateKey_->public_key;
    externalSender_.credential = ::mlspp::Credential::basic({0x00, 0x01, 0x01, 0x00});
}

// TLS-serializes the external sender (signature key + credential).
std::vector ExternalSender::GetMarshalledExternalSender()
{
    return ::mlspp::tls::marshal(externalSender_);
}

// Produces a signed external Add proposal for `keyPackage` at `epoch`, serialized as a `false`
// revoke flag followed by a vector containing the single proposal message.
std::vector ExternalSender::ProposeAdd(uint32_t epoch, std::vector const& keyPackage)
{
    const auto keyPackageBytes = ::mlspp::bytes_ns::bytes(keyPackage);
    auto proposal =
      ::mlspp::Proposal{::mlspp::Add{{::mlspp::tls::get<::mlspp::KeyPackage>(keyPackageBytes)}}};

    auto message = ::mlspp::external_proposal(
      ciphersuite_, groupId_, epoch, proposal, signerIndex_, *signaturePrivateKey_);

    bool isRevoke = false;

    ::mlspp::tls::ostream out;
    out << isRevoke;
    out << std::vector<::mlspp::MLSMessage>{message};
    return out.bytes();
}

// Splits a TLS stream holding an MLSMessage (the commit) followed by a Welcome into two
// separately serialized byte vectors: {commit, welcome}.
std::pair, std::vector> ExternalSender::SplitCommitWelcome(
  std::vector const& commitWelcome)
{
    auto commitWelcomeBytes = ::mlspp::bytes_ns::bytes(commitWelcome);
    ::mlspp::tls::istream in(commitWelcomeBytes);

    ::mlspp::MLSMessage commitMessage;
    ::mlspp::Welcome welcomeMessage;

    in >> commitMessage;
    in >> welcomeMessage;

    ::mlspp::tls::ostream commitOut;
    commitOut << commitMessage;
    auto commitBytes = commitOut.bytes();

    ::mlspp::tls::ostream welcomeOut;
    welcomeOut << welcomeMessage;
    auto welcomeBytes = welcomeOut.bytes();

    return std::make_pair(commitBytes, welcomeBytes);
}

} // namespace test
} // namespace dave
} // namespace discord
================================================
FILE: cpp/test/external_sender.h
================================================
#pragma once

#include
#include

namespace discord {
namespace dave {
namespace test {

// Test-only MLS external sender used to craft external Add proposals and to split a combined
// commit+welcome payload.
class ExternalSender {
public:
    ExternalSender(discord::dave::ProtocolVersion protocolVersion, uint64_t groupId);

    std::vector GetMarshalledExternalSender();

    std::vector ProposeAdd(uint32_t epoch, std::vector const& keyPackage);

    std::pair, std::vector> SplitCommitWelcome(
      std::vector const& commitWelcome);

private:
    uint32_t signerIndex_{0};
    ::mlspp::CipherSuite ciphersuite_;
    ::mlspp::bytes_ns::bytes groupId_;
    std::shared_ptr<::mlspp::SignaturePrivateKey> signaturePrivateKey_;
    ::mlspp::ExternalSender externalSender_;
};

} // namespace test
} // namespace dave
} // namespace
discord
================================================
FILE: cpp/test/static_key_ratchet.cpp
================================================
#include "static_key_ratchet.h"

// NOTE(review): the <...> targets of the three includes below were lost in this extract.
#include
#include
#include

#include "common.h"

namespace discord {
namespace dave {
namespace test {

// Test helper: parses `userID` as a base-10 integer and returns the static key for it.
// strtoull reports no error here: input without leading digits yields 0 and an
// out-of-range value yields ULLONG_MAX, so malformed IDs silently map to those keys.
EncryptionKey MakeStaticSenderKey(const std::string& userID)
{
    auto u64userID = strtoull(userID.c_str(), nullptr, 10);
    return MakeStaticSenderKey(u64userID);
}

// Builds a deterministic AES-GCM-128 key by writing the 8 in-memory bytes of `u64userID`
// twice. The bytes are read through a pointer cast, so the key depends on host byte order.
EncryptionKey MakeStaticSenderKey(uint64_t u64userID)
{
    // Two copies of the 64-bit ID must exactly fill the key.
    static_assert(kAesGcm128KeyBytes == 2 * sizeof(u64userID));

    EncryptionKey senderKey(kAesGcm128KeyBytes);
    const uint8_t* bytePtr = reinterpret_cast(&u64userID);
    std::copy_n(bytePtr, sizeof(u64userID), senderKey.begin());
    std::copy_n(bytePtr, sizeof(u64userID), senderKey.begin() + sizeof(u64userID));
    return senderKey;
}

// Parses the user ID once, with the same strtoull rules as MakeStaticSenderKey(string).
StaticKeyRatchet::StaticKeyRatchet(const std::string& userId) noexcept
  : u64userID_(strtoull(userId.c_str(), nullptr, 10))
{
}

// Returns the same key for every generation; `generation` is only logged.
EncryptionKey StaticKeyRatchet::GetKey(KeyGeneration generation) noexcept
{
    DISCORD_LOG(LS_INFO) << "Retrieving static key for generation " << generation << " for user "
                         << u64userID_;
    return MakeStaticSenderKey(u64userID_);
}

// Nothing is stored, so there is nothing to delete.
void StaticKeyRatchet::DeleteKey([[maybe_unused]] KeyGeneration generation) noexcept
{
    // noop
}

} // namespace test
} // namespace dave
} // namespace discord

================================================
FILE: cpp/test/static_key_ratchet.h
================================================
#pragma once

// NOTE(review): the <...> targets of the two includes below were lost in this extract.
#include
#include

namespace discord {
namespace dave {
namespace test {

// Deterministic per-user test keys; see static_key_ratchet.cpp for how they are derived.
EncryptionKey MakeStaticSenderKey(const std::string& userID);
EncryptionKey MakeStaticSenderKey(uint64_t u64userID);

// Test double for IKeyRatchet: hands out one fixed key per user for every generation.
class StaticKeyRatchet : public IKeyRatchet {
public:
    StaticKeyRatchet(const std::string& userId) noexcept;
    ~StaticKeyRatchet() noexcept override = default;

    EncryptionKey GetKey(KeyGeneration generation) noexcept override;
    void DeleteKey(KeyGeneration generation) noexcept override;

private:
    // Numeric user ID parsed in the constructor; the only input to the key.
    uint64_t u64userID_;
};

} // namespace test
} // namespace dave
} // namespace discord

================================================
FILE: cpp/test/xssl_cryptor_tests.cpp
================================================
#include "gtest/gtest.h"

// NOTE(review): the <...> target of this include was lost in this extract.
#include

#ifdef WITH_BORINGSSL
#include "boringssl_cryptor.h"
#else
#include "openssl_cryptor.h"
#endif

#include "dave_test.h"
#include "static_key_ratchet.h"

namespace discord {
namespace dave {
namespace test {

// The tests below run against whichever cryptor backend the build selects.
#ifdef WITH_BORINGSSL
using CryptorVariant = BoringSSLCryptor;
#else
using CryptorVariant = OpenSSLCryptor;
#endif

// Round trip: Encrypt then Decrypt with the same key, nonce and additional data must
// reproduce the plaintext, and the ciphertext must differ from the plaintext.
// NOTE(review): the std::vector element types were stripped in this extract.
TEST_F(DaveTests, XSSLEncryptDecrypt)
{
    constexpr size_t PLAINTEXT_SIZE = 1024;

    auto plaintextBufferIn = std::vector(PLAINTEXT_SIZE, 0);
    auto additionalDataBuffer = std::vector(PLAINTEXT_SIZE, 0);
    auto plaintextBufferOut = std::vector(PLAINTEXT_SIZE, 0);
    auto ciphertextBuffer = std::vector(PLAINTEXT_SIZE, 0);
    auto nonceBuffer = std::vector(kAesGcm128NonceBytes, 0);
    auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0);

    auto plaintextIn = MakeArrayView(plaintextBufferIn.data(), plaintextBufferIn.size());
    auto additionalData = MakeArrayView(additionalDataBuffer.data(), additionalDataBuffer.size());
    auto plaintextOut = MakeArrayView(plaintextBufferOut.data(), plaintextBufferOut.size());
    // The Out/In view pairs alias the same buffers, so Decrypt consumes what Encrypt wrote.
    auto ciphertextOut = MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size());
    auto ciphertextIn = MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size());
    auto nonce = MakeArrayView(nonceBuffer.data(), nonceBuffer.size());
    auto tagOut = MakeArrayView(tagBuffer.data(), tagBuffer.size());
    auto tagIn = MakeArrayView(tagBuffer.data(), tagBuffer.size());

    CryptorVariant cryptor(MakeStaticSenderKey("12345678901234567890"));

    EXPECT_TRUE(cryptor.Encrypt(ciphertextOut, plaintextIn, nonce, additionalData, tagOut));

    // The ciphertext should not be the same as the plaintext
    EXPECT_FALSE(memcmp(plaintextBufferIn.data(), ciphertextBuffer.data(), PLAINTEXT_SIZE) == 0);

    EXPECT_TRUE(cryptor.Decrypt(plaintextOut, ciphertextIn, tagIn, nonce, additionalData));

    // The plaintext should be the same as the original plaintext
    EXPECT_TRUE(memcmp(plaintextBufferIn.data(), plaintextBufferOut.data(), PLAINTEXT_SIZE) == 0);
}

// Authentication: changing one byte of the additional data after encryption must make
// Decrypt fail.
TEST_F(DaveTests, XSSLAdditionalDataAuth)
{
    constexpr size_t PLAINTEXT_SIZE = 1024;

    auto plaintextBufferIn = std::vector(PLAINTEXT_SIZE, 0);
    auto additionalDataBuffer = std::vector(PLAINTEXT_SIZE, 0);
    auto plaintextBufferOut = std::vector(PLAINTEXT_SIZE, 0);
    auto ciphertextBuffer = std::vector(PLAINTEXT_SIZE, 0);
    auto nonceBuffer = std::vector(kAesGcm128NonceBytes, 0);
    auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0);

    auto plaintextIn = MakeArrayView(plaintextBufferIn.data(), plaintextBufferIn.size());
    auto additionalData = MakeArrayView(additionalDataBuffer.data(), additionalDataBuffer.size());
    auto plaintextOut = MakeArrayView(plaintextBufferOut.data(), plaintextBufferOut.size());
    auto ciphertextOut = MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size());
    auto ciphertextIn = MakeArrayView(ciphertextBuffer.data(), ciphertextBuffer.size());
    auto nonce = MakeArrayView(nonceBuffer.data(), nonceBuffer.size());
    auto tagOut = MakeArrayView(tagBuffer.data(), tagBuffer.size());
    auto tagIn = MakeArrayView(tagBuffer.data(), tagBuffer.size());

    CryptorVariant cryptor(MakeStaticSenderKey("12345678901234567890"));

    EXPECT_TRUE(cryptor.Encrypt(ciphertextOut, plaintextIn, nonce, additionalData, tagOut));

    // We modify the additional data before decryption
    additionalDataBuffer[0] = 1;

    EXPECT_FALSE(cryptor.Decrypt(plaintextOut, ciphertextIn, tagIn, nonce, additionalData));
}

// Key separation: the same plaintext, nonce and additional data encrypted under two
// different keys must give different ciphertexts. Both calls write into the same tag
// buffer; the tag is not checked.
TEST_F(DaveTests, XSSLKeyDiff)
{
    constexpr size_t PLAINTEXT_SIZE = 1024;

    auto plaintextBuffer1 = std::vector(PLAINTEXT_SIZE, 0);
    auto additionalDataBuffer1 = std::vector(PLAINTEXT_SIZE, 0);
    auto plaintextBuffer2 = std::vector(PLAINTEXT_SIZE, 0);
    auto additionalDataBuffer2 = std::vector(PLAINTEXT_SIZE, 0);
    auto ciphertextBuffer1 = std::vector(PLAINTEXT_SIZE, 0);
    auto ciphertextBuffer2 = std::vector(PLAINTEXT_SIZE, 0);
    auto nonceBuffer = std::vector(kAesGcm128NonceBytes, 0);
    auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0);

    auto plaintext1 = MakeArrayView(plaintextBuffer1.data(), plaintextBuffer1.size());
    auto additionalData1 = MakeArrayView(additionalDataBuffer1.data(), additionalDataBuffer1.size());
    auto plaintext2 = MakeArrayView(plaintextBuffer2.data(), plaintextBuffer2.size());
    auto additionalData2 = MakeArrayView(additionalDataBuffer2.data(), additionalDataBuffer2.size());
    auto ciphertext1 = MakeArrayView(ciphertextBuffer1.data(), ciphertextBuffer1.size());
    auto ciphertext2 = MakeArrayView(ciphertextBuffer2.data(), ciphertextBuffer2.size());
    auto nonce = MakeArrayView(nonceBuffer.data(), nonceBuffer.size());
    auto tag = MakeArrayView(tagBuffer.data(), tagBuffer.size());

    CryptorVariant cryptor1(MakeStaticSenderKey("12345678901234567890"));
    CryptorVariant cryptor2(MakeStaticSenderKey("09876543210987654321"));

    EXPECT_TRUE(cryptor1.Encrypt(ciphertext1, plaintext1, nonce, additionalData1, tag));
    EXPECT_TRUE(cryptor2.Encrypt(ciphertext2, plaintext2, nonce, additionalData2, tag));

    EXPECT_FALSE(memcmp(ciphertextBuffer1.data(), ciphertextBuffer2.data(), PLAINTEXT_SIZE) == 0);
}

// Nonce separation: the same key and plaintext with an all-0 versus an all-1 nonce must
// give different ciphertexts.
TEST_F(DaveTests, XSSLNonceDiff)
{
    constexpr size_t PLAINTEXT_SIZE = 1024;

    auto plaintextBuffer1 = std::vector(PLAINTEXT_SIZE, 0);
    auto additionalDataBuffer1 = std::vector(PLAINTEXT_SIZE, 0);
    auto plaintextBuffer2 = std::vector(PLAINTEXT_SIZE, 0);
    auto additionalDataBuffer2 = std::vector(PLAINTEXT_SIZE, 0);
    auto ciphertextBuffer1 = std::vector(PLAINTEXT_SIZE, 0);
    auto ciphertextBuffer2 = std::vector(PLAINTEXT_SIZE, 0);
    auto nonceBuffer1 = std::vector(kAesGcm128NonceBytes, 0);
    auto nonceBuffer2 = std::vector(kAesGcm128NonceBytes, 1);
    auto tagBuffer = std::vector(kAesGcm128TruncatedTagBytes, 0);

    auto plaintext1 = MakeArrayView(plaintextBuffer1.data(), plaintextBuffer1.size());
    auto additionalData1 = MakeArrayView(additionalDataBuffer1.data(), additionalDataBuffer1.size());
    auto plaintext2 = MakeArrayView(plaintextBuffer2.data(), plaintextBuffer2.size());
    auto additionalData2 = MakeArrayView(additionalDataBuffer2.data(), additionalDataBuffer2.size());
    auto ciphertext1 = MakeArrayView(ciphertextBuffer1.data(), ciphertextBuffer1.size());
    auto ciphertext2 = MakeArrayView(ciphertextBuffer2.data(), ciphertextBuffer2.size());
    auto nonce1 = MakeArrayView(nonceBuffer1.data(), nonceBuffer1.size());
    auto nonce2 = MakeArrayView(nonceBuffer2.data(), nonceBuffer2.size());
    auto tag = MakeArrayView(tagBuffer.data(), tagBuffer.size());

    CryptorVariant cryptor(MakeStaticSenderKey("12345678901234567890"));

    EXPECT_TRUE(cryptor.Encrypt(ciphertext1, plaintext1, nonce1, additionalData1, tag));
    EXPECT_TRUE(cryptor.Encrypt(ciphertext2, plaintext2, nonce2, additionalData2, tag));

    EXPECT_FALSE(memcmp(ciphertextBuffer1.data(), ciphertextBuffer2.data(), PLAINTEXT_SIZE) == 0);
}

} // namespace test
} // namespace dave
} // namespace discord

================================================
FILE: cpp/vcpkg-alts/boringssl/overlay-ports/mlspp/portfile.cmake
================================================
# Fetch cisco/mlspp at the port version (the commit hash in this port's vcpkg.json)
# and verify the archive against the pinned SHA512.
vcpkg_from_github(
    OUT_SOURCE_PATH SOURCE_PATH
    REPO cisco/mlspp
    REF "${VERSION}"
    SHA512 5d37631e2c47daae1133ef074e60cc09ca2d395f9e11c416f829060e374051cf219d2d7fe98dae49d1d045292e07d6a09f4814a5f16e6cc05e67e7cd96f146c4
)

# On macOS with OpenSSL headers under /usr/local/include, pass vcpkg's installed include
# directory through CMAKE_CXX_FLAGS (presumably so those headers win over the system copy).
if(VCPKG_TARGET_IS_OSX AND EXISTS "/usr/local/include/openssl/")
    set(VCPKG_INCLUDE_OVERRIDE "-DCMAKE_CXX_FLAGS=-I${CURRENT_INSTALLED_DIR}/include")
endif()

# Configure mlspp with GREASE disabled, mlspp's "alternatives/boringssl" manifest
# (presumably its BoringSSL build), and its C++ namespace set to mlspp.
vcpkg_cmake_configure(
    SOURCE_PATH "${SOURCE_PATH}"
    OPTIONS
        ${VCPKG_INCLUDE_OVERRIDE}
        -DDISABLE_GREASE=ON
        -DVCPKG_MANIFEST_DIR="alternatives/boringssl"
        -DMLS_CXX_NAMESPACE="mlspp"
)

vcpkg_cmake_install()
# Relocate the exported MLSPP CMake package config from share/MLSPP to vcpkg's layout.
vcpkg_cmake_config_fixup(PACKAGE_NAME "MLSPP" CONFIG_PATH "share/MLSPP")

================================================
FILE: cpp/vcpkg-alts/boringssl/overlay-ports/mlspp/vcpkg.json
================================================
{ "name": "mlspp", "version-string": "1cc50a124a3bc4e143a787ec934280dc70c1034d",
"description": "Cisco MLS C++ library", "dependencies": [ { "name": "boringssl", "version>=": "2023-10-13" }, "nlohmann-json", { "name": "vcpkg-cmake", "host": true }, { "name": "vcpkg-cmake-config", "host": true } ], "builtin-baseline": "eb33d2f7583405fca184bcdf7fdd5828ec88ac05" } ================================================ FILE: cpp/vcpkg-alts/boringssl/vcpkg.json ================================================ { "name": "libdave", "license": "MIT", "dependencies": [ { "name": "boringssl", "version>=": "2023-10-13" }, "gtest", "mlspp" ], "builtin-baseline": "7adc2e4d49e8d0efc07a369079faa6bc3dbb90f3", "vcpkg-configuration": { "overlay-ports": [ "./overlay-ports" ] } } ================================================ FILE: cpp/vcpkg-alts/openssl_1.1/overlay-ports/mlspp/portfile.cmake ================================================ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO cisco/mlspp REF "${VERSION}" SHA512 5d37631e2c47daae1133ef074e60cc09ca2d395f9e11c416f829060e374051cf219d2d7fe98dae49d1d045292e07d6a09f4814a5f16e6cc05e67e7cd96f146c4 ) if(VCPKG_TARGET_IS_OSX AND EXISTS "/usr/local/include/openssl/") set(VCPKG_INCLUDE_OVERRIDE "-DCMAKE_CXX_FLAGS=-I${CURRENT_INSTALLED_DIR}/include") endif() vcpkg_cmake_configure( SOURCE_PATH "${SOURCE_PATH}" OPTIONS ${VCPKG_INCLUDE_OVERRIDE} -DDISABLE_GREASE=ON -DVCPKG_MANIFEST_DIR="alternatives/openssl_1.1" -DMLS_CXX_NAMESPACE="mlspp" ) vcpkg_cmake_install() vcpkg_cmake_config_fixup(PACKAGE_NAME "MLSPP" CONFIG_PATH "share/MLSPP") ================================================ FILE: cpp/vcpkg-alts/openssl_1.1/overlay-ports/mlspp/vcpkg.json ================================================ { "name": "mlspp", "version-string": "1cc50a124a3bc4e143a787ec934280dc70c1034d", "description": "Cisco MLS C++ library", "dependencies": [ { "name": "openssl", "version>=": "1.1.1n" }, "catch2", "nlohmann-json", { "name": "vcpkg-cmake", "host": true }, { "name": "vcpkg-cmake-config", "host": true } ], "builtin-baseline": 
"eb33d2f7583405fca184bcdf7fdd5828ec88ac05", "overrides": [ { "name": "openssl", "version-string": "1.1.1n" } ] } ================================================ FILE: cpp/vcpkg-alts/openssl_1.1/vcpkg.json ================================================ { "name": "libdave", "license": "MIT", "dependencies": [ { "name": "openssl", "version>=": "1.1.1n" }, "gtest", "mlspp" ], "builtin-baseline": "d07689ef165f033de5c0710e4f67c193a85373e1", "vcpkg-configuration": { "overlay-ports": [ "./overlay-ports" ] }, "overrides": [ { "name": "openssl", "version-string": "1.1.1n" } ] } ================================================ FILE: cpp/vcpkg-alts/openssl_3/overlay-ports/mlspp/portfile.cmake ================================================ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO cisco/mlspp REF "${VERSION}" SHA512 5d37631e2c47daae1133ef074e60cc09ca2d395f9e11c416f829060e374051cf219d2d7fe98dae49d1d045292e07d6a09f4814a5f16e6cc05e67e7cd96f146c4 ) if(VCPKG_TARGET_IS_OSX AND EXISTS "/usr/local/include/openssl/") set(VCPKG_INCLUDE_OVERRIDE "-DCMAKE_CXX_FLAGS=-I${CURRENT_INSTALLED_DIR}/include") endif() vcpkg_cmake_configure( SOURCE_PATH "${SOURCE_PATH}" OPTIONS ${VCPKG_INCLUDE_OVERRIDE} -DDISABLE_GREASE=ON -DVCPKG_MANIFEST_DIR="alternatives/openssl_3" -DMLS_CXX_NAMESPACE="mlspp" ) vcpkg_cmake_install() vcpkg_cmake_config_fixup(PACKAGE_NAME "MLSPP" CONFIG_PATH "share/MLSPP") ================================================ FILE: cpp/vcpkg-alts/openssl_3/overlay-ports/mlspp/vcpkg.json ================================================ { "name": "mlspp", "version-string": "1cc50a124a3bc4e143a787ec934280dc70c1034d", "description": "Cisco MLS C++ library", "dependencies": [ { "name": "openssl", "version>=": "3.0.7" }, "catch2", "nlohmann-json", { "name": "vcpkg-cmake", "host": true }, { "name": "vcpkg-cmake-config", "host": true } ], "builtin-baseline": "eb33d2f7583405fca184bcdf7fdd5828ec88ac05", "overrides": [ { "name": "openssl", "version": "3.0.7" } ] } 
================================================ FILE: cpp/vcpkg-alts/openssl_3/vcpkg.json ================================================ { "name": "libdave", "license": "MIT", "dependencies": [ { "name": "openssl", "version>=": "3.0.7" }, "gtest", "mlspp" ], "builtin-baseline": "d07689ef165f033de5c0710e4f67c193a85373e1", "vcpkg-configuration": { "overlay-ports": [ "./overlay-ports" ] }, "overrides": [ { "name": "openssl", "version": "3.0.7" } ] } ================================================ FILE: cpp/vcpkg-alts/wasm/overlay-ports/mlspp/portfile.cmake ================================================ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO cisco/mlspp REF "${VERSION}" SHA512 5d37631e2c47daae1133ef074e60cc09ca2d395f9e11c416f829060e374051cf219d2d7fe98dae49d1d045292e07d6a09f4814a5f16e6cc05e67e7cd96f146c4 ) if(VCPKG_TARGET_IS_EMSCRIPTEN) set(VCPKG_C_FLAGS "${VCPKG_C_FLAGS} -s WASM=1") set(VCPKG_CXX_FLAGS "${VCPKG_CXX_FLAGS} -s WASM=1") set(VCPKG_LINKER_FLAGS "${VCPKG_LINKER_FLAGS} -s WASM=1 -s ALLOW_MEMORY_GROWTH=1") endif() if(VCPKG_TARGET_IS_OSX AND EXISTS "/usr/local/include/openssl/") set(VCPKG_INCLUDE_OVERRIDE "-DCMAKE_CXX_FLAGS=-I${CURRENT_INSTALLED_DIR}/include") endif() vcpkg_cmake_configure( SOURCE_PATH "${SOURCE_PATH}" OPTIONS ${VCPKG_INCLUDE_OVERRIDE} -DDISABLE_GREASE=ON -DVCPKG_MANIFEST_DIR="alternatives/openssl_3" -DMLS_CXX_NAMESPACE="mlspp" ) vcpkg_cmake_install() vcpkg_cmake_config_fixup(PACKAGE_NAME "MLSPP" CONFIG_PATH "share/MLSPP") ================================================ FILE: cpp/vcpkg-alts/wasm/overlay-ports/mlspp/vcpkg.json ================================================ { "name": "mlspp", "version-string": "1cc50a124a3bc4e143a787ec934280dc70c1034d", "description": "Cisco MLS C++ library", "dependencies": [ { "name": "openssl", "version>=": "3.0.7" }, "nlohmann-json", { "name": "vcpkg-cmake", "host": true }, { "name": "vcpkg-cmake-config", "host": true } ], "builtin-baseline": "eb33d2f7583405fca184bcdf7fdd5828ec88ac05" } 
================================================ FILE: cpp/vcpkg-alts/wasm/vcpkg.json ================================================ { "name": "libdave", "license": "MIT", "dependencies": [ { "name": "openssl", "version>=": "3.0.7" }, "gtest", "mlspp" ], "builtin-baseline": "7adc2e4d49e8d0efc07a369079faa6bc3dbb90f3", "vcpkg-configuration": { "overlay-ports": [ "./overlay-ports" ] } } ================================================ FILE: js/.gitignore ================================================ # Logs logs *.log npm-debug.log* yarn-debug.log* yarn-error.log* lerna-debug.log* .pnpm-debug.log* # Diagnostic reports (https://nodejs.org/api/report.html) report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json # Runtime data pids *.pid *.seed *.pid.lock # Directory for instrumented libs generated by jscoverage/JSCover lib-cov # Coverage directory used by tools like istanbul coverage *.lcov # nyc test coverage .nyc_output # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) .grunt # Bower dependency directory (https://bower.io/) bower_components # node-waf configuration .lock-wscript # Compiled binary addons (https://nodejs.org/api/addons.html) build/Release # Dependency directories node_modules/ jspm_packages/ # Snowpack dependency directory (https://snowpack.dev/) web_modules/ # TypeScript cache *.tsbuildinfo # Optional npm cache directory .npm # Optional eslint cache .eslintcache # Optional stylelint cache .stylelintcache # Microbundle cache .rpt2_cache/ .rts2_cache_cjs/ .rts2_cache_es/ .rts2_cache_umd/ # Optional REPL history .node_repl_history # Output of 'npm pack' *.tgz # Yarn Integrity file .yarn-integrity # dotenv environment variable files .env .env.development.local .env.test.local .env.production.local .env.local # parcel-bundler cache (https://parceljs.org/) .cache .parcel-cache # Next.js build output .next out # Nuxt.js build / generate output .nuxt dist # Gatsby files .cache/ # Comment in the public line in if your project uses 
Gatsby and not Next.js # https://nextjs.org/blog/next-9-1#public-directory-support # public # vuepress build output .vuepress/dist # vuepress v2.x temp and cache directory .temp .cache # Docusaurus cache and generated files .docusaurus # Serverless directories .serverless/ # FuseBox cache .fusebox/ # DynamoDB Local files .dynamodb/ # TernJS port file .tern-port # Stores VSCode versions used for testing VSCode extensions .vscode-test # yarn v2 .yarn/cache .yarn/unplugged .yarn/build-state.yml .yarn/install-state.gz .pnp.* # Source maps *.wasm.map ================================================ FILE: js/.npmrc ================================================ node-linker = hoisted ================================================ FILE: js/README.md ================================================ ## libdave JS Contains the package @discordapp/libdave. This is leveraged by Discord clients to enable out-of-band verifications of DAVE protocol call members and the MLS epoch authenticator. ### Testing Testing uses [Jest](https://jestjs.io/). You can run the tests with `pnpm jest`. 
### Dependencies

- [@noble/hashes](https://github.com/paulmillr/noble-hashes)
- [base64-js](https://www.npmjs.com/package/base64-js)

================================================
FILE: js/__tests__/DisplayableCode-test.ts
================================================
import {describe, expect, test} from '@jest/globals';
import {generateDisplayableCode} from '../src/DisplayableCode';

describe('DisplayableCode', () => {
  test('expectedOutput', () => {
    const shortData = new Uint8Array([0xaa, 0xbb, 0xcc, 0xdd, 0xee]);
    expect(generateDisplayableCode(shortData, 5, 5)).toBe('05870');

    const longDataBuffer = Buffer.from('aabbccddeebbccddeeffccddeeffaaddeeffaabbeeffaabbccffaabbccdd', 'hex');
    const longData = Uint8Array.from(longDataBuffer);
    expect(generateDisplayableCode(longData, 30, 5)).toBe('058708105556138052119572494877');
  });

  test('expectedFailure', () => {
    const tooShortData = new Uint8Array([0xaa, 0xbb, 0xcc, 0xdd]);
    expect(() => {
      generateDisplayableCode(tooShortData, 5, 5);
    }).toThrow();

    const goodData = new Uint8Array([0xaa, 0xbb, 0xcc, 0xdd]);
    expect(() => {
      generateDisplayableCode(goodData, 4, 3);
    }).toThrow();

    const randomData = new Uint8Array(1024);
    globalThis.crypto.getRandomValues(randomData);
    expect(() => {
      generateDisplayableCode(randomData, 1024, 11);
    }).toThrow();
  });
});

================================================
FILE: js/__tests__/KeyFingerprint-test.ts
================================================
import {describe, expect, test} from '@jest/globals';
import {generateKeyFingerprint} from '../src/KeyFingerprint';

describe('KeyFingerprint', () => {
  test('expectedOutput', async () => {
    const shortData = new Uint8Array(33);
    expect((await generateKeyFingerprint(0, shortData, '1234')).join('')).toBe(
      '000000000000000000000000000000000000000004210',
    );

    const longData = new Uint8Array(65);
    expect((await generateKeyFingerprint(0, longData, '12345678')).join('')).toBe(
      '0000000000000000000000000000000000000000000000000000000000000000000000001889778',
    );
  });

  test('expectedFailure', async () => {
    const data = new Uint8Array(33);
    await expect(generateKeyFingerprint(1, data, '1234')).rejects.toThrow();
    await expect(generateKeyFingerprint(0, data, 'abcd')).rejects.toThrow();
    await expect(generateKeyFingerprint(0, new Uint8Array(0), '1234')).rejects.toThrow();
  });
});

================================================
FILE: js/__tests__/KeySerialization-test.ts
================================================
import {describe, expect, test} from '@jest/globals';
import {serializeKey} from '../src/KeySerialization';

describe('KeySerialization', () => {
  test('expectedOutput', async () => {
    const zeroData = new Uint8Array(6);
    expect(serializeKey(zeroData)).toBe('AAAAAAAA');

    const moreData = new Uint8Array([0, 1, 0xff, 0x7f, 0x80]);
    expect(serializeKey(moreData)).toBe('AAH/f4A=');
  });
});

================================================
FILE: js/__tests__/PairwiseFingerprint-test.ts
================================================
import {describe, expect, test} from '@jest/globals';
import {generatePairwiseFingerprint} from '../src/PairwiseFingerprint';

describe('PairwiseFingerprint', () => {
  test('expectedOutput', async () => {
    const data1 = new Uint8Array(33);
    const data2 = new Uint8Array(65);
    // `.resolves` assertions must be awaited; otherwise the test ends before they run.
    await expect(generatePairwiseFingerprint(0, data1, '1234', data2, '5678')).resolves.toEqual(
      new Uint8Array([
        133, 129, 241, 44, 36, 135, 79, 195, 27, 28, 151, 69, 124, 197, 189, 41, 192, 7, 16, 45, 79, 247, 138, 58, 126,
        161, 178, 136, 12, 109, 96, 164, 169, 92, 2, 232, 136, 174, 74, 156, 173, 144, 191, 184, 34, 45, 242, 136, 41,
        133, 14, 158, 119, 79, 204, 48, 6, 220, 121, 6, 242, 11, 164, 60,
      ]),
    );
  });

  test('badSort', async () => {
    const data1 = new Uint8Array([0, 100]);
    const data2 = new Uint8Array([0, 20]);
    await expect(generatePairwiseFingerprint(0, data1, '1', data2, '2')).resolves.toEqual(
      new Uint8Array([
        141, 169, 194, 143, 22, 72, 22, 245, 13, 140, 66, 228, 159, 195, 101, 106, 119, 240, 69, 191, 178, 227, 194,
        126, 162, 255, 222, 148, 138, 5, 33, 215, 240, 167, 234, 245, 149, 182, 46, 20, 4, 83, 191, 31, 165, 74, 253,
        165, 199, 16, 29, 71, 193, 205, 169, 154, 255, 154, 34, 30, 94, 171, 247, 43,
      ]),
    );
  });

  test('expectedFailure', async () => {
    const data = new Uint8Array(33);
    await expect(generatePairwiseFingerprint(1, data, '1234', data, '5678')).rejects.toThrow();
    await expect(generatePairwiseFingerprint(0, data, 'abcd', data, '5678')).rejects.toThrow();
    await expect(generatePairwiseFingerprint(0, new Uint8Array(0), '1234', data, '5678')).rejects.toThrow();
  });
});

================================================
FILE: js/jest-setup.js
================================================
const crypto = require('crypto');

// Maps a WebCrypto algorithm identifier (a string, or an object with a `name`) to a Node
// hash name, e.g. 'SHA-512' -> 'sha512'. Anything else is passed through unchanged.
function convertAlgorithm(algorithm) {
  const name = typeof algorithm === 'string' ? algorithm : algorithm.name;
  return /^SHA-\d+$/i.test(name) ? name.replace('-', '').toLowerCase() : name;
}

// Minimal WebCrypto stand-in for the Jest environment.
Object.defineProperty(globalThis, 'crypto', {
  value: {
    // Like WebCrypto, fill the caller's typed array in place and return it.
    // crypto.randomBytes() would allocate a new buffer and leave `arr` untouched.
    getRandomValues: (arr) => crypto.randomFillSync(arr),
    subtle: {
      // Like WebCrypto, resolve to an ArrayBuffer holding exactly the digest bytes.
      digest: async (algorithm, data) => {
        const bytes = ArrayBuffer.isView(data) ? data : new Uint8Array(data);
        const hash = crypto.createHash(convertAlgorithm(algorithm)).update(bytes).digest();
        // A Buffer may view a larger shared allocation, so copy out only its own bytes.
        return hash.buffer.slice(hash.byteOffset, hash.byteOffset + hash.byteLength);
      },
    },
  },
});

================================================
FILE: js/jest.config.js
================================================
/**
 * For a detailed explanation regarding each configuration property, visit:
 * https://jestjs.io/docs/configuration
 */

/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
  preset: 'ts-jest',
  testEnvironment: 'node',
  // Paths are relative to the project root via <rootDir>; bare '/src' would mean filesystem root.
  roots: ['<rootDir>/src', '<rootDir>/__tests__'],
  setupFiles: ['<rootDir>/jest-setup.js'],
  transform: {
    '^.+\\.tsx?$': ['ts-jest', {
      tsconfig: 'tsconfig.json',
    }],
  },
};

================================================
FILE: js/package.json
================================================
{
  "name": "@discordapp/libdave",
  "description": "Discord's DAVE library for end-to-end encryption",
  "license": "MIT",
  "version": "1.0.1",
  "private": true,
  "main": "src/index.ts",
  "exports": {
    ".": {
      "import": "./src/index.ts"
    },
    "./wasm": {
      "import": "./src/wasm.ts"
    }
  },
  "files": [
    "src/DisplayableCode.ts",
    "src/index.ts",
    "src/KeyFingerprint.ts",
    "src/KeySerialization.ts",
    "src/PairwiseFingerprint.ts",
    "src/wasm.ts",
    "wasm/libdave.js",
    "wasm/libdave.d.ts",
    "wasm/libdave.wasm"
  ],
  "scripts": {
    "test": "jest",
    "build:watch": "tsc --watch",
    "build": "tsc",
    "clean": "rm -rf dist",
    "prepublishOnly": "npm run clean && npm run build"
  },
  "devDependencies": {
    "@jest/globals": "^30.2.0",
    "jest": "^30.2.0",
    "ts-jest": "^29.4.6",
    "typescript": "^5.9.3"
  },
  "dependencies": {
    "@noble/hashes": "1.5.0",
    "base64-js": "1.5.1"
  },
  "bundledDependencies": true
}

================================================
FILE: js/src/DisplayableCode.ts
================================================
const MAX_GROUP_SIZE = 8;

/**
 * Renders the first `desiredLength` bytes of `data` as a string of decimal digits.
 *
 * Bytes are consumed in groups of `groupSize`. Each group is read as a big-endian unsigned
 * integer, reduced modulo 10^groupSize, and zero-padded to `groupSize` digits, so the
 * result is exactly `desiredLength` characters long.
 *
 * @throws Error if `data` is shorter than `desiredLength`, if `desiredLength` is not a
 *   non-negative integer multiple of `groupSize`, or if `groupSize` is not an integer in
 *   [1, MAX_GROUP_SIZE].
 */
export function generateDisplayableCode(data: Uint8Array, desiredLength: number, groupSize: number): string {
  if (data.byteLength < desiredLength) {
    throw new Error('data.byteLength must be greater than or equal to desiredLength');
  }
  if (!Number.isInteger(desiredLength) || desiredLength < 0) {
    throw new Error('desiredLength must be a non-negative integer');
  }
  // Without this check, 0 reports a misleading "multiple" error and negative or fractional
  // values surface as a RangeError from BigInt.
  if (!Number.isInteger(groupSize) || groupSize <= 0) {
    throw new Error('groupSize must be a positive integer');
  }
  if (desiredLength % groupSize !== 0) {
    throw new Error('desiredLength must be a multiple of groupSize');
  }
  if (groupSize > MAX_GROUP_SIZE) {
    throw new Error(`groupSize must be less than or equal to ${MAX_GROUP_SIZE}`);
  }

  const groupModulus = 10n ** BigInt(groupSize);
  let result = '';
  for (let i = 0; i < desiredLength; i += groupSize) {
    let groupValue = 0n;
    for (let j = 0; j < groupSize; j++) {
      const nextByte = data[i + j];
      if (nextByte === undefined) {
        throw new Error('Out of bounds access from data array');
      }
      groupValue = (groupValue << 8n) | BigInt(nextByte);
    }
    result += (groupValue % groupModulus).toString().padStart(groupSize, '0');
  }
  return result;
}

================================================
FILE: js/src/KeyFingerprint.ts
================================================
const VERSION_LEN = 2;
const UID_LEN = 8;
const DECIMAL_DIGITS = /^[0-9]+$/;

/**
 * Builds the version-0 key fingerprint buffer:
 * 16-bit big-endian `version` || `key` bytes || 64-bit big-endian user ID.
 *
 * @param version Fingerprint format version; only 0 is supported.
 * @param key Public key bytes; must be non-empty.
 * @param userId User ID as a base-10 string that fits in an unsigned 64-bit integer.
 * @returns The concatenated fingerprint bytes.
 * @throws Error (as a rejected promise) on an unsupported version, an empty key, or an
 *   empty, non-decimal or out-of-range user ID.
 */
export async function generateKeyFingerprint(version: number, key: Uint8Array, userId: string): Promise<Uint8Array> {
  if (version !== 0) {
    throw new Error('unsupported fingerprint format version');
  }
  if (key.byteLength === 0) {
    throw new Error('zero-length key');
  }
  if (userId.length === 0) {
    throw new Error('zero-length user ID');
  }
  // BigInt() also accepts 0x/0o/0b prefixes and surrounding whitespace. That would let
  // different ID strings produce the same fingerprint, so require plain decimal digits.
  if (!DECIMAL_DIGITS.test(userId)) {
    throw new Error('user ID must be a decimal integer');
  }
  const userIdInt = BigInt(userId);
  if (userIdInt >= 2n ** 64n) {
    throw new Error('user ID out of range');
  }

  const lbuf = new Uint8Array(VERSION_LEN + key.byteLength + UID_LEN);
  lbuf.set(key, VERSION_LEN);
  // DataView setters default to big-endian.
  const dv = new DataView(lbuf.buffer);
  dv.setUint16(0, version);
  dv.setBigUint64(VERSION_LEN + key.byteLength, userIdInt);
  return lbuf;
}

================================================
FILE: js/src/KeySerialization.ts
================================================
import base64 from 'base64-js';

export function serializeKey(data: Uint8Array): string {
  return base64.fromByteArray(data);
}

================================================
FILE: js/src/PairwiseFingerprint.ts
================================================
import {generateKeyFingerprint} from './KeyFingerprint';
import {scryptAsync} from '@noble/hashes/scrypt';

const salt = Uint8Array.of(
  0x24, 0xca, 0xb1, 0x7a, 0x7a, 0xf8, 0xec, 0x2b, 0x82, 0xb4, 0x12, 0xb9, 0x2d, 0xab, 0x19, 0x2e,
);

const scryptParams = {
  N: 16384,
  r: 8,
  p: 2,
  dkLen: 64,
};

/** Lexicographic byte comparison (a shorter prefix sorts first), for Array.prototype.sort. */
function compareArrays(a: Uint8Array, b: Uint8Array): number {
  const sharedLength = Math.min(a.length, b.length);
  for (let i = 0; i < sharedLength; i++) {
    // i is within both arrays, so the indexed reads are defined.
    const diff = a[i]! - b[i]!;
    if (diff !== 0) {
      return diff;
    }
  }
  return a.length - b.length;
}

/**
 * Derives a 64-byte pairwise fingerprint for two users' keys.
 *
 * The two per-user fingerprints are sorted before hashing, so the result is the same
 * whichever user is passed as A. The sorted concatenation is stretched with scrypt using
 * the fixed salt and parameters above.
 *
 * @throws Error (as a rejected promise) under the same conditions as generateKeyFingerprint.
 */
export async function generatePairwiseFingerprint(
  version: number,
  keyA: Uint8Array,
  userIdA: string,
  keyB: Uint8Array,
  userIdB: string,
): Promise<Uint8Array> {
  const fingerprints = await Promise.all([
    generateKeyFingerprint(version, keyA, userIdA),
    generateKeyFingerprint(version, keyB, userIdB),
  ]);
  const [first, second] = fingerprints.sort(compareArrays);

  const input = new Uint8Array(first.byteLength + second.byteLength);
  input.set(first, 0);
  input.set(second, first.byteLength);

  const ret = await scryptAsync(input, salt, scryptParams);
  return new Uint8Array(ret);
}

================================================
FILE: js/src/index.ts
================================================
export {generateDisplayableCode} from './DisplayableCode';
export {generateKeyFingerprint} from './KeyFingerprint';
export {generatePairwiseFingerprint} from './PairwiseFingerprint';
export {serializeKey} from './KeySerialization';

================================================
FILE: js/src/wasm.ts
================================================
export {default as DaveModuleFactory} from '../wasm/libdave';
export type {MainModule as DaveModule} from '../wasm/libdave';
export * from '../wasm/libdave';

================================================
FILE: js/tsconfig.json
================================================
{
  "compilerOptions": {
    "esModuleInterop": true,
    "skipLibCheck": true,
    "target": "es2022",
    "allowJs": true,
    "resolveJsonModule": true,
    "moduleDetection": "force",
    "isolatedModules": true,
    "strict": true,
    "noUncheckedIndexedAccess": true,
    "noImplicitOverride": true,
    "module": "NodeNext",
    "moduleResolution": "NodeNext",
    "outDir": "dist",
    "sourceMap": true,
    "declaration": true,
    "declarationMap": true,
    "lib": ["es2022", "dom", "dom.iterable"]
  },
  "include": ["src/**/*"],
  "exclude": ["node_modules", "dist", "__tests__/**/*"]
}

================================================
FILE: js/wasm/.gitignore
================================================ libdave.* ================================================ FILE: samples/typescript/DaveSessionManager.ts ================================================ import type {Session, TransientKeys, DaveModule, SignaturePrivateKey} from '@discordapp/libdave/wasm'; const MLS_NEW_GROUP_EXPECTED_EPOCH = '1'; const DAVE_PROTOCOL_INIT_TRANSITION_ID = 0; export class DaveSessionManager { private readonly dave: DaveModule; private readonly transientKeys: TransientKeys | null; private readonly mlsSession: Session; private readonly selfUserId: string; private readonly groupId: string; private readonly recognizedUserIds: Set = new Set(); private readonly daveProtocolTransitions: Map = new Map(); private latestPreparedTransitionVersion: number = 0; constructor(dave: DaveModule, transientKeys: TransientKeys | null, selfUserId: string, groupId: string) { this.dave = dave; this.transientKeys = transientKeys; this.selfUserId = selfUserId; this.groupId = groupId; // These are only used with persistent key storage and can be ignored most of the time const context = ''; const authSessionId = ''; this.mlsSession = new dave.Session(context, authSessionId, (source: string, reason: string) => { console.error(`MLS failure: ${source} ${reason}`); }); } // Add an allowed user to the connection public createUser(userId: string) { this.recognizedUserIds.add(userId); this._setupKeyRatchetForUser(userId, this.latestPreparedTransitionVersion); } // Remove an allowed user from the connection public destroyUser(userId: string) { this.recognizedUserIds.delete(userId); // TODO: Signal the relevant media code that a user has left the call and the associated Encryptor/Decryptor should be destroyed } // Incoming Voice Gateway Requests // Opcode SELECT_PROTOCOL_ACK (1) public onSelectProtocolAck(protocolVersion: number) { this._handleDaveProtocolInit(protocolVersion); } // Opcode DAVE_PROTOCOL_PREPARE_TRANSITION (21) public 
onDaveProtocolPrepareTransition(transitionId: number, protocolVersion: number) {
    // Prepare every user's ratchet for the announced protocol version, then acknowledge the
    // transition to the gateway (no acknowledgement is sent for the init transition ID).
    this._prepareDaveProtocolRatchets(transitionId, protocolVersion);
    this._maybeSendDaveProtocolReadyForTransition(transitionId);
  }

  // Opcode DAVE_PROTOCOL_EXECUTE_TRANSITION (22)
  /**
   * Executes a transition previously stored by a prepare step.
   * Transition IDs that were never stored are ignored.
   */
  public onDaveProtocolExecuteTransition(transitionId: number) {
    this._handleDaveProtocolExecuteTransition(transitionId);
  }

  // Opcode DAVE_PROTOCOL_PREPARE_EPOCH (24)
  /**
   * Handles an epoch announcement. When `epoch` equals `MLS_NEW_GROUP_EXPECTED_EPOCH`, the MLS
   * session is initialized for `protocolVersion` and a new key package is sent to the gateway.
   */
  public onDaveProtocolPrepareEpoch(epoch: string, protocolVersion: number) {
    this._handleDaveProtocolPrepareEpoch(epoch, protocolVersion, this.groupId);
    if (epoch === MLS_NEW_GROUP_EXPECTED_EPOCH) {
      this._sendMLSKeyPackage();
    }
  }

  // Opcode MLS_EXTERNAL_SENDER_PACKAGE (25)
  /** Hands the gateway's external sender package to the MLS session. */
  public onDaveProtocolMLSExternalSenderPackage(externalSenderPackage: ArrayBuffer) {
    this.mlsSession.SetExternalSender(externalSenderPackage);
  }

  // Opcode MLS_PROPOSALS (27)
  /**
   * Processes MLS proposals against the current set of recognized users. If the session
   * returns a commit/welcome message, it is forwarded to the gateway.
   */
  public onMLSProposals(proposals: ArrayBuffer) {
    const commitWelcome = this.mlsSession.ProcessProposals(proposals, this._getRecognizedUserIDs());
    if (commitWelcome) {
      this._sendMLSCommitWelcome(commitWelcome);
    }
  }

  // Opcode MLS_PREPARE_COMMIT_TRANSITION (29)
  /**
   * Processes an MLS commit tied to an upcoming transition.
   * - Commits the session reports as ignored are dropped without a response.
   * - If the commit yields a roster update, ratchets are prepared and readiness is signaled.
   * - Otherwise the commit is flagged as invalid and the DAVE protocol is re-initialized
   *   with the session's current protocol version.
   */
  public onMLSPrepareCommitTransition(transitionId: number, commit: ArrayBuffer) {
    const processedCommit = this.mlsSession.ProcessCommit(commit);

    // Bail out before inspecting the result of a commit the session chose to ignore.
    if (processedCommit.ignored) {
      return;
    }

    const joinedGroup = processedCommit.rosterUpdate != null;
    if (joinedGroup) {
      this._prepareDaveProtocolRatchets(transitionId, this.mlsSession.GetProtocolVersion());
      this._maybeSendDaveProtocolReadyForTransition(transitionId);
    } else {
      this._flagMLSInvalidCommitWelcome(transitionId);
      this._handleDaveProtocolInit(this.mlsSession.GetProtocolVersion());
    }
  }

  // Opcode MLS_WELCOME (30)
  /**
   * Processes an MLS welcome message. On success, ratchets are prepared and readiness is
   * signaled. On failure, the welcome is flagged as invalid and a new key package is sent.
   */
  public onMLSWelcome(transitionId: number, welcome: ArrayBuffer) {
    const roster = this.mlsSession.ProcessWelcome(welcome, this._getRecognizedUserIDs());
    const joinedGroup = roster != null;
    if (joinedGroup) {
      this._prepareDaveProtocolRatchets(transitionId, this.mlsSession.GetProtocolVersion());
      this._maybeSendDaveProtocolReadyForTransition(transitionId);
    } else {
      this._flagMLSInvalidCommitWelcome(transitionId);
      this._sendMLSKeyPackage();
    }
  }

  // Outgoing Voice Gateway Responses

  // Opcode MLS_KEY_PACKAGE (26)
  /** Serializes this client's MLS key package so it can be sent to the gateway. */
  private _sendMLSKeyPackage() {
    const _keyPackage = this.mlsSession.GetMarshalledKeyPackage();
    // TODO: Send keyPackage to the voice gateway using the MLS_KEY_PACKAGE (26) opcode
  }

  // Opcode DAVE_PROTOCOL_READY_FOR_TRANSITION (23)
  /** Signals readiness for a transition. Nothing is sent for the init transition ID. */
  private _maybeSendDaveProtocolReadyForTransition(transitionId: number) {
    if (transitionId !== DAVE_PROTOCOL_INIT_TRANSITION_ID) {
      // TODO: Send the transition ready message to the voice gateway using the DAVE_PROTOCOL_READY_FOR_TRANSITION (23) opcode
    }
  }

  // Opcode MLS_COMMIT_WELCOME (28)
  private _sendMLSCommitWelcome(commitWelcomeMessage: ArrayBuffer) {
    // TODO: Send the commit welcome message to the voice gateway using the MLS_COMMIT_WELCOME (28) opcode
  }

  // Opcode MLS_INVALID_COMMIT_WELCOME (31)
  private _flagMLSInvalidCommitWelcome(transitionId: number) {
    // TODO: Send the invalid commit welcome message to the voice gateway using the MLS_INVALID_COMMIT_WELCOME (31) opcode
  }

  // Internal methods

  /**
   * Builds the key ratchet for `userId` under `protocolVersion`. The ratchet is `null` when
   * the version is the disabled version.
   */
  private _setupKeyRatchetForUser(userId: string, protocolVersion: number) {
    const keyRatchet = this._makeUserKeyRatchet(userId, protocolVersion);
    // TODO: Signal the relevant media code that a key ratchet has changed and the associated Encryptor/Decryptor needs to be updated
  }

  /**
   * (Re)initializes the DAVE protocol.
   * - For a positive `protocolVersion`, the MLS session is set up for a new group and a key
   *   package is sent.
   * - Otherwise an init transition to `protocolVersion` is prepared and executed immediately.
   */
  private _handleDaveProtocolInit(protocolVersion: number) {
    if (protocolVersion > 0) {
      this._handleDaveProtocolPrepareEpoch(MLS_NEW_GROUP_EXPECTED_EPOCH, protocolVersion, this.groupId);
      this._sendMLSKeyPackage();
    } else {
      this._prepareDaveProtocolRatchets(DAVE_PROTOCOL_INIT_TRANSITION_ID, protocolVersion);
      this._handleDaveProtocolExecuteTransition(DAVE_PROTOCOL_INIT_TRANSITION_ID);
    }
  }

  /**
   * Initializes the MLS session when `epoch` is the new-group epoch. Other epochs are a no-op.
   * If transient keys are available, they supply the signature private key.
   * `groupId` is converted with `BigInt`, which throws a `SyntaxError` if it is not an
   * integer string.
   */
  private _handleDaveProtocolPrepareEpoch(epoch: string, protocolVersion: number, groupId: string): void {
    if (epoch === MLS_NEW_GROUP_EXPECTED_EPOCH) {
      let privateKey: SignaturePrivateKey | null = null;
      if (this.transientKeys != null) {
        privateKey = this.transientKeys.GetTransientPrivateKey(protocolVersion);
      }
      this.mlsSession.Init(protocolVersion, BigInt(groupId), this.selfUserId, privateKey);
    }
  }

  /**
   * Applies a pending transition: removes it from the pending map, resets the MLS session if
   * the target version is the disabled version, and rebuilds this user's own ratchet.
   */
  private _handleDaveProtocolExecuteTransition(transitionID: number): void {
    // A single guarded lookup replaces `has` + `get(...)!`. The map only ever stores numbers,
    // so `undefined` means "not pending".
    const protocolVersion = this.daveProtocolTransitions.get(transitionID);
    if (protocolVersion === undefined) {
      // Nothing was stored under this ID. For example, init transitions are applied
      // immediately by _prepareDaveProtocolRatchets and are never stored.
      return;
    }
    this.daveProtocolTransitions.delete(transitionID);

    if (protocolVersion === this.dave.kDisabledVersion) {
      this.mlsSession.Reset();
    }

    this._setupKeyRatchetForUser(this.selfUserId, protocolVersion);
  }

  /**
   * Returns the recognized user IDs plus this client's own ID, with no duplicates, even if
   * `selfUserId` is already present in `recognizedUserIds`.
   */
  private _getRecognizedUserIDs(): string[] {
    const ids = new Set<string>(this.recognizedUserIds);
    ids.add(this.selfUserId);
    return Array.from(ids);
  }

  /**
   * Returns the MLS key ratchet for `userId`, or `null` when `protocolVersion` is the
   * disabled version.
   * NOTE(review): typed `any` because the ratchet's binding type is not visible here.
   */
  private _makeUserKeyRatchet(userId: string, protocolVersion: number): any {
    if (protocolVersion === this.dave.kDisabledVersion) {
      return null;
    }

    return this.mlsSession.GetKeyRatchet(userId);
  }

  /**
   * Rebuilds the ratchets of all other recognized users for `protocolVersion`.
   * This user's own ratchet is rebuilt immediately for the init transition. Otherwise the
   * version is stored until the transition is executed.
   */
  private _prepareDaveProtocolRatchets(transitionID: number, protocolVersion: number): void {
    for (const userId of this._getRecognizedUserIDs()) {
      if (userId === this.selfUserId) {
        continue;
      }

      this._setupKeyRatchetForUser(userId, protocolVersion);
    }

    // NOTE(review): other methods compare against DAVE_PROTOCOL_INIT_TRANSITION_ID. Presumably
    // it equals this.dave.kInitTransitionId; confirm, and use a single source for the init ID.
    if (transitionID === this.dave.kInitTransitionId) {
      this._setupKeyRatchetForUser(this.selfUserId, protocolVersion);
    } else {
      this.daveProtocolTransitions.set(transitionID, protocolVersion);
    }

    this.latestPreparedTransitionVersion = protocolVersion;
  }
}

================================================ FILE: samples/typescript/README.md ================================================ # TypeScript DAVE Session Manager This directory contains an example implementation of a TypeScript class capable of handling Voice Gateway events for DAVE (Discord Audio Video Encryption) support. ## Overview The `DaveSessionManager` class can handle and respond to voice gateway events concerning DAVE. ## Features - Voice Gateway event handling - DAVE key ratchet generation ## Usage 1.
**Initialize**: Create one `DaveSessionManager` object per `selfUserId`/`groupId` pair, and call the corresponding handler method whenever an opcode is received from the Voice Gateway.
2. **Implementation**: Search for the `TODO` comments and implement the relevant networking and media code.
3. **Key Updates**: Whenever you are signaled that a new key ratchet is available for a user, update that user's associated encryptor/decryptor.