Repository: halfgaar/FlashMQ Branch: master Commit: af7ab0c2be24 Files: 235 Total size: 1.7 MB Directory structure: gitextract_eh9hwlw7/ ├── .build-ci.sh ├── .clang-format ├── .clang-format.original ├── .dockerignore ├── .get-os-codename-and-stamp.sh ├── .github/ │ └── workflows/ │ ├── building.yml │ ├── codestyle.yml │ ├── docker.yml │ ├── linting.yml │ ├── testing.yml │ └── testing_non_sse.yml ├── .gitignore ├── CMakeLists.shared ├── CMakeLists.txt ├── Dockerfile ├── FlashMQTests/ │ ├── CMakeLists.txt │ ├── UTF-8-test.txt │ ├── bridgeprefixtests.cpp │ ├── conffiletemp.cpp │ ├── conffiletemp.h │ ├── configtests.cpp │ ├── dnstests.cpp │ ├── filecloser.cpp │ ├── filecloser.h │ ├── flashmqtempdir.cpp │ ├── flashmqtempdir.h │ ├── main.cpp │ ├── mainappasfork.cpp │ ├── mainappasfork.h │ ├── mainappinthread.cpp │ ├── mainappinthread.h │ ├── maintests.cpp │ ├── maintests.h │ ├── plugins/ │ │ ├── curlfunctions.cpp │ │ ├── curlfunctions.h │ │ ├── test_plugin.cpp │ │ ├── test_plugin.h │ │ └── test_plugin.pro │ ├── plugintests.cpp │ ├── retaintests.cpp │ ├── run-make-from-ci.sh │ ├── run-tests-from-ci.sh │ ├── sharedsubscriptionstests.cpp │ ├── subscriptionidtests.cpp │ ├── testhelpers.cpp │ ├── testhelpers.h │ ├── testinitializer.cpp │ ├── testinitializer.h │ ├── tst_maintests.cpp │ ├── utiltests.cpp │ ├── websockettests.cpp │ └── willtests.cpp ├── LICENSE ├── README.md ├── acksender.cpp ├── acksender.h ├── acltree.cpp ├── acltree.h ├── backgroundworker.cpp ├── backgroundworker.h ├── bindaddr.cpp ├── bindaddr.h ├── bridgeconfig.cpp ├── bridgeconfig.h ├── bridgeinfodb.cpp ├── bridgeinfodb.h ├── build.sh ├── checkedsharedptr.h ├── checkedweakptr.h ├── cirbuf.cpp ├── cirbuf.h ├── client.cpp ├── client.h ├── clientacceptqueue.cpp ├── clientacceptqueue.h ├── configfileparser.cpp ├── configfileparser.h ├── debian/ │ ├── conffiles │ ├── flashmq.service │ ├── postinst │ ├── postrm │ ├── preinst │ └── prerm ├── derivablecounter.cpp ├── derivablecounter.h ├── dnsresolver.cpp ├── 
dnsresolver.h ├── driftcounter.cpp ├── driftcounter.h ├── enums.h ├── evpencodectxmanager.cpp ├── evpencodectxmanager.h ├── examples/ │ └── plugin_libcurl/ │ ├── CMakeLists.txt │ ├── LICENSE │ ├── README.md │ ├── build.sh │ ├── src/ │ │ ├── authenticatingclient.cpp │ │ ├── authenticatingclient.h │ │ ├── curl_functions.cpp │ │ ├── curl_functions.h │ │ ├── plugin_libcurl.cpp │ │ ├── pluginstate.cpp │ │ └── pluginstate.h │ └── vendor/ │ ├── flashmq_plugin.h │ └── flashmq_public.h ├── exceptions.cpp ├── exceptions.h ├── fdmanaged.cpp ├── fdmanaged.h ├── flags.h ├── flashmq.conf ├── flashmq_plugin.cpp ├── flashmq_plugin.h ├── flashmq_plugin_deprecated.h ├── flashmq_public.h ├── flashmqtestclient.cpp ├── flashmqtestclient.h ├── fmqmain.cpp ├── fmqmain.h ├── fmqsockaddr.cpp ├── fmqsockaddr.h ├── fmqssl.cpp ├── fmqssl.h ├── forward_declarations.h ├── fuzz-helper.sh ├── globals.cpp ├── globals.h ├── globalstats.cpp ├── globalstats.h ├── globber.cpp ├── globber.h ├── haproxy.cpp ├── haproxy.h ├── http.cpp ├── http.h ├── iowrapper.cpp ├── iowrapper.h ├── listener.cpp ├── listener.h ├── lockedsharedptr.h ├── lockedweakptr.cpp ├── lockedweakptr.h ├── logger.cpp ├── logger.h ├── main.cpp ├── mainapp.cpp ├── mainapp.h ├── man/ │ ├── .gitignore │ ├── Makefile │ ├── README.md │ ├── docbook5-refentry-xslt/ │ │ ├── docbook5-refentry-to-html5.xsl │ │ └── docbook5-refentry-to-manpage.xsl │ ├── flashmq-docbook5-refentry-to-html5.xsl │ ├── flashmq-docbook5-refentry-to-manpage.xsl │ ├── flashmq.1 │ ├── flashmq.1.dbk5 │ ├── flashmq.1.html │ ├── flashmq.conf.5 │ ├── flashmq.conf.5.dbk5 │ ├── flashmq.conf.5.html │ ├── refentry.colophon.dbk5 │ └── reference.dbk5 ├── mosquittoauthoptcompatwrap.cpp ├── mosquittoauthoptcompatwrap.h ├── mqtt5properties.cpp ├── mqtt5properties.h ├── mqttpacket.cpp ├── mqttpacket.h ├── mutexowned.h ├── network.cpp ├── network.h ├── nocopy.cpp ├── nocopy.h ├── oneinstancelock.cpp ├── oneinstancelock.h ├── packetdatatypes.cpp ├── packetdatatypes.h ├── 
persistencefile.cpp ├── persistencefile.h ├── persistencefunctions.cpp ├── persistencefunctions.h ├── plugin.cpp ├── plugin.h ├── pluginloader.cpp ├── pluginloader.h ├── publishcopyfactory.cpp ├── publishcopyfactory.h ├── qospacketqueue.cpp ├── qospacketqueue.h ├── queuedtasks.cpp ├── queuedtasks.h ├── release.sh ├── retainedmessage.cpp ├── retainedmessage.h ├── retainedmessagesdb.cpp ├── retainedmessagesdb.h ├── rwlockguard.cpp ├── rwlockguard.h ├── scopedsocket.cpp ├── scopedsocket.h ├── sdnotify.cpp ├── sdnotify.h ├── session.cpp ├── session.h ├── sessionsandsubscriptionsdb.cpp ├── sessionsandsubscriptionsdb.h ├── settings.cpp ├── settings.h ├── sharedsubscribers.cpp ├── sharedsubscribers.h ├── sslctxmanager.cpp ├── sslctxmanager.h ├── subscription.cpp ├── subscription.h ├── subscriptionstore.cpp ├── subscriptionstore.h ├── threaddata.cpp ├── threaddata.h ├── threadglobals.cpp ├── threadglobals.h ├── threadlocalutils.cpp ├── threadlocalutils.h ├── threadlocked.h ├── threadloop.cpp ├── threadloop.h ├── types.cpp ├── types.h ├── unscopedlock.cpp ├── unscopedlock.h ├── utils.cpp ├── utils.h ├── variablebyteint.cpp ├── variablebyteint.h ├── x509manager.cpp └── x509manager.h ================================================ FILE CONTENTS ================================================ ================================================ FILE: .build-ci.sh ================================================ #!/bin/bash set -e # If any step reports a problem consider the whole build a failure wget https://github.com/linuxdeploy/linuxdeploy/releases/download/continuous/linuxdeploy-x86_64.AppImage sudo mv linuxdeploy-x86_64.AppImage /usr/local/bin sudo chmod +x /usr/local/bin/linuxdeploy-x86_64.AppImage sudo apt update sudo apt install -y shellcheck shellcheck debian/post* debian/pre* ./build.sh ./FlashMQBuildRelease/flashmq --version sudo dpkg -i ./FlashMQBuildRelease/*.deb set +e # Prevent Travis internals from breaking our build, see
https://github.com/travis-ci/travis-ci/issues/891 ================================================ FILE: .clang-format ================================================ --- Language: Cpp # BasedOnStyle: LLVM AccessModifierOffset: -4 AlignAfterOpenBracket: BlockIndent AlignArrayOfStructures: None AlignConsecutiveAssignments: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: true AlignConsecutiveBitFields: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: false AlignConsecutiveDeclarations: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: false AlignConsecutiveMacros: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: false AlignEscapedNewlines: Right AlignOperands: Align AlignTrailingComments: false AllowAllArgumentsOnNextLine: true AllowAllParametersOfDeclarationOnNextLine: true AllowShortEnumsOnASingleLine: true AllowShortBlocksOnASingleLine: Never AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: None AllowShortLambdasOnASingleLine: All AllowShortIfStatementsOnASingleLine: Never AllowShortLoopsOnASingleLine: false AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: false AlwaysBreakTemplateDeclarations: MultiLine AttributeMacros: - __capability BinPackArguments: true BinPackParameters: true BraceWrapping: AfterCaseLabel: true AfterClass: true AfterControlStatement: Always AfterEnum: false AfterFunction: true AfterNamespace: false AfterObjCDeclaration: false AfterStruct: false AfterUnion: false AfterExternBlock: false BeforeCatch: true BeforeElse: true BeforeLambdaBody: false BeforeWhile: true IndentBraces: false SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true BreakBeforeBinaryOperators: None BreakBeforeConceptDeclarations: Always #BreakBeforeBraces: Allman BreakBeforeBraces: 
Custom BreakBeforeInheritanceComma: false BreakInheritanceList: BeforeColon BreakBeforeTernaryOperators: true BreakConstructorInitializersBeforeComma: false BreakConstructorInitializers: AfterColon BreakAfterJavaFieldAnnotations: false BreakStringLiterals: false ColumnLimit: 180 CommentPragmas: '^ IWYU pragma:' QualifierAlignment: Leave CompactNamespaces: false ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DeriveLineEnding: true DerivePointerAlignment: false DisableFormat: false EmptyLineAfterAccessModifier: Never EmptyLineBeforeAccessModifier: LogicalBlock ExperimentalAutoDetectBinPacking: false PackConstructorInitializers: Never BasedOnStyle: '' ConstructorInitializerAllOnOneLineOrOnePerLine: false AllowAllConstructorInitializersOnNextLine: true FixNamespaceComments: true ForEachMacros: - foreach - Q_FOREACH - BOOST_FOREACH IfMacros: - KJ_IF_MAYBE IncludeBlocks: Preserve IncludeCategories: - Regex: '^"(llvm|llvm-c|clang|clang-c)/' Priority: 2 SortPriority: 0 CaseSensitive: false - Regex: '^(<|"(gtest|gmock|isl|json)/)' Priority: 3 SortPriority: 0 CaseSensitive: false - Regex: '.*' Priority: 1 SortPriority: 0 CaseSensitive: false IncludeIsMainRegex: '(Test)?$' IncludeIsMainSourceRegex: '' IndentAccessModifiers: false IndentCaseLabels: false IndentCaseBlocks: false IndentGotoLabels: true IndentPPDirectives: None IndentExternBlock: AfterExternBlock IndentRequiresClause: true IndentWidth: 4 IndentWrappedFunctionNames: false InsertBraces: false InsertTrailingCommas: None JavaScriptQuotes: Leave JavaScriptWrapImports: true KeepEmptyLinesAtTheStartOfBlocks: true LambdaBodyIndentation: Signature MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 2 NamespaceIndentation: None ObjCBinPackProtocolList: Auto ObjCBlockIndentWidth: 2 ObjCBreakBeforeNestedBlockParam: true ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: true PenaltyBreakAssignment: 2 PenaltyBreakBeforeFirstCallParameter: 19 PenaltyBreakComment: 300 
PenaltyBreakFirstLessLess: 120 PenaltyBreakOpenParenthesis: 0 PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 60 PenaltyIndentedWhitespace: 0 PointerAlignment: Right PPIndentWidth: -1 ReferenceAlignment: Pointer ReflowComments: false RemoveBracesLLVM: false RequiresClausePosition: OwnLine SeparateDefinitionBlocks: Leave ShortNamespaceLines: 1 SortIncludes: Never SortJavaStaticImport: Before SortUsingDeclarations: true SpaceAfterCStyleCast: false SpaceAfterLogicalNot: false SpaceAfterTemplateKeyword: true SpaceBeforeAssignmentOperators: true SpaceBeforeCaseColon: false SpaceBeforeCpp11BracedList: true SpaceBeforeCtorInitializerColon: true SpaceBeforeInheritanceColon: true SpaceBeforeParens: ControlStatements SpaceBeforeParensOptions: AfterControlStatements: true AfterForeachMacros: true AfterFunctionDefinitionName: false AfterFunctionDeclarationName: false AfterIfMacros: true AfterOverloadedOperator: false AfterRequiresInClause: false AfterRequiresInExpression: false BeforeNonEmptyParentheses: false SpaceAroundPointerQualifiers: Default SpaceBeforeRangeBasedForLoopColon: true SpaceInEmptyBlock: false SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 1 SpacesInAngles: Never SpacesInConditionalStatement: false SpacesInContainerLiterals: true SpacesInCStyleCastParentheses: false SpacesInLineCommentPrefix: Minimum: 1 Maximum: -1 SpacesInParentheses: false SpacesInSquareBrackets: false SpaceBeforeSquareBrackets: false BitFieldColonSpacing: Both Standard: Latest StatementAttributeLikeMacros: - Q_EMIT StatementMacros: - Q_UNUSED - QT_REQUIRE_VERSION TabWidth: 8 UseCRLF: false UseTab: Never WhitespaceSensitiveMacros: - STRINGIZE - PP_STRINGIZE - BOOST_PP_STRINGIZE - NS_SWIFT_NAME - CF_SWIFT_NAME ... 
================================================ FILE: .clang-format.original ================================================ --- Language: Cpp # BasedOnStyle: LLVM AccessModifierOffset: -2 AlignAfterOpenBracket: Align AlignArrayOfStructures: None AlignConsecutiveAssignments: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: true AlignConsecutiveBitFields: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: false AlignConsecutiveDeclarations: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: false AlignConsecutiveMacros: Enabled: false AcrossEmptyLines: false AcrossComments: false AlignCompound: false PadOperators: false AlignEscapedNewlines: Right AlignOperands: Align AlignTrailingComments: true AllowAllArgumentsOnNextLine: true AllowAllParametersOfDeclarationOnNextLine: true AllowShortEnumsOnASingleLine: true AllowShortBlocksOnASingleLine: Never AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: All AllowShortLambdasOnASingleLine: All AllowShortIfStatementsOnASingleLine: Never AllowShortLoopsOnASingleLine: false AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterReturnType: None AlwaysBreakBeforeMultilineStrings: false AlwaysBreakTemplateDeclarations: MultiLine AttributeMacros: - __capability BinPackArguments: true BinPackParameters: true BraceWrapping: AfterCaseLabel: false AfterClass: false AfterControlStatement: Never AfterEnum: false AfterFunction: false AfterNamespace: false AfterObjCDeclaration: false AfterStruct: false AfterUnion: false AfterExternBlock: false BeforeCatch: false BeforeElse: false BeforeLambdaBody: false BeforeWhile: false IndentBraces: false SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true BreakBeforeBinaryOperators: None BreakBeforeConceptDeclarations: Always BreakBeforeBraces: Attach BreakBeforeInheritanceComma: false BreakInheritanceList: 
BeforeColon BreakBeforeTernaryOperators: true BreakConstructorInitializersBeforeComma: false BreakConstructorInitializers: BeforeColon BreakAfterJavaFieldAnnotations: false BreakStringLiterals: true ColumnLimit: 80 CommentPragmas: '^ IWYU pragma:' QualifierAlignment: Leave CompactNamespaces: false ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 4 Cpp11BracedListStyle: true DeriveLineEnding: true DerivePointerAlignment: false DisableFormat: false EmptyLineAfterAccessModifier: Never EmptyLineBeforeAccessModifier: LogicalBlock ExperimentalAutoDetectBinPacking: false PackConstructorInitializers: BinPack BasedOnStyle: '' ConstructorInitializerAllOnOneLineOrOnePerLine: false AllowAllConstructorInitializersOnNextLine: true FixNamespaceComments: true ForEachMacros: - foreach - Q_FOREACH - BOOST_FOREACH IfMacros: - KJ_IF_MAYBE IncludeBlocks: Preserve IncludeCategories: - Regex: '^"(llvm|llvm-c|clang|clang-c)/' Priority: 2 SortPriority: 0 CaseSensitive: false - Regex: '^(<|"(gtest|gmock|isl|json)/)' Priority: 3 SortPriority: 0 CaseSensitive: false - Regex: '.*' Priority: 1 SortPriority: 0 CaseSensitive: false IncludeIsMainRegex: '(Test)?$' IncludeIsMainSourceRegex: '' IndentAccessModifiers: false IndentCaseLabels: false IndentCaseBlocks: false IndentGotoLabels: true IndentPPDirectives: None IndentExternBlock: AfterExternBlock IndentRequiresClause: true IndentWidth: 2 IndentWrappedFunctionNames: false InsertBraces: false InsertTrailingCommas: None JavaScriptQuotes: Leave JavaScriptWrapImports: true KeepEmptyLinesAtTheStartOfBlocks: true LambdaBodyIndentation: Signature MacroBlockBegin: '' MacroBlockEnd: '' MaxEmptyLinesToKeep: 1 NamespaceIndentation: None ObjCBinPackProtocolList: Auto ObjCBlockIndentWidth: 2 ObjCBreakBeforeNestedBlockParam: true ObjCSpaceAfterProperty: false ObjCSpaceBeforeProtocolList: true PenaltyBreakAssignment: 2 PenaltyBreakBeforeFirstCallParameter: 19 PenaltyBreakComment: 300 PenaltyBreakFirstLessLess: 120 PenaltyBreakOpenParenthesis: 0 
PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 60 PenaltyIndentedWhitespace: 0 PointerAlignment: Right PPIndentWidth: -1 ReferenceAlignment: Pointer ReflowComments: true RemoveBracesLLVM: false RequiresClausePosition: OwnLine SeparateDefinitionBlocks: Leave ShortNamespaceLines: 1 SortIncludes: CaseSensitive SortJavaStaticImport: Before SortUsingDeclarations: true SpaceAfterCStyleCast: false SpaceAfterLogicalNot: false SpaceAfterTemplateKeyword: true SpaceBeforeAssignmentOperators: true SpaceBeforeCaseColon: false SpaceBeforeCpp11BracedList: false SpaceBeforeCtorInitializerColon: true SpaceBeforeInheritanceColon: true SpaceBeforeParens: ControlStatements SpaceBeforeParensOptions: AfterControlStatements: true AfterForeachMacros: true AfterFunctionDefinitionName: false AfterFunctionDeclarationName: false AfterIfMacros: true AfterOverloadedOperator: false AfterRequiresInClause: false AfterRequiresInExpression: false BeforeNonEmptyParentheses: false SpaceAroundPointerQualifiers: Default SpaceBeforeRangeBasedForLoopColon: true SpaceInEmptyBlock: false SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 1 SpacesInAngles: Never SpacesInConditionalStatement: false SpacesInContainerLiterals: true SpacesInCStyleCastParentheses: false SpacesInLineCommentPrefix: Minimum: 1 Maximum: -1 SpacesInParentheses: false SpacesInSquareBrackets: false SpaceBeforeSquareBrackets: false BitFieldColonSpacing: Both Standard: Latest StatementAttributeLikeMacros: - Q_EMIT StatementMacros: - Q_UNUSED - QT_REQUIRE_VERSION TabWidth: 8 UseCRLF: false UseTab: Never WhitespaceSensitiveMacros: - STRINGIZE - PP_STRINGIZE - BOOST_PP_STRINGIZE - NS_SWIFT_NAME - CF_SWIFT_NAME ... 
================================================ FILE: .dockerignore ================================================ Dockerfile ================================================ FILE: .get-os-codename-and-stamp.sh ================================================ #!/bin/bash my_codename="" my_int_version="" if [[ -e "/etc/os-release" ]]; then eval "$(cat "/etc/os-release")" my_codename="$VERSION_CODENAME" my_int_version=${VERSION_ID//./} elif [[ -e "/etc/lsb-release" ]]; then eval "$(cat "/etc/lsb-release")" my_codename="$DISTRIB_CODENAME" my_int_version=${DISTRIB_RELEASE//./} else 1>&2 echo "ERROR in determining OS codename: neither /etc/os-release nor /etc/lsb-release exists" exit 1 fi if [[ -z "$my_codename" ]]; then 1>&2 echo "ERROR in determining OS codename" exit 1 fi if [[ ! "$my_int_version" =~ ^[0-9]+$ ]]; then 1>&2 echo "ERROR: int version '$my_int_version' is not an int. We need a numeric string for proper debian-revision version comparison." exit 1 fi # Sequence numbers make sure that when one upgrades the OS, the package for # the new distro version is selected.
sequence="$my_int_version" if [[ -z "$sequence" ]]; then 1>&2 echo "ERROR: no OS sequence defined for $my_codename" exit 2 fi echo -n "${sequence}+${my_codename}+${EPOCHSECONDS}" ================================================ FILE: .github/workflows/building.yml ================================================ name: Building on: [push] jobs: compilation: strategy: matrix: include: - os: ubuntu-24.04 friendly: clang CXX: /usr/bin/clang++ aptpkg: clang - os: ubuntu-24.04 friendly: gcc CXX: /usr/bin/g++ aptpkg: build-essential - os: ubuntu-22.04 friendly: clang CXX: /usr/bin/clang++ aptpkg: clang - os: ubuntu-22.04 friendly: gcc CXX: /usr/bin/g++ aptpkg: build-essential runs-on: ${{ matrix.os }} name: "${{ matrix.os }}: ${{ matrix.friendly }}" steps: - name: Checkout uses: actions/checkout@v3 - run: sudo apt update # Build prerequisites - run: sudo apt install -y xsltproc - run: sudo apt install -y cmake libssl-dev libcurl4-openssl-dev ${{ matrix.aptpkg }} # Build example plugin(s) - run: CXX="${{ matrix.CXX }}" ./examples/plugin_libcurl/build.sh "Release" # Building - run: CXX="${{ matrix.CXX }}" ./build.sh "Release" # Testing build results - run: ./FlashMQBuildRelease/flashmq --version - run: sudo dpkg -i ./FlashMQBuildRelease/*.deb ================================================ FILE: .github/workflows/codestyle.yml ================================================ name: Checking code style on: [pull_request] jobs: check_cpp_codestyle: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - run: sudo apt install -y clang-format-15 - run: git fetch origin "$GITHUB_BASE_REF" --depth 1 # Github CI only checks out a single commit, so we need to fetch what to compare against - run: CHANGES=$(git diff -U0 --no-color "origin/$GITHUB_BASE_REF" | clang-format-diff-15 -p1); [[ -z "$CHANGES" ]] || { echo "Please fix the following C++ style issues:"; echo "$CHANGES"; exit 2; } ================================================ FILE: 
.github/workflows/docker.yml ================================================ name: Docker on: [push] jobs: docker: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - run: docker build . -t halfgaar/flashmq - run: docker run halfgaar/flashmq /bin/flashmq --version ================================================ FILE: .github/workflows/linting.yml ================================================ name: Linting on: [push] jobs: shellcheck: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - run: sudo apt install -y shellcheck - run: shellcheck debian/post* debian/pre* - run: find . -type f -iname '*.sh' -exec shellcheck '{}' '+' markdownlint: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - run: sudo snap install mdl # Exclude MD013: line-length. If it looks good to the dev editing the Markdown file it's good enough for us - run: find . -type f -iname '*.md' -exec mdl --rules "~MD013" '{}' '+' ================================================ FILE: .github/workflows/testing.yml ================================================ name: Testing on: [push] jobs: compilation: strategy: matrix: include: - os: ubuntu-24.04 friendly: clang compiler: clang++ aptpkg: clang - os: ubuntu-24.04 friendly: gcc compiler: g++ aptpkg: build-essential - os: ubuntu-22.04 friendly: clang compiler: clang++ aptpkg: clang - os: ubuntu-22.04 friendly: gcc compiler: g++ aptpkg: build-essential runs-on: ${{ matrix.os }} defaults: run: working-directory: FlashMQTests name: "Normal test - ${{ matrix.os }}: ${{ matrix.friendly }}" steps: - name: Checkout uses: actions/checkout@v3 - run: sudo apt update # Build prerequisites - run: sudo apt install -y cmake libssl-dev libcurl4-openssl-dev ${{ matrix.aptpkg }} # Building - run: ./run-make-from-ci.sh --compiler "${{ matrix.compiler }}" - run: ./run-tests-from-ci.sh ================================================ FILE: .github/workflows/testing_non_sse.yml 
================================================ name: TestingNoSse on: [push] jobs: compilation: strategy: matrix: include: - os: ubuntu-24.04 friendly: clang compiler: clang++ aptpkg: clang - os: ubuntu-24.04 friendly: gcc compiler: g++ aptpkg: build-essential - os: ubuntu-22.04 friendly: clang compiler: clang++ aptpkg: clang - os: ubuntu-22.04 friendly: gcc compiler: g++ aptpkg: build-essential runs-on: ${{ matrix.os }} defaults: run: working-directory: FlashMQTests name: "Non-SSE test - ${{ matrix.os }}: ${{ matrix.friendly }}" steps: - name: Checkout uses: actions/checkout@v3 - run: sudo apt update # Build prerequisites - run: sudo apt install -y cmake libssl-dev libcurl4-openssl-dev ${{ matrix.aptpkg }} # Building - run: ./run-make-from-ci.sh --compiler "${{ matrix.compiler }}" --extra-config "FMQ_NO_SSE=1" - run: ./run-tests-from-ci.sh ================================================ FILE: .gitignore ================================================ *.user build-* FlashMQBuild* *.swp compile_commands.json .clangd/ .cache/ build FlashMQTests/build ================================================ FILE: CMakeLists.shared ================================================ # This file is shared between FlashMQ itself and FlashMQTests # It should not contain any definitions that isn't used by both # When building FlashMQTests: # CMAKE_CURRENT_SOURCE_DIR -> /home/user/FlashMQ/FlashMQTests # CMAKE_CURRENT_LIST_DIR -> /home/user/FlashMQ # # When building FlashMQ: # CMAKE_CURRENT_SOURCE_DIR -> /home/user/FlashMQ # CMAKE_CURRENT_LIST_DIR -> /home/user/FlashMQ if (CMAKE_CURRENT_LIST_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(RELPATH "./") else() set(RELPATH "../") endif() message("Determined RELPATH: ${RELPATH}") set(FLASHMQ_HEADERS ${RELPATH}forward_declarations.h ${RELPATH}mainapp.h ${RELPATH}utils.h ${RELPATH}threaddata.h ${RELPATH}client.h ${RELPATH}session.h ${RELPATH}mqttpacket.h ${RELPATH}exceptions.h ${RELPATH}types.h ${RELPATH}subscriptionstore.h 
${RELPATH}rwlockguard.h ${RELPATH}retainedmessage.h ${RELPATH}cirbuf.h ${RELPATH}logger.h ${RELPATH}plugin.h ${RELPATH}configfileparser.h ${RELPATH}sslctxmanager.h ${RELPATH}iowrapper.h ${RELPATH}mosquittoauthoptcompatwrap.h ${RELPATH}settings.h ${RELPATH}listener.h ${RELPATH}unscopedlock.h ${RELPATH}scopedsocket.h ${RELPATH}bindaddr.h ${RELPATH}oneinstancelock.h ${RELPATH}evpencodectxmanager.h ${RELPATH}acltree.h ${RELPATH}enums.h ${RELPATH}threadlocalutils.h ${RELPATH}flashmq_plugin.h ${RELPATH}flashmq_plugin_deprecated.h ${RELPATH}retainedmessagesdb.h ${RELPATH}persistencefile.h ${RELPATH}sessionsandsubscriptionsdb.h ${RELPATH}qospacketqueue.h ${RELPATH}threadglobals.h ${RELPATH}threadloop.h ${RELPATH}publishcopyfactory.h ${RELPATH}variablebyteint.h ${RELPATH}mqtt5properties.h ${RELPATH}globalstats.h ${RELPATH}derivablecounter.h ${RELPATH}packetdatatypes.h ${RELPATH}haproxy.h ${RELPATH}network.h ${RELPATH}subscription.h ${RELPATH}sharedsubscribers.h ${RELPATH}pluginloader.h ${RELPATH}queuedtasks.h ${RELPATH}acksender.h ${RELPATH}bridgeconfig.h ${RELPATH}dnsresolver.h ${RELPATH}globber.h ${RELPATH}bridgeinfodb.h ${RELPATH}x509manager.h ${RELPATH}backgroundworker.h ${RELPATH}fmqmain.h ${RELPATH}driftcounter.h ${RELPATH}lockedweakptr.h ${RELPATH}lockedsharedptr.h ${RELPATH}globals.h ${RELPATH}nocopy.h ${RELPATH}fdmanaged.h ${RELPATH}checkedweakptr.h ${RELPATH}mutexowned.h ${RELPATH}http.h ${RELPATH}fmqsockaddr.h ${RELPATH}flags.h ${RELPATH}persistencefunctions.h ${RELPATH}clientacceptqueue.h ${RELPATH}checkedsharedptr.h ${RELPATH}flashmq_public.h ${RELPATH}sdnotify.h ${RELPATH}threadlocked.h ) set(FLASHMQ_IMPLS ${RELPATH}mainapp.cpp ${RELPATH}utils.cpp ${RELPATH}threaddata.cpp ${RELPATH}client.cpp ${RELPATH}session.cpp ${RELPATH}mqttpacket.cpp ${RELPATH}exceptions.cpp ${RELPATH}types.cpp ${RELPATH}subscriptionstore.cpp ${RELPATH}rwlockguard.cpp ${RELPATH}retainedmessage.cpp ${RELPATH}cirbuf.cpp ${RELPATH}logger.cpp ${RELPATH}plugin.cpp 
${RELPATH}configfileparser.cpp ${RELPATH}sslctxmanager.cpp ${RELPATH}iowrapper.cpp ${RELPATH}mosquittoauthoptcompatwrap.cpp ${RELPATH}settings.cpp ${RELPATH}listener.cpp ${RELPATH}unscopedlock.cpp ${RELPATH}scopedsocket.cpp ${RELPATH}bindaddr.cpp ${RELPATH}oneinstancelock.cpp ${RELPATH}evpencodectxmanager.cpp ${RELPATH}acltree.cpp ${RELPATH}threadlocalutils.cpp ${RELPATH}flashmq_plugin.cpp ${RELPATH}retainedmessagesdb.cpp ${RELPATH}persistencefile.cpp ${RELPATH}sessionsandsubscriptionsdb.cpp ${RELPATH}qospacketqueue.cpp ${RELPATH}threadglobals.cpp ${RELPATH}threadloop.cpp ${RELPATH}publishcopyfactory.cpp ${RELPATH}variablebyteint.cpp ${RELPATH}mqtt5properties.cpp ${RELPATH}globalstats.cpp ${RELPATH}derivablecounter.cpp ${RELPATH}packetdatatypes.cpp ${RELPATH}haproxy.cpp ${RELPATH}network.cpp ${RELPATH}subscription.cpp ${RELPATH}sharedsubscribers.cpp ${RELPATH}pluginloader.cpp ${RELPATH}queuedtasks.cpp ${RELPATH}acksender.cpp ${RELPATH}bridgeconfig.cpp ${RELPATH}dnsresolver.cpp ${RELPATH}bridgeinfodb.cpp ${RELPATH}globber.cpp ${RELPATH}x509manager.cpp ${RELPATH}backgroundworker.cpp ${RELPATH}fmqmain.cpp ${RELPATH}driftcounter.cpp ${RELPATH}lockedweakptr.cpp ${RELPATH}globals.cpp ${RELPATH}nocopy.cpp ${RELPATH}fdmanaged.cpp ${RELPATH}http.cpp ${RELPATH}fmqsockaddr.cpp ${RELPATH}fmqssl.cpp ${RELPATH}persistencefunctions.cpp ${RELPATH}clientacceptqueue.cpp ${RELPATH}sdnotify.cpp ) ================================================ FILE: CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.5) cmake_policy(SET CMP0048 NEW) include(CheckCXXCompilerFlag) include(CMakeLists.shared) project(FlashMQ VERSION 1.26.1 LANGUAGES CXX) add_definitions(-DOPENSSL_API_COMPAT=0x10100000L) add_definitions(-DFLASHMQ_VERSION=\"${PROJECT_VERSION}\") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) if (FMQ_ASAN) message("Building with ASAN.") add_compile_options(-fsanitize=address) add_link_options(-fsanitize=address) endif() 
check_cxx_compiler_flag("-msse4.2" COMPILER_RT_HAS_MSSE4_2_FLAG) if (${COMPILER_RT_HAS_MSSE4_2_FLAG}) #[[ Append instead of overwriting, so CMAKE_CXX_FLAGS supplied by the user or the environment (e.g. distro hardening flags) are preserved. ]] string(APPEND CMAKE_CXX_FLAGS " -msse4.2") endif() add_compile_options(-fvisibility=hidden -fvisibility-inlines-hidden) add_link_options(-rdynamic) add_compile_options(-Wall -Wextra) add_executable(flashmq ${FLASHMQ_HEADERS} ${FLASHMQ_IMPLS} main.cpp ) target_link_libraries(flashmq pthread dl ssl crypto resolv anl) #[[ Absolute path, so this does not depend on the directory cmake was started from. ]] execute_process(COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/.get-os-codename-and-stamp.sh" OUTPUT_VARIABLE OS_CODENAME) install(TARGETS flashmq RUNTIME DESTINATION "/usr/bin/") install(DIRECTORY DESTINATION "/var/lib/flashmq") install(DIRECTORY DESTINATION "/var/log/flashmq") install(FILES flashmq.conf DESTINATION "/etc/flashmq") install(FILES debian/flashmq.service DESTINATION "/lib/systemd/system") install(FILES man/flashmq.conf.5 DESTINATION "/usr/share/man/man5") install(FILES man/flashmq.1 DESTINATION "/usr/share/man/man1") SET(CPACK_DEBIAN_PACKAGE_CONTROL_EXTRA "${CMAKE_CURRENT_SOURCE_DIR}/debian/conffiles;${CMAKE_CURRENT_SOURCE_DIR}/debian/preinst;${CMAKE_CURRENT_SOURCE_DIR}/debian/postinst;${CMAKE_CURRENT_SOURCE_DIR}/debian/postrm;${CMAKE_CURRENT_SOURCE_DIR}/debian/prerm") SET(CPACK_GENERATOR "DEB") SET(CPACK_DEBIAN_PACKAGE_MAINTAINER "Wiebe Cazemier ") SET(CPACK_DEBIAN_PACKAGE_DESCRIPTION "Light-weight, high performance MQTT server capable of million+ messages per second.") SET(CPACK_PACKAGE_HOMEPAGE_URL "https://www.flashmq.org/") SET(CPACK_DEBIAN_PACKAGE_SHLIBDEPS ON) SET(CPACK_DEBIAN_FILE_NAME "DEB-DEFAULT") SET(CPACK_DEBIAN_PACKAGE_RELEASE ${OS_CODENAME}) SET(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR}) SET(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR}) SET(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH}) INCLUDE(CPack) ================================================ FILE: Dockerfile ================================================ # build target, used for building the binary, providing shared libraries and could be used as a development env FROM
debian:trixie-slim AS build # install build dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -y install g++ make cmake libssl-dev file # create flashmq user and group for runtime image below RUN useradd --system --shell /bin/false --user-group --no-log-init flashmq ARG BUILD_TYPE=Release RUN echo "Building ${BUILD_TYPE} version" WORKDIR /usr/src/app COPY . . RUN rm -rf FlashMQBuild* 2>/dev/null RUN ./build.sh ${BUILD_TYPE} # convert docker buildx platform name to Debian platform name FROM scratch AS run-amd64 ARG PLATFORM=x86_64 ARG LD_LOCATION=/lib64/ld-linux-x86-64.so.2 FROM scratch AS run-arm64 ARG PLATFORM=aarch64 ARG LD_LOCATION=/lib/ld-linux-aarch64.so.1 # from scratch image is empty FROM run-$TARGETARCH AS run USER flashmq:flashmq COPY --from=build /etc/passwd /etc/passwd COPY --from=build /etc/group /etc/group # copy in the shared libaries in use discovered using ldd on release binary COPY --from=build /lib/${PLATFORM}-linux-gnu/libpthread.so.0 /lib/${PLATFORM}-linux-gnu/libpthread.so.0 COPY --from=build /lib/${PLATFORM}-linux-gnu/libdl.so.2 /lib/${PLATFORM}-linux-gnu/libdl.so.2 COPY --from=build /usr/lib/${PLATFORM}-linux-gnu/libssl.so.3 /usr/lib/${PLATFORM}-linux-gnu/libssl.so.3 COPY --from=build /usr/lib/${PLATFORM}-linux-gnu/libcrypto.so.3 /usr/lib/${PLATFORM}-linux-gnu/libcrypto.so.3 COPY --from=build /usr/lib/${PLATFORM}-linux-gnu/libstdc++.so.6 /usr/lib/${PLATFORM}-linux-gnu/libstdc++.so.6 COPY --from=build /lib/${PLATFORM}-linux-gnu/libgcc_s.so.1 /lib/${PLATFORM}-linux-gnu/libgcc_s.so.1 COPY --from=build /lib/${PLATFORM}-linux-gnu/libc.so.6 /lib/${PLATFORM}-linux-gnu/libc.so.6 COPY --from=build ${LD_LOCATION} ${LD_LOCATION} COPY --from=build /lib/${PLATFORM}-linux-gnu/libm.so.6 /lib/${PLATFORM}-linux-gnu/libm.so.6 COPY --from=build /lib/${PLATFORM}-linux-gnu/libresolv.so.2 /lib/${PLATFORM}-linux-gnu/libresolv.so.2 COPY --from=build /lib/${PLATFORM}-linux-gnu/libanl.so.1 /lib/${PLATFORM}-linux-gnu/libanl.so.1 COPY 
--from=build /lib/${PLATFORM}-linux-gnu/libz.so.1 /lib/${PLATFORM}-linux-gnu/libz.so.1 COPY --from=build /lib/${PLATFORM}-linux-gnu/libzstd.so.1 /lib/${PLATFORM}-linux-gnu/libzstd.so.1 # copy in the FlashMQ binary itself COPY --from=build /usr/src/app/FlashMQBuildRelease/flashmq /bin/flashmq EXPOSE 1883 CMD ["/bin/flashmq"] ================================================ FILE: FlashMQTests/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.5) cmake_policy(SET CMP0048 NEW) include(CheckCXXCompilerFlag) include(../CMakeLists.shared) project(FlashMQTests VERSION 1.0.0 LANGUAGES CXX) add_definitions(-DOPENSSL_API_COMPAT=0x10100000L) add_definitions(-DFLASHMQ_VERSION=\"${PROJECT_VERSION}\") add_definitions(-DTESTING) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) if (FMQ_NO_SSE) add_definitions(-DFMQ_NO_SSE) message("Building tests wihout SSE.") else() SET(CMAKE_CXX_FLAGS "-msse4.2") endif() add_link_options(-rdynamic) add_compile_options(-Wall -Wextra -fsanitize=address) add_link_options(-fsanitize=address) add_library(test_plugin SHARED plugins/test_plugin.h plugins/test_plugin.cpp plugins/curlfunctions.h plugins/curlfunctions.cpp ) target_link_libraries(test_plugin curl) # Copying to a version 0.0.1 file is a bit of a hack. I'm not sure how to version it in this CMake otherwise. add_custom_command(TARGET test_plugin POST_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory "plugins/") add_custom_command(TARGET test_plugin POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy $ "plugins/libtest_plugin.so.0.0.1") # Copy in a way that always copies, to avoid unexpected staleness. 
add_custom_command(TARGET test_plugin POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_SOURCE_DIR}/UTF-8-test.txt" "${CMAKE_BINARY_DIR}/UTF-8-test.txt") add_custom_command(TARGET test_plugin POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_SOURCE_DIR}/../fuzztests/plainwebsocketpacket1_handshake.dat" "${CMAKE_BINARY_DIR}/plainwebsocketpacket1_handshake.dat") add_executable(flashmq-tests ${FLASHMQ_HEADERS} ${FLASHMQ_IMPLS} ../flashmqtestclient.cpp ../flashmqtestclient.h main.cpp mainappinthread.h mainappinthread.cpp maintests.h maintests.cpp testhelpers.cpp testhelpers.h flashmqtempdir.h flashmqtempdir.cpp tst_maintests.cpp retaintests.cpp conffiletemp.cpp conffiletemp.h filecloser.cpp filecloser.h plugintests.cpp configtests.cpp sharedsubscriptionstests.cpp websockettests.cpp willtests.cpp dnstests.cpp utiltests.cpp testinitializer.h testinitializer.cpp mainappasfork.h mainappasfork.cpp subscriptionidtests.cpp bridgeprefixtests.cpp ) target_compile_options(flashmq-tests PUBLIC -fvisibility=hidden -fvisibility-inlines-hidden) target_link_options(flashmq-tests PUBLIC -rdynamic) target_include_directories(flashmq-tests PUBLIC ..) 
target_link_libraries(flashmq-tests pthread dl ssl crypto resolv anl)

================================================
FILE: FlashMQTests/bridgeprefixtests.cpp
================================================
#include "maintests.h"
#include "conffiletemp.h"
#include "flashmqtestclient.h"
#include "mainappasfork.h"
#include "testhelpers.h"
#include "flashmqtempdir.h"

/**
 * @brief Repeatedly publish a "connectiontest" message (QoS 2) on `topic` via `one` until
 *        `two.waitForMessageCount(1)` succeeds, with at most 10 attempts.
 *
 * The tests call this to wait until the bridge is passing messages before they start asserting.
 *
 * @param one Client that publishes.
 * @param two Client expected to receive the message on the other side of the bridge.
 * @param topic Topic to publish on.
 * @throws std::runtime_error when all 10 attempts fail.
 */
void waitForMessagesOverBridge(FlashMQTestClient &one, FlashMQTestClient &two, const std::string &topic)
{
    int wait_i = 0;
    for(wait_i = 0; wait_i < 10; wait_i++)
    {
        one.publish(topic, "connectiontest", 2);

        try
        {
            two.waitForMessageCount(1);
            break;
        }
        catch (std::exception &ex)
        {
            // Deliberately ignored: the bridge may not be up yet, so publish again on the next iteration.
        }
    }

    if (wait_i >= 10)
        throw std::runtime_error("Timeout waiting for messages over bridge");
}

/**
 * @brief Bridge with both local_prefix and remote_prefix set.
 *
 * Expects outgoing 'ManglerLocal/boots' to arrive at the remote as 'ManglerRemote/boots', and incoming
 * 'ManglerRemote/shoes' to arrive locally as 'ManglerLocal/shoes'. Runs once per bridge protocol_version.
 */
void MainTests::forkingTestBridgeWithLocalAndRemotePrefix()
{
    for (const std::string protocol_version : {"mqtt5", "mqtt3.1"})
    {
        cleanup();

        ConfFileTemp confFile;
        const std::string config = R"( allow_anonymous true log_debug false bridge { address ::1 port 21883 subscribe ManglerRemote/shoes 2 publish ManglerLocal/boots 2 clientid_prefix Mangler local_prefix ManglerLocal/ remote_prefix ManglerRemote/ protocol_version %s } listen { protocol mqtt port 51183 })";
        confFile.writeLine(formatString(config, protocol_version.c_str()));
        confFile.closeFile();

        std::vector args {"--config-file", confFile.getFilePath()};
        MainAppAsFork app(args);
        app.start();
        app.waitForStarted(51183);

        // We will consider the test server initialized here the 'remote' broker.
        init();

        FlashMQTestClient clientToLocalWithBridge;
        clientToLocalWithBridge.start();

        FlashMQTestClient clientToRemote;
        clientToRemote.start();

        clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt31, 51183);
        clientToLocalWithBridge.subscribe("#", 1);

        clientToRemote.connectClient(ProtocolVersion::Mqtt5);
        clientToRemote.subscribe("#", 1);

        waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "ManglerLocal/boots");

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        // Local -> remote: local prefix stripped, remote prefix added.
        {
            clientToLocalWithBridge.publish("ManglerLocal/boots", "asdf", 2);
            clientToRemote.waitForMessageCount(1);

            auto ro = clientToRemote.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/boots"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "asdf");
        }

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        // Remote -> local: remote prefix stripped, local prefix added.
        {
            clientToRemote.publish("ManglerRemote/shoes", "are made for walking", 2);
            clientToLocalWithBridge.waitForMessageCount(1);

            auto ro = clientToLocalWithBridge.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerLocal/shoes"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "are made for walking");
        }
    }
}

/**
 * @brief Test the internal packet cache; that we don't accidentally cache packets with the prefixes applied.
 *
 * The PublishCopyFactory was temporarily broken first to write this test.
 */
void MainTests::forkingTestBridgePrefixesOtherClientsUnaffected()
{
    for (const std::string protocol_version : {"mqtt5", "mqtt3.1"})
    {
        cleanup();

        ConfFileTemp confFile;
        const std::string config = R"( allow_anonymous true log_debug false bridge { address ::1 port 21883 subscribe ManglerRemote/shoes 2 publish ManglerLocal/boots 2 clientid_prefix Mangler local_prefix ManglerLocal/ remote_prefix ManglerRemote/ protocol_version %s } listen { protocol mqtt port 51183 })";
        confFile.writeLine(formatString(config, protocol_version.c_str()));
        confFile.closeFile();

        std::vector args {"--config-file", confFile.getFilePath()};
        MainAppAsFork app(args);
        app.start();
        app.waitForStarted(51183);

        // We will consider the test server initialized here the 'remote' broker.
        init();

        FlashMQTestClient randomOtherLocalClient1;
        randomOtherLocalClient1.start();
        randomOtherLocalClient1.connectClient(ProtocolVersion::Mqtt5, 51183);
        // Not using wild-cards because we happen to know that produces the correct delivery order to subscribers to test this.
        randomOtherLocalClient1.subscribe("ManglerLocal/boots", 1);
        randomOtherLocalClient1.subscribe("ManglerRemote/boots", 1);

        FlashMQTestClient clientToLocalWithBridge;
        clientToLocalWithBridge.start();

        FlashMQTestClient clientToRemote;
        clientToRemote.start();

        clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt31, 51183);
        clientToLocalWithBridge.subscribe("#", 1);

        clientToRemote.connectClient(ProtocolVersion::Mqtt5);
        clientToRemote.subscribe("#", 1);

        FlashMQTestClient randomOtherLocalClient2;
        randomOtherLocalClient2.start();
        randomOtherLocalClient2.connectClient(ProtocolVersion::Mqtt5, 51183);
        // Not using wild-cards because we happen to know that produces the correct delivery order to subscribers to test this.
        randomOtherLocalClient2.subscribe("ManglerLocal/boots", 1);
        randomOtherLocalClient2.subscribe("ManglerRemote/boots", 1);

        FlashMQTestClient randomOtherLocalClient3;
        randomOtherLocalClient3.start();
        randomOtherLocalClient3.connectClient(ProtocolVersion::Mqtt5, 51183);
        // For this one, we do use the wildcard subscription, just to be sure.
        randomOtherLocalClient3.subscribe("ManglerLocal/#", 1);

        waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "ManglerLocal/boots");

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        {
            clientToLocalWithBridge.publish("ManglerLocal/boots", "asdf", 2);
            clientToRemote.waitForMessageCount(1);

            auto ro = clientToRemote.receivedObjects.lock();
            // NOTE(review): the topic assertion below is commented out, while the sibling prefix tests assert
            // 'ManglerRemote/boots' at this point -- confirm whether it was disabled deliberately.
            //FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/boots"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "asdf");
        }

        clientToLocalWithBridge.clearReceivedLists();

        // Every local non-bridge subscriber must only ever see the unprefixed-for-remote (local) topic.
        {
            randomOtherLocalClient1.waitForMessageCount(2);
            auto ro = randomOtherLocalClient1.receivedObjects.lock();
            FMQ_VERIFY(std::all_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) {
                return startsWith(p.getTopic(), "ManglerLocal/");
            }));
        }

        {
            randomOtherLocalClient2.waitForMessageCount(2);
            auto ro = randomOtherLocalClient2.receivedObjects.lock();
            FMQ_VERIFY(std::all_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) {
                return startsWith(p.getTopic(), "ManglerLocal/");
            }));
        }

        {
            randomOtherLocalClient3.waitForMessageCount(2);
            auto ro = randomOtherLocalClient3.receivedObjects.lock();
            FMQ_VERIFY(std::all_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) {
                return startsWith(p.getTopic(), "ManglerLocal/");
            }));
        }
    }
}

/**
 * @brief Bridge with only remote_prefix set: outgoing 'boots' becomes 'ManglerRemote/boots' remotely, incoming
 * 'ManglerRemote/shoes' becomes 'shoes' locally, and other local subscribers keep seeing plain 'boots'.
 */
void MainTests::forkingTestBridgeWithOnlyRemotePrefix()
{
    for (const std::string protocol_version : {"mqtt5", "mqtt3.1"})
    {
        cleanup();

        ConfFileTemp confFile;
        const std::string config = R"( allow_anonymous
true log_debug false bridge { address ::1 port 21883 subscribe ManglerRemote/shoes 2 publish boots 2 clientid_prefix Mangler remote_prefix ManglerRemote/ protocol_version %s } listen { protocol mqtt port 51183 })";
        confFile.writeLine(formatString(config, protocol_version.c_str()));
        confFile.closeFile();

        std::vector args {"--config-file", confFile.getFilePath()};
        MainAppAsFork app(args);
        app.start();
        app.waitForStarted(51183);

        // We will consider the test server initialized here the 'remote' broker.
        init();

        FlashMQTestClient clientToLocalWithBridge;
        clientToLocalWithBridge.start();

        FlashMQTestClient clientToRemote;
        clientToRemote.start();

        FlashMQTestClient receiverToLocal;
        receiverToLocal.start();

        // Subscribed before waitForMessagesOverBridge() and never cleared, so it also collects the
        // connection-test messages; hence the .back() checks further down.
        FlashMQTestClient receiverToLocal2;
        receiverToLocal2.start();
        receiverToLocal2.connectClient(ProtocolVersion::Mqtt31, 51183);
        receiverToLocal2.subscribe("#", 1);

        clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt31, 51183);
        clientToLocalWithBridge.subscribe("#", 1);

        clientToRemote.connectClient(ProtocolVersion::Mqtt5);
        clientToRemote.subscribe("#", 1);

        waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "boots");

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        receiverToLocal.connectClient(ProtocolVersion::Mqtt31, 51183);
        receiverToLocal.subscribe("#", 1);

        {
            clientToLocalWithBridge.publish("boots", "asdf", 2);
            clientToRemote.waitForMessageCount(1);

            auto ro = clientToRemote.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/boots"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "asdf");
        }

        // Make sure other clients connected to the server with prefixes defined get normal topics. I.e. test packet cache.
        {
            receiverToLocal.waitForMessageCount(1);
            auto roLocal = receiverToLocal.receivedObjects.lock();
            FMQ_COMPARE(roLocal->receivedPublishes.at(0).getTopic(), std::string("boots"));
            FMQ_COMPARE(roLocal->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(roLocal->receivedPublishes.at(0).getPayloadView(), "asdf");
        }

        {
            receiverToLocal2.waitForMessageCount(1);
            auto roLocal2 = receiverToLocal2.receivedObjects.lock();
            FMQ_COMPARE(roLocal2->receivedPublishes.back().getTopic(), std::string("boots"));
            FMQ_COMPARE(roLocal2->receivedPublishes.back().getQos(), 1);
            FMQ_COMPARE(roLocal2->receivedPublishes.back().getPayloadView(), "asdf");
        }

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        {
            clientToRemote.publish("ManglerRemote/shoes", "are made for walking", 2);
            clientToLocalWithBridge.waitForMessageCount(1);

            auto ro = clientToLocalWithBridge.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("shoes"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "are made for walking");
        }
    }
}

/**
 * @brief Bridge with only local_prefix set: outgoing 'ManglerLocal/boots' becomes 'boots' remotely, incoming
 * 'shoes' becomes 'ManglerLocal/shoes' locally, and a plain local publish ('panic') is delivered unchanged.
 */
void MainTests::forkingTestBridgeWithOnlyLocalPrefix()
{
    for (const std::string protocol_version : {"mqtt5", "mqtt3.1"})
    {
        cleanup();

        ConfFileTemp confFile;
        const std::string config = R"( allow_anonymous true log_debug false bridge { address ::1 port 21883 subscribe shoes 2 publish ManglerLocal/boots 2 clientid_prefix Mangler local_prefix ManglerLocal/ protocol_version %s } listen { protocol mqtt port 51183 })";
        confFile.writeLine(formatString(config, protocol_version.c_str()));
        confFile.closeFile();

        std::vector args {"--config-file", confFile.getFilePath()};
        MainAppAsFork app(args);
        app.start();
        app.waitForStarted(51183);

        // We will consider the test server initialized here the 'remote' broker.
        init();

        FlashMQTestClient clientToLocalWithBridge;
        clientToLocalWithBridge.start();

        FlashMQTestClient clientToRemote;
        clientToRemote.start();

        FlashMQTestClient senderToLocal;
        senderToLocal.start();

        clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt31, 51183);
        clientToLocalWithBridge.subscribe("#", 1);

        clientToRemote.connectClient(ProtocolVersion::Mqtt5);
        clientToRemote.subscribe("#", 1);

        waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "ManglerLocal/boots");

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        {
            clientToLocalWithBridge.publish("ManglerLocal/boots", "asdf", 2);
            clientToRemote.waitForMessageCount(1);

            auto ro = clientToRemote.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("boots"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "asdf");
        }

        clientToLocalWithBridge.clearReceivedLists();
        clientToRemote.clearReceivedLists();

        {
            clientToRemote.publish("shoes", "are made for walking", 2);
            clientToLocalWithBridge.waitForMessageCount(1);

            auto ro = clientToLocalWithBridge.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerLocal/shoes"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "are made for walking");
        }

        // Make sure other clients connected to the server with prefixes defined get normal topics.
        {
            clientToLocalWithBridge.clearReceivedLists();
            senderToLocal.connectClient(ProtocolVersion::Mqtt31, 51183);
            senderToLocal.publish("panic", "attack", 0);
            clientToLocalWithBridge.waitForMessageCount(1);

            auto ro = clientToLocalWithBridge.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("panic"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 0);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "attack");
        }
    }
}

/**
 * @brief Test that we don't apply the prefix on outgoing messages when it's the same as the topic string.
 *
 * When having a prefix of 'one/two/', this is to prevent a subscription like 'one/two/#', turning topic 'one/two/'
 * into '', which is illegal (while empty subtopic strings, like caused by trailing slash, is legal).
 */
void MainTests::forkingTestBridgeOutgoingTopicEqualsPrefix()
{
    cleanup();

    ConfFileTemp confFile;
    const std::string config = R"( allow_anonymous true log_debug false bridge { address ::1 port 21883 subscribe shoes 2 publish ManglerLocal/# 2 clientid_prefix Mangler local_prefix ManglerLocal/ protocol_version mqtt5 } listen { protocol mqtt port 51183 })";
    confFile.writeLine(config);
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};
    MainAppAsFork app(args);
    app.start();
    app.waitForStarted(51183);

    // We will consider the test server initialized here the 'remote' broker.
    init();

    FlashMQTestClient clientToLocalWithBridge;
    clientToLocalWithBridge.start();

    FlashMQTestClient clientToRemote;
    clientToRemote.start();

    // NOTE(review): senderToLocal is started but never used in this test.
    FlashMQTestClient senderToLocal;
    senderToLocal.start();

    clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt31, 51183);
    clientToLocalWithBridge.subscribe("#", 1);

    clientToRemote.connectClient(ProtocolVersion::Mqtt5);
    clientToRemote.subscribe("#", 1);

    waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "ManglerLocal/boots");

    clientToLocalWithBridge.clearReceivedLists();
    clientToRemote.clearReceivedLists();

    {
        // The topic equals the local prefix exactly, so it must arrive at the remote unmodified.
        clientToLocalWithBridge.publish("ManglerLocal/", "WC4WQbYJb76TguUT", 2);
        clientToRemote.waitForMessageCount(1);

        auto ro = clientToRemote.receivedObjects.lock();
        FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerLocal/"));
        FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
        FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "WC4WQbYJb76TguUT");
    }
}

/**
 * @brief Same as forkingTestBridgeOutgoingTopicEqualsPrefix, but then for incoming.
 */
void MainTests::forkingTestBridgeIncomingTopicEqualsPrefix()
{
    cleanup();

    ConfFileTemp confFile;
    const std::string config = R"( allow_anonymous true log_debug false bridge { address ::1 port 21883 subscribe ManglerRemote/# 2 publish boots 2 clientid_prefix Mangler remote_prefix ManglerRemote/ protocol_version mqtt5 } listen { protocol mqtt port 51183 })";
    confFile.writeLine(config);
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};
    MainAppAsFork app(args);
    app.start();
    app.waitForStarted(51183);

    // We will consider the test server initialized here the 'remote' broker.
    init();

    FlashMQTestClient clientToLocalWithBridge;
    clientToLocalWithBridge.start();

    FlashMQTestClient clientToRemote;
    clientToRemote.start();

    clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt31, 51183);
    clientToLocalWithBridge.subscribe("#", 1);

    clientToRemote.connectClient(ProtocolVersion::Mqtt5);
    clientToRemote.subscribe("#", 1);

    waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "boots");

    clientToLocalWithBridge.clearReceivedLists();
    clientToRemote.clearReceivedLists();

    {
        // The topic equals the remote prefix exactly, so it must arrive locally unmodified.
        clientToRemote.publish("ManglerRemote/", "XN0KoFeDOimgRiTs", 2);
        clientToLocalWithBridge.waitForMessageCount(1);

        auto ro = clientToLocalWithBridge.receivedObjects.lock();
        FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/"));
        FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
        FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "XN0KoFeDOimgRiTs");
    }
}

/**
 * @brief Retained messages over a bridge with both prefixes set.
 *
 * A message retained on the remote before the bridge exists must show up locally (retained) with the local
 * prefix. A message retained locally afterwards must show up retained on the remote with the remote prefix.
 * Here the forked server is the 'remote' and the in-process test server (init) is the 'local' with the bridge.
 */
void MainTests::forkingTestBridgeZWithLocalAndRemotePrefixRetained()
{
    for (const std::string protocol_version : {"mqtt5", "mqtt3.1"})
    {
        cleanup();

        ConfFileTemp conf_file_remote;

        {
            // NOTE(review): config_remote has no %s, so the protocol_version argument to formatString is unused.
            const std::string config_remote = R"( allow_anonymous true log_debug false listen { protocol mqtt port 51883 })";
            conf_file_remote.writeLine(formatString(config_remote, protocol_version.c_str()));
            conf_file_remote.closeFile();
        }

        std::vector args_remote {"--config-file", conf_file_remote.getFilePath()};
        MainAppAsFork remoteServer(args_remote);
        remoteServer.start();
        remoteServer.waitForStarted(51883);

        FlashMQTestClient clientToRemote;
        clientToRemote.start();
        clientToRemote.connectClient(ProtocolVersion::Mqtt5, 51883);

        {
            Publish pub("ManglerRemote/shoes/retainme", "asdf", 2);
            pub.retain = true;
            clientToRemote.publish(pub);
        }

        ConfFileTemp conf_file_local;

        {
            const std::string config_local = R"( allow_anonymous true log_debug false bridge { address ::1 port 51883 subscribe ManglerRemote/shoes/# 2 publish ManglerLocal/boots/# 2 clientid_prefix Mangler local_prefix ManglerLocal/ remote_prefix ManglerRemote/
protocol_version %s } listen { protocol mqtt port 21883 })";
            conf_file_local.writeLine(formatString(config_local, protocol_version.c_str()));
            conf_file_local.closeFile();

            std::vector args_local {"--config-file", conf_file_local.getFilePath()};

            // We consider our normal test client as 'local'
            init(args_local);
        }

        FlashMQTestClient clientToLocal;
        clientToLocal.start();
        clientToLocal.connectClient(ProtocolVersion::Mqtt5);
        clientToLocal.subscribe("#", 1);
        clientToLocal.waitForMessageCount(1);

        {
            auto ro = clientToLocal.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.size(), static_cast(1));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerLocal/shoes/retainme"));
        }

        clientToLocal.clearReceivedLists();

        waitForMessagesOverBridge(clientToRemote, clientToLocal, "ManglerRemote/shoes/connectiontest");

        {
            auto ro = clientToLocal.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.size(), static_cast(1));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerLocal/shoes/connectiontest"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), std::string("connectiontest"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getRetain(), false);
        }

        // A fresh local subscriber must get the bridged retained message, with the local prefix applied.
        {
            FlashMQTestClient clientToLocal;
            clientToLocal.start();
            clientToLocal.connectClient(ProtocolVersion::Mqtt5);
            clientToLocal.subscribe("#", 1);
            clientToLocal.waitForPacketCount(1);

            auto ro = clientToLocal.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.size(), static_cast(1));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerLocal/shoes/retainme"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "asdf");
            FMQ_COMPARE(ro->receivedPublishes.at(0).getRetain(), true);
        }

        // Push a retained message to local
        {
            FlashMQTestClient clientToLocal;
            clientToLocal.start();
            clientToLocal.connectClient(ProtocolVersion::Mqtt5);
            Publish pub("ManglerLocal/boots/retainmetoo", "zuOJHOekvdkGD9FH", 2);
            pub.retain = true;
            clientToLocal.publish(pub);
        }

        // That we then must see as retained on the remote
        {
            FlashMQTestClient clientToRemote;
            clientToRemote.start();
            clientToRemote.connectClient(ProtocolVersion::Mqtt5, 51883);
            clientToRemote.subscribe("ManglerRemote/boots/#", 2);
            clientToRemote.waitForMessageCount(1);

            auto ro = clientToRemote.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.size(), static_cast(1));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/boots/retainmetoo"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 2);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "zuOJHOekvdkGD9FH");
            FMQ_COMPARE(ro->receivedPublishes.at(0).getRetain(), true);
        }
    }
}

/**
 * @brief Test if queued QoS messages have their prefixes applied. Special care is taken in FlashMQ to support
 * that (the topic_override is part of the QueuedPublish), so we need to test it.
 *
 * Here the forked server is the 'local' with the bridge (non-clean sessions on both sides), and the in-process
 * test server is the 'remote', which is stopped and restarted to force the message to be queued.
 */
void MainTests::forkingTestBridgeWithLocalAndRemotePrefixQueuedQoS()
{
    for (const std::string protocol_version : {"mqtt5", "mqtt3.1.1"})
    {
        cleanup();

        ConfFileTemp conf_file_local;
        const std::string config_local = R"( allow_anonymous true log_debug false bridge { address ::1 port 21883 subscribe ManglerRemote/shoes/# 2 publish ManglerLocal/boots/# 2 clientid_prefix Mangler remote_clean_start false local_clean_start false remote_session_expiry_interval 300 local_session_expiry_interval 300 local_prefix ManglerLocal/ remote_prefix ManglerRemote/ protocol_version %s } listen { protocol mqtt port 51883 })";
        conf_file_local.writeLine(formatString(config_local, protocol_version.c_str()));
        conf_file_local.closeFile();

        std::vector args_local {"--config-file", conf_file_local.getFilePath()};
        MainAppAsFork localServer(args_local);
        localServer.start();
        localServer.waitForStarted(51883);

        FlashMQTestClient clientToLocalWithBridge;
        clientToLocalWithBridge.start();
        clientToLocalWithBridge.connectClient(ProtocolVersion::Mqtt5, 51883);

        // Give the remote persistent storage so its sessions survive the restart below.
        FlashMQTempDir remote_server_storage_dir;

        ConfFileTemp conf_file_remote;
        const std::string config_remote = R"( allow_anonymous true log_debug false storage_dir %s listen { protocol mqtt port 21883 } )";
        conf_file_remote.writeLine(formatString(config_remote, remote_server_storage_dir.getPath().c_str()));
        conf_file_remote.closeFile();

        std::vector args_remote {"--config-file", conf_file_remote.getFilePath()};

        // Bring the remote on-line
        init(args_remote);

        FlashMQTestClient clientToRemote;
        clientToRemote.start();
        clientToRemote.connectClient(ProtocolVersion::Mqtt5);
        clientToRemote.subscribe("#", 1);

        waitForMessagesOverBridge(clientToLocalWithBridge, clientToRemote, "ManglerLocal/boots/connectiontest");

        {
            auto ro = clientToRemote.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.size(), static_cast(1));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/boots/connectiontest"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
            FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), std::string("connectiontest"));
        }

        clientToRemote.clearReceivedLists();

        // Stop the remote
        cleanup();

        // Published while the remote is down, so the bridge has to queue it.
        Publish pub("ManglerLocal/boots/queue_me", "late to the party", 1);
        clientToLocalWithBridge.publish(pub);

        std::cout << "Starting remote server again" << std::endl;

        // Start the remote again
        init(args_remote);

        {
            FlashMQTestClient clientToRemote;
            clientToRemote.start();
            clientToRemote.connectClient(ProtocolVersion::Mqtt5, false, 1000, [](Connect &c) {
                c.clientid = "QueuedReceiver_666";
            });
            clientToRemote.subscribe("+/boots/queue_me", 1);
            clientToRemote.waitForMessageCount(1, 10);

            auto ro = clientToRemote.receivedObjects.lock();
            FMQ_COMPARE(ro->receivedPublishes.size(), static_cast(1));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getTopic(), std::string("ManglerRemote/boots/queue_me"));
            FMQ_COMPARE(ro->receivedPublishes.at(0).getQos(), 1);
FMQ_COMPARE(ro->receivedPublishes.at(0).getPayloadView(), "late to the party"); } } } ================================================ FILE: FlashMQTests/conffiletemp.cpp ================================================ #include "conffiletemp.h" #include #include #include ConfFileTemp::ConfFileTemp() { const std::string templateName("/tmp/flashmqconf_XXXXXX"); std::vector nameBuf(templateName.size() + 1, 0); std::copy(templateName.begin(), templateName.end(), nameBuf.begin()); this->fd = mkstemp(nameBuf.data()); if (this->fd < 0) { throw std::runtime_error("mkstemp error."); } this->filePath = nameBuf.data(); } ConfFileTemp::~ConfFileTemp() { closeFile(); if (!this->filePath.empty()) unlink(this->filePath.c_str()); } const std::string &ConfFileTemp::getFilePath() const { if (fd > 0) throw std::runtime_error("You first need to close the file before using it."); return this->filePath; } void ConfFileTemp::writeLine(const std::string &line) { if (write(this->fd, line.c_str(), line.size()) < 0) throw std::runtime_error("Config file write failed"); if (write(this->fd, "\n", 1) < 0) throw std::runtime_error("Config file write failed"); } void ConfFileTemp::closeFile() { if (this->fd < 0) return; close(this->fd); this->fd = -1; } ================================================ FILE: FlashMQTests/conffiletemp.h ================================================ #ifndef CONFFILETEMP_H #define CONFFILETEMP_H #include class ConfFileTemp { int fd = -1; std::string filePath; public: ConfFileTemp(); ~ConfFileTemp(); const std::string &getFilePath() const; void writeLine(const std::string &line); void closeFile(); ConfFileTemp &operator=(const ConfFileTemp &other) = delete; ConfFileTemp(const ConfFileTemp &other) = delete; ConfFileTemp(ConfFileTemp &&other) = delete; }; #endif // CONFFILETEMP_H ================================================ FILE: FlashMQTests/configtests.cpp ================================================ #include "maintests.h" #include "testhelpers.h" #include 
"conffiletemp.h" #include "exceptions.h" #include "settings.h" void MainTests::test_loading_second_value() { /* this is expected to work*/ { ConfFileTemp config; config.writeLine("bridge {"); config.writeLine(" address localhost"); config.writeLine(" publish send/this 1"); // this value should be different from the default (0) config.writeLine("}"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); parser.loadFile(false); Settings settings = parser.getSettings(); std::list bridges = settings.stealBridges(); FMQ_VERIFY(!bridges.empty()); BridgeConfig &bridge = bridges.front(); FMQ_COMPARE(bridge.publishes[0].topic, "send/this"); FMQ_COMPARE(bridge.publishes[0].qos, (uint8_t)1); } /* this is expecte to fail because "address" doesn't take a second value */ { ConfFileTemp config; config.writeLine("bridge {"); config.writeLine(" address localhost thisisnotok"); config.writeLine(" publish send/this 1"); config.writeLine("}"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); try { parser.loadFile(false); FMQ_FAIL("The config parser is too liberal"); } catch (ConfigFileException &ex) { /* Excellent, what we wanted */ } } } void MainTests::test_parsing_numbers() { /* this should work: 180 */ { ConfFileTemp config; config.writeLine("expire_sessions_after_seconds 180"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); parser.loadFile(false); Settings settings = parser.getSettings(); FMQ_COMPARE(settings.expireSessionsAfterSeconds, (uint32_t)180); } /* this should fail: 180days */ { ConfFileTemp config; config.writeLine("expire_sessions_after_seconds 180days"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); try { parser.loadFile(false); FMQ_FAIL("The parser was too liberal"); } catch (ConfigFileException&) { /* Good! 
This is where we want to end up in */ } } /* this should also fail: 180 days */ { ConfFileTemp config; config.writeLine("expire_sessions_after_seconds 180 days"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); try { parser.loadFile(false); FMQ_FAIL("The parser was too liberal"); } catch (ConfigFileException&) { /* Good! This is where we want to end up in */ } } /* Last one that should fail: 180 days and a bit */ { ConfFileTemp config; config.writeLine("expire_sessions_after_seconds 180 days and a bit more"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); try { parser.loadFile(false); FMQ_FAIL("The parser was too liberal"); } catch (ConfigFileException&) { /* Good! This is where we want to end up in */ } } } void MainTests::testStringDistances() { FMQ_COMPARE(distanceBetweenStrings("", ""), (unsigned int)0); FMQ_COMPARE(distanceBetweenStrings("dog", ""), (unsigned int)3); FMQ_COMPARE(distanceBetweenStrings("", "dog"), (unsigned int)3); FMQ_COMPARE(distanceBetweenStrings("dog", "horse"), (unsigned int)4); FMQ_COMPARE(distanceBetweenStrings("horse", "dog"), (unsigned int)4); FMQ_COMPARE(distanceBetweenStrings("industry", "interest"), (unsigned int)6); FMQ_COMPARE(distanceBetweenStrings("kitten", "sitting"), (unsigned int)3); FMQ_COMPARE(distanceBetweenStrings("uninformed", "uniformed"), (unsigned int)1); } void MainTests::testConfigSuggestion() { // User made a small typo: 'session' instead of 'sessions' { ConfFileTemp config; config.writeLine("expire_session_after_seconds 180"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); try { parser.loadFile(false); FMQ_FAIL("The parser is too liberal"); } catch (ConfigFileException &ex) { FMQ_COMPARE(ex.what(), "Config key 'expire_session_after_seconds' is not valid (here). Did you mean: expire_sessions_after_seconds ?"); } } // User entered gibberish. 
Let's not suggest gibberish back { ConfFileTemp config; config.writeLine("foobarbaz 180"); config.closeFile(); ConfigFileParser parser(config.getFilePath()); try { parser.loadFile(false); FMQ_FAIL("The parser is too liberal"); } catch (ConfigFileException &ex) { FMQ_COMPARE(ex.what(), "Config key 'foobarbaz' is not valid (here)."); } } } void MainTests::testFlags() { Flags flags; FMQ_VERIFY(flags.hasNone()); flags.setAll(); FMQ_VERIFY(flags.hasAll()); flags.clearFlag(PersistenceDataToSave::BridgeInfo); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::RetainedMessages)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::BridgeInfo)); flags.clearFlag(PersistenceDataToSave::RetainedMessages); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::RetainedMessages)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::BridgeInfo)); flags.clearFlag(PersistenceDataToSave::SessionsAndSubscriptions); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::RetainedMessages)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::BridgeInfo)); flags.setFlag(PersistenceDataToSave::SessionsAndSubscriptions); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::RetainedMessages)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::BridgeInfo)); flags.setFlag(PersistenceDataToSave::RetainedMessages); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::RetainedMessages)); FMQ_VERIFY(!flags.hasFlagSet(PersistenceDataToSave::BridgeInfo)); flags.setFlag(PersistenceDataToSave::BridgeInfo); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)); 
FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::RetainedMessages)); FMQ_VERIFY(flags.hasFlagSet(PersistenceDataToSave::BridgeInfo)); flags.clearAll(); FMQ_VERIFY(flags.hasNone()); } ================================================ FILE: FlashMQTests/dnstests.cpp ================================================ #include "maintests.h" #include "testhelpers.h" #include "utils.h" void MainTests::testDnsResolver() { try { DnsResolver resolver; resolver.query("demo.flashmq.org", ListenerProtocol::IPv46, std::chrono::milliseconds(5000)); int count = 0; while (++count < 100) { std::list results = resolver.getResult(); if (!results.empty()) { QVERIFY(std::any_of(results.begin(), results.end(), [](FMQSockaddr &x){return x.getText() == "89.188.6.194";})); QVERIFY(std::any_of(results.begin(), results.end(), [](FMQSockaddr &x){return x.getText() == "2a01:1b0:7996:418:83:137:146:230";})); break; } usleep(10000); } if (count >= 100) QVERIFY(false); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testDnsResolverDontCancel() { try { DnsResolver resolver; resolver.query("demo.flashmq.org", ListenerProtocol::IPv46, std::chrono::milliseconds(5000)); resolver.query("demo.flashmq.org", ListenerProtocol::IPv46, std::chrono::milliseconds(5000)); int count = 0; while (++count < 100) { std::list results = resolver.getResult(); if (!results.empty()) { QVERIFY(std::any_of(results.begin(), results.end(), [](FMQSockaddr &x){return x.getText() == "89.188.6.194";})); QVERIFY(std::any_of(results.begin(), results.end(), [](FMQSockaddr &x){return x.getText() == "2a01:1b0:7996:418:83:137:146:230";})); break; } usleep(10000); } if (count >= 100) QVERIFY(false); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testDnsResolverSecondQuery() { try { DnsResolver resolver; for (int i = 0; i < 2; i++) { resolver.query("demo.flashmq.org", ListenerProtocol::IPv46, std::chrono::milliseconds(5000)); int count = 0; while (++count < 100) { std::list 
results = resolver.getResult(); if (!results.empty()) { QVERIFY(std::any_of(results.begin(), results.end(), [](FMQSockaddr &x){return x.getText() == "89.188.6.194";})); QVERIFY(std::any_of(results.begin(), results.end(), [](FMQSockaddr &x){return x.getText() == "2a01:1b0:7996:418:83:137:146:230";})); break; } usleep(10000); } if (count >= 100) QVERIFY(false); } } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testDnsResolverInvalid() { try { DnsResolver resolver; const std::string rnd = getSecureRandomString(16); const std::string domain = rnd + ".flashmq.org"; resolver.query(domain, ListenerProtocol::IPv46, std::chrono::milliseconds(5000)); int count = 0; while (count++ < 60) { std::list results = resolver.getResult(); if (!results.empty()) { for (const auto &r : results) { std::cerr << "Wrong DNS result: " << r.getText() << std::endl; } QVERIFY2(false, "A DNS result was returned when we expected nothing."); break; } usleep(100000); } QVERIFY2(false, "It took too long to get a result. 
That's weird, because we should have gotten a timeout exception."); } catch (std::exception &ex) { const std::string err = str_tolower(ex.what()); std::cout << "For reference, the error that we're scanning: " << err << std::endl; QVERIFY(strContains(err, "name or service not known")); } } void MainTests::testGetResultWhenThereIsNone() { try { DnsResolver resolver; std::list results = resolver.getResult(); QVERIFY(results.empty()); QVERIFY(false); } catch (std::exception &ex) { std::string err = str_tolower(ex.what()); QVERIFY(strContains(err, "no dns query in progress")); } } ================================================ FILE: FlashMQTests/filecloser.cpp ================================================ #include "filecloser.h" #include FileCloser::FileCloser(int fd) : fd(fd) { } FileCloser::~FileCloser() { if (fd >= 0) close(fd); fd = -1; } ================================================ FILE: FlashMQTests/filecloser.h ================================================ #ifndef FILECLOSER_H #define FILECLOSER_H class FileCloser { int fd = -1; public: FileCloser(int fd); ~FileCloser(); }; #endif // FILECLOSER_H ================================================ FILE: FlashMQTests/flashmqtempdir.cpp ================================================ #include "flashmqtempdir.h" #include #include #include #include "utils.h" FlashMQTempDir::FlashMQTempDir() { const std::string templateName(std::filesystem::temp_directory_path() / "flashmq_test_XXXXXX"); std::vector nameBuf(templateName.size() + 1, 0); std::copy(templateName.begin(), templateName.end(), nameBuf.begin()); this->path = std::string(mkdtemp(nameBuf.data())); } FlashMQTempDir::~FlashMQTempDir() { if (this->path.empty() || !strContains(this->path, "flashmq_test_")) return; // Not pretty, but whatever works... 
int pid = fork(); if (pid == 0) { execlp("rm", "rm", "-rf", "--", this->path.c_str(), (char*)NULL); } } const std::filesystem::path &FlashMQTempDir::getPath() const { return this->path; } ================================================ FILE: FlashMQTests/flashmqtempdir.h ================================================ #ifndef FLASHMQTEMPDIR_H #define FLASHMQTEMPDIR_H #include #include #include #include class FlashMQTempDir { std::filesystem::path path; public: FlashMQTempDir(); ~FlashMQTempDir(); const std::filesystem::path &getPath() const; }; #endif // FLASHMQTEMPDIR_H ================================================ FILE: FlashMQTests/main.cpp ================================================ #include #include #include "maintests.h" void printHelp(const std::string &arg0) { std::cout << std::endl; std::cout << "Usage: " << arg0 << " [ --skip-tests-with-internet ] [ --skip-server-tests ] " << " " << std::endl; } int main(int argc, char *argv[]) { bool skip_tests_with_internet = false; bool skip_server_tests = false; bool abort_on_first_fail = false; std::vector tests; bool option_list_terminated = false; for (int i = 1; i < argc ; i++) { const std::string name(argv[i]); if (option_list_terminated) tests.push_back(name); else if (name == "--") option_list_terminated = true; else if (name == "--help") { printHelp(argv[0]); return 1; } else if (name == "--skip-tests-with-internet") skip_tests_with_internet = true; else if (name == "--skip-server-tests") skip_server_tests = true; else if (name == "--abort-on-first-fail") abort_on_first_fail = true; else if (name.find("--") == 0) { std::cerr << "Unknown argument " << name << std::endl; printHelp(argv[0]); return 1; } else tests.push_back(name); } MainTests maintests; if (!maintests.test(skip_tests_with_internet, skip_server_tests, abort_on_first_fail, tests)) return 1; return 0; } ================================================ FILE: FlashMQTests/mainappasfork.cpp ================================================ 
#include "mainappasfork.h"

#include <arpa/inet.h>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <iostream>
#include <list>
#include <netinet/in.h>
#include <stdexcept>
#include <string>
#include <sys/socket.h>
#include <thread>
#include <unistd.h>
#include <vector>

#include "signal.h"
#include "fmqmain.h"
#include "sys/wait.h"

/**
 * Returns the argument that follows "--config-file" in args. Returns "" when there is none or nothing follows it.
 */
std::string MainAppAsFork::getConfigFileFromArgs(const std::vector<std::string> &args)
{
    bool next = false;

    for (const std::string &arg : args)
    {
        if (arg == "--config-file")
        {
            next = true;
            continue;
        }

        if (next)
            return arg;
    }

    return "";
}

/**
 * Default setup: writes a temporary config containing 'allow_anonymous true' and passes it via --config-file.
 */
MainAppAsFork::MainAppAsFork()
{
    defaultConf.writeLine("allow_anonymous true");
    defaultConf.closeFile();

    args.push_back("--config-file");
    args.push_back(defaultConf.getFilePath());
}

MainAppAsFork::MainAppAsFork(const std::vector<std::string> &args) :
    args(args)
{

}

MainAppAsFork::~MainAppAsFork()
{
    this->stop();
}

/**
 * Forks. The child runs fmqmain() with "FlashMQTests" followed by args as its argv; the parent records the
 * child's pid and returns.
 *
 * @throws std::runtime_error when fork() fails.
 */
void MainAppAsFork::start()
{
    // We must not have threads when we fork.
    Logger::getInstance()->quit();

    pid_t pid = fork();

    if (pid < 0)
        throw std::runtime_error("What the fork?");

    if (pid == 0)
    {
        try
        {
            // fmqmain() takes char *argv[], so keep mutable, NUL-terminated copies alive for the call.
            std::list<std::vector<char>> argCopies;

            const std::string programName = "FlashMQTests";
            argCopies.emplace_back(programName.begin(), programName.end());
            argCopies.back().push_back('\0');

            for (const std::string &arg : args)
            {
                argCopies.emplace_back(arg.begin(), arg.end());
                argCopies.back().push_back('\0');
            }

            // Sized to fit instead of a fixed 256-entry array, which overflowed with too many arguments. The
            // trailing nullptr keeps argv[argc] == nullptr, like a real main() gets.
            std::vector<char*> argv;
            argv.reserve(argCopies.size() + 1);
            for (std::vector<char> &argCopy : argCopies)
                argv.push_back(argCopy.data());
            argv.push_back(nullptr);

            int r = fmqmain(static_cast<int>(argCopies.size()), argv.data());
            ::exit(r);
        }
        catch (std::exception &ex)
        {
            std::cout << "The forked process threw an exception: " << ex.what() << std::endl;
            std::cerr << "The forked process threw an exception: " << ex.what() << std::endl;
        }

        // Does not call destructors.
        abort();
    }

    this->child = pid;
}

/**
 * Sends SIGTERM to the child (if one was started) and waits for it to exit.
 */
void MainAppAsFork::stop()
{
    if (this->child <= 0)
        return;

    kill(this->child, SIGTERM);
    int status = 0;
    waitpid(this->child, &status, 0);
    this->child = -1;
}

/**
 * Repeatedly tries a TCP connect to 127.0.0.1:port, 20 ms apart, until one succeeds.
 *
 * Each attempt uses a fresh socket: POSIX leaves a socket's state unspecified after a failed connect(), so
 * retrying on the same descriptor is not portable.
 *
 * @throws std::runtime_error when socket() fails, or after about 100 failed attempts.
 */
void MainAppAsFork::waitForStarted(int port)
{
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(struct sockaddr_in));
    addr.sin_port = htons(port);
    addr.sin_family = AF_INET;
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
    const struct sockaddr *addr2 = reinterpret_cast<const struct sockaddr*>(&addr);

    int n = 0;
    while (true)
    {
        const int sockfd = socket(AF_INET, SOCK_STREAM, 0);
        if (sockfd < 0)
            throw std::runtime_error(std::string("socket() failed: ") + strerror(errno));

        const bool connected = connect(sockfd, addr2, sizeof(struct sockaddr_in)) == 0;
        close(sockfd);

        if (connected)
            return;

        std::this_thread::sleep_for(std::chrono::milliseconds(20));

        if (n++ > 100)
            throw std::runtime_error("Forked mainapp failed to start?");
    }
}
================================================ FILE: FlashMQTests/mainappasfork.h ================================================
#ifndef MAINAPPASFORK_H
#define MAINAPPASFORK_H

#include <string>
#include <sys/types.h>
#include <vector>

#include "conffiletemp.h"

/**
 * @brief Runs the FlashMQ server in a separate process for tests.
 *
 * Advantages:
 *  - isolation;
 *  - you can run multiple versions at once (bridging);
 *  - no conflicts with (thread) globals.
 *
 * Drawbacks:
 *  - you don't have access to the internal state (for assertions);
 *  - on abnormal exit, the child processes may linger (although we could devise stuff against that).
 */
class MainAppAsFork
{
    pid_t child = -1;
    std::vector<std::string> args;
    ConfFileTemp defaultConf;
public:
    static std::string getConfigFileFromArgs(const std::vector<std::string> &args);

    MainAppAsFork();
    MainAppAsFork(const std::vector<std::string> &args);
    ~MainAppAsFork();

    void start();
    void stop();
    void waitForStarted(int port=21883);
};

#endif // MAINAPPASFORK_H
================================================ FILE: FlashMQTests/mainappinthread.cpp ================================================
#include "mainappinthread.h"

#include <chrono>
#include <list>
#include <stdexcept>
#include <string>
#include <vector>

MainAppInThread::MainAppInThread()
{

}

MainAppInThread::MainAppInThread(const std::vector<std::string> &args) :
    mArgs(args)
{

}

MainAppInThread::~MainAppInThread()
{
    stopApp();
}

/**
 * Starts a thread that:
 *  1. builds a MainApp from "FlashMQTests" plus mArgs and publishes it in mMainApp (under its lock);
 *  2. runs it until it quits.
 * With no args, anonymous access is allowed.
 */
void MainAppInThread::start()
{
    const auto args = mArgs;

    auto thread_task = [this, args]()
    {
        // initMainApp() takes char *argv[]; these buffers stay alive for the whole thread task.
        std::list<std::vector<char>> argCopies;

        const std::string programName = "FlashMQTests";
        argCopies.emplace_back(programName.begin(), programName.end());
        argCopies.back().push_back('\0');

        for (const std::string &arg : args)
        {
            argCopies.emplace_back(arg.begin(), arg.end());
            argCopies.back().push_back('\0');
        }

        // Sized to fit instead of a fixed 256-entry array, which overflowed with too many arguments.
        std::vector<char*> argv;
        argv.reserve(argCopies.size() + 1);
        for (std::vector<char> &argCopy : argCopies)
            argv.push_back(argCopy.data());
        argv.push_back(nullptr);

        std::shared_ptr<MainApp> mainapp;

        {
            auto mainapp_locked = mMainApp.lock();
            std::shared_ptr<MainApp> &mainapp_member = *mainapp_locked;
            mainapp_member = MainApp::initMainApp(static_cast<int>(argCopies.size()), argv.data());
            mainapp = mainapp_member;
        }

        // A hack: when I supply args I probably define a config for auth stuff.
        if (args.empty())
            mainapp->settings.allowAnonymous = true;

        mainapp->start();
    };

    this->thread = std::thread(thread_task);
}

/**
 * Asks the MainApp (if already created) to quit, then joins the thread.
 */
void MainAppInThread::stopApp()
{
    {
        auto mainapp_locked = mMainApp.lock();
        std::shared_ptr<MainApp> &mainapp = *mainapp_locked;
        if (mainapp)
            mainapp->quit();
    }

    if (this->thread.joinable())
        this->thread.join();
}

/**
 * Polls every 10 ms until the MainApp exists and reports getStarted().
 *
 * @throws std::runtime_error after roughly 500 polls (about 5 seconds).
 */
void MainAppInThread::waitForStarted()
{
    auto started_f = [this]
    {
        std::shared_ptr<MainApp> mainapp;

        {
            auto mainapp_locked = mMainApp.lock();
            mainapp = *mainapp_locked;
        }

        if (!mainapp)
            return false;

        return mainapp->getStarted();
    };

    int n = 0;
    while(!started_f())
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(10));

        if (n++ > 500)
            throw std::runtime_error("Waiting for app to start failed.");
    }
}
================================================ FILE: FlashMQTests/mainappinthread.h ================================================
#ifndef MAINAPPINTHREAD_H
#define MAINAPPINTHREAD_H

#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "mainapp.h"
#include "mutexowned.h"

/**
 * @brief Runs a MainApp on a background thread in this process; stopApp() (also called by the destructor)
 *        quits it and joins the thread.
 */
class MainAppInThread
{
    std::thread thread;
    const std::vector<std::string> mArgs;
    MutexOwned<std::shared_ptr<MainApp>> mMainApp;
public:
    MainAppInThread();
    MainAppInThread(const std::vector<std::string> &args);
    ~MainAppInThread();

    void start();
    void stopApp();
    void waitForStarted();
};

#endif // MAINAPPINTHREAD_H
================================================ FILE: FlashMQTests/maintests.cpp ================================================
#include <iostream>
#include <map>
#include <string>
#include <unistd.h>
#include <utility>
#include <vector>

#include "maintests.h"
#include "testhelpers.h"
#include "testinitializer.h"
#include "threadglobals.h"

/**
 * Self-test of the assertion macros. Each one must count exactly one assertion, and count one failure
 * exactly when it should fail. Throws std::exception otherwise.
 */
void MainTests::testAsserts()
{
    // Each assertion runs inside its own lambda, so an assertion macro that returns on failure only leaves that
    // lambda, not testAsserts() itself.
    auto expectCounts = [](auto assertion, int expectedFailCount)
    {
        assert_count = 0;
        assert_fail_count = 0;
        assertion();
        if (assert_count != 1 || assert_fail_count != expectedFailCount)
            throw std::exception();
    };

    expectCounts([] { QCOMPARE(1, 1); }, 0);
    expectCounts([] { QCOMPARE(1, 2); }, 1);
    expectCounts([] { QVERIFY(true); }, 0);
    expectCounts([] { QVERIFY(false); }, 1);
    expectCounts([] { QVERIFY2(true, ""); }, 0);
    expectCounts([] { QVERIFY2(false, ""); }, 1);
    expectCounts([] { QFAIL(""); }, 1);

    // NOTE(review): the static_cast target types are a reconstruction; confirm them against upstream.
    expectCounts([] { MYCASTCOMPARE(static_cast<int>(1), static_cast<int>(1)); }, 0);
    expectCounts([] { MYCASTCOMPARE(static_cast<int>(1), static_cast<int>(2)); }, 1);
}

/**
 * Per-test setup. It:
 *  1. drops any previous in-thread server;
 *  2. optionally starts a new one with args and waits for it;
 *  3. installs dummy thread data and this->settings into ThreadGlobals.
 */
void MainTests::initBeforeEachTest(const std::vector<std::string> &args, bool startServer)
{
    mainApp.reset();

    if (startServer)
    {
        mainApp = std::make_unique<MainAppInThread>(args);
        mainApp->start();
        mainApp->waitForStarted();
    }

    // We test functions directly that the server normally only calls from worker threads, in which thread data
    // is available. This is kind of a dummy-fix, until we actually need correct thread data at those points (at
    // this point, it's only to increase message counters).
    // NOTE(review): ThreadData / weak_ptr<MainApp> types are a reconstruction; confirm against threaddata.h.
    this->dummyThreadData = std::make_shared<ThreadData>(666, settings, pluginLoader, std::weak_ptr<MainApp>());
    ThreadGlobals::assignThreadData(dummyThreadData);
    ThreadGlobals::assignSettings(&this->settings);
}

void MainTests::initBeforeEachTest(bool startServer)
{
    std::vector<std::string> args;
    initBeforeEachTest(args, startServer);
}

/**
 * Per-test teardown: stops and discards the in-thread server, if any.
 */
void MainTests::cleanupAfterEachTest()
{
    if (this->mainApp)
        this->mainApp->stopApp();
    this->mainApp.reset();
}

/**
 * Adds or replaces the test registered under name.
 */
void MainTests::registerFunction(const std::string &name, std::function<void()> f, bool requiresServer, bool requiresInternet)
{
    TestFunction &tf = testFunctions[name];
    tf.f = std::move(f);
    tf.requiresServer = requiresServer;
    tf.requiresInternet = requiresInternet;
}

void MainTests::testDummy()
{
    FMQ_COMPARE(1,1);
}

/**
 * Singleton constructor: self-tests the assertion macros, then registers all tests.
 *
 * Tests live in a std::map, so they run alphabetically by name.
 *
 * @throws std::runtime_error on a second instantiation.
 */
MainTests::MainTests()
{
    static int instanceCount = 0;
    if (instanceCount++ > 0)
        throw std::runtime_error("Don't instantiate this more than once.");

    asserts_print = false;
    testAsserts();
    asserts_print = true;
    assert_count = 0;
    assert_fail_count = 0;

    /*
     * Forking tests need to be done first (by alphabet), otherwise you break everything with DNS. This is a limitation
     * of getaddrinfo_a(). So, name them appropriately.
     */
    REGISTER_FUNCTION(forkingTestForkingTestServer);
    REGISTER_FUNCTION(forkingTestSaveAndLoadDelayedWill);
    REGISTER_FUNCTION(forkingTestBridgeWithLocalAndRemotePrefix);
    REGISTER_FUNCTION(forkingTestBridgePrefixesOtherClientsUnaffected);
    REGISTER_FUNCTION(forkingTestBridgeWithOnlyRemotePrefix);
    REGISTER_FUNCTION(forkingTestBridgeWithOnlyLocalPrefix);
    REGISTER_FUNCTION(forkingTestBridgeOutgoingTopicEqualsPrefix);
    REGISTER_FUNCTION(forkingTestBridgeIncomingTopicEqualsPrefix);
    REGISTER_FUNCTION(forkingTestBridgeWithLocalAndRemotePrefixQueuedQoS);
    REGISTER_FUNCTION(forkingTestBridgeZWithLocalAndRemotePrefixRetained);

    REGISTER_FUNCTION(testDummy);
    REGISTER_FUNCTION3(test_circbuf);
    REGISTER_FUNCTION3(test_circbuf_unwrapped_doubling);
    REGISTER_FUNCTION3(test_circbuf_wrapped_doubling);
    REGISTER_FUNCTION3(test_circbuf_full_wrapped_buffer_doubling);
    REGISTER_FUNCTION3(test_cirbuf_vector_methods);
    REGISTER_FUNCTION3(test_validSubscribePath);
    REGISTER_FUNCTION(test_retained);
    REGISTER_FUNCTION(test_retained_double_set);
    REGISTER_FUNCTION(test_retained_mode_drop);
    REGISTER_FUNCTION(test_retained_mode_downgrade);
    REGISTER_FUNCTION(test_retained_mode_no_retain);
    REGISTER_FUNCTION(test_retained_changed);
    REGISTER_FUNCTION(test_retained_removed);
    REGISTER_FUNCTION(test_retained_tree);
    REGISTER_FUNCTION(test_retained_global_expire);
    REGISTER_FUNCTION(test_retained_per_message_expire);
    REGISTER_FUNCTION(test_retained_tree_purging);
    REGISTER_FUNCTION(testRetainAsPublished);
    REGISTER_FUNCTION(testRetainAsPublishedNegative);
    REGISTER_FUNCTION(testRetainedParentOfWildcard);
    REGISTER_FUNCTION(testRetainedWildcard);
    REGISTER_FUNCTION(testRetainedAclReadCheck);
    REGISTER_FUNCTION(testRetainHandlingDontGiveRetain);
    REGISTER_FUNCTION(testRetainHandlingDontGiveRetainOnExistingSubscription);
    REGISTER_FUNCTION(test_various_packet_sizes);
    REGISTER_FUNCTION3(test_acl_tree);
    REGISTER_FUNCTION3(test_acl_tree2);
    REGISTER_FUNCTION3(test_acl_patterns_username);
    REGISTER_FUNCTION3(test_acl_patterns_clientid);
    REGISTER_FUNCTION(test_loading_acl_file);
    REGISTER_FUNCTION3(test_loading_second_value);
    REGISTER_FUNCTION3(test_parsing_numbers);
    REGISTER_FUNCTION3(test_validUtf8Generic);
#ifndef FMQ_NO_SSE
    REGISTER_FUNCTION3(test_sse_split);
    REGISTER_FUNCTION3(test_validUtf8Sse);
    REGISTER_FUNCTION3(test_utf8_nonchars);
    REGISTER_FUNCTION3(test_utf8_overlong);
    REGISTER_FUNCTION3(test_utf8_compare_implementation);
#endif
    REGISTER_FUNCTION3(testPacketInt16Parse);
    REGISTER_FUNCTION3(testRetainedMessageDB);
    REGISTER_FUNCTION3(testRetainedMessageDBNotPresent);
    REGISTER_FUNCTION3(testRetainedMessageDBEmptyList);
    REGISTER_FUNCTION3(testSavingSessions);
    REGISTER_FUNCTION3(testParsePacket);
    REGISTER_FUNCTION3(testbufferToMqttPacketsFuzz);
    REGISTER_FUNCTION(testDowngradeQoSOnSubscribeQos2to2);
    REGISTER_FUNCTION(testDowngradeQoSOnSubscribeQos2to1);
    REGISTER_FUNCTION(testDowngradeQoSOnSubscribeQos2to0);
    REGISTER_FUNCTION(testDowngradeQoSOnSubscribeQos1to1);
    REGISTER_FUNCTION(testDowngradeQoSOnSubscribeQos1to0);
    REGISTER_FUNCTION(testDowngradeQoSOnSubscribeQos0to0);
    REGISTER_FUNCTION(testNotMessingUpQosLevels);
    REGISTER_FUNCTION(testUnSubscribe);
    REGISTER_FUNCTION(testUnsubscribeNonExistingWildcard);
    REGISTER_FUNCTION(testBasicsWithFlashMQTestClient);
    REGISTER_FUNCTION(testDontRemoveSessionGivenToNewClientWithSameId);
    REGISTER_FUNCTION(testKeepSubscriptionOnKickingOutExistingClientWithCleanSessionFalse);
    REGISTER_FUNCTION(testPickUpSessionWithSubscriptionsAfterDisconnect);
    REGISTER_FUNCTION(testMqtt3will);
    REGISTER_FUNCTION(testMqtt3NoWillOnDisconnect);
    REGISTER_FUNCTION(testMqtt5NoWillOnDisconnect);
    REGISTER_FUNCTION(testMqtt5DelayedWill);
    REGISTER_FUNCTION(testMqtt5DelayedWillAlwaysOnSessionEnd);
    REGISTER_FUNCTION(testWillOnSessionTakeOvers);
    REGISTER_FUNCTION(testOverrideWillDelayOnSessionDestructionByTakeOver);
    REGISTER_FUNCTION(testDisabledWills);
    REGISTER_FUNCTION(testMqtt5DelayedWillsDisabled);
    REGISTER_FUNCTION3(testStringDistances);
    REGISTER_FUNCTION3(testConfigSuggestion);
    REGISTER_FUNCTION3(testFlags);
    REGISTER_FUNCTION(testIncomingTopicAlias);
    REGISTER_FUNCTION(testOutgoingTopicAlias);
    REGISTER_FUNCTION(testOutgoingTopicAliasBeyondMax);
    REGISTER_FUNCTION(testOutgoingTopicAliasStoredPublishes);
    REGISTER_FUNCTION(testReceivingRetainedMessageWithQoS);
    REGISTER_FUNCTION(testQosDowngradeOnOfflineClients);
    REGISTER_FUNCTION(testPacketOrderOnSessionPickup);
    REGISTER_FUNCTION(testUserProperties);
    REGISTER_FUNCTION(testMessageExpiry);
    REGISTER_FUNCTION(testExpiredQueuedMessages);
    REGISTER_FUNCTION(testQoSPublishQueue);
    REGISTER_FUNCTION3(testQoSPublishQueueMemoryLeak);
    REGISTER_FUNCTION3(testTimePointToAge);
    REGISTER_FUNCTION(testMosquittoPasswordFile);
    REGISTER_FUNCTION(testOverrideAllowAnonymousToTrue);
    REGISTER_FUNCTION(testOverrideAllowAnonymousToFalse);
    REGISTER_FUNCTION(testKeepAllowAnonymousFalse);
    REGISTER_FUNCTION(testAllowAnonymousWithoutPasswordsLoaded);
    REGISTER_FUNCTION3(testAddrMatchesSubnetIpv4);
    REGISTER_FUNCTION3(testAddrMatchesSubnetIpv6);
    REGISTER_FUNCTION(testSharedSubscribersUnit);
    REGISTER_FUNCTION(testSharedSubscribers);
    REGISTER_FUNCTION(testDisconnectedSharedSubscribers);
    REGISTER_FUNCTION(testUnsubscribedSharedSubscribers);
    REGISTER_FUNCTION(testSharedSubscribersSurviveRestart);
    REGISTER_FUNCTION(testSharedSubscriberDoesntGetRetainedMessages);
    REGISTER_FUNCTION(testExtendedAuthOneStepSucceed);
    REGISTER_FUNCTION(testExtendedAuthOneStepDeny);
    REGISTER_FUNCTION(testExtendedAuthOneStepBadAuthMethod);
    REGISTER_FUNCTION(testExtendedAuthTwoStep);
    REGISTER_FUNCTION(testExtendedAuthTwoStepSecondStepFail);
    REGISTER_FUNCTION(testExtendedReAuth);
    REGISTER_FUNCTION(testExtendedReAuthTwoStep);
    REGISTER_FUNCTION(testExtendedReAuthFail);
    REGISTER_FUNCTION(testSimpleAuthAsync);
    REGISTER_FUNCTION(testFailedAsyncClientCrashOnSession);
    REGISTER_FUNCTION(testAsyncWithImmediateFollowUpPackets);
    REGISTER_FUNCTION(testAsyncWithException);
    REGISTER_FUNCTION(testPluginAuthFail);
    REGISTER_FUNCTION(testPluginAuthSucceed);
    REGISTER_FUNCTION(testPluginOnDisconnect);
    REGISTER_FUNCTION(testPluginGetClientAddress);
    REGISTER_FUNCTION(testChangePublish);
    REGISTER_FUNCTION(testClientRemovalByPlugin);
    REGISTER_FUNCTION(testSubscriptionRemovalByPlugin);
    REGISTER_FUNCTION(testPublishByPlugin);
    REGISTER_FUNCTION(testWillDenialByPlugin);
    REGISTER_FUNCTION(testPluginMainInit);
    REGISTER_FUNCTION2(testAsyncCurl, true, true);
    REGISTER_FUNCTION(testSubscribeWithoutRetainedDelivery);
    REGISTER_FUNCTION(testDontUpgradeWildcardDenyMode);
    REGISTER_FUNCTION(testAlsoDontApproveOnErrorInPluginWithWildcardDenyMode);
    REGISTER_FUNCTION(testDenyWildcardSubscription);
    REGISTER_FUNCTION(testUserPropertiesPresent);
    REGISTER_FUNCTION(testPublishInThread);
    REGISTER_FUNCTION(testPublishToItself);
    REGISTER_FUNCTION(testNoLocalPublishToItself);
    REGISTER_FUNCTION3(testTopicMatchingInSubscriptionTree);
    REGISTER_FUNCTION2(testDnsResolver, false, true);
    REGISTER_FUNCTION2(testDnsResolverDontCancel, false, true);
    REGISTER_FUNCTION2(testDnsResolverSecondQuery, false, true);
    REGISTER_FUNCTION2(testDnsResolverInvalid, false, true);
    REGISTER_FUNCTION(testGetResultWhenThereIsNone);
    REGISTER_FUNCTION(testWebsocketPing);
    REGISTER_FUNCTION(testWebsocketCorruptLengthFrame);
    REGISTER_FUNCTION(testWebsocketHugePing);
    REGISTER_FUNCTION(testWebsocketManyBigPingFrames);
    REGISTER_FUNCTION(testWebsocketClose);
    REGISTER_FUNCTION3(testStartsWith);
    REGISTER_FUNCTION3(testStringValuesParsing);
    REGISTER_FUNCTION3(testStringValuesParsingEscaping);
    REGISTER_FUNCTION3(testStringValuesFuzz);
    REGISTER_FUNCTION3(testStringValuesInvalid);
    REGISTER_FUNCTION2(testPreviouslyValidConfigFile, false, false);
    REGISTER_FUNCTION3(testNoCopy);
    REGISTER_FUNCTION3(testBase64);
    REGISTER_FUNCTION(testSessionTakeoverOtherUsername);
    REGISTER_FUNCTION(testCorrelationData);
    REGISTER_FUNCTION(testSubscriptionIdOnlineClient);
    REGISTER_FUNCTION(testSubscriptionIdOfflineClient);
    REGISTER_FUNCTION(testSubscriptionIdRetainedMessages);
    REGISTER_FUNCTION(testSubscriptionIdSharedSubscriptions);
    REGISTER_FUNCTION(testSubscriptionIdChange);
    REGISTER_FUNCTION(testSubscriptionIdOverlappingSubscriptions);
}

/**
 * Runs the registered tests (or only the named ones) and prints a summary.
 *
 * @param tests Names to run; empty means all. An unknown name aborts immediately with false.
 * @return true only if at least one test ran and none failed.
 *
 * A test fails when any of these happens:
 *  - it performs no assertions;
 *  - any of its assertions fails;
 *  - it throws.
 */
bool MainTests::test(bool skip_tests_with_internet, bool skip_server_tests, bool abort_on_first_fail, const std::vector<std::string> &tests)
{
    int testCount = 0;
    int testPassCount = 0;
    int testFailCount = 0;
    int testExceptionCount = 0;

    std::map<std::string, TestFunction> subset;

    for(const std::string &test_name : tests)
    {
        auto pos = this->testFunctions.find(test_name);

        if (pos == this->testFunctions.end())
        {
            std::cerr << "Test '" << test_name << "' not found." << std::endl;
            return false;
        }

        subset[test_name] = pos->second;
    }

    const std::map<std::string, TestFunction> &selectedTests = subset.empty() ? this->testFunctions : subset;

    std::vector<std::string> failedTests;

    for (const auto &pair : selectedTests)
    {
        if (abort_on_first_fail && testFailCount > 0)
            break;

        const TestFunction &tf = pair.second;

        if (skip_tests_with_internet && tf.requiresInternet)
            continue;

        if (skip_server_tests && tf.requiresServer)
            continue;

        testCount++;

        try
        {
            std::cout << CYAN << "INIT" << COLOR_END << ": " << pair.first << std::endl;
            if (!isatty(2))
                std::cerr << "INIT: " << pair.first << std::endl;

            TestInitializer testInitializer(this);
            testInitializer.init(tf.requiresServer);

            const int failCountBefore = assert_fail_count;
            const int assertCountBefore = assert_count;

            std::cout << CYAN << "RUN" << COLOR_END << ": " << pair.first << std::endl;
            if (!isatty(2))
                std::cerr << "RUN: " << pair.first << std::endl;

            tf.f();

            const int failCountAfter = assert_fail_count;
            const int assertCountAfter = assert_count;

            testInitializer.cleanup();

            if (assertCountBefore == assertCountAfter)
            {
                std::cout << RED << "FAIL" << COLOR_END << ": " << pair.first << ": no asserts performed" << std::endl;
                testFailCount++;
                failedTests.push_back(pair.first);
            }
            else if (failCountBefore != failCountAfter)
            {
                std::cout << RED << "FAIL" << COLOR_END << ": " << pair.first << std::endl;
                testFailCount++;
                failedTests.push_back(pair.first);
            }
            else
            {
                std::cout << GREEN << "PASS" << COLOR_END << ": " << pair.first << std::endl;
                testPassCount++;
            }
        }
        catch (std::exception &ex)
        {
            // TODO: get details
            testFailCount++;
            testExceptionCount++;
            failedTests.push_back(pair.first);
            std::cout << RED << "FAIL EXCEPTION" << COLOR_END << ": " << pair.first << ": " << ex.what() << std::endl;
        }

        std::cout << std::endl;
    }

    Logger::getInstance()->quit();

    std::cout << "Tests run: " << testCount << ". Passed: " << testPassCount << ". Failed: "
              << testFailCount << " (of which " << testExceptionCount << " exceptions). Total assertions: " << assert_count << "." << std::endl;

    std::cout << std::endl << std::endl;

    if (testCount == 0)
    {
        std::cout << std::endl << RED << "No tests ran." << COLOR_END << std::endl;
        return false;
    }
    else if (assert_fail_count == 0 && testFailCount == 0)
    {
        std::cout << std::endl << GREEN << "TESTS PASSED" << COLOR_END << std::endl;
        return true;
    }
    else
    {
        std::cout << "Failed tests: " << std::endl;

        for (const std::string &test_name : failedTests)
        {
            std::cout << " - " << test_name << std::endl;
        }

        std::cout << std::endl << RED << "TESTS FAILED" << COLOR_END << std::endl;
        return false;
    }
}
================================================ FILE: FlashMQTests/maintests.h ================================================ #ifndef MAINTESTS_H #define MAINTESTS_H #include #include #include #include "mainappinthread.h" #define REGISTER_FUNCTION(name) registerFunction(#name, std::bind(&MainTests::name, this)) #define REGISTER_FUNCTION2(name, server, internet) registerFunction(#name, std::bind(&MainTests::name, this), server, internet) #define REGISTER_FUNCTION3(name) registerFunction(#name, std::bind(&MainTests::name, this), false, false) struct TestFunction { std::function f; bool requiresServer = true; bool requiresInternet = false; }; class MainTests { friend class TestInitializer; std::shared_ptr dummyThreadData; std::unique_ptr mainApp; Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::map
testFunctions; void testAsserts(); void initBeforeEachTest(const std::vector &args, bool startServer=true); void initBeforeEachTest(bool startServer=true); void cleanupAfterEachTest(); void registerFunction(const std::string &name, std::function f, bool requiresServer=true, bool requiresInternet=false); // Compatability for porting the tests away from Qt. The function names are too vague so want to phase them out. void init(const std::vector &args) { initBeforeEachTest(args);} void init() {initBeforeEachTest();} void cleanup() {cleanupAfterEachTest();} void testParsePacketHelper(const std::string &topic, uint8_t from_qos, bool retain); void testTopicMatchingInSubscriptionTreeHelper(const std::string &subscribe_topic, const std::string &publish_topic, int match_count=1); void testDummy(); void test_circbuf(); void test_circbuf_unwrapped_doubling(); void test_circbuf_wrapped_doubling(); void test_circbuf_full_wrapped_buffer_doubling(); void test_cirbuf_vector_methods(); void test_validSubscribePath(); /** * Retain tests */ void test_retained(); void test_retained_double_set(); void test_retained_mode_drop(); void test_retained_mode_downgrade(); void test_retained_mode_no_retain(); void test_retained_changed(); void test_retained_removed(); void test_retained_tree(); void test_retained_global_expire(); void test_retained_per_message_expire(); void test_retained_tree_purging(); void testRetainAsPublished(); void testRetainAsPublishedNegative(); void testRetainedParentOfWildcard(); void testRetainedWildcard(); void testRetainedAclReadCheck(); void testRetainHandlingDontGiveRetain(); void testRetainHandlingDontGiveRetainOnExistingSubscription(); void test_various_packet_sizes(); void test_acl_tree(); void test_acl_tree2(); void test_acl_patterns_username(); void test_acl_patterns_clientid(); void test_loading_acl_file(); void test_loading_second_value(); void test_parsing_numbers(); void test_validUtf8Generic(); #ifndef FMQ_NO_SSE void test_sse_split(); void 
test_validUtf8Sse(); void test_utf8_nonchars(); void test_utf8_overlong(); void test_utf8_compare_implementation(); #endif void testPacketInt16Parse(); void testRetainedMessageDB(); void testRetainedMessageDBNotPresent(); void testRetainedMessageDBEmptyList(); void testSavingSessions(); void testParsePacket(); void testbufferToMqttPacketsFuzz(); void testDowngradeQoSOnSubscribeQos2to2(); void testDowngradeQoSOnSubscribeQos2to1(); void testDowngradeQoSOnSubscribeQos2to0(); void testDowngradeQoSOnSubscribeQos1to1(); void testDowngradeQoSOnSubscribeQos1to0(); void testDowngradeQoSOnSubscribeQos0to0(); void testNotMessingUpQosLevels(); void testUnSubscribe(); void testUnsubscribeNonExistingWildcard(); void testBasicsWithFlashMQTestClient(); void testDontRemoveSessionGivenToNewClientWithSameId(); void testKeepSubscriptionOnKickingOutExistingClientWithCleanSessionFalse(); void testPickUpSessionWithSubscriptionsAfterDisconnect(); /** * Will tests. */ void testMqtt3will(); void testMqtt3NoWillOnDisconnect(); void testMqtt5NoWillOnDisconnect(); void testMqtt5DelayedWill(); void testMqtt5DelayedWillAlwaysOnSessionEnd(); void testWillOnSessionTakeOvers(); void testOverrideWillDelayOnSessionDestructionByTakeOver(); void testDisabledWills(); void testMqtt5DelayedWillsDisabled(); void testStringDistances(); void testConfigSuggestion(); void testFlags(); void testIncomingTopicAlias(); void testOutgoingTopicAlias(); void testOutgoingTopicAliasBeyondMax(); void testOutgoingTopicAliasStoredPublishes(); void testReceivingRetainedMessageWithQoS(); void testQosDowngradeOnOfflineClients(); void testPacketOrderOnSessionPickup(); void testUserProperties(); void testMessageExpiry(); void testExpiredQueuedMessages(); void testQoSPublishQueue(); void testQoSPublishQueueMemoryLeak(); void testTimePointToAge(); void testMosquittoPasswordFile(); void testOverrideAllowAnonymousToTrue(); void testOverrideAllowAnonymousToFalse(); void testKeepAllowAnonymousFalse(); void 
testAllowAnonymousWithoutPasswordsLoaded(); void testAddrMatchesSubnetIpv4(); void testAddrMatchesSubnetIpv6(); /** * Shared subscriptions tests */ void testSharedSubscribersUnit(); void testSharedSubscribers(); void testDisconnectedSharedSubscribers(); void testUnsubscribedSharedSubscribers(); void testSharedSubscribersSurviveRestart(); void testSharedSubscriberDoesntGetRetainedMessages(); /** * Plugin tests */ void testExtendedAuthOneStepSucceed(); void testExtendedAuthOneStepDeny(); void testExtendedAuthOneStepBadAuthMethod(); void testExtendedAuthTwoStep(); void testExtendedAuthTwoStepSecondStepFail(); void testExtendedReAuth(); void testExtendedReAuthTwoStep(); void testExtendedReAuthFail(); void testSimpleAuthAsync(); void testFailedAsyncClientCrashOnSession(); void testAsyncWithImmediateFollowUpPackets(); void testAsyncWithException(); void testPluginAuthFail(); void testPluginAuthSucceed(); void testPluginOnDisconnect(); void testPluginGetClientAddress(); void testChangePublish(); void testClientRemovalByPlugin(); void testSubscriptionRemovalByPlugin(); void testPublishByPlugin(); void testWillDenialByPlugin(); void testPluginMainInit(); void testAsyncCurl(); void testSubscribeWithoutRetainedDelivery(); void testDontUpgradeWildcardDenyMode(); void testAlsoDontApproveOnErrorInPluginWithWildcardDenyMode(); void testDenyWildcardSubscription(); void testUserPropertiesPresent(); void testPublishInThread(); void testPublishToItself(); void testNoLocalPublishToItself(); void testTopicMatchingInSubscriptionTree(); void testDnsResolver(); void testDnsResolverDontCancel(); void testDnsResolverSecondQuery(); void testDnsResolverInvalid(); void testGetResultWhenThereIsNone(); void testWebsocketPing(); void testWebsocketCorruptLengthFrame(); void testWebsocketHugePing(); void testWebsocketManyBigPingFrames(); void testWebsocketClose(); void testStartsWith(); void forkingTestForkingTestServer(); void testStringValuesParsing(); void testStringValuesParsingEscaping(); void 
testStringValuesFuzz(); void testStringValuesInvalid(); void testPreviouslyValidConfigFile(); void forkingTestSaveAndLoadDelayedWill(); void testBase64(); void testNoCopy(); void testSessionTakeoverOtherUsername(); void testCorrelationData(); void testSubscriptionIdOnlineClient(); void testSubscriptionIdOfflineClient(); void testSubscriptionIdRetainedMessages(); void testSubscriptionIdSharedSubscriptions(); void testSubscriptionIdChange(); void testSubscriptionIdOverlappingSubscriptions(); void forkingTestBridgeWithLocalAndRemotePrefix(); void forkingTestBridgePrefixesOtherClientsUnaffected(); void forkingTestBridgeWithOnlyRemotePrefix(); void forkingTestBridgeWithOnlyLocalPrefix(); void forkingTestBridgeOutgoingTopicEqualsPrefix(); void forkingTestBridgeIncomingTopicEqualsPrefix(); void forkingTestBridgeZWithLocalAndRemotePrefixRetained(); void forkingTestBridgeWithLocalAndRemotePrefixQueuedQoS(); public: MainTests(); bool test(bool skip_tests_with_internet, bool skip_server_tests, bool abort_on_first_fail, const std::vector &tests); }; #endif // MAINTESTS_H ================================================ FILE: FlashMQTests/plugins/curlfunctions.cpp ================================================ #include "curlfunctions.h" #include #include "../../flashmq_plugin.h" #include "test_plugin.h" #include /** * @brief This is curl telling us what events to watch for. 
* @param easy * @param s * @param what * @param clientp * @param socketp * @return */ int socket_event_watch_notification(CURL *easy, curl_socket_t s, int what, void *clientp, void *socketp) { (void)easy; (void)clientp; (void)socketp; if (what == CURL_POLL_REMOVE) flashmq_poll_remove_fd(s); else { int events = 0; if (what == CURL_POLL_IN) events |= EPOLLIN; else if (what == CURL_POLL_OUT) events |= EPOLLOUT; else if (what == CURL_POLL_INOUT) events = EPOLLIN | EPOLLOUT; else return 1; flashmq_poll_add_fd(s, events, std::weak_ptr()); } return 0; } void check_all_active_curls(TestPluginData *p, CURLM *curlMulti) { CURLMsg *msg; int msgs_left; while((msg = curl_multi_info_read(curlMulti, &msgs_left))) { if (msg->msg == CURLMSG_DONE) { CURL *easy = msg->easy_handle; AuthenticatingClient *c = nullptr; curl_easy_getinfo(easy, CURLINFO_PRIVATE, &c); flashmq_logf(LOG_INFO, "Libcurl said: %s", curl_easy_strerror(msg->data.result)); std::string answer(c->response.data(), std::min(9, c->response.size())); if (answer == "client, AuthResult::success, std::string(), std::string(), 0); else flashmq_continue_async_authentication_v4(c->client, AuthResult::login_denied, std::string(), std::string(), 0); // Normally we have to have something in the AuthenticatingClient to look up which request to delete, but we only have one here. p->curlTestClient.reset(); } } } void call_timed_curl_multi_socket_action(CURLM *multi, TestPluginData *p) { p->current_timer = 0; int a = 0; int rc = curl_multi_socket_action(multi, CURL_SOCKET_TIMEOUT, 0, &a); /* Curl says: "When this function returns error, the state of all transfers are uncertain and they cannot be * continued. curl_multi_socket_action should not be called again on the same multi handle after an error has * been returned, unless first removing all the handles and adding new ones." */ if (rc != CURLM_OK) { // This would normally be removing all our pending requests, but we only have one here. 
p->curlTestClient.reset(); return; } check_all_active_curls(p, multi); } int timer_callback(CURLM *multi, long timeout_ms, void *clientp) { TestPluginData *p = static_cast(clientp); // We also remove the last known task before it executes if curl tells us to install a new one. This // is suggested by the unclear and incomplete example at https://curl.se/libcurl/c/CURLMOPT_TIMERFUNCTION.html. if (timeout_ms == -1 || p->current_timer > 0) { flashmq_remove_task(p->current_timer); p->current_timer = 0; } if (timeout_ms >= 0) { auto f = std::bind(&call_timed_curl_multi_socket_action, multi, p); p->current_timer = flashmq_add_task(f, timeout_ms); } return CURLM_OK; } size_t curl_write_cb(char *data, size_t n, size_t l, void *userp) { AuthenticatingClient *ac = static_cast(userp); int pos = ac->response.size(); ac->response.resize(ac->response.size() + n*l); std::memcpy(&ac->response[pos], data, n*l); return n*l; } ================================================ FILE: FlashMQTests/plugins/curlfunctions.h ================================================ #ifndef CURLFUNCTIONS_H #define CURLFUNCTIONS_H #include #include "test_plugin.h" int socket_event_watch_notification(CURL *easy, curl_socket_t s, int what, void *clientp, void *socketp); void check_all_active_curls(TestPluginData *p, CURLM *curlMulti); void call_timed_curl_multi_socket_action(CURLM *multi, TestPluginData *p); int timer_callback(CURLM *multi, long timeout_ms, void *clientp); size_t curl_write_cb(char *data, size_t n, size_t l, void *userp); #endif // CURLFUNCTIONS_H ================================================ FILE: FlashMQTests/plugins/test_plugin.cpp ================================================ #include #include #include #include #include #include #include "../../flashmq_plugin.h" #include "test_plugin.h" #include #include #include "curlfunctions.h" TestPluginData::TestPluginData() : curlMulti(curl_multi_init(), curl_multi_cleanup) { if (!curlMulti) throw std::runtime_error("Curl failed to 
init"); curl_multi_setopt(curlMulti.get(), CURLMOPT_SOCKETFUNCTION, socket_event_watch_notification); curl_multi_setopt(curlMulti.get(), CURLMOPT_TIMERFUNCTION, timer_callback); curl_multi_setopt(curlMulti.get(), CURLMOPT_TIMERDATA, this); } TestPluginData::~TestPluginData() { if (this->t.joinable()) t.join(); } void get_auth_result_delayed(std::weak_ptr client, AuthResult result) { usleep(500000); flashmq_continue_async_authentication_v4(client, result, "", "", 0); } int flashmq_plugin_version() { return FLASHMQ_PLUGIN_VERSION; } void flashmq_plugin_allocate_thread_memory(void **thread_data, std::unordered_map &plugin_opts) { TestPluginData *p = new TestPluginData(); *thread_data = p; (void)plugin_opts; } void flashmq_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map &plugin_opts) { (void)plugin_opts; TestPluginData *p = static_cast(thread_data); delete p; } void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t events, const std::weak_ptr &ptr) { (void)ptr; TestPluginData *p = static_cast(thread_data); int new_events = CURL_CSELECT_ERR; if (events & EPOLLIN) { new_events &= ~CURL_CSELECT_ERR; new_events |= CURL_CSELECT_IN; } if (events & EPOLLOUT) { new_events &= ~CURL_CSELECT_ERR; new_events |= CURL_CSELECT_OUT; } int n = -1; if (curl_multi_socket_action(p->curlMulti.get(), fd, new_events, &n) != CURLM_OK) { p->curlTestClient.reset(); return; } check_all_active_curls(p, p->curlMulti.get()); } void flashmq_plugin_init(void *thread_data, std::unordered_map &plugin_opts, bool reloading) { (void)thread_data; (void)plugin_opts; (void)reloading; TestPluginData *p = static_cast(thread_data); if (plugin_opts.find("main_init was here") != plugin_opts.end()) p->main_init_ran = true; } void flashmq_plugin_deinit(void *thread_data, std::unordered_map &plugin_opts, bool reloading) { (void)thread_data; (void)plugin_opts; (void)reloading; } void flashmq_plugin_periodic_event(void *thread_data) { (void)thread_data; } AuthResult 
flashmq_plugin_login_check(void *thread_data, const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client) { (void)thread_data; (void)clientid; (void)username; (void)password; (void)userProperties; (void)client; if (username.find("async") == 0) { TestPluginData *p = static_cast(thread_data); p->c = client; AuthResult result = password == "success" ? AuthResult::success : AuthResult::login_denied; auto delayedResult = std::bind(&get_auth_result_delayed, p->c, result); if (p->t.joinable()) p->t.join(); p->t = std::thread(delayedResult); return AuthResult::async; } if (username == "failme") return AuthResult::login_denied; if (username == "getaddress") { struct sockaddr_storage addr_mem; struct sockaddr *addr = reinterpret_cast(&addr_mem); socklen_t addrlen = sizeof(addr_mem); std::string text; flashmq_get_client_address_v4(client, &text, addr, &addrlen); sockaddr sockaddr_after; std::memcpy(&sockaddr_after, addr, std::min(addrlen, sizeof(sockaddr))); flashmq_publish_message("getaddresstest/address", 0, false, text); if (sockaddr_after.sa_family == AF_INET) { flashmq_publish_message("getaddresstest/family", 0, false, "AF_INET"); } } if (username == "curl") { TestPluginData *p = static_cast(thread_data); p->curlTestClient = std::make_unique(); p->curlTestClient->client = client; curl_easy_setopt(p->curlTestClient->easy_handle.get(), CURLOPT_WRITEFUNCTION, curl_write_cb); curl_easy_setopt(p->curlTestClient->easy_handle.get(), CURLOPT_WRITEDATA, p->curlTestClient.get()); curl_easy_setopt(p->curlTestClient->easy_handle.get(), CURLOPT_PRIVATE, p->curlTestClient.get()); // Keep in mind that DNS resovling may be blocking too. You could perhaps resolve the DNS once and use the result. 
curl_easy_setopt(p->curlTestClient->easy_handle.get(), CURLOPT_URL, "http://www.google.com/"); p->curlTestClient->addToMulti(p->curlMulti); return AuthResult::async; } return AuthResult::success; } void publish_in_thread() { flashmq_publish_message("topic/from/thread", 0, false, "payload from thread"); } AuthResult flashmq_plugin_acl_check( void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties) { (void)thread_data; (void)access; (void)clientid; (void)username; (void)subtopics; (void)qos; (void)retain; (void)correlationData; (void)responseTopic; (void)userProperties; (void)shareName; (void)contentType; (void)expiresAt; if (clientid == "return_error") return AuthResult::error; if (access == AclAccess::subscribe && clientid == "success_without_retained_delivery") return AuthResult::success_without_retained_delivery; if (clientid == "test_user_without_retain_as_published_CswU21YA" && access == AclAccess::read) assert(!retain); if (clientid == "test_user_with_retain_as_published_v8sIeCvI" && access == AclAccess::read) assert(retain); assert(access == AclAccess::subscribe || !payload.empty()); if (access == AclAccess::register_will && topic == "will/disallowed") return AuthResult::acl_denied; if (topic == "removeclient" || topic == "removeclientandsession") { std::weak_ptr ses; flashmq_get_session_pointer(clientid, username, ses); flashmq_plugin_remove_client_v4(ses, topic == "removeclientandsession", ServerDisconnectReasons::NormalDisconnect); } if (clientid == "unsubscribe" && access == AclAccess::write) { std::weak_ptr session; flashmq_get_session_pointer(clientid, username, session); 
flashmq_plugin_remove_subscription_v4(session, topic); } if (clientid == "generate_publish") { flashmq_logf(LOG_INFO, "Publishing from plugin."); const std::string topic = "generated/topic"; const std::string payload = "money"; flashmq_publish_message(topic, 0, false, payload); } if ((access == AclAccess::read || access == AclAccess::write) && topic == "test_user_property") { assert(userProperties); } if (topic == "publish_in_thread" && access == AclAccess::write) { std::thread t(publish_in_thread); pthread_setname_np(t.native_handle(), "PubInThread"); t.join(); } return AuthResult::success; } AuthResult flashmq_plugin_extended_auth(void *thread_data, const std::string &clientid, ExtendedAuthStage stage, const std::string &authMethod, const std::string &authData, const std::vector> *userProperties, std::string &returnData, std::string &username, const std::weak_ptr &client) { (void)thread_data; (void)stage; (void)authMethod; (void)authData; (void)username; (void)clientid; (void)userProperties; (void)returnData; (void)client; if (authMethod == "always_good_passing_back_the_auth_data") { if (authData == "actually not good.") return AuthResult::login_denied; returnData = authData; return AuthResult::success; } if (authMethod == "always_fail") { return AuthResult::login_denied; } if (authMethod == "two_step") { if (authData == "Hello") returnData = "Hello back"; if (authData == "grant me already!") { returnData = "OK, if you insist."; return AuthResult::success; } else if (authData == "whoops, wrong data.") return AuthResult::login_denied; else return AuthResult::auth_continue; } return AuthResult::auth_method_not_supported; } bool flashmq_plugin_alter_publish( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::optional &correlationData, std::optional &responseTopic, std::optional &contentType, std::vector> *userProperties) { (void)thread_data; (void)clientid; 
(void)subtopics; (void)qos; (void)retain; (void)correlationData; (void)responseTopic; (void)userProperties; (void)contentType; TestPluginData *p = static_cast(thread_data); assert(!payload.empty()); if (topic == "changeme") { topic = "changed"; qos = 2; return true; } if (topic == "check_main_init_presence" && p->main_init_ran) { topic = "check_main_init_presence_confirmed"; return true; } return false; } void flashmq_plugin_client_disconnected(void *thread_data, const std::string &clientid) { (void)thread_data; flashmq_logf(LOG_INFO, "flashmq_plugin_client_disconnected called for '%s'", clientid.c_str()); flashmq_publish_message("disconnect/confirmed", 0, false, "adsf"); } void flashmq_plugin_main_init(std::unordered_map &plugin_opts) { (void)plugin_opts; flashmq_logf(LOG_INFO, "The tester was here."); // The plugin_opts aren't const. I don't know if that was a mistake or not anymore, but it works in my favor now. plugin_opts["main_init was here"] = "true"; if (curl_global_init(CURL_GLOBAL_ALL) != 0) throw std::runtime_error("Global curl init failed to init"); } void flashmq_plugin_main_deinit(std::unordered_map &plugin_opts) { (void)plugin_opts; curl_global_cleanup(); } AuthenticatingClient::AuthenticatingClient() : easy_handle(curl_easy_init(), curl_easy_cleanup) { } AuthenticatingClient::~AuthenticatingClient() { auto x = registeredAtMultiHandle.lock(); if (x) { curl_multi_remove_handle(x.get(), easy_handle.get()); } } void AuthenticatingClient::addToMulti(std::shared_ptr &curlMulti) { if (curl_multi_add_handle(curlMulti.get(), easy_handle.get()) != CURLM_OK) throw std::runtime_error("curl_multi_add_handle failed"); registeredAtMultiHandle = curlMulti; } ================================================ FILE: FlashMQTests/plugins/test_plugin.h ================================================ #ifndef TESTPLUGIN_H #define TESTPLUGIN_H #include #include #include #include "../../forward_declarations.h" #include struct AuthenticatingClient { std::weak_ptr client; 
std::vector response; std::unique_ptr easy_handle; std::weak_ptr registeredAtMultiHandle; public: AuthenticatingClient(); ~AuthenticatingClient(); void addToMulti(std::shared_ptr &curlMulti); }; class TestPluginData { public: std::thread t; std::weak_ptr c; bool main_init_ran = false; std::shared_ptr curlMulti; uint32_t current_timer = 0; // Normally we keep some kind of indexed record of requests, but in our test plugin, we just track one. std::unique_ptr curlTestClient; public: TestPluginData(); ~TestPluginData(); }; #endif // TESTPLUGIN_H ================================================ FILE: FlashMQTests/plugins/test_plugin.pro ================================================ QT -= gui CONFIG += c++17 TARGET = test_plugin TEMPLATE = lib VERSION=0.0.1 LIBS += -lcurl HEADERS += test_plugin.h \ curlfunctions.h SOURCES += test_plugin.cpp \ curlfunctions.cpp ================================================ FILE: FlashMQTests/plugintests.cpp ================================================ #include "maintests.h" #include "testhelpers.h" #include "conffiletemp.h" #include "flashmqtestclient.h" #include void MainTests::testWillDenialByPlugin() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "will/allowed"; will->payload = "mypayload"; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt311); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe("will/+", 0); sender.reset(); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MqttPacket pubPack = ro->receivedPublishes.front(); std::shared_ptr client = receiver.getClient(); 
pubPack.parsePublishData(client); QCOMPARE(pubPack.getPublishData().topic, "will/allowed"); QCOMPARE(pubPack.getPublishData().payload, "mypayload"); QCOMPARE(pubPack.getPublishData().qos, 0); } receiver.clearReceivedLists(); // Now set a will that we will deny. { sender = std::make_unique(); sender->start(); std::shared_ptr will2 = std::make_shared(); will2->topic = "will/disallowed"; will2->payload = "mypayload"; sender->setWill(will2); sender->connectClient(ProtocolVersion::Mqtt311); sender.reset(); usleep(500000); receiver.waitForMessageCount(0); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } } void MainTests::testPluginAuthFail() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); for (ProtocolVersion &version : versions) { FlashMQTestClient client; client.start(); client.connectClient(version, false, 120, [](Connect &connect) { connect.username = "failme"; connect.password = "boo"; }); auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); if (version >= ProtocolVersion::Mqtt5) QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); else QVERIFY(static_cast(connAckData.reasonCode) == 5); } } void MainTests::testPluginAuthSucceed() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); for (ProtocolVersion &version : versions) { FlashMQTestClient client; client.start(); client.connectClient(version, false, 120, [](Connect &connect) { connect.username = "passme"; connect.password = "boo"; }); auto 
ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); QVERIFY(connAckData.reasonCode == ReasonCodes::Success); } } void MainTests::testExtendedAuthOneStepSucceed() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.username = "me"; connect.password = "me me"; connect.authenticationMethod = "always_good_passing_back_the_auth_data"; connect.authenticationData = "I have a proposal to put to ye."; }); auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); QVERIFY(connAckData.reasonCode == ReasonCodes::Success); QVERIFY(connAckData.authData == "I have a proposal to put to ye."); } void MainTests::testExtendedAuthOneStepDeny() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.username = "me"; connect.password = "me me"; connect.authenticationMethod = "always_fail"; }); auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized); } void MainTests::testExtendedAuthOneStepBadAuthMethod() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); 
FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.username = "me"; connect.password = "me me"; connect.authenticationMethod = "doesnt_exist"; }); auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); QVERIFY(connAckData.reasonCode == ReasonCodes::BadAuthenticationMethod); } void MainTests::testExtendedAuthTwoStep() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.username = "me"; connect.password = "me me"; connect.authenticationMethod = "two_step"; connect.authenticationData = "Hello"; }); { auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); AuthPacketData authData = ro->receivedPackets.front().parseAuthData(); QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication); QVERIFY(authData.data == "Hello back"); } client.clearReceivedLists(); const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!"); client.writeAuth(auth); client.waitForConnack(); auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); QVERIFY(connAckData.reasonCode == ReasonCodes::Success); QVERIFY(connAckData.authData == "OK, if you insist."); } void MainTests::testExtendedAuthTwoStepSecondStepFail() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 
120, [](Connect &connect) {
        connect.username = "me";
        connect.password = "me me";
        connect.authenticationMethod = "two_step";
        connect.authenticationData = "Hello";
    });

    // First round trip: the "two_step" method answers with ContinueAuthentication and "Hello back".
    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        AuthPacketData authData = ro->receivedPackets.front().parseAuthData();
        QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
        QVERIFY(authData.data == "Hello back");
    }

    client.clearReceivedLists();

    // Second round trip with data the method does not accept: the CONNACK must say NotAuthorized.
    const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "whoops, wrong data.");
    client.writeAuth(auth);
    client.waitForConnack();

    auto ro = client.receivedObjects.lock();
    QVERIFY(ro->receivedPackets.size() == 1);
    ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData();
    QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
}

/**
 * Connects with the "always_good_passing_back_the_auth_data" method (CONNACK Success), then sends a second
 * AUTH on the established connection and expects an AUTH reply with reason Success that echoes the data.
 */
void MainTests::testExtendedReAuth()
{
    ConfFileTemp confFile;
    confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    FlashMQTestClient client;
    client.start();
    client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
        connect.username = "me";
        connect.password = "me me";
        connect.authenticationMethod = "always_good_passing_back_the_auth_data";
        connect.authenticationData = "Santa Claus";
    });

    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData();
        QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
    }

    client.clearReceivedLists();

    // Then reauth.
    // NOTE(review): MQTT 5 has the client start re-authentication with reason code 0x19 (Re-authenticate), which is
    // what testExtendedReAuthTwoStep and testExtendedReAuthFail send. This test sends ContinueAuthentication
    // instead; confirm that is intentional.
    Auth auth(ReasonCodes::ContinueAuthentication, "always_good_passing_back_the_auth_data", "Again Santa Claus");
    client.writeAuth(auth);
    client.waitForConnack();

    // The reply to a re-authentication is an AUTH packet, not a CONNACK.
    auto ro = client.receivedObjects.lock();
    QVERIFY(ro->receivedPackets.size() == 1);
    AuthPacketData authData = ro->receivedPackets.front().parseAuthData();
    QVERIFY(authData.reasonCode == ReasonCodes::Success);
    QVERIFY(authData.data == "Again Santa Claus");
}

/**
 * Runs the "two_step" method twice.
 *
 * At connect time, "Hello" is answered with ContinueAuthentication and "Hello back". Then "grant me already!" is
 * answered with a CONNACK Success carrying "OK, if you insist.".
 *
 * It then repeats the same two steps as a re-authentication started with ReAuthenticate. There, the final step
 * is answered by an AUTH packet with reason Success and the same data.
 */
void MainTests::testExtendedReAuthTwoStep()
{
    ConfFileTemp confFile;
    confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    FlashMQTestClient client;
    client.start();
    client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
        connect.username = "me";
        connect.password = "me me";
        connect.authenticationMethod = "two_step";
        connect.authenticationData = "Hello";
    });

    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        AuthPacketData authData = ro->receivedPackets.front().parseAuthData();
        QVERIFY(authData.reasonCode == ReasonCodes::ContinueAuthentication);
        QVERIFY(authData.data == "Hello back");
    }

    client.clearReceivedLists();

    const Auth auth(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
    client.writeAuth(auth);
    client.waitForConnack();

    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData();
        QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
        QVERIFY(connAckData.authData == "OK, if you insist.");
    }

    client.clearReceivedLists();

    // Then reauth.
    const Auth reauth(ReasonCodes::ReAuthenticate, "two_step", "Hello");
    client.writeAuth(reauth);
    client.waitForConnack();

    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        AuthPacketData reauthData = ro->receivedPackets.front().parseAuthData();
        QVERIFY(reauthData.reasonCode == ReasonCodes::ContinueAuthentication);
        QVERIFY(reauthData.data == "Hello back");
    }

    client.clearReceivedLists();

    // During re-authentication the final answer is an AUTH packet, not a CONNACK.
    const Auth reauthFinish(ReasonCodes::ContinueAuthentication, "two_step", "grant me already!");
    client.writeAuth(reauthFinish);
    client.waitForConnack();

    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        AuthPacketData reauthFinishData = ro->receivedPackets.front().parseAuthData();
        QVERIFY(reauthFinishData.reasonCode == ReasonCodes::Success);
        QVERIFY(reauthFinishData.data == "OK, if you insist.");
    }
}

/**
 * Connects with "always_good_passing_back_the_auth_data"; the CONNACK is Success and echoes the auth data.
 * It then re-authenticates with "actually not good." and expects a DISCONNECT carrying NotAuthorized.
 */
void MainTests::testExtendedReAuthFail()
{
    ConfFileTemp confFile;
    confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    FlashMQTestClient client;
    client.start();
    client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
        connect.username = "me";
        connect.password = "me me";
        connect.authenticationMethod = "always_good_passing_back_the_auth_data";
        connect.authenticationData = "I have a proposal to put to ye.";
    });

    {
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData();
        QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
        QVERIFY(connAckData.authData == "I have a proposal to put to ye.");
    }

    client.clearReceivedLists();

    // Then reauth.
    {
        const Auth reauth(ReasonCodes::ReAuthenticate, "always_good_passing_back_the_auth_data", "actually not good.");
        client.writeAuth(reauth);
        client.waitForPacketCount(1);

        // A rejected re-authentication is answered with a DISCONNECT, not an AUTH packet.
        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);
        QVERIFY(ro->receivedPackets.front().packetType == PacketType::DISCONNECT);
        DisconnectData data = ro->receivedPackets.front().parseDisconnectData();
        QVERIFY(data.reasonCode == ReasonCodes::NotAuthorized);
    }
}

/**
 * Logs in as user "async", first with password "success" and then with "fail". The CONNACK must be Success and
 * NotAuthorized respectively. The username presumably selects the test plugin's asynchronous auth path.
 */
void MainTests::testSimpleAuthAsync()
{
    ConfFileTemp confFile;
    confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    std::list results { "success", "fail" };

    for (std::string &result : results)
    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [&](Connect &connect) {
            connect.username = "async";
            connect.password = result;
        });

        auto ro = client.receivedObjects.lock();
        QVERIFY(ro->receivedPackets.size() == 1);

        ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData();

        if (result == "success")
            QVERIFY(connAckData.reasonCode == ReasonCodes::Success);
        else
            QVERIFY(connAckData.reasonCode == ReasonCodes::NotAuthorized);
    }
}

/**
 * There was a crash when doing session stuff with a client that was rejected by exception
 * in continuationOfAuthentication (by duplicate session id between different users).
*/
void MainTests::testFailedAsyncClientCrashOnSession()
{
    ConfFileTemp confFile;
    confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    std::list clients;

    // First client: username "async1" claims client ID "duplicate". It is used at the end to check that the
    // server still works.
    clients.emplace_back();
    clients.back().start();
    clients.back().connectClient(ProtocolVersion::Mqtt5, false, 120, [&](Connect &connect) {
        connect.username = "async1";
        connect.password = "success";
        connect.clientid = "duplicate";
    }, 21883);

    // Second client: a different username ("async2") with the same client ID, which is the rejected case.
    // NOTE(review): the trailing 'false' presumably means "don't wait for the CONNACK". Confirm this against
    // FlashMQTestClient::connectClient.
    clients.emplace_back();
    clients.back().start();
    clients.back().connectClient(ProtocolVersion::Mqtt5, false, 120, [&](Connect &connect) {
        connect.username = "async2";
        connect.password = "success";
        connect.clientid = "duplicate";
    }, 21883, false);

    // Because it crashed, we don't have events to wait for.
    std::this_thread::sleep_for(std::chrono::seconds(1));

    /*
     * An if statement in a test is dodgy, but the underlying client was made weak after this
     * test was written, so when the original problem occurs, the client hasn't disconnected nor expired.
     */
    if (!clients.back().clientExpired())
    {
        // Push a QoS 2 publish through the rejected client's connection.
        Publish pub("sdf", "wer", 2);
        MqttPacket pubPack(ProtocolVersion::Mqtt5, pub);
        if (pub.qos > 0)
            pubPack.setPacketId(3);
        clients.back().getClient()->writeMqttPacketAndBlameThisClient(pubPack);
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(500));

    // Test if the server still works after that.
    clients.front().clearReceivedLists();
    clients.front().publish("sdf", "sfd", 2);

    auto ro = clients.front().receivedObjects.lock();
    FMQ_VERIFY(!ro->receivedPackets.empty());
    // A QoS 2 publish flow ends with PUBCOMP, so that must be the last packet received.
    FMQ_COMPARE(ro->receivedPackets.back().packetType, PacketType::PUBCOMP);
}

/**
 * We send a subscription directly after authentication, without waiting for it. These should be queued and dealt
 * with after the login was approved.
*/
void MainTests::testAsyncWithImmediateFollowUpPackets()
{
    ConfFileTemp confFile;
    confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    const std::string topic = "our/random/topic";
    const std::string payload = "1Xs0QC5XInKLGHfm";

    FlashMQTestClient subscriber;
    subscriber.start();

    // Asynchronous login. The trailing 'false' (presumably "don't wait for the CONNACK") lets the SUBSCRIBE
    // below be written while authentication is still pending.
    subscriber.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
        connect.username = "async";
        connect.password = "success";
    }, 21883, false);

    // Pipeline a SUBSCRIBE straight after the CONNECT, before any answer has arrived.
    {
        const uint16_t subscribePacketId = 66;
        std::vector subscriptions;
        subscriptions.emplace_back(topic, 0);
        MqttPacket subscribePacket(subscriber.getProtocolVersion(), subscribePacketId, 0, subscriptions);
        subscriber.getClient()->writeMqttPacketAndBlameThisClient(subscribePacket);
    }

    subscriber.waitForConnack();
    subscriber.getClient()->setAuthenticated(true);

    {
        auto received = subscriber.receivedObjects.lock();
        FMQ_VERIFY(received->receivedPackets.size() >= 1);
        const ConnAckData connack = received->receivedPackets.front().parseConnAckData();
        QVERIFY(connack.reasonCode == ReasonCodes::Success);
    }

    subscriber.clearReceivedLists();

    // A second client publishes on the topic. It is scoped so that it is destroyed before the wait below.
    {
        FlashMQTestClient publisher;
        publisher.start();
        publisher.connectClient(ProtocolVersion::Mqtt311);
        publisher.publish(topic, payload, 0);
    }

    // The queued subscription must have taken effect, so the publish reaches the first client.
    subscriber.waitForMessageCount(1);

    {
        auto received = subscriber.receivedObjects.lock();
        MqttPacket &firstPublish = received->receivedPublishes.front();
        FMQ_COMPARE(firstPublish.getTopic(), topic);
        FMQ_COMPARE(firstPublish.getPayloadCopy(), payload);
    }
}

/**
 * We're testing behavior that throws an internal FlashMQ exception in handling async auth, namely trying
 * to take over session with a different username. An exception inside the plugin would not be a good test,
 * because that is handled locally by the authentication.
*/ void MainTests::testAsyncWithException() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); // ========= for (ProtocolVersion p : {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}) { FlashMQTestClient client1; client1.start(); client1.connectClient(p, true, 600, [](Connect &connect) { connect.clientid = "thesameclientid"; connect.username = "async0"; connect.password = "success"; }); { auto ro = client1.receivedObjects.lock(); auto &pack = ro->receivedPackets.at(0); FMQ_COMPARE(pack.packetType, PacketType::CONNACK); ConnAckData ackData = pack.parseConnAckData(); FMQ_COMPARE(ackData.reasonCode, ReasonCodes::Success); } FlashMQTestClient client2; client2.start(); client2.connectClient(p, true, 600, [](Connect &connect) { connect.clientid = "thesameclientid"; connect.username = "async1"; connect.password = "success"; }); { auto ro = client2.receivedObjects.lock(); auto &pack = ro->receivedPackets.at(0); FMQ_COMPARE(pack.packetType, PacketType::CONNACK); ConnAckData ackData = pack.parseConnAckData(); int expectedCode = p == ProtocolVersion::Mqtt5 ? 
static_cast(ReasonCodes::NotAuthorized) : static_cast(ConnAckReturnCodes::NotAuthorized); FMQ_COMPARE(static_cast(ackData.reasonCode), expectedCode); } } } void MainTests::testClientRemovalByPlugin() { std::list methods { "removeclient", "removeclientandsession"}; ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; for (std::string &method : methods) { cleanup(); init(args); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5, false, 120); const std::string sender_client_id = sender.getClientId(); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("#", 2); sender.publish(method, "asdf", 0); sender.waitForDisconnectPacket(); auto ro = sender.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); QVERIFY(ro->receivedPackets.front().packetType == PacketType::DISCONNECT); std::shared_ptr store = globals->subscriptionStore; std::shared_ptr session = store->lockSession(sender_client_id); if (method == "removeclient") { QVERIFY(session); } else { QVERIFY(!session); } } } void MainTests::testSubscriptionRemovalByPlugin() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.clientid = "unsubscribe"; }); receiver.subscribe("a/b/c", 2); FlashMQTestClient dummyreceiver; dummyreceiver.start(); dummyreceiver.connectClient(ProtocolVersion::Mqtt5); dummyreceiver.subscribe("#", 2); sender.publish("a/b/c", "asdf", 0); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); 
QVERIFY(ro->receivedPublishes.size() == 1); } receiver.clearReceivedLists(); sender.clearReceivedLists(); receiver.publish("a/b/c", "sdf", 0); // Because the clientid of this client, this will unsubscribe. dummyreceiver.clearReceivedLists(); // A hack way to make sure the relevant thread has done the unsubscribe (by doing work we can detect). const int nprocs = get_nprocs(); for (int i = 0; i < nprocs; i++) receiver.publish("waitforthis", "sdf", 0); dummyreceiver.waitForPacketCount(nprocs); receiver.clearReceivedLists(); dummyreceiver.clearReceivedLists(); sender.publish("a/b/c", "asdf", 0); usleep(200000); receiver.waitForMessageCount(0); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } } void MainTests::testPublishByPlugin() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.clientid = "generate_publish"; }); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, false, 120); receiver.subscribe("#", 2); sender.publish("boo", "booboo", 0); receiver.waitForMessageCount(2); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 2); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) { return p.getTopic() == "boo"; })); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) { return p.getTopic() == "generated/topic"; })); } void MainTests::testChangePublish() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; 
cleanup(); init(args); for (ProtocolVersion &version : versions) { FlashMQTestClient sender; sender.start(); sender.connectClient(version, false, 120); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(version, false, 100); receiver.subscribe("#", 2); FlashMQTestClient receiver_of_pattern; receiver_of_pattern.start(); receiver_of_pattern.connectClient(version, false, 100); receiver_of_pattern.subscribe("changed", 2); sender.publish("changeme", "hello", 1); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MYCASTCOMPARE(ro->receivedPublishes.front().getTopic(), "changed"); MYCASTCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), "hello"); MYCASTCOMPARE(ro->receivedPublishes.front().getQos(), 2); } receiver_of_pattern.waitForMessageCount(1); { auto ro = receiver_of_pattern.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MYCASTCOMPARE(ro->receivedPublishes.front().getTopic(), "changed"); MYCASTCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), "hello"); MYCASTCOMPARE(ro->receivedPublishes.front().getQos(), 2); } } } void MainTests::testPluginOnDisconnect() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("#", 0); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5); client.disconnect(ReasonCodes::Success); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); QCOMPARE(ro->receivedPublishes.front().getTopic(), "disconnect/confirmed"); } } void MainTests::testPluginGetClientAddress() { ConfFileTemp confFile; 
confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("getaddresstest/#", 0); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) { connect.username = "getaddress"; }); try { receiver.waitForMessageCount(2); } catch(std::exception &e) { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 2); } auto ro = receiver.receivedObjects.lock(); QCOMPARE(ro->receivedPublishes[0].getTopic(), "getaddresstest/address"); QCOMPARE(ro->receivedPublishes[0].getPayloadCopy(), "127.0.0.1"); QCOMPARE(ro->receivedPublishes[1].getTopic(), "getaddresstest/family"); QCOMPARE(ro->receivedPublishes[1].getPayloadCopy(), "AF_INET"); } void MainTests::testPluginMainInit() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, false, 120); receiver.subscribe("#", 2); sender.publish("check_main_init_presence", "booboo", 0); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MYCASTCOMPARE(ro->receivedPublishes.front().getTopic(), "check_main_init_presence_confirmed"); } void MainTests::testAsyncCurl() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, false, 120, 
[](Connect &connect) { connect.username = "curl"; connect.password = "boo"; }); auto ro = client.receivedObjects.lock(); QVERIFY(ro->receivedPackets.size() == 1); ConnAckData connAckData = ro->receivedPackets.front().parseConnAckData(); QVERIFY(connAckData.reasonCode == ReasonCodes::Success); } void MainTests::testSubscribeWithoutRetainedDelivery() { // Control case without plugin loaded. { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5, true, 0, [] (Connect &connect){ connect.clientid = "success_without_retained_delivery"; }); receiver.subscribe(topic, 0); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); } ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5, true, 0, [] (Connect &connect){ connect.clientid = "success_without_retained_delivery"; }); receiver.subscribe(topic, 0); usleep(250000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } sender.publish("retaintopic/one/two/three", "on-line payload", 0); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); 
QVERIFY(ro->receivedPublishes.front().getPayloadView() == "on-line payload"); } } void MainTests::testDontUpgradeWildcardDenyMode() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.writeLine("minimum_wildcard_subscription_depth 2"); confFile.writeLine("wildcard_subscription_deny_mode deny_retained_only"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("#", 0); usleep(250000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } sender.publish("retaintopic/one/two/three", "on-line payload", 0); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); QVERIFY(ro->receivedPublishes.front().getPayloadView() == "on-line payload"); } } void MainTests::testAlsoDontApproveOnErrorInPluginWithWildcardDenyMode() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.writeLine("minimum_wildcard_subscription_depth 2"); confFile.writeLine("wildcard_subscription_deny_mode deny_retained_only"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5, true, 0, [] 
(Connect &connect){ connect.clientid = "return_error"; }); bool suback_errored = false; try { receiver.subscribe("#", 0); } catch (SubAckIsError &ex) { suback_errored = true; } QVERIFY(suback_errored); usleep(250000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } sender.publish("retaintopic/one/two/three", "on-line payload", 0); usleep(250000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } } void MainTests::testDenyWildcardSubscription() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.writeLine("minimum_wildcard_subscription_depth 2"); confFile.writeLine("wildcard_subscription_deny_mode deny_all"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5, true, 0, [] (Connect &connect){ connect.clientid = "success_without_retained_delivery"; }); bool suback_errored = false; try { receiver.subscribe("bla/#", 0); } catch (SubAckIsError &ex) { suback_errored = true; } QVERIFY(suback_errored); } void MainTests::testUserPropertiesPresent() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "werwer payload"; const std::string topic = "test_user_property"; sender.connectClient(ProtocolVersion::Mqtt5); receiver.connectClient(ProtocolVersion::Mqtt5, true, 0, [] (Connect &connect){ 
connect.clientid = "qq5HD9s9VDomlF2l"; }); receiver.subscribe(topic, 0); Publish pub1(topic, payload, 0); pub1.addUserProperty("myprop", "myval"); sender.publish(pub1); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); QVERIFY(ro->receivedPublishes.front().getPayloadView() == payload); QVERIFY(ro->receivedPublishes.front().getUserProperties() != nullptr); auto props = ro->receivedPublishes.front().getUserProperties(); QVERIFY(std::any_of(props->begin(), props->end(), [](std::pair &p) { return p.first == "myprop" && p.second == "myval"; })); } void MainTests::testPublishInThread() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); sender.connectClient(ProtocolVersion::Mqtt5); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("topic/from/thread", 0); Publish pub1("publish_in_thread", "dummy", 0); sender.publish(pub1); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); QVERIFY(ro->receivedPublishes.front().getPayloadView() == "payload from thread"); } ================================================ FILE: FlashMQTests/retaintests.cpp ================================================ #include "maintests.h" #include "flashmqtestclient.h" #include "conffiletemp.h" #include "testhelpers.h" #include "utils.h" #include "retainedmessagesdb.h" void MainTests::test_retained() { std::vector protocols {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; for (const ProtocolVersion senderVersion : protocols) { for (const ProtocolVersion receiverVersion : protocols) { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are 
testing"; const std::string topic = "retaintopic"; sender.connectClient(senderVersion); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); Publish pub2("dummy2", "Nobody sees this", 0); pub2.retain = true; sender.publish(pub2); receiver.connectClient(receiverVersion); receiver.subscribe("dummy", 0); receiver.subscribe(topic, 0); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &msg = ro->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), payload); QCOMPARE(msg.getTopic(), topic); QVERIFY(msg.getRetain()); ro->clear(); } sender.publish(pub1); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); QVERIFY2(ro->receivedPublishes.size() == 1, "There must be one message in the received list"); MqttPacket &msg2 = ro->receivedPublishes.front(); QCOMPARE(msg2.getPayloadCopy(), payload); QCOMPARE(msg2.getTopic(), topic); QVERIFY2(!msg2.getRetain(), "Getting a retained message while already being subscribed must be marked as normal, not retain."); } } } } /** * @brief MainTests::test_retained_double_set Test incepted because of different locking paths in first tree node and second tree node. */ void MainTests::test_retained_double_set() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string topic = "one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); { Publish pub("one", "dummy node creator", 0); pub.retain = true; pub.qos = 1; sender.publish(pub); } Publish pub1(topic, "nobody sees this", 0); pub1.retain = true; pub1.qos = 1; sender.publish(pub1); pub1.payload = "We are setting twice"; sender.publish(pub1); { // Confirm the retained messages are present. 
FlashMQTestClient c; c.start(); c.connectClient(ProtocolVersion::Mqtt5); c.subscribe("one/#", 0); c.waitForMessageCount(2); } receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("one/#", 0); receiver.waitForMessageCount(2); MYCASTCOMPARE(receiver.receivedObjects.lock()->receivedPublishes.size(), 2); { auto ro = receiver.receivedObjects.lock(); auto msg = std::find_if(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](const MqttPacket &p) {return p.getTopic() == "one";}); FMQ_VERIFY(msg != ro->receivedPublishes.end()); QCOMPARE(msg->getPayloadCopy(), "dummy node creator"); QCOMPARE(msg->getTopic(), "one"); QVERIFY(msg->getRetain()); } { auto ro = receiver.receivedObjects.lock(); auto msg2 = std::find_if(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [&](const MqttPacket &p) {return p.getTopic() == topic;}); FMQ_VERIFY(msg2 != ro->receivedPublishes.end()); QCOMPARE(msg2->getPayloadCopy(), "We are setting twice"); QCOMPARE(msg2->getTopic(), topic); QVERIFY(msg2->getRetain()); } } /** * @brief MainTests::test_retained_disabled Copied from test_retained and adjusted */ void MainTests::test_retained_mode_drop() { ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("retained_messages_mode drop"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); std::vector protocols {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; for (const ProtocolVersion senderVersion : protocols) { for (const ProtocolVersion receiverVersion : protocols) { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are testing"; const std::string topic = "retaintopic"; sender.connectClient(senderVersion); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); Publish pub2("dummy2", "Nobody sees this", 0); pub2.retain = true; sender.publish(pub2); receiver.connectClient(receiverVersion); 
receiver.subscribe("dummy", 0); receiver.subscribe(topic, 0); usleep(250000); receiver.waitForMessageCount(0); { auto ro = receiver.receivedObjects.lock(); QVERIFY2(ro->receivedPublishes.empty(), "In drop mode, retained publishes should be stored as retained messages."); } receiver.clearReceivedLists(); sender.publish(pub1); usleep(250000); receiver.waitForMessageCount(0); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } } } } /** * @brief MainTests::test_retained_mode_downgrade copied from test_retained and adjusted */ void MainTests::test_retained_mode_downgrade() { ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("retained_messages_mode downgrade"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); std::vector protocols {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; for (const ProtocolVersion senderVersion : protocols) { for (const ProtocolVersion receiverVersion : protocols) { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are testing"; const std::string topic = "retaintopic"; sender.connectClient(senderVersion); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); Publish pub2("dummy2", "Nobody sees this", 0); pub2.retain = true; sender.publish(pub2); receiver.connectClient(receiverVersion); receiver.subscribe("dummy", 0); receiver.subscribe(topic, 0, false, true); usleep(250000); receiver.waitForMessageCount(0); { auto ro = receiver.receivedObjects.lock(); QVERIFY2(ro->receivedPublishes.empty(), "In downgrade mode, retained publishes should not be stored as retained messages."); } receiver.clearReceivedLists(); sender.publish(pub1); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &msg2 = ro->receivedPublishes.front(); 
QCOMPARE(msg2.getPayloadCopy(), payload); QCOMPARE(msg2.getTopic(), topic); QVERIFY2(!msg2.getRetain(), "Getting a retained message while already being subscribed must be marked as normal, not retain."); } } } } /** * @brief Tests 'enabled_without_retaining', which relays the message with the 'retain' flag set, but does not retain. */ void MainTests::test_retained_mode_no_retain() { ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("retained_messages_mode enabled_without_retaining"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are testing"; const std::string topic = "retaintopic/foo/bar"; sender.connectClient(ProtocolVersion::Mqtt5); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe(topic, 0, false, true); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.size() == 1); MqttPacket &msg = ro->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), payload); QCOMPARE(msg.getTopic(), topic); QVERIFY2(msg.getRetain(), "We were supposed to have seen the retain flag, because 'retain as published' was on."); } FlashMQTestClient late_receiver; late_receiver.start(); late_receiver.connectClient(ProtocolVersion::Mqtt5); late_receiver.subscribe("#", 0); usleep(250000); receiver.waitForMessageCount(0); QVERIFY2(late_receiver.receivedObjects.lock()->receivedPublishes.empty(), "In enabled_without_retaining mode, retained publishes should not be stored as retained messages."); } void MainTests::test_retained_changed() { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt311); const std::string topic = "retaintopic"; Publish p(topic, "We are testing", 0); p.retain = true; p.qos = 1; 
sender.publish(p); p.payload = "Changed payload"; sender.publish(p); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe(topic, 0); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &pack = ro->receivedPublishes.front(); QCOMPARE(pack.getPayloadCopy(), p.payload); QVERIFY(pack.getRetain()); } void MainTests::test_retained_removed() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); std::string payload = "We are testing"; std::string topic = "retaintopic"; sender.connectClient(ProtocolVersion::Mqtt311); Publish pub1(topic, payload, 1); pub1.retain = true; sender.publish(pub1); pub1.payload = ""; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe(topic, 0); usleep(100000); receiver.waitForMessageCount(0); auto ro = receiver.receivedObjects.lock(); QVERIFY2(ro->receivedPublishes.empty(), "We erased the retained message. We shouldn't have received any."); } /** * @brief MainTests::test_retained_tree tests a bug I found, where '+/+' yields different results than '#', where it should be the same. 
*/ void MainTests::test_retained_tree() { FlashMQTestClient sender; sender.start(); std::string payload = "We are testing"; const std::string topic1 = "TopicA/B"; const std::string topic2 = "Topic/C"; const std::string topic3 = "TopicB/C"; const std::list topics {topic1, topic2, topic3}; sender.connectClient(ProtocolVersion::Mqtt311); Publish p1(topic1, payload, 0); p1.retain = true; sender.publish(p1); Publish p2(topic2, payload, 0); p2.retain = true; sender.publish(p2); Publish p3(topic3, payload, 0); p3.retain = true; sender.publish(p3); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("+/+", 0); receiver.waitForMessageCount(3); auto ro = receiver.receivedObjects.lock(); QCOMPARE(ro->receivedPublishes.size(), topics.size()); for (const std::string &s : topics) { bool r = std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getTopic() == s && pack.getPayloadCopy() == payload; }); QVERIFY2(r, formatString("%s not found in retained messages.", s.c_str()).c_str()); } } void MainTests::test_retained_global_expire() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("expire_retained_messages_after_seconds 1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); std::vector protocols {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; for (const ProtocolVersion senderVersion : protocols) { for (const ProtocolVersion receiverVersion : protocols) { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); sender.connectClient(senderVersion); Publish pub1("retaintopic/1", "We are testing", 0); pub1.retain = true; sender.publish(pub1); Publish pub2("retaintopic/2", "asdf", 0); pub2.retain = true; sender.publish(pub2); usleep(2000000); 
receiver.connectClient(receiverVersion); receiver.subscribe("#", 0); usleep(500000); receiver.waitForMessageCount(0); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 0); } } } void MainTests::test_retained_per_message_expire() { std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("expire_retained_messages_after_seconds 10"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1("retaintopic/1", "We are testing", 0); pub1.retain = true; sender.publish(pub1); Publish pub2("retaintopic/2", "asdf", 0); pub2.retain = true; pub2.setExpireAfter(1); sender.publish(pub2); usleep(2000000); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("#", 0); usleep(500000); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &msg = ro->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), "We are testing"); QCOMPARE(msg.getTopic(), "retaintopic/1"); QVERIFY(msg.getRetain()); } void MainTests::test_retained_tree_purging() { std::shared_ptr store = globals->subscriptionStore; int toDeleteCount = 0; for (int i = 0; i < 10; i++) { for (int j = 0; j < 10; j++) { std::string topic = formatString("retain%d/bla%d/asdf", i, j); Publish pub(topic, "willnotexpire", 0); if (i % 2 == 0) { pub.setExpireAfter(1); pub.payload = "willexpire"; toDeleteCount++; } std::vector subtopics = splitTopic(topic); store->setRetainedMessage(pub, subtopics); } } { Publish pubStray("retain0/bla5", "willnotexpire", 0); std::vector subtopics = splitTopic(pubStray.topic); store->setRetainedMessage(pubStray, subtopics); } const int beforeCount = 
store->getAllRetainedMessages().size(); usleep(2000000); store->expireRetainedMessages(); /* Walking the whole retained tree after expiry must defer nothing and only find 'willnotexpire' messages; exactly toDeleteCount messages must be gone. */ std::vector list; const std::chrono::time_point limit = std::chrono::steady_clock::now() + std::chrono::milliseconds(1000); std::deque> deferred; store->getRetainedMessages(store->retainedMessagesRoot.get(), list, limit, 100000, deferred); QVERIFY(deferred.empty()); QVERIFY(std::none_of(list.begin(), list.end(), [](RetainedMessage &rm) { return rm.publish.payload == "willexpire"; })); QVERIFY(std::all_of(list.begin(), list.end(), [](RetainedMessage &rm) { return rm.publish.payload == "willnotexpire"; })); MYCASTCOMPARE(store->getAllRetainedMessages().size(), beforeCount - toDeleteCount); } /** * @brief MainTests::testRetainAsPublished: 'client' subscribes with 'retain as published' (the fourth subscribe() argument) and then publishes a retained message itself; the copy it gets back must keep the retain flag. client2 and client3 subscribe without 'retain as published'. */ void MainTests::testRetainAsPublished() { FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient client2; client2.start(); client2.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient client3; client3.start(); client3.connectClient(ProtocolVersion::Mqtt5); client2.subscribe("mytopic", 1, false, false); client.subscribe("mytopic", 1, false, true); client3.subscribe("mytopic", 1, false, false); try { Publish pub("mytopic", "mypayload", 1); pub.retain = true; client.publish(pub); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } try { client.waitForMessageCount(1); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } auto ro = client.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); const MqttPacket &first = ro->receivedPublishes.front(); QVERIFY(first.getRetain()); } /** * @brief MainTests::testRetainAsPublishedNegative: like testRetainAsPublished, but subscribed without 'retain as published', so the received copy of the retained publish must have the retain flag cleared. */ void MainTests::testRetainAsPublishedNegative() { FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5); client.subscribe("mytopic", 1, false, false); try { Publish pub("mytopic", "mypayload", 1); pub.retain = true; client.publish(pub); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } try { client.waitForMessageCount(1); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } auto ro = 
client.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); const MqttPacket &first = ro->receivedPublishes.front(); QVERIFY(!first.getRetain()); } /** * @brief MainTests::testRetainedParentOfWildcard tests whether subscribing to 'one/two/three/four/#' gives you 'one/two/three/four'. */ void MainTests::testRetainedParentOfWildcard() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are testing testRetainedParentOfWildcard"; const std::string publish_topic = "one/two/three/four"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(publish_topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("dummy", 0); receiver.subscribe("one/two/three/four/#", 0); try { receiver.waitForMessageCount(1); } catch (std::exception &ex) { QVERIFY2(false, "Exception happened. Likely waited for retained messages, but none received."); } auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &msg = ro->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), payload); QCOMPARE(msg.getTopic(), publish_topic); QVERIFY(msg.getRetain()); } /** * @brief MainTests::testRetainedWildcard tests that subscribing to 'one/two/three/#' gives exactly the retained 'one/two/three/four', and not the unrelated retained 'publish/into/nothing'. */ void MainTests::testRetainedWildcard() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are testing testRetainedWildcard"; const std::string publish_topic = "one/two/three/four"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(publish_topic, payload, 0); pub1.retain = true; sender.publish(pub1); Publish pub2("publish/into/nothing", payload, 0); pub2.retain = true; sender.publish(pub2); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("dummy", 0); receiver.subscribe("one/two/three/#", 0); try { receiver.waitForMessageCount(1); } catch (std::exception &ex) { QVERIFY2(false, "Exception happened. 
Likely waited for retained messages, but none received."); } auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &msg = ro->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), payload); QCOMPARE(msg.getTopic(), publish_topic); QVERIFY(msg.getRetain()); } /** * @brief MainTests::testRetainedAclReadCheck tests the manipulation of the retain bit in the original incoming packet. * * This has to do with the optimization in CopyFactory, that under certain conditions, the original packet's vector is * just used to write to the client. */ void MainTests::testRetainedAclReadCheck() { ConfFileTemp confFile; confFile.writeLine("plugin plugins/libtest_plugin.so.0.0.1"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt5, true, 30, [](Connect &connect) { connect.clientid = "test_user_with_retain_as_published_v8sIeCvI"; }); client.subscribe("mytopic", 1, false, true); FlashMQTestClient client2; client2.start(); client2.connectClient(ProtocolVersion::Mqtt5, true, 30, [](Connect &connect) { connect.clientid = "test_user_without_retain_as_published_CswU21YA"; }); client2.subscribe("mytopic", 1, false, false); FlashMQTestClient publish_client; publish_client.start(); publish_client.connectClient(ProtocolVersion::Mqtt5); Publish pub("mytopic", "mypayload", 1); pub.retain = true; publish_client.publish(pub); try { client.waitForMessageCount(1); client2.waitForMessageCount(1); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } /* NOTE(review): only the delivery counts are asserted below, not the retain flags; presumably the test plugin loaded above inspects the retain bit for these two specific client ids — confirm in test_plugin.cpp. */ auto ro = client.receivedObjects.lock(); auto ro2 = client2.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MYCASTCOMPARE(ro2->receivedPublishes.size(), 1); } /** * @brief MainTests::testRetainHandlingDontGiveRetain: with RetainHandling::DoNotSendRetainedMessages, a new subscription to a topic that holds a retained message must receive nothing. */ void MainTests::testRetainHandlingDontGiveRetain() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained 
payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5, true, 0); receiver.subscribe(topic, 0, false, false, 0, RetainHandling::DoNotSendRetainedMessages); usleep(250000); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } /** * @brief MainTests::testRetainHandlingDontGiveRetainOnExistingSubscription: with RetainHandling::SendRetainedMessagesAtNewSubscribeOnly, the first subscription gets the retained message, but subscribing again to the same topic (an existing subscription) must not resend it. */ void MainTests::testRetainHandlingDontGiveRetainOnExistingSubscription() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "retained payload"; const std::string topic = "retaintopic/one/two/three"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); { receiver.connectClient(ProtocolVersion::Mqtt5, true, 0); receiver.subscribe(topic, 0, false, false, 0, RetainHandling::SendRetainedMessagesAtNewSubscribeOnly); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.size() == 1); } receiver.clearReceivedLists(); receiver.subscribe(topic, 1, false, false, 0, RetainHandling::SendRetainedMessagesAtNewSubscribeOnly); usleep(250000); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); }
================================================
FILE: FlashMQTests/run-make-from-ci.sh
================================================
#!/bin/bash

# Configures a Debug build of the current directory into ./buildtests with cmake, then builds it with make,
# using 'nproc' jobs (4 if nproc fails).
#
# Options:
#   --compiler <cxx>       C++ compiler, passed to cmake as CXX (default: g++).
#   --extra-config <val>   Extra cmake cache entry, passed as '-D' '<val>'. May be repeated.
#   --                     Stop option parsing.

extra_configs=()
compiler="g++"

while [ -n "$*" ]; do
  flag=$1
  value=$2

  case "$flag" in
    "--compiler")
      compiler="$value"
      shift
    ;;
    "--extra-config")
      extra_configs+=("$value")
      shift
    ;;
    "--")
      shift
      break
    ;;
    *)
      echo -e "unknown option $flag\n"
      exit 1
    ;;
  esac

  shift
done

set -u

# Under 'set -u', bash older than 4.4 treats expanding an empty or never-assigned array as an 'unbound variable'
# error. So extra_configs_expanded is declared up front, and both arrays are expanded with the
# ${arr[@]+"${arr[@]}"} guard, which yields nothing for an empty array and the quoted elements otherwise.
extra_configs_expanded=()
for item in ${extra_configs[@]+"${extra_configs[@]}"}; do
  extra_configs_expanded+=("-D" "$item")
done

CXX="$compiler" cmake -DCMAKE_BUILD_TYPE=Debug -S . -B buildtests ${extra_configs_expanded[@]+"${extra_configs_expanded[@]}"}

nprocs=4
if _nprocs=$(nproc); then
  nprocs="$_nprocs"
fi

make -C buildtests -j "$nprocs"
================================================
FILE: FlashMQTests/run-tests-from-ci.sh
================================================
#!/bin/bash

STDERR_LOG=$(mktemp)

cd buildtests || exit 1

# Using --abort-on-first-fail because the output can be hard to find in CI when it's swamped out by the rest.
if ./flashmq-tests --abort-on-first-fail 2> "$STDERR_LOG" ; then
  echo -e '\033[01;32mSUCCESS!\033[00m'
else
  echo -e '\033[01;31mBummer\033[00m'
  echo -e "\n\nTail of stderr:\n\n"
  tail -n 200 "$STDERR_LOG"
  exit 1
fi
================================================
FILE: FlashMQTests/sharedsubscriptionstests.cpp
================================================
#include "maintests.h"
#include "testhelpers.h"
#include "flashmqtestclient.h"
#include "flashmqtempdir.h"
#include "conffiletemp.h"
#include "utils.h"
#include "threadglobals.h"

/**
 * @brief Unit test of SharedSubscribers membership and getNext() cycling, using three sessions backed by Client objects
 * created with fd -1.
 */
void MainTests::testSharedSubscribersUnit() { Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); ThreadGlobals::assignThreadData(t); std::shared_ptr c1(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); c1->setClientProperties(ProtocolVersion::Mqtt5, "clientid1", {}, "user1", true, 60); std::shared_ptr c2(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); c2->setClientProperties(ProtocolVersion::Mqtt5, "clientid2", {}, "user2", true, 60); std::shared_ptr c3(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); c3->setClientProperties(ProtocolVersion::Mqtt5, "clientid3", {}, "user3", true, 60); std::shared_ptr ses1 = std::make_shared(c1->getClientId(), c1->getUsername(), std::optional()); 
ses1->assignActiveConnection(c1); std::shared_ptr ses2 = std::make_shared(c2->getClientId(), c2->getUsername(), std::optional()); ses2->assignActiveConnection(c2); std::shared_ptr ses3 = std::make_shared(c3->getClientId(), c3->getUsername(), std::optional()); ses3->assignActiveConnection(c3); SharedSubscribers s("dummy", {}); s[ses1->getClientId()].session = ses1; MYCASTCOMPARE(s.members.size(), 1); s[ses2->getClientId()].session = ses2; MYCASTCOMPARE(s.members.size(), 2); s[ses3->getClientId()].session = ses3; MYCASTCOMPARE(s.members.size(), 3); /* Resetting a member keeps its slot (size stays 3), but getNext() must skip it, until purgeAndReIndex() drops the slot. */ s[ses2->getClientId()].reset(); MYCASTCOMPARE(s.members.size(), 3); QCOMPARE(*s.getNext(), s[ses1->getClientId()]); QCOMPARE(*s.getNext(), s[ses3->getClientId()]); s.purgeAndReIndex(); MYCASTCOMPARE(s.members.size(), 2); // We still should get the same two active members QCOMPARE(*s.getNext(), s[ses1->getClientId()]); QCOMPARE(*s.getNext(), s[ses3->getClientId()]); s.erase(ses3->getClientId()); // Now we only have one left QCOMPARE(*s.getNext(), s[ses1->getClientId()]); QCOMPARE(*s.getNext(), s[ses1->getClientId()]); /* After erasing the last member, empty() is still false; only purgeAndReIndex() makes the set empty. */ s.erase(ses1->getClientId()); QVERIFY(!s.empty()); s.purgeAndReIndex(); QVERIFY(s.empty()); } /** * @brief MainTests::testSharedSubscribers: two members of share group 'ahTahHu5' get one of two publishes each (receiver1 the first, receiver2 the second); after receiver1 unsubscribes, receiver2 must get both of the next two publishes. */ void MainTests::testSharedSubscribers() { FlashMQTestClient receiver1; receiver1.start(); receiver1.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt5); receiver1.subscribe("$share/ahTahHu5/one/two/three", 1); receiver2.subscribe("$share/ahTahHu5/one/two/three", 1); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); sender.publish("one/two/three", "rainy day", 1); sender.publish("one/two/three", "sunny day", 1); receiver1.waitForMessageCount(1); receiver2.waitForMessageCount(1); { auto ro1 = receiver1.receivedObjects.lock(); auto ro2 = receiver2.receivedObjects.lock(); MYCASTCOMPARE(ro1->receivedPublishes.size(), 1); MYCASTCOMPARE(ro2->receivedPublishes.size(), 1); int rain = 
std::count_if(ro1->receivedPublishes.begin(), ro1->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "rainy day";}); int sun = std::count_if(ro2->receivedPublishes.begin(), ro2->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "sunny day";}); QCOMPARE(rain, 1); QCOMPARE(sun, 1); } /* After receiver1 leaves the share group, receiver2 is the only member and must get every message. */ receiver1.unsubscribe("$share/ahTahHu5/one/two/three"); receiver1.clearReceivedLists(); receiver2.clearReceivedLists(); sender.publish("one/two/three", "received by one", 1); sender.publish("one/two/three", "received by one", 1); receiver2.waitForMessageCount(2); { auto ro1 = receiver1.receivedObjects.lock(); auto ro2 = receiver2.receivedObjects.lock(); MYCASTCOMPARE(ro1->receivedPublishes.size(), 0); MYCASTCOMPARE(ro2->receivedPublishes.size(), 2); QCOMPARE(ro2->receivedPublishes.at(0).getPayloadCopy(), "received by one"); } } /** * @brief MainTests::testDisconnectedSharedSubscribers: receiver2 of share group 'iidahs2U' disconnects, so the two publishes must go to receiver1 and receiver3, one each. */ void MainTests::testDisconnectedSharedSubscribers() { FlashMQTestClient receiver1; receiver1.start(); receiver1.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver3; receiver3.start(); receiver3.connectClient(ProtocolVersion::Mqtt5); receiver1.subscribe("$share/iidahs2U/one/two/three", 1); receiver2.subscribe("$share/iidahs2U/one/two/three", 1); receiver3.subscribe("$share/iidahs2U/one/two/three", 1); receiver2.disconnect(ReasonCodes::Success); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); sender.publish("one/two/three", "rainy day", 1); sender.publish("one/two/three", "sunny day", 1); receiver1.waitForMessageCount(1); receiver3.waitForMessageCount(1); auto ro1 = receiver1.receivedObjects.lock(); auto ro3 = receiver3.receivedObjects.lock(); MYCASTCOMPARE(ro1->receivedPublishes.size(), 1); MYCASTCOMPARE(ro3->receivedPublishes.size(), 1); int rain = std::count_if(ro1->receivedPublishes.begin(), ro1->receivedPublishes.end(), [](const MqttPacket 
&pack) { return pack.getPayloadCopy() == "rainy day";}); int sun = std::count_if(ro3->receivedPublishes.begin(), ro3->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "sunny day";}); QCOMPARE(rain, 1); QCOMPARE(sun, 1); } /** * @brief MainTests::testUnsubscribedSharedSubscribers: like testDisconnectedSharedSubscribers, but receiver2 leaves share group 'iidahs2U' by unsubscribing; receiver1 and receiver3 must get one publish each. */ void MainTests::testUnsubscribedSharedSubscribers() { FlashMQTestClient receiver1; receiver1.start(); receiver1.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver3; receiver3.start(); receiver3.connectClient(ProtocolVersion::Mqtt5); receiver1.subscribe("$share/iidahs2U/one/two/three", 1); receiver2.subscribe("$share/iidahs2U/one/two/three", 1); receiver3.subscribe("$share/iidahs2U/one/two/three", 1); receiver2.unsubscribe("$share/iidahs2U/one/two/three"); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); sender.publish("one/two/three", "rainy day", 1); sender.publish("one/two/three", "sunny day", 1); receiver1.waitForMessageCount(1); receiver3.waitForMessageCount(1); auto ro1 = receiver1.receivedObjects.lock(); auto ro3 = receiver3.receivedObjects.lock(); MYCASTCOMPARE(ro1->receivedPublishes.size(), 1); MYCASTCOMPARE(ro3->receivedPublishes.size(), 1); int rain = std::count_if(ro1->receivedPublishes.begin(), ro1->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "rainy day";}); int sun = std::count_if(ro3->receivedPublishes.begin(), ro3->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "sunny day";}); QCOMPARE(rain, 1); QCOMPARE(sun, 1); } /** * @brief MainTests::testSharedSubscribersSurviveRestart: with 'storage_dir' configured, the shared subscriptions of two sessions are expected to be restored after a server restart; once both clients reconnect, two publishes are divided between them as in testSharedSubscribers. */ void MainTests::testSharedSubscribersSurviveRestart() { FlashMQTempDir storageDir; ConfFileTemp confFile; confFile.writeLine(formatString("storage_dir %s", storageDir.getPath().c_str())); confFile.writeLine("allow_anonymous yes"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); FlashMQTestClient 
receiver1; receiver1.start(); receiver1.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect){ connect.clientid = "Receiver1"; }); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect){ connect.clientid = "Receiver2"; }); receiver1.subscribe("$share/kw3O9fGK/one/two/three", 1); receiver2.subscribe("$share/kw3O9fGK/one/two/three", 1); /* Both receivers use fixed client ids and pass false as the second connectClient() argument (presumably clean start off — confirm), so their sessions can be resumed after the restart below. */ // Restart the server. cleanup(); init(args); receiver1.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect){ connect.clientid = "Receiver1"; }); receiver2.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect){ connect.clientid = "Receiver2"; }); // Now that we should have resumed sessions, perform a test like testSharedSubscribers() FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); sender.publish("one/two/three", "rainy day", 1); sender.publish("one/two/three", "sunny day", 1); receiver1.waitForMessageCount(1); receiver2.waitForMessageCount(1); auto ro1 = receiver1.receivedObjects.lock(); auto ro2 = receiver2.receivedObjects.lock(); MYCASTCOMPARE(ro1->receivedPublishes.size(), 1); MYCASTCOMPARE(ro2->receivedPublishes.size(), 1); int rain = std::count_if(ro1->receivedPublishes.begin(), ro1->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "rainy day";}); int sun = std::count_if(ro2->receivedPublishes.begin(), ro2->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "sunny day";}); QCOMPARE(rain, 1); QCOMPARE(sun, 1); // This makes sure the server is shut down before FlashMQTempDir can remove our temp dir. 
cleanup(); } void MainTests::testSharedSubscriberDoesntGetRetainedMessages() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string payload = "We are testing"; const std::string topic = "$share/sharename/retaintopic"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe(topic, 0); usleep(250000); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 0); } ================================================ FILE: FlashMQTests/subscriptionidtests.cpp ================================================ #include "maintests.h" #include "flashmqtestclient.h" #include "testhelpers.h" void MainTests::testSubscriptionIdOnlineClient() { FlashMQTestClient client1; client1.start(); client1.connectClient(ProtocolVersion::Mqtt5); client1.subscribe("several/sub/topics", 1, false, false, 666); FlashMQTestClient client2; client2.start(); client2.connectClient(ProtocolVersion::Mqtt5); client2.subscribe("several/sub/topics", 1, false, false, 777); // Also one without an identifier. 
FlashMQTestClient client3; client3.start(); client3.connectClient(ProtocolVersion::Mqtt5); client3.subscribe("several/sub/topics", 1, false, false); { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); Publish pub("several/sub/topics", "payload", 1); sender.publish(pub); } client1.waitForMessageCount(1); client2.waitForMessageCount(1); client3.waitForMessageCount(1); { auto ro = client1.receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(666)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } { auto ro = client2.receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(777)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } /* client3 subscribed without an identifier, so 0 is expected. */ { auto ro = client3.receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(0)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } } /** * @brief MainTests::testSubscriptionIdOfflineClient: clients 'one' and 'two' subscribe with identifiers 42 and 99 and are then destroyed (reset()); a message published while they are gone must be delivered with the right identifier once they reconnect with the same client ids. */ void MainTests::testSubscriptionIdOfflineClient() { std::optional client1; client1.emplace(); client1->start(); client1->connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &c) { c.clientid = "one"; }); client1->subscribe("several/sub/topics", 1, false, false, 42); client1.reset(); std::optional client2; client2.emplace(); client2->start(); client2->connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &c) { c.clientid = "two"; }); client2->subscribe("several/sub/topics", 1, false, false, 99); client2.reset(); { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); Publish pub("several/sub/topics", "payload", 1); sender.publish(pub); } client1.emplace(); client1->start(); client1->connectClient(ProtocolVersion::Mqtt5, false, 120, 
[](Connect &c) { c.clientid = "one"; }); client2.emplace(); client2->start(); client2->connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &c) { c.clientid = "two"; }); client1->waitForMessageCount(1); client2->waitForMessageCount(1); { auto ro = client1.value().receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(42)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } { auto ro = client2.value().receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(99)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } } /** * @brief MainTests::testSubscriptionIdRetainedMessages: a retained message delivered on subscribe must carry the identifier of that subscription (123 for receiver1, 1000000 for receiver2) and keep the retain flag. */ void MainTests::testSubscriptionIdRetainedMessages() { FlashMQTestClient sender; sender.start(); const std::string payload = "We are testing"; const std::string topic = "retaintopic"; sender.connectClient(ProtocolVersion::Mqtt5); Publish pub1(topic, payload, 0); pub1.retain = true; sender.publish(pub1); FlashMQTestClient receiver1; receiver1.start(); receiver1.connectClient(ProtocolVersion::Mqtt5); receiver1.subscribe("dummy", 0); receiver1.subscribe(topic, 0, false, false, 123); receiver1.waitForMessageCount(1); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt5); receiver2.subscribe("dummy", 0); receiver2.subscribe(topic, 0, false, false, 1000000); receiver2.waitForMessageCount(1); { auto ro1 = receiver1.receivedObjects.lock(); auto ro2 = receiver2.receivedObjects.lock(); MYCASTCOMPARE(ro1->receivedPublishes.size(), 1); MYCASTCOMPARE(ro2->receivedPublishes.size(), 1); MqttPacket &msg = ro1->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), payload); QCOMPARE(msg.getTopic(), topic); QVERIFY(msg.getRetain()); } { auto ro1 = receiver1.receivedObjects.lock(); auto &pack = ro1->receivedPublishes.at(0); 
FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(123)); FMQ_COMPARE(pack.getTopic(), topic); FMQ_COMPARE(pack.getPayloadView(), payload); FMQ_VERIFY(pack.getRetain()); } { auto ro2 = receiver2.receivedObjects.lock(); auto &pack = ro2->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(1000000)); FMQ_COMPARE(pack.getTopic(), topic); FMQ_COMPARE(pack.getPayloadView(), payload); FMQ_VERIFY(pack.getRetain()); } } /** * @brief MainTests::testSubscriptionIdSharedSubscriptions: three members of share group 'myshare', with identifiers 666, 777 and none; three publishes are divided one per member, and each copy must carry that member's own identifier (0 for none). */ void MainTests::testSubscriptionIdSharedSubscriptions() { FlashMQTestClient client1; client1.start(); client1.connectClient(ProtocolVersion::Mqtt5); client1.subscribe("$share/myshare/several/sub/topics", 1, false, false, 666); FlashMQTestClient client2; client2.start(); client2.connectClient(ProtocolVersion::Mqtt5); client2.subscribe("$share/myshare/several/sub/topics", 1, false, false, 777); // Also one without an identifier. FlashMQTestClient client3; client3.start(); client3.connectClient(ProtocolVersion::Mqtt5); client3.subscribe("$share/myshare/several/sub/topics", 1, false, false); { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); Publish pub("several/sub/topics", "payload", 1); for (int i = 0; i < 3; i++) sender.publish(pub); } client1.waitForMessageCount(1); client2.waitForMessageCount(1); client3.waitForMessageCount(1); { auto ro1 = client1.receivedObjects.lock(); FMQ_COMPARE(ro1->receivedPublishes.size(), static_cast(1)); auto &pack = ro1->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(666)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } { auto ro2 = client2.receivedObjects.lock(); FMQ_COMPARE(ro2->receivedPublishes.size(), static_cast(1)); auto &pack = ro2->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(777)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); 
FMQ_COMPARE(pack.getPayloadView(), "payload"); } { auto ro3 = client3.receivedObjects.lock(); FMQ_COMPARE(ro3->receivedPublishes.size(), static_cast(1)); auto &pack = ro3->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(0)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } } /** * @brief MainTests::testSubscriptionIdChange: re-subscribing to the same filter with a new identifier (666, then 667) must change the identifier sent along with later messages. */ void MainTests::testSubscriptionIdChange() { FlashMQTestClient client1; client1.start(); client1.connectClient(ProtocolVersion::Mqtt5); client1.subscribe("several/sub/topics", 1, false, false, 666); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); Publish pub("several/sub/topics", "payload", 1); sender.publish(pub); client1.waitForMessageCount(1); { auto ro = client1.receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(666)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } // Now we subscribe again, but with a different identifier. 
/* NOTE(review): there is no clearReceivedLists() before this re-subscribe, yet receivedPublishes.at(0) below must be the new 667 message rather than the earlier 666 one; presumably subscribe() clears the received lists itself — confirm in FlashMQTestClient. */ client1.subscribe("several/sub/topics", 1, false, false, 667); sender.publish(pub); client1.waitForMessageCount(1); { auto ro = client1.receivedObjects.lock(); auto &pack = ro->receivedPublishes.at(0); FMQ_COMPARE(pack.publishData.subscriptionIdentifierTesting, static_cast(667)); FMQ_COMPARE(pack.getTopic(), "several/sub/topics"); FMQ_COMPARE(pack.getPayloadView(), "payload"); } } /** * @brief MainTests::testSubscriptionIdOverlappingSubscriptions: one client with two overlapping subscriptions ('several/sub/topics' with 666 and 'several/#' with 999) must receive two copies of a single publish, one carrying each identifier. */ void MainTests::testSubscriptionIdOverlappingSubscriptions() { FlashMQTestClient client1; client1.start(); client1.connectClient(ProtocolVersion::Mqtt5); client1.subscribe("several/sub/topics", 1, false, false, 666); client1.subscribe("several/#", 1, false, false, 999); { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); Publish pub("several/sub/topics", "payload", 1); sender.publish(pub); } client1.waitForMessageCount(2); { auto ro = client1.receivedObjects.lock(); auto pos = std::find_if(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [] (MqttPacket &p) { return p.publishData.subscriptionIdentifierTesting == 999; }); FMQ_VERIFY(pos != ro->receivedPublishes.end()); FMQ_COMPARE(pos->publishData.subscriptionIdentifierTesting, static_cast(999)); FMQ_COMPARE(pos->getTopic(), "several/sub/topics"); FMQ_COMPARE(pos->getPayloadView(), "payload"); } { auto ro = client1.receivedObjects.lock(); auto pos = std::find_if(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [] (MqttPacket &p) { return p.publishData.subscriptionIdentifierTesting == 666; }); FMQ_VERIFY(pos != ro->receivedPublishes.end()); FMQ_COMPARE(pos->publishData.subscriptionIdentifierTesting, static_cast(666)); FMQ_COMPARE(pos->getTopic(), "several/sub/topics"); FMQ_COMPARE(pos->getPayloadView(), "payload"); } } ================================================ FILE: FlashMQTests/testhelpers.cpp ================================================ #include "testhelpers.h" #include #include int assert_count; int assert_fail_count; bool asserts_print; /** * @brief Records one assertion: always increments assert_count, and assert_fail_count when b is false, in which case a FAIL report is printed to std::cout if asserts_print is set. A failmsg containing a newline is treated as pre-formatted and printed on its own lines. * @return b, so the calling macro can bail out on failure. */ bool fmq_assert(bool b, const 
char *failmsg, const char *actual, const char *expected, const char *file, int line) { assert_count++; if (!b) { assert_fail_count++; if (asserts_print) { // There are two types of failmsg: unformatted ones and formatted ones. // By testing for a newline we can detect formatted ones. if (strchr(failmsg, '\n') == nullptr) { // unformatted std::cout << RED << "FAIL" << COLOR_END << ": '" << failmsg << "', " << actual << " != " << expected << std::endl << " in " << file << ", line " << line << std::endl; } else { // formatted std::cout << RED << "FAIL" << COLOR_END << " in " << file << ", line " << line << std::endl; std::cout << failmsg << std::endl; std::cout << "Comparison: " << actual << " != " << expected << std::endl; } } } return b; } /** * @brief Records an unconditional failure (increments both assert_count and assert_fail_count) and prints it to std::cout if asserts_print is set. Unlike fmq_assert(), it returns nothing. */ void fmq_fail(const char *failmsg, const char *file, int line) { assert_count++; assert_fail_count++; if (asserts_print) std::cout << RED << "FAIL" << COLOR_END << ": " << failmsg << std::endl << " in " << file << ", line " << line << std::endl; } ================================================ FILE: FlashMQTests/testhelpers.h ================================================ #ifndef TESTHELPERS_H #define TESTHELPERS_H #include #define RED "\033[01;31m" #define GREEN "\033[01;32m" #define CYAN "\033[01;36m" #define COLOR_END "\033[00m" extern int assert_count; extern int assert_fail_count; extern bool asserts_print; bool fmq_assert(bool b, const char *failmsg, const char *actual, const char *expected, const char *file, int line); void fmq_fail(const char *failmsg, const char *file, int line); /* FMQ_COMPARE, FMQ_VERIFY and FMQ_VERIFY2 record the result and execute 'return;' in the calling function on failure, so they can only be used in functions returning void. */ #define FMQ_COMPARE(actual, expected) \ do { \ if (!fmq_compare(actual, expected, #actual, #expected, __FILE__, __LINE__))\ return; \ } while(false) #define FMQ_VERIFY(val) \ do { \ if (!fmq_assert(static_cast(val), "assertion failed", #val, "true", __FILE__, __LINE__))\ return; \ } while (false) #define FMQ_VERIFY2(val, failmsg) \ do { \ if (!fmq_assert(static_cast(val), failmsg, #val, "true", __FILE__, __LINE__))\ return; \ } while (false) 
#define FMQ_FAIL(msg) fmq_fail(msg, __FILE__, __LINE__) // Compatibility for porting the tests away from Qt. #define QCOMPARE(actual, expected) FMQ_COMPARE(actual, expected) #define QVERIFY(val) FMQ_VERIFY(val) #define QVERIFY2(val, failmsg) FMQ_VERIFY2(val, failmsg) #define QFAIL(msg) FMQ_FAIL(msg) inline bool fmq_compare(const std::string &s1, const std::string &s2, const char *actual, const char *expected, const char *file, int line) { std::ostringstream oss; oss << s1 << " != " << s2; return fmq_assert(s1 == s2, oss.str().c_str(), actual, expected, file, line); } template inline bool fmq_compare(const T1 &t1, const T2 &t2, const char *actual, const char *expected, const char *file, int line) { return fmq_assert(t1 == t2, "Values aren't the same", actual, expected, file, line); } inline bool fmq_compare(const char *c1, const char *c2, const char *actual, const char *expected, const char *file, int line) { std::string s1(c1); std::string s2(c2); std::ostringstream oss; if (s1.length() + s2.length() < 100) { // short form oss << s1 << " != " << s2; } else { oss << "Actual: " << s1 << std::endl; oss << "Expected: " << s2; } return fmq_assert(s1 == s2, oss.str().c_str(), actual, expected, file, line); } template inline bool myCastCompare(const T1 &t1, const T2 &t2, const char *actual, const char *expected, const char *file, int line) { T1 t2_ = static_cast(t2); return fmq_compare(t1, t2_, actual, expected, file, line); } #define MYCASTCOMPARE(actual, expected) \ do {\ if (!myCastCompare(actual, expected, #actual, #expected, __FILE__, __LINE__))\ return;\ } while (false) #endif // TESTHELPERS_H ================================================ FILE: FlashMQTests/testinitializer.cpp ================================================ #include "testinitializer.h" TestInitializer::TestInitializer(MainTests *tests) : tests(tests) { } void TestInitializer::init(bool startServer) { if (!tests) return; tests->initBeforeEachTest(startServer); } void TestInitializer::cleanup() { 
if (!tests) return; tests->cleanupAfterEachTest(); tests = nullptr; } TestInitializer::~TestInitializer() { cleanup(); } ================================================ FILE: FlashMQTests/testinitializer.h ================================================ #ifndef TESTINITIALIZER_H #define TESTINITIALIZER_H #include "maintests.h" /** * @brief Simple RAII way to make sure test cleanup is run. */ class TestInitializer { MainTests *tests = nullptr; public: TestInitializer(MainTests *tests); virtual ~TestInitializer(); TestInitializer(const TestInitializer &other) = delete; TestInitializer(TestInitializer &&other) = delete; TestInitializer &operator=(const TestInitializer &other) = delete; void init(bool startServer); void cleanup(); }; #endif // TESTINITIALIZER_H ================================================ FILE: FlashMQTests/tst_maintests.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include #include #include #include #include #include "maintests.h" #include "testhelpers.h" #include "flashmqtestclient.h" #include "conffiletemp.h" #include "mainappasfork.h" #include "threadglobals.h" #include "threadlocalutils.h" #include "retainedmessagesdb.h" #include "utils.h" #include "exceptions.h" #include "flashmqtempdir.h" void MainTests::test_circbuf() { CirBuf buf(64); MYCASTCOMPARE(buf.freeSpace(), 63); uint write_n = 40; char *head = buf.headPtr(); for (uint i = 0; i < write_n; i++) { head[i] = i+1; } buf.advanceHead(write_n); QCOMPARE(buf.head, write_n); MYCASTCOMPARE(buf.tail, 0); QCOMPARE(buf.maxReadSize(), write_n); QCOMPARE(buf.maxWriteSize(), (64 - write_n - 1)); QCOMPARE(buf.freeSpace(), 64 - write_n - 1); for (uint i = 0; i < write_n; i++) { MYCASTCOMPARE(buf.tailPtr()[i], i+1); } buf.advanceTail(write_n); QVERIFY(buf.tail == buf.head); QCOMPARE(buf.tail, write_n); MYCASTCOMPARE(buf.maxReadSize(), 0); QCOMPARE(buf.maxWriteSize(), (64 - write_n)); // no longer -1, because the head can point to 0 afterwards MYCASTCOMPARE(buf.freeSpace(), 63); write_n = buf.maxWriteSize(); head = buf.headPtr(); for (uint i = 0; i < write_n; i++) { head[i] = i+1; } buf.advanceHead(write_n); MYCASTCOMPARE(buf.head, 0); // Now write more, starting at the beginning. write_n = buf.maxWriteSize(); head = buf.headPtr(); for (uint i = 0; i < write_n; i++) { head[i] = i+100; // Offset by 100 so we can see if we overwrite the tail } buf.advanceHead(write_n); QCOMPARE(buf.tailPtr()[0], 1); // Did we not overwrite the tail? QCOMPARE(buf.head, buf.tail - 1); } void MainTests::test_circbuf_unwrapped_doubling() { CirBuf buf(64); int w = 63; char *head = buf.headPtr(); for (int i = 0; i < w; i++) { head[i] = i+1; } buf.advanceHead(63); char *tail = buf.tailPtr(); for (int i = 0; i < w; i++) { QCOMPARE(tail[i], i+1); } QCOMPARE(buf.buf[63], 0); // Vacant place, because of the circulerness. 
MYCASTCOMPARE(buf.head, 63); MYCASTCOMPARE(buf.freeSpace(), 0); buf.doubleCapacity(); tail = buf.tailPtr(); for (int i = 0; i < w; i++) { QCOMPARE(tail[i], i+1); } for (int i = 63; i < 127; i++) { QCOMPARE(tail[i], 5); } QCOMPARE(tail[127], 68); MYCASTCOMPARE(buf.tail, 0); MYCASTCOMPARE(buf.head, 63); MYCASTCOMPARE(buf.maxWriteSize(), 64); MYCASTCOMPARE(buf.maxReadSize(), 63); } void MainTests::test_circbuf_wrapped_doubling() { CirBuf buf(64); int w = 40; char *head = buf.headPtr(); for (int i = 0; i < w; i++) { head[i] = i+1; } buf.advanceHead(w); MYCASTCOMPARE(buf.tail, 0); MYCASTCOMPARE(buf.head, w); MYCASTCOMPARE(buf.maxReadSize(), 40); MYCASTCOMPARE(buf.maxWriteSize(), 23); buf.advanceTail(40); MYCASTCOMPARE(buf.maxWriteSize(), 24); head = buf.headPtr(); for (int i = 0; i < 24; i++) { head[i] = 99; } buf.advanceHead(24); MYCASTCOMPARE(buf.tail, 40); MYCASTCOMPARE(buf.head, 0); MYCASTCOMPARE(buf.maxReadSize(), 24); MYCASTCOMPARE(buf.maxWriteSize(), 39); // Now write a little more, which starts at the start head = buf.headPtr(); for (int i = 0; i < 10; i++) { head[i] = 88; } buf.advanceHead(10); MYCASTCOMPARE(buf.head, 10); buf.doubleCapacity(); // The 88's that were appended at the start, should now appear at the end; for (int i = 64; i < 74; i++) { MYCASTCOMPARE(buf.buf[i], 88); } MYCASTCOMPARE(buf.tail, 40); MYCASTCOMPARE(buf.head, 74); } void MainTests::test_circbuf_full_wrapped_buffer_doubling() { CirBuf buf(64); buf.head = 10; buf.tail = 10; memset(buf.headPtr(), 1, buf.maxWriteSize()); buf.advanceHead(buf.maxWriteSize()); memset(buf.headPtr(), 2, buf.maxWriteSize()); buf.advanceHead(buf.maxWriteSize()); for (int i = 0; i < 9; i++) { QCOMPARE(buf.buf[i], 2); } QCOMPARE(buf.buf[9], 0); for (int i = 10; i < 64; i++) { QCOMPARE(buf.buf[i], 1); } QVERIFY(true); buf.doubleCapacity(); // The places where value was 1 are the same for (int i = 10; i < 64; i++) { QCOMPARE(buf.buf[i], 1); } // The nine 2's have been moved to the end for (int i = 64; i < 73; i++) { 
QCOMPARE(buf.buf[i], 2); } // The rest are our debug 5. for (int i = 73; i < 128; i++) { QCOMPARE(buf.buf[i], 5); } QVERIFY(true); } void MainTests::test_cirbuf_vector_methods() { std::vector source(47); getrandom(source.data(), source.size(), 0); CirBuf target(8); for (int i = 0; i < 4096; i++) { target.writerange(source.begin(), source.end()); FMQ_COMPARE(target.usedBytes(), source.size()); std::vector peeked = target.peekAllToVector(); std::vector reread = i % 2 == 0 ? target.readToVector(source.size()) : target.readAllToVector(); FMQ_COMPARE(reread, source); FMQ_COMPARE(peeked, source); FMQ_COMPARE(target.usedBytes(), static_cast(0)); FMQ_COMPARE(target.freeSpace(), static_cast(63)); if (i % 64 == 0) target.resetCapacity(8); } } void MainTests::test_validSubscribePath() { QVERIFY(isValidSubscribePath("one/two/three")); QVERIFY(isValidSubscribePath("one//three")); QVERIFY(isValidSubscribePath("one/+/three")); QVERIFY(isValidSubscribePath("one/+/#")); QVERIFY(isValidSubscribePath("#")); QVERIFY(isValidSubscribePath("///")); QVERIFY(isValidSubscribePath("//#")); QVERIFY(isValidSubscribePath("+")); QVERIFY(isValidSubscribePath("")); QVERIFY(isValidSubscribePath("hello")); QVERIFY(isValidSubscribePath("$SYS/hello")); QVERIFY(isValidSubscribePath("hello/$SYS")); // Hmm, is this valid? 
QVERIFY(!isValidSubscribePath("one/tw+o/three")); QVERIFY(!isValidSubscribePath("one/+o/three")); QVERIFY(!isValidSubscribePath("one/a+/three")); QVERIFY(!isValidSubscribePath("#//three")); QVERIFY(!isValidSubscribePath("#//+")); QVERIFY(!isValidSubscribePath("one/#/+")); QVERIFY(!isValidSubscribePath("one/two#")); QVERIFY(!isValidSubscribePath("one/two#/three")); QVERIFY(!isValidSubscribePath("one/asdf+/+")); QVERIFY(!isValidSubscribePath("+one/asdf+/+")); } void MainTests::test_various_packet_sizes() { std::vector protocols {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; std::list payloads {std::string(8000,3), std::string(10*1024*1024, 5)}; for (const ProtocolVersion senderVersion : protocols) { for (const ProtocolVersion receiverVersion : protocols) { for (const std::string &payload : payloads) { FlashMQTestClient sender; FlashMQTestClient receiver; std::string topic = "hugepacket"; sender.start(); sender.connectClient(senderVersion); receiver.start(); receiver.connectClient(receiverVersion); receiver.subscribe(topic, 0); sender.publish(topic, payload, 0); receiver.waitForMessageCount(1, 2); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &msg = ro->receivedPublishes.front(); QCOMPARE(msg.getPayloadCopy(), payload); QVERIFY(!msg.getRetain()); } } } } void MainTests::test_acl_tree() { AclTree aclTree; aclTree.addTopic("one/two/#", AclGrant::ReadWrite, AclTopicType::Strings); aclTree.addTopic("one/two/three", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("a/+/c", AclGrant::Read, AclTopicType::Strings); aclTree.addTopic("1/+/3", AclGrant::ReadWrite, AclTopicType::Strings); aclTree.addTopic("1/blocked/3", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("cat/+/dog", AclGrant::Write, AclTopicType::Strings); aclTree.addTopic("cat/blocked/dog", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("cat/blocked/dog/bla/bla/#", AclGrant::Deny, AclTopicType::Strings); 
aclTree.addTopic("cat/turtle/dog/%u/bla/#", AclGrant::ReadWrite, AclTopicType::Strings); aclTree.addTopic("fish/turtle/dog/%u/bla/#", AclGrant::ReadWrite, AclTopicType::Strings, "john"); aclTree.addTopic("fish/turtle/dog/%u/bla/#", AclGrant::ReadWrite, AclTopicType::Strings, "AAA"); QCOMPARE(aclTree.findPermission(splitToVector("one/two", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/two/four", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/two/four/five/six", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/two/four/five/six", '/'), AclGrant::Write, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/two/three", '/'), AclGrant::Read, "", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("asdf", '/'), AclGrant::Read, "", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/b/c", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("a/b/c", '/'), AclGrant::Write, "", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/wildcardmatch/c", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("1/2/3", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("1/2/3", '/'), AclGrant::Write, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("1/wildcardmatch/3", '/'), AclGrant::Write, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("1/wildcardmatch/3", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("cat/2/dog", '/'), AclGrant::Write, "", "clientid"), 
AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("cat/2/dog", '/'), AclGrant::Read, "", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("cat/blocked/dog", '/'), AclGrant::Write, "", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("cat/blocked/dog", '/'), AclGrant::Read, "", "clientid"), AuthResult::acl_denied); // Test that wildcards aren't replaced here QCOMPARE(aclTree.findPermission(splitToVector("cat/turtle/dog/%u/bla/sdf", '/'), AclGrant::Read, "", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("fish/turtle/dog/%u/bla/sdf", '/'), AclGrant::Read, "john", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("fish/turtle/dog/john/bla/sdf", '/'), AclGrant::Read, "john", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("fish/turtle/dog/AAA/bla/sdf", '/'), AclGrant::Read, "AAA", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("fish/turtle/dog/john/bla/sdf", '/'), AclGrant::Read, "john", "clientid"), AuthResult::acl_denied); } void MainTests::test_acl_tree2() { AclTree aclTree; aclTree.addTopic("one/two/#", AclGrant::ReadWrite, AclTopicType::Strings); aclTree.addTopic("one/two/three", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("one/two/three", AclGrant::ReadWrite, AclTopicType::Strings, "Metusalem"); aclTree.addTopic("a/+/c", AclGrant::Read, AclTopicType::Strings); aclTree.addTopic("1/+/3", AclGrant::ReadWrite, AclTopicType::Strings); aclTree.addTopic("1/blocked/3", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("cat/+/dog", AclGrant::Write, AclTopicType::Strings); aclTree.addTopic("cat/blocked/dog", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("cat/blocked/dog/bla/bla/#", AclGrant::Deny, AclTopicType::Strings); // Test all these with a user, which should be denied. 
QCOMPARE(aclTree.findPermission(splitToVector("one/two", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("one/two/four", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("one/two/four/five/six", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("one/two/four/five/six", '/'), AclGrant::Write, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("one/two/three", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("asdf", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/b/c", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/b/c", '/'), AclGrant::Write, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/wildcardmatch/c", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("1/2/3", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("1/2/3", '/'), AclGrant::Write, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("1/wildcardmatch/3", '/'), AclGrant::Write, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("1/wildcardmatch/3", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("cat/2/dog", '/'), AclGrant::Write, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("cat/2/dog", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("cat/blocked/dog", '/'), AclGrant::Write, "a", 
"clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("cat/blocked/dog", '/'), AclGrant::Read, "a", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("one/two/three", '/'), AclGrant::Read, "Metusalem", "clientid"), AuthResult::success); } void MainTests::test_acl_patterns_username() { AclTree aclTree; aclTree.addTopic("one/%u/three", AclGrant::ReadWrite, AclTopicType::Patterns); aclTree.addTopic("a/%u/c", AclGrant::Read, AclTopicType::Patterns); aclTree.addTopic("d/%u/f/#", AclGrant::Read, AclTopicType::Patterns); aclTree.addTopic("one/Jheronimus/three", AclGrant::Deny, AclTopicType::Strings); aclTree.addTopic("one/santaclause/three", AclGrant::Deny, AclTopicType::Strings, "santaclause"); aclTree.addTopic("%u/#", AclGrant::ReadWrite, AclTopicType::Patterns); // Succeeds, because the anonymous deny should have no effect on the authenticated ACL check, so it checks the pattern based. QCOMPARE(aclTree.findPermission(splitToVector("one/Jheronimus/three", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::success); // The fixed-strings deny for 'santaclause' should override the pattern based ReadWrite. 
QCOMPARE(aclTree.findPermission(splitToVector("one/santaclause/three", '/'), AclGrant::Read, "santaclause", "clientid"), AuthResult::acl_denied); aclTree.addTopic("some/thing", AclGrant::ReadWrite, AclTopicType::Strings, "Rembrandt"); aclTree.addTopic("some/thing", AclGrant::ReadWrite, AclTopicType::Patterns); QCOMPARE(aclTree.findPermission(splitToVector("one/Jheronimus/three", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/Theo/three", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/Jheronimus/c", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("a/NotJheronimus/c", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/Jheronimus/c", '/'), AclGrant::Write, "Jheronimus", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("d/Jheronimus/f", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("d/Jheronimus/f/A", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("d/Jheronimus/f/A/B", '/'), AclGrant::Read, "Jheronimus", "clientid"), AuthResult::success); // Repeat the test, but now with a user for which there is also an unrelated user specific ACL. 
QCOMPARE(aclTree.findPermission(splitToVector("one/Rembrandt/three", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/Theo/three", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/Rembrandt/c", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("a/NotRembrandt/c", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/Rembrandt/c", '/'), AclGrant::Write, "Rembrandt", "clientid"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("d/Rembrandt/f", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("d/Rembrandt/f/A", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("d/Rembrandt/f/A/B", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("Rembrandt", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("Rembrandt", '/'), AclGrant::Write, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("Rembrandt/wer", '/'), AclGrant::Read, "Rembrandt", "clientid"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("Rembrandt/eee", '/'), AclGrant::Write, "Rembrandt", "clientid"), AuthResult::success); } void MainTests::test_acl_patterns_clientid() { AclTree aclTree; aclTree.addTopic("one/%c/three", AclGrant::ReadWrite, AclTopicType::Patterns); aclTree.addTopic("a/%c/c", AclGrant::Read, AclTopicType::Patterns); aclTree.addTopic("d/%c/f/#", AclGrant::Read, AclTopicType::Patterns); QCOMPARE(aclTree.findPermission(splitToVector("one/clientid_one/three", '/'), AclGrant::Read, "foo", 
"clientid_one"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("one/clientid_two/three", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/clientid_one/c", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("a/not_clientidone/c", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("a/clientid_one/c", '/'), AclGrant::Write, "foo", "clientid_one"), AuthResult::acl_denied); QCOMPARE(aclTree.findPermission(splitToVector("d/clientid_one/f", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("d/clientid_one/f/A", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::success); QCOMPARE(aclTree.findPermission(splitToVector("d/clientid_one/f/A/B", '/'), AclGrant::Read, "foo", "clientid_one"), AuthResult::success); } /** * @brief MainTests::test_loading_acl_file was created because assertions in it failed when publishing $SYS topics were passed through the ACL * layer. That's why it seemingly doesn't do anything. */ void MainTests::test_loading_acl_file() { ConfFileTemp aclFile; aclFile.writeLine("topic readwrite one/two"); aclFile.closeFile(); ConfFileTemp confFile; confFile.writeLine("mosquitto_acl_file " + aclFile.getFilePath()); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); usleep(1000000); QVERIFY(true); } #ifndef FMQ_NO_SSE void MainTests::test_sse_split() { SimdUtils data; std::list topics; topics.push_back("one/two/threeabcasdfasdf/koe"); topics.push_back("/two/threeabcasdfasdf/koe"); // Test empty component. topics.push_back("//two/threeabcasdfasdf/koe"); // Test two empty components. topics.push_back("//1234567890abcde/bla/koe"); // Test two empty components, 15 char topic (one byte short of 16 alignment). 
topics.push_back("//1234567890abcdef/bla/koe"); // Test two empty components, 16 char topic topics.push_back("//1234567890abcdefg/bla/koe"); // Test two empty components, 17 char topic topics.push_back("//1234567890abcdefg/1234567890abcdefg/koe"); // Test two empty components, two 17 char topics topics.push_back("//1234567890abcdef/1234567890abcdefg/koe"); // Test two empty components, 16 and 17 char topics.push_back("//1234567890abcdef/1234567890abcdefg/koe/"); topics.push_back("//1234567890abcdef/1234567890abcdefg/koe//"); topics.push_back("//1234567890abcdef/1234567890abcdef/"); topics.push_back("/"); topics.push_back(""); for (const std::string &t : topics) { std::vector output = data.splitTopic(t); QCOMPARE(output, splitToVector(t, '/')); } } #endif void MainTests::test_validUtf8Generic() { char m[16]; QVERIFY(isValidUtf8Generic("")); QVERIFY(isValidUtf8Generic("ƀ")); QVERIFY(isValidUtf8Generic("Hello")); std::memset(m, 0, 16); QVERIFY(!isValidUtf8Generic(std::string(m, 16))); QVERIFY(isValidUtf8Generic("Straƀe")); // two byte chars QVERIFY(isValidUtf8Generic("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars QVERIFY(isValidUtf8Generic("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars std::memset(m, 0, 16); m[0] = 'a'; m[1] = 13; // is \r QVERIFY(!isValidUtf8Generic(std::string(m, 16))); const std::string unicode_ballet_shoes("🩰"); QVERIFY(unicode_ballet_shoes.length() == 4); QVERIFY(isValidUtf8Generic(unicode_ballet_shoes)); const std::string unicode_ballot_box("☐"); QVERIFY(unicode_ballot_box.length() == 3); QVERIFY(isValidUtf8Generic(unicode_ballot_box)); std::memset(m, 0, 16); m[0] = 0b11000001; // Start 2 byte char m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string a(m, 2); QVERIFY(!isValidUtf8Generic(a)); std::memset(m, 0, 16); m[0] = 0b11100001; // Start 3 byte char m[1] = 0b10100001; m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string 
b(m, 3); QVERIFY(!isValidUtf8Generic(b)); std::memset(m, 0, 16); m[0] = 0b11110001; // Start 4 byte char m[1] = 0b10100001; m[2] = 0b10100001; m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string c(m, 4); QVERIFY(!isValidUtf8Generic(c)); std::memset(m, 0, 16); m[0] = 0b11110001; // Start 4 byte char m[1] = 0b10100001; m[2] = 0b00100001; // Doesn't start with 1: invalid. m[3] = 0b10000001; std::string d(m, 4); QVERIFY(!isValidUtf8Generic(d)); // Upper ASCII, invalid std::memset(m, 0, 16); m[0] = 127; std::string e(m, 1); QVERIFY(!isValidUtf8Generic(e)); } #ifndef FMQ_NO_SSE void MainTests::test_validUtf8Sse() { SimdUtils data; char m[16]; QVERIFY(data.isValidUtf8("")); QVERIFY(data.isValidUtf8("ƀ")); QVERIFY(data.isValidUtf8("Hello")); std::memset(m, 0, 16); QVERIFY(!data.isValidUtf8(std::string(m, 16))); QVERIFY(data.isValidUtf8("Straƀe")); // two byte chars QVERIFY(data.isValidUtf8("StraƀeHelloHelloHelloHelloHelloHello")); // two byte chars QVERIFY(data.isValidUtf8("HelloHelloHelloHelloHelloHelloHelloHelloStraƀeHelloHelloHelloHelloHelloHello")); // two byte chars QVERIFY(!data.isValidUtf8("Straƀe#", true)); QVERIFY(!data.isValidUtf8("ƀ#", true)); QVERIFY(!data.isValidUtf8("#ƀ", true)); QVERIFY(!data.isValidUtf8("+", true)); QVERIFY(!data.isValidUtf8("🩰+asdfasdfasdf", true)); QVERIFY(!data.isValidUtf8("+asdfasdfasdf", true)); std::memset(m, 0, 16); m[0] = 'a'; m[1] = 13; // is \r QVERIFY(!data.isValidUtf8(std::string(m, 16))); const std::string unicode_ballet_shoes("🩰"); QVERIFY(unicode_ballet_shoes.length() == 4); QVERIFY(data.isValidUtf8(unicode_ballet_shoes)); const std::string unicode_ballot_box("☐"); QVERIFY(unicode_ballot_box.length() == 3); QVERIFY(data.isValidUtf8(unicode_ballot_box)); std::memset(m, 0, 16); m[0] = 0b11000001; // Start 2 byte char m[1] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string a(m, 2); QVERIFY(!data.isValidUtf8(a)); std::memset(m, 0, 16); m[0] = 0b11100001; // Start 3 byte char 
m[1] = 0b10100001; m[2] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string b(m, 3); QVERIFY(!data.isValidUtf8(b)); std::memset(m, 0, 16); m[0] = 0b11110001; // Start 4 byte char m[1] = 0b10100001; m[2] = 0b10100001; m[3] = 0b00000001; // Next byte doesn't start with 1, which is wrong std::string c(m, 4); QVERIFY(!data.isValidUtf8(c)); std::memset(m, 0, 16); m[0] = 0b11110001; // Start 4 byte char m[1] = 0b10100001; m[2] = 0b00100001; // Doesn't start with 1: invalid. m[3] = 0b10000001; std::string d(m, 4); QVERIFY(!data.isValidUtf8(d)); // Upper ASCII, invalid std::memset(m, 0, 16); m[0] = 127; std::string e(m, 1); QVERIFY(!data.isValidUtf8(e)); } /** * @brief MainTests::test_utf8_nonchars tests a subset of the 66 Unicode noncharacters (U+FDD0..U+FDEF, U+FFFE/U+FFFF and U+1FFFE/U+1FFFF), and some valid code points around them. */ void MainTests::test_utf8_nonchars() { SimdUtils simd_utils; for (int i = 0x80; i < 0x90; i++) { std::string c; c.push_back(0xEF); c.push_back(0xB7); c.push_back(i); QVERIFY(isValidUtf8Generic(c)); QVERIFY(simd_utils.isValidUtf8(c)); } // The invalid ones for (int i = 0x90; i <= 0xAF; i++) { std::string c; c.push_back(0xEF); c.push_back(0xB7); c.push_back(i); QVERIFY(!isValidUtf8Generic(c)); QVERIFY(!simd_utils.isValidUtf8(c)); } for (int i = 0xB0; i < 0xB5; i++) { std::string c; c.push_back(0xEF); c.push_back(0xB7); c.push_back(i); QVERIFY(isValidUtf8Generic(c)); QVERIFY(simd_utils.isValidUtf8(c)); } // Now the last two code points of the multilingual planes { std::string s; s.clear(); s.push_back(0xEF); s.push_back(0xBF); s.push_back(0xBE); QVERIFY(!isValidUtf8Generic(s)); QVERIFY(!simd_utils.isValidUtf8(s)); s.clear(); s.push_back(0xEF); s.push_back(0xBF); s.push_back(0xBF); QVERIFY(!isValidUtf8Generic(s)); QVERIFY(!simd_utils.isValidUtf8(s)); // Adjacent one that is valid. 
s.clear(); s.push_back(0xEF); s.push_back(0xBF); s.push_back(0xBD); QVERIFY(isValidUtf8Generic(s)); QVERIFY(simd_utils.isValidUtf8(s)); } { std::string s; s.clear(); s.push_back(0xF0); s.push_back(0x9F); s.push_back(0xBF); s.push_back(0xBE); QVERIFY(!isValidUtf8Generic(s)); QVERIFY(!simd_utils.isValidUtf8(s)); s.clear(); s.push_back(0xF0); s.push_back(0x9F); s.push_back(0xBF); s.push_back(0xBF); QVERIFY(!isValidUtf8Generic(s)); QVERIFY(!simd_utils.isValidUtf8(s)); // Adjacent one that is valid. s.clear(); s.push_back(0xF0); s.push_back(0x9F); s.push_back(0xBF); s.push_back(0xBD); QVERIFY(isValidUtf8Generic(s)); QVERIFY(simd_utils.isValidUtf8(s)); } // TODO: there are more planes to check, but programming that out means encoding in UTF8. } /** * @brief MainTests::test_utf8_overlong tests that overlong (multi-byte) encodings of '/' (U+002F, ASCII 0x2F) are rejected. * * U+002F = c0 af * U+002F = e0 80 af * U+002F = f0 80 80 af */ void MainTests::test_utf8_overlong() { SimdUtils simd_utils; { std::string two; two.push_back(0xc0); two.push_back(0xaf); QVERIFY(!isValidUtf8Generic(two)); QVERIFY(!simd_utils.isValidUtf8(two)); } { std::string three; three.push_back(0xe0); three.push_back(0x80); three.push_back(0xaf); QVERIFY(!isValidUtf8Generic(three)); QVERIFY(!simd_utils.isValidUtf8(three)); } { std::string four; four.push_back(0xf0); four.push_back(0x80); four.push_back(0x80); four.push_back(0xaf); QVERIFY(!isValidUtf8Generic(four)); QVERIFY(!simd_utils.isValidUtf8(four)); } } void MainTests::test_utf8_compare_implementation() { SimdUtils simd_utils; // Just something to look at. It prefixes lines with a red cross when the checker returns false. Note that this means you // don't see the difference between invalid UTF8 and valid but invalid for MQTT. 
FlashMQTempDir tmpdir; std::ofstream outfile(tmpdir.getPath() / "flashmq_utf8_test_result.txt", std::ios::binary); int line_count = 0; std::ifstream infile("UTF-8-test.txt", std::ios::binary); for(std::string line; getline(infile, line ); ) { const bool a = isValidUtf8Generic(line); const bool b = simd_utils.isValidUtf8(line); QVERIFY(a == b); if (a) outfile << "\u2705 "; else outfile << "\u274C "; outfile << line << std::endl; line_count++; } QVERIFY(line_count > 40); } #endif void MainTests::testPacketInt16Parse() { std::vector tests {128, 300, 64, 65550, 32000}; for (const uint16_t id : tests) { Publish pub("hallo", "content", 1); MqttPacket packet(ProtocolVersion::Mqtt311, pub); packet.setPacketId(id); packet.pos -= 2; uint16_t idParsed = packet.readTwoBytesToUInt16(); QVERIFY(id == idParsed); } } void MainTests::testRetainedMessageDB() { try { std::string longpayload = getSecureRandomString(65537); std::string longTopic = formatString("one/two/%s", getSecureRandomString(4000).c_str()); std::vector messages; messages.emplace_back(Publish("one/two/three", "payload", 0)); messages.emplace_back(Publish("one/two/wer", "payload", 1)); messages.emplace_back(Publish("one/e/wer", "payload", 1)); messages.emplace_back(Publish("one/wee/wer", "asdfasdfasdf", 1)); messages.emplace_back(Publish("one/two/wer", "µsdf", 1)); messages.emplace_back(Publish("/boe/bah", longpayload, 1)); messages.emplace_back(Publish("one/two/wer", "paylasdfaoad", 1)); messages.emplace_back(Publish("one/two/wer", "payload", 1)); messages.emplace_back(Publish(longTopic, "payload", 1)); messages.emplace_back(Publish(longTopic, longpayload, 1)); messages.emplace_back(Publish("one", "µsdf", 1)); messages.emplace_back(Publish("/boe", longpayload, 1)); messages.emplace_back(Publish("one", "µsdf", 1)); int clientidCount = 1; int usernameCount = 1; for (RetainedMessage &rm : messages) { rm.publish.client_id = formatString("Clientid__%d", clientidCount++); rm.publish.username = formatString("Username__%d", 
usernameCount++); } FlashMQTempDir tmpdir; auto dbpath = tmpdir.getPath() / "flashmqtests_retained.db"; RetainedMessagesDB db(dbpath); db.openWrite(); db.saveData(messages); db.closeFile(); RetainedMessagesDB db2(dbpath); db2.openRead(); std::list messagesLoaded = db2.readData(); db2.closeFile(); QCOMPARE(messagesLoaded.size(), messages.size()); auto itOrg = messages.begin(); auto itLoaded = messagesLoaded.begin(); while (itOrg != messages.end() && itLoaded != messagesLoaded.end()) { RetainedMessage &one = *itOrg; RetainedMessage &two = *itLoaded; // Comparing the fields because the RetainedMessage class has an == operator that only looks at topic. QCOMPARE(one.publish.topic, two.publish.topic); QCOMPARE(one.publish.payload, two.publish.payload); QCOMPARE(one.publish.qos, two.publish.qos); QVERIFY(!two.publish.client_id.empty()); QVERIFY(!two.publish.username.empty()); QCOMPARE(two.publish.client_id, one.publish.client_id); QCOMPARE(two.publish.username, one.publish.username); itOrg++; itLoaded++; } } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testRetainedMessageDBNotPresent() { try { FlashMQTempDir tmpdir; RetainedMessagesDB db2(tmpdir.getPath() / "flashmqtests_asdfasdfasdf.db"); db2.openRead(); std::list messagesLoaded = db2.readData(); db2.closeFile(); MYCASTCOMPARE(messagesLoaded.size(), 0); QVERIFY2(false, "We should have run into an exception."); } catch (PersistenceFileCantBeOpened &ex) { QVERIFY(true); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testRetainedMessageDBEmptyList() { try { std::vector messages; FlashMQTempDir tmpdir; std::string dbpath = tmpdir.getPath() / "flashmqtests_retained.db"; RetainedMessagesDB db(dbpath); db.openWrite(); db.saveData(messages); db.closeFile(); RetainedMessagesDB db2(dbpath); db2.openRead(); std::list messagesLoaded = db2.readData(); db2.closeFile(); MYCASTCOMPARE(messages.size(), messagesLoaded.size()); MYCASTCOMPARE(messages.size(), 0); } catch 
(std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testSavingSessions() { try { Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); std::shared_ptr c1(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); c1->setClientProperties(ProtocolVersion::Mqtt5, "c1", {}, "user1", true, 60); store->registerClientAndKickExistingOne(c1, false, 512, 120); c1->getSession()->addIncomingQoS2MessageId(2); c1->getSession()->addIncomingQoS2MessageId(3); std::shared_ptr c2(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); c2->setClientProperties(ProtocolVersion::Mqtt5, "c2", {}, "user2", true, 60); store->registerClientAndKickExistingOne(c2, false, 512, 120); c2->getSession()->addOutgoingQoS2MessageId(55); c2->getSession()->addOutgoingQoS2MessageId(66); const std::string topic1 = "one/two/three"; std::vector subtopics; subtopics = splitTopic(topic1); store->addSubscription(c1->getSession(), subtopics, 0, true, false, "", 0); const std::string topic2 = "four/five/six"; subtopics = splitTopic(topic2); store->addSubscription(c2->getSession(), subtopics, 0, false, true, "", 0); store->addSubscription(c1->getSession(), subtopics, 0, false, false, "", 94612); const std::string topic3 = ""; subtopics = splitTopic(topic3); store->addSubscription(c2->getSession(), subtopics, 0, false, false, "", 0); const std::string topic4 = "#"; subtopics = splitTopic(topic4); store->addSubscription(c2->getSession(), subtopics, 0, false, false, "", 0); Publish publish("a/b/c", "Hello Barry", 1); publish.client_id = "ClientIdFromFakePublisher"; publish.username = "UsernameFromFakePublisher"; publish.setExpireAfter(10); usleep(1000000); std::shared_ptr c1ses = c1->getSession(); c1.reset(); MqttPacket 
publishPacket(ProtocolVersion::Mqtt5, publish); PublishCopyFactory fac(&publishPacket); c1ses->writePacket(fac, 1, false, 6268); FlashMQTempDir tmpdir; auto dbpath = tmpdir.getPath() / "flashmqtests_sessions.db"; store->saveSessionsAndSubscriptions(dbpath); usleep(1000000); std::shared_ptr store2(new SubscriptionStore()); store2->loadSessionsAndSubscriptions(dbpath); MYCASTCOMPARE(store->sessionsById.size(), 2); MYCASTCOMPARE(store2->sessionsById.size(), 2); for (auto &pair : store->sessionsById) { std::shared_ptr &ses = pair.second; std::shared_ptr &ses2 = store2->sessionsById[pair.first]; MutexLocked qos_locked = ses->qos.lock(); MutexLocked qos_locked2 = ses2->qos.lock(); QCOMPARE(pair.first, ses2->getClientId()); QCOMPARE(ses->username, ses2->username); QCOMPARE(ses->client_id, ses2->client_id); QCOMPARE(qos_locked->incomingQoS2MessageIds, qos_locked2->incomingQoS2MessageIds); QCOMPARE(qos_locked->outgoingQoS2MessageIds, qos_locked2->outgoingQoS2MessageIds); QCOMPARE(qos_locked->nextPacketId, qos_locked2->nextPacketId); } std::unordered_map> store1Subscriptions; store1Subscriptions = store->getSubscriptions(); std::unordered_map> store2Subscriptions; store2Subscriptions = store2->getSubscriptions(); MYCASTCOMPARE(store1Subscriptions.size(), 4); MYCASTCOMPARE(store2Subscriptions.size(), 4); int noLocalCount = 0; int retainAsPublishedCount = 0; int withSubscriptionIdentifierCount = 0; for(auto &pair : store1Subscriptions) { std::list &subscList1 = pair.second; std::list &subscList2 = store2Subscriptions[pair.first]; QCOMPARE(subscList1.size(), subscList2.size()); // They're not sorted/deterministic, so resorting to this. 
for (SubscriptionForSerializing &one : subscList1) { int match_count = 0; for (SubscriptionForSerializing &two : subscList2) { if (one.clientId != two.clientId) continue; match_count++; QCOMPARE(one.clientId, two.clientId); QCOMPARE(one.qos, two.qos); QCOMPARE(one.noLocal, two.noLocal); QCOMPARE(one.retainAsPublished, two.retainAsPublished); QCOMPARE(one.subscriptionidentifier, two.subscriptionidentifier); if (two.noLocal) noLocalCount++; if (two.retainAsPublished) retainAsPublishedCount++; if (two.subscriptionidentifier > 0) withSubscriptionIdentifierCount++; } FMQ_COMPARE(match_count, 1); } } QVERIFY(noLocalCount > 0); QVERIFY(retainAsPublishedCount == 1); FMQ_VERIFY(withSubscriptionIdentifierCount == 1); std::shared_ptr loadedSes = store2->sessionsById["c1"]; MutexLocked qos_loaded_locked = loadedSes->qos.lock(); std::shared_ptr queuedPublishLoaded = qos_loaded_locked->qosPacketQueue.popNext(); QCOMPARE(queuedPublishLoaded->getPublish().topic, "a/b/c"); QCOMPARE(queuedPublishLoaded->getPublish().payload, "Hello Barry"); QCOMPARE(queuedPublishLoaded->getPublish().qos, 1); QCOMPARE(queuedPublishLoaded->getPublish().client_id, "ClientIdFromFakePublisher"); QCOMPARE(queuedPublishLoaded->getPublish().username, "UsernameFromFakePublisher"); QCOMPARE(queuedPublishLoaded->getPublish().expireInfo.value().expiresAfter.count(), 9); const auto publish_age = queuedPublishLoaded->getPublish().getAge().count(); FMQ_VERIFY(publish_age > 900); FMQ_VERIFY(publish_age < 2100); QCOMPARE(queuedPublishLoaded->getPublish().subscriptionIdentifierTesting, static_cast(6268)); } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testParsePacketHelper(const std::string &topic, uint8_t from_qos, bool retain) { Logger::getInstance()->setFlags(LogLevel::None, false, false); Settings settings; settings.logLevel = LogLevel::Info; std::shared_ptr store(new SubscriptionStore()); std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr t(new ThreadData(0, 
settings, pluginLoader, std::weak_ptr())); std::shared_ptr dummyClient(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "qostestclient", {}, "user1", true, 60); store->registerClientAndKickExistingOne(dummyClient, false, 512, 120); uint16_t packetid = 66; for (int len = 0; len < 150; len++ ) { const uint16_t pack_id = packetid++; std::vector parsedPackets; const std::string payloadOne = getSecureRandomString(len); Publish pubOne(topic, payloadOne, from_qos); pubOne.retain = retain; MqttPacket stagingPacketOne(ProtocolVersion::Mqtt311, pubOne); if (from_qos > 0) stagingPacketOne.setPacketId(pack_id); CirBuf stagingBufOne(1024); stagingPacketOne.readIntoBuf(stagingBufOne); MqttPacket::bufferToMqttPackets(stagingBufOne, parsedPackets, dummyClient); QVERIFY(parsedPackets.size() == 1); MqttPacket parsedPacketOne = std::move(parsedPackets.front()); parsedPacketOne.parsePublishData(dummyClient); if (retain) // A normal handled packet always has retain=0, so I force setting it here. parsedPacketOne.setRetain(true); QCOMPARE(stagingPacketOne.getTopic(), parsedPacketOne.getTopic()); QCOMPARE(stagingPacketOne.getPayloadCopy(), parsedPacketOne.getPayloadCopy()); QCOMPARE(stagingPacketOne.getRetain(), parsedPacketOne.getRetain()); QCOMPARE(stagingPacketOne.getQos(), parsedPacketOne.getQos()); QCOMPARE(stagingPacketOne.first_byte, parsedPacketOne.first_byte); } } /** * @brief MainTests::testCopyPacket tests the actual bytes of a published packet that would be written to a client. 
*/ void MainTests::testParsePacket() { for (int retain = 0; retain < 2; retain++) { testParsePacketHelper("John/McLane", 0, retain); testParsePacketHelper("Ben/Sisko", 1, retain); testParsePacketHelper("Rebecca/Bunch", 2, retain); testParsePacketHelper("Buffy/Slayer", 1, retain); testParsePacketHelper("Sarah/Connor", 2, retain); testParsePacketHelper("Susan/Mayer", 2, retain); } } /** * @brief MainTests::testbufferToMqttPacketsFuzz perform a quick fuzz on parsing MQTT packets. * * There was some chatter about this function maybe crashing, so this test was added. Nothing * could be reproduced, not even when running for hours. */ void MainTests::testbufferToMqttPacketsFuzz() { Logger::getInstance()->setFlags(LogLevel::None, false, false); Settings settings; settings.logLevel = LogLevel::Info; std::shared_ptr store(new SubscriptionStore()); std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); settings.maxPacketSize = 32768; std::shared_ptr dummyClient(new Client(ClientType::Normal, -1, t, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt311, "dummy", {}, "user1", true, 60); store->registerClientAndKickExistingOne(dummyClient, false, 512, 120); // To avoid the restriction on packet size for unauthenticated clients. dummyClient->setAuthenticated(true); // To avoid writing random MQTT headers that happen to say 'packet is 100 MB big' and will result in the parser // thinking we are supposed to get more bytes, which will make it get stuck. 
const ssize_t len = settings.maxPacketSize * 2; std::vector randombuf(len); CirBuf stagingBuf(len); size_t protocol_error_count = 0; size_t parsed_packet_count = 0; const auto then = std::chrono::steady_clock::now() + std::chrono::seconds(4); for (auto now = std::chrono::steady_clock::now(); now < then; now = std::chrono::steady_clock::now()) { if (getrandom(randombuf.data(), len, 0) != len) throw std::runtime_error("Random error"); std::vector parsedPackets; stagingBuf.ensureFreeSpace(len); stagingBuf.write(randombuf.data(), len); try { MqttPacket::bufferToMqttPackets(stagingBuf, parsedPackets, dummyClient); parsed_packet_count += parsedPackets.size(); FMQ_VERIFY(true); } catch (ProtocolError&) { stagingBuf.reset(); protocol_error_count++; FMQ_VERIFY(true); } catch (std::exception&) { FMQ_VERIFY(false); } } std::cout << std::endl << "Flash-fuzzing bufferToMqttPackets done. Parsed packets: " << parsed_packet_count << ". Protocol errors: " << protocol_error_count << std::endl; } void testDowngradeQoSOnSubscribeHelper(const uint8_t pub_qos, const uint8_t sub_qos) { std::vector protocols {ProtocolVersion::Mqtt31, ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; for (const ProtocolVersion senderVersion : protocols) { for (const ProtocolVersion receiverVersion : protocols) { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); receiver.start(); const std::string topic("Star/Trek"); const std::string payload("Captain Kirk"); sender.connectClient(senderVersion); receiver.connectClient(receiverVersion); receiver.subscribe(topic, sub_qos); sender.publish(topic, payload, pub_qos); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); MqttPacket &recv = ro->receivedPublishes.front(); const uint8_t expected_qos = std::min(pub_qos, sub_qos); QVERIFY2(recv.getQos() == expected_qos, formatString("Failure: received QoS is %d. Published is %d. Subscribed as %d. 
Expected QoS is %d", recv.getQos(), pub_qos, sub_qos, expected_qos).c_str()); QVERIFY(recv.getTopic() == topic); QVERIFY(recv.getPayloadCopy() == payload); } } } void MainTests::testDowngradeQoSOnSubscribeQos2to2() { testDowngradeQoSOnSubscribeHelper(2, 2); } void MainTests::testDowngradeQoSOnSubscribeQos2to1() { testDowngradeQoSOnSubscribeHelper(2, 1); } void MainTests::testDowngradeQoSOnSubscribeQos2to0() { testDowngradeQoSOnSubscribeHelper(2, 0); } void MainTests::testDowngradeQoSOnSubscribeQos1to1() { testDowngradeQoSOnSubscribeHelper(1, 1); } void MainTests::testDowngradeQoSOnSubscribeQos1to0() { testDowngradeQoSOnSubscribeHelper(1, 0); } void MainTests::testDowngradeQoSOnSubscribeQos0to0() { testDowngradeQoSOnSubscribeHelper(0, 0); } /** * @brief MainTests::testNotMessingUpQosLevels was divised because we optimize by preventing packet copies. This entails changing the vector of the original * incoming packet, resulting in possibly changing values like QoS levels for later subscribers. 
*/ void MainTests::testNotMessingUpQosLevels() { const std::string topic = "HK7c1MFu6kdT69fWY"; const std::string payload = "M4XK2LZ2Smaazba8RobZOgoe6CENxCll"; std::list senderVersions {ProtocolVersion::Mqtt31, ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; std::list receiverVersions {ProtocolVersion::Mqtt31, ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5}; for (ProtocolVersion senderVersion : senderVersions) { for (ProtocolVersion receiverVersion : receiverVersions) { FlashMQTestClient testContextSender; FlashMQTestClient testContextReceiver1; FlashMQTestClient testContextReceiver2; FlashMQTestClient testContextReceiver3; FlashMQTestClient testContextReceiver4; FlashMQTestClient testContextReceiver5; FlashMQTestClient testContextReceiverMqtt3; FlashMQTestClient testContextReceiverMqtt5; testContextReceiver1.start(); testContextReceiver1.connectClient(receiverVersion); testContextReceiver1.subscribe(topic, 0); testContextReceiver2.start(); testContextReceiver2.connectClient(receiverVersion); testContextReceiver2.subscribe(topic, 1); testContextReceiver3.start(); testContextReceiver3.connectClient(receiverVersion); testContextReceiver3.subscribe(topic, 2); testContextReceiver4.start(); testContextReceiver4.connectClient(receiverVersion); testContextReceiver4.subscribe(topic, 1); testContextReceiver5.start(); testContextReceiver5.connectClient(receiverVersion); testContextReceiver5.subscribe(topic, 0); testContextReceiverMqtt3.start(); testContextReceiverMqtt3.connectClient(ProtocolVersion::Mqtt311); testContextReceiverMqtt3.subscribe(topic, 0); testContextReceiverMqtt5.start(); testContextReceiverMqtt5.connectClient(ProtocolVersion::Mqtt5); testContextReceiverMqtt5.subscribe(topic, 0); testContextSender.start(); testContextSender.connectClient(senderVersion); testContextSender.publish(topic, payload, 2); testContextReceiver1.waitForMessageCount(1); testContextReceiver2.waitForMessageCount(1); testContextReceiver3.waitForMessageCount(1); 
testContextReceiver4.waitForMessageCount(1); testContextReceiver5.waitForMessageCount(1); auto testContextReceiver1_ro = testContextReceiver1.receivedObjects.lock(); auto testContextReceiver2_ro = testContextReceiver2.receivedObjects.lock(); auto testContextReceiver3_ro = testContextReceiver3.receivedObjects.lock(); auto testContextReceiver4_ro = testContextReceiver4.receivedObjects.lock(); auto testContextReceiver5_ro = testContextReceiver5.receivedObjects.lock(); auto testContextReceiverMqtt3_ro = testContextReceiverMqtt3.receivedObjects.lock(); auto testContextReceiverMqtt5_ro = testContextReceiverMqtt5.receivedObjects.lock(); MYCASTCOMPARE(testContextReceiver1_ro->receivedPublishes.size(), 1); MYCASTCOMPARE(testContextReceiver2_ro->receivedPublishes.size(), 1); MYCASTCOMPARE(testContextReceiver3_ro->receivedPublishes.size(), 1); MYCASTCOMPARE(testContextReceiver4_ro->receivedPublishes.size(), 1); MYCASTCOMPARE(testContextReceiver5_ro->receivedPublishes.size(), 1); MYCASTCOMPARE(testContextReceiverMqtt3_ro->receivedPublishes.size(), 1); MYCASTCOMPARE(testContextReceiverMqtt5_ro->receivedPublishes.size(), 1); QCOMPARE(testContextReceiver1_ro->receivedPublishes.front().getQos(), 0); QCOMPARE(testContextReceiver2_ro->receivedPublishes.front().getQos(), 1); QCOMPARE(testContextReceiver3_ro->receivedPublishes.front().getQos(), 2); QCOMPARE(testContextReceiver4_ro->receivedPublishes.front().getQos(), 1); QCOMPARE(testContextReceiver5_ro->receivedPublishes.front().getQos(), 0); QCOMPARE(testContextReceiverMqtt3_ro->receivedPublishes.front().getQos(), 0); QCOMPARE(testContextReceiverMqtt5_ro->receivedPublishes.front().getQos(), 0); QCOMPARE(testContextReceiver1_ro->receivedPublishes.front().getPayloadCopy(), payload); QCOMPARE(testContextReceiver2_ro->receivedPublishes.front().getPayloadCopy(), payload); QCOMPARE(testContextReceiver3_ro->receivedPublishes.front().getPayloadCopy(), payload); QCOMPARE(testContextReceiver4_ro->receivedPublishes.front().getPayloadCopy(), 
payload); QCOMPARE(testContextReceiver5_ro->receivedPublishes.front().getPayloadCopy(), payload); QCOMPARE(testContextReceiverMqtt3_ro->receivedPublishes.front().getPayloadCopy(), payload); QCOMPARE(testContextReceiverMqtt5_ro->receivedPublishes.front().getPayloadCopy(), payload); QCOMPARE(testContextReceiver2_ro->receivedPublishes.front().getPacketId(), 1); QCOMPARE(testContextReceiver3_ro->receivedPublishes.front().getPacketId(), 1); QCOMPARE(testContextReceiver4_ro->receivedPublishes.front().getPacketId(), 1); } } } void MainTests::testUnSubscribe() { FlashMQTestClient sender; FlashMQTestClient receiver; sender.start(); sender.connectClient(ProtocolVersion::Mqtt311); receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe("Rebecca/Bunch", 2); receiver.subscribe("Josh/Chan", 1); receiver.subscribe("White/Josh", 1); sender.publish("Rebecca/Bunch", "Bunch here", 2); sender.publish("White/Josh", "Anteater", 2); sender.publish("Josh/Chan", "Human flip-flop", 2); receiver.waitForMessageCount(3); { auto ro = receiver.receivedObjects.lock(); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "Bunch here" && pack.getTopic() == "Rebecca/Bunch"; })); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "Anteater" && pack.getTopic() == "White/Josh"; })); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "Human flip-flop" && pack.getTopic() == "Josh/Chan"; })); MYCASTCOMPARE(ro->receivedPublishes.size(), 3); } receiver.clearReceivedLists(); receiver.unsubscribe("Josh/Chan"); sender.publish("Rebecca/Bunch", "Bunch here", 2); sender.publish("White/Josh", "Anteater", 2); sender.publish("Josh/Chan", "Human flip-flop", 2); receiver.waitForMessageCount(2); { auto ro = 
receiver.receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 2); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "Bunch here" && pack.getTopic() == "Rebecca/Bunch"; })); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](const MqttPacket &pack) { return pack.getPayloadCopy() == "Anteater" && pack.getTopic() == "White/Josh"; })); } } void MainTests::testUnsubscribeNonExistingWildcard() { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt311); sender.unsubscribe("#"); FMQ_VERIFY(true); // I just wanted to test asserts in the code } /** * @brief MainTests::testBasicsWithFlashMQTestClient was used to develop FlashMQTestClient. */ void MainTests::testBasicsWithFlashMQTestClient() { FlashMQTestClient client; client.start(); client.connectClient(ProtocolVersion::Mqtt311); { auto ro = client.receivedObjects.lock(); MqttPacket &connAckPack = ro->receivedPackets.at(0); QVERIFY(connAckPack.packetType == PacketType::CONNACK); } { client.subscribe("a/b", 1); auto ro = client.receivedObjects.lock(); MqttPacket &subAck = ro->receivedPackets.at(0); SubAckData subAckData = subAck.parseSubAckData(); QVERIFY(subAckData.subAckCodes.size() == 1); QVERIFY(subAckData.subAckCodes.front() == 1); } { client.subscribe("c/d", 2); auto ro = client.receivedObjects.lock(); MqttPacket &subAck = ro->receivedPackets.at(0); SubAckData subAckData = subAck.parseSubAckData(); QVERIFY(subAckData.subAckCodes.size() == 1); QVERIFY(subAckData.subAckCodes.front() == 2); } client.clearReceivedLists(); FlashMQTestClient publisher; publisher.start(); publisher.connectClient(ProtocolVersion::Mqtt5); { publisher.publish("a/b", "wave", 2); client.waitForMessageCount(1); auto ro = client.receivedObjects.lock(); MqttPacket &p = ro->receivedPublishes.at(0); QCOMPARE(p.getPublishData().topic, "a/b"); QCOMPARE(p.getPayloadCopy(), "wave"); 
QCOMPARE(p.getPublishData().qos, 1); QVERIFY(p.getPacketId() > 0); QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311); } client.clearReceivedLists(); { publisher.publish("c/d", "asdfasdfasdf", 2); client.waitForMessageCount(1); auto ro = client.receivedObjects.lock(); MqttPacket &p = ro->receivedPublishes.back(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); QCOMPARE(p.getPublishData().topic, "c/d"); QCOMPARE(p.getPayloadCopy(), "asdfasdfasdf"); QCOMPARE(p.getPublishData().qos, 2); QVERIFY(p.getPacketId() > 1); // It's the same client, so it should not re-use packet id 1 QVERIFY(p.protocolVersion == ProtocolVersion::Mqtt311); } } void MainTests::testDontRemoveSessionGivenToNewClientWithSameId() { FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311, true, 60, [] (Connect &c) { c.clientid = "Sandra-nonrandom"; }); receiver.subscribe("just/a/path", 0); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); { Publish pub("just/a/path", "AAAAA", 0); pub.topicAlias = 1; sender.publish(pub); } { receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); const MqttPacket &pack1 = ro->receivedPublishes.at(0); QCOMPARE(pack1.getTopic(), "just/a/path"); QCOMPARE(pack1.getPayloadCopy(), "AAAAA"); } FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt311, true, 60, [] (Connect &c) { c.clientid = "Sandra-nonrandom"; }); receiver2.subscribe("just/a/path", 0); { Publish pub("just/a/path", "AAAAA", 0); pub.topicAlias = 1; sender.publish(pub); } { try { receiver2.waitForMessageCount(1); } catch(std::exception &ex) { QFAIL("The second subscriber did not get the message, so the subscription failed."); } auto ro = receiver2.receivedObjects.lock(); const MqttPacket &pack1 = ro->receivedPublishes.at(0); QCOMPARE(pack1.getTopic(), "just/a/path"); QCOMPARE(pack1.getPayloadCopy(), "AAAAA"); } } void 
MainTests::testKeepSubscriptionOnKickingOutExistingClientWithCleanSessionFalse() { FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311, false, 60, [] (Connect &c) { c.clientid = "Carl-nonrandom"; }); receiver.subscribe("just/a/path", 0); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt311, false, 60, [] (Connect &c) { c.clientid = "Carl-nonrandom"; }); { Publish pub("just/a/path", "AAAAA", 0); pub.topicAlias = 1; sender.publish(pub); } { try { receiver2.waitForMessageCount(1); } catch(std::exception &ex) { QFAIL("The second subscriber did not get the message, so the subscription failed."); } auto ro = receiver2.receivedObjects.lock(); const MqttPacket &pack1 = ro->receivedPublishes.at(0); QCOMPARE(pack1.getTopic(), "just/a/path"); QCOMPARE(pack1.getPayloadCopy(), "AAAAA"); } } void MainTests::testPickUpSessionWithSubscriptionsAfterDisconnect() { FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311, false, 60, [] (Connect &c) { c.clientid = "Hungus-nonrandom"; }); receiver.subscribe("just/a/path", 0); receiver.disconnect(ReasonCodes::Success); usleep(50000); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt311, false, 60, [] (Connect &c) { c.clientid = "Hungus-nonrandom"; }); { Publish pub("just/a/path", "AAAAAB", 0); pub.topicAlias = 1; sender.publish(pub); } { try { receiver2.waitForMessageCount(1); } catch(std::exception &ex) { QFAIL("The second subscriber did not get the message, so the subscription failed."); } auto ro = receiver2.receivedObjects.lock(); const MqttPacket &pack1 = ro->receivedPublishes.at(0); QCOMPARE(pack1.getTopic(), "just/a/path"); QCOMPARE(pack1.getPayloadCopy(), "AAAAAB"); } } void 
MainTests::testIncomingTopicAlias()
{
    // An MQTT5 sender first publishes with topic alias 1, then with an empty topic and alias 1;
    // the receiver must see the full topic for both messages.
    FlashMQTestClient receiver;
    receiver.start();
    receiver.connectClient(ProtocolVersion::Mqtt5);
    receiver.subscribe("just/a/path", 0);

    FlashMQTestClient sender;
    sender.start();
    sender.connectClient(ProtocolVersion::Mqtt5);

    {
        Publish pub("just/a/path", "AAAAA", 0);
        pub.topicAlias = 1;
        sender.publish(pub);
    }

    {
        // Empty topic: the broker has to resolve it through the alias set by the previous publish.
        Publish pub2("", "BBBBB", 0);
        pub2.topicAlias = 1;
        sender.publish(pub2);
    }

    receiver.waitForMessageCount(2);

    auto ro = receiver.receivedObjects.lock();

    const MqttPacket &pack1 = ro->receivedPublishes.at(0);
    const MqttPacket &pack2 = ro->receivedPublishes.at(1);

    QCOMPARE(pack1.getTopic(), "just/a/path");
    QCOMPARE(pack1.getPayloadCopy(), "AAAAA");

    QCOMPARE(pack2.getTopic(), "just/a/path");
    QCOMPARE(pack2.getPayloadCopy(), "BBBBB");
}

/**
 * @brief Outgoing topic aliases: receiver1 announces maxIncomingTopicAliasValue = 10, receiver2 connects with
 * default settings. Of two identical publishes, the second one to receiver1 must omit the topic string
 * (14 bytes instead of 31), while receiver2 gets two full packets (28 bytes each).
 */
void MainTests::testOutgoingTopicAlias()
{
    FlashMQTestClient receiver1;
    receiver1.start();
    receiver1.connectClient(ProtocolVersion::Mqtt5, true, 300, [](Connect &connect){
        connect.maxIncomingTopicAliasValue = 10;
    });
    receiver1.subscribe("don't/be/a/laywer", 0);

    FlashMQTestClient receiver2;
    receiver2.start();
    receiver2.connectClient(ProtocolVersion::Mqtt5);
    receiver2.subscribe("don't/be/a/laywer", 0);

    FlashMQTestClient sender;
    sender.start();
    sender.connectClient(ProtocolVersion::Mqtt311);

    sender.publish("don't/be/a/laywer", "ABCDEF", 0);
    sender.publish("don't/be/a/laywer", "ABCDEF", 0);

    receiver1.waitForMessageCount(2);
    receiver2.waitForMessageCount(2);

    {
        // First delivery to receiver1: topic string plus alias property.
        auto ro = receiver1.receivedObjects.lock();
        const MqttPacket &fullPacket = ro->receivedPublishes.at(0);
        QCOMPARE(fullPacket.getTopic(), "don't/be/a/laywer");
        QCOMPARE(fullPacket.getPayloadCopy(), "ABCDEF");
        MYCASTCOMPARE(fullPacket.bites.size(), 31);
        std::string arrayContent(fullPacket.bites.data(), fullPacket.bites.size());
        QVERIFY(strContains(arrayContent, "don't/be/a/laywer"));
    }

    {
        // Second delivery to receiver1: alias only, the topic string must not be in the raw bytes.
        auto ro = receiver1.receivedObjects.lock();
        const MqttPacket &shorterPacket = ro->receivedPublishes.at(1);
        QCOMPARE(shorterPacket.getTopic(), "don't/be/a/laywer");
        QCOMPARE(shorterPacket.getPayloadCopy(), "ABCDEF");
        MYCASTCOMPARE(shorterPacket.bites.size(), 14);
        std::string arrayContent(shorterPacket.bites.data(), shorterPacket.bites.size());
        QVERIFY(!strContains(arrayContent, "don't/be/a/laywer"));
    }

    {
        auto ro = receiver2.receivedObjects.lock();

        MYCASTCOMPARE(ro->receivedPublishes.size(), 2);

        std::for_each(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &packet) {
            QCOMPARE(packet.getTopic(), "don't/be/a/laywer");
            QCOMPARE(packet.getPayloadCopy(), "ABCDEF");
            MYCASTCOMPARE(packet.bites.size(), 28); // That's 3 less than the other one, because the alias id is not there.
        });
    }
}

/**
 * @brief The receiver allows only 5 incoming topic aliases while 7 distinct topics are published twice.
 * In the first round every packet carries the full topic (72 bytes with an alias property for the first
 * five, 69 without for the last two); in the second round the first five are alias-only (14 bytes) and
 * the last two are still full (69 bytes).
 */
void MainTests::testOutgoingTopicAliasBeyondMax()
{
    FlashMQTestClient receiver1;
    receiver1.start();
    receiver1.connectClient(ProtocolVersion::Mqtt5, true, 300, [](Connect &connect){
        connect.maxIncomingTopicAliasValue = 5;
    });
    receiver1.subscribe("+/bottles/of/beer/on/the/wall/take/one/down/pass/it/around", 0);

    FlashMQTestClient sender;
    sender.start();
    sender.connectClient(ProtocolVersion::Mqtt311);

    // Set all the aliases with first publishes.
    for (int i = 0; i < 7; i++)
    {
        sender.publish(std::to_string(i) + "/bottles/of/beer/on/the/wall/take/one/down/pass/it/around", "ABCDEF", 0);
    }

    receiver1.waitForMessageCount(7);

    {
        auto ro = receiver1.receivedObjects.lock();

        for (int i = 0; i < 7; i++)
        {
            auto &packet = ro->receivedPublishes.at(i);
            FMQ_COMPARE(packet.getTopic(), std::to_string(i) + "/bottles/of/beer/on/the/wall/take/one/down/pass/it/around");
            size_t expected_size = i < 5 ? 72 : 69; // The ones with a topic alias in them are slightly bigger.
            FMQ_COMPARE(packet.bites.size(), expected_size);
        }
    }

    receiver1.clearReceivedLists();

    // Now again, which means the aliases should be used for the known topics.
    for (int i = 0; i < 7; i++)
    {
        sender.publish(std::to_string(i) + "/bottles/of/beer/on/the/wall/take/one/down/pass/it/around", "ABCDEF", 0);
    }

    receiver1.waitForMessageCount(7);

    {
        auto ro = receiver1.receivedObjects.lock();

        // Now the first five should be smaller, and the last two normal (no topic alias property and the topic string included).
        for (int i = 0; i < 7; i++)
        {
            auto &packet = ro->receivedPublishes.at(i);
            FMQ_COMPARE(packet.getTopic(), std::to_string(i) + "/bottles/of/beer/on/the/wall/take/one/down/pass/it/around");
            size_t expected_size = i < 5 ? 14 : 69;
            FMQ_COMPARE(packet.bites.size(), expected_size);
        }
    }
}

/**
 * @brief A will message must reuse the outgoing topic alias that a normal publish on the same topic
 * established: the first delivery is 42 bytes, the will delivered afterwards must be smaller than 35.
 */
void MainTests::testOutgoingTopicAliasStoredPublishes()
{
    std::unique_ptr sender = std::make_unique();
    sender->start();
    std::shared_ptr will = std::make_shared();
    will->topic = "last/dance/long/long";
    will->payload = "will payload";
    sender->setWill(will);
    sender->connectClient(ProtocolVersion::Mqtt5);

    FlashMQTestClient receiver1;
    receiver1.start();
    receiver1.connectClient(ProtocolVersion::Mqtt5, true, 300, [](Connect &connect){
        connect.maxIncomingTopicAliasValue = 10;
    });
    receiver1.subscribe("last/dance/#", 0);

    FlashMQTestClient sender2;
    sender2.start();
    sender2.connectClient(ProtocolVersion::Mqtt5);

    // Establish the first time use of the topic for the alias.
    sender2.publish("last/dance/long/long", "normal payload", 0);

    receiver1.waitForMessageCount(1);

    {
        auto ro = receiver1.receivedObjects.lock();
        QCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), "normal payload");
        MYCASTCOMPARE(ro->receivedPublishes.front().bites.size(), 42);
    }

    receiver1.clearReceivedLists();

    // This will send a will, which should re-use the alias.
    sender.reset();

    receiver1.waitForMessageCount(1);

    {
        auto ro = receiver1.receivedObjects.lock();
        QCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), "will payload");
        QVERIFY(ro->receivedPublishes.front().bites.size() < 35);
    }
}

/**
 * @brief For every combination of publish QoS and subscribe QoS (0..2), a retained message must be delivered
 * to a later subscriber with QoS min(publish QoS, subscribe QoS) and the retain flag set.
 */
void MainTests::testReceivingRetainedMessageWithQoS()
{
    int testCount = 0;

    for (uint8_t sendQos = 0; sendQos < 3; sendQos++)
    {
        for (uint8_t subscribeQos = 0; subscribeQos < 3; subscribeQos++)
        {
            testCount++;

            FlashMQTestClient sender;
            sender.start();

            const std::string payload = "We are testing";

            sender.connectClient(ProtocolVersion::Mqtt311);

            Publish p1("topic1/FOOBAR", payload, sendQos);
            p1.retain = true;
            sender.publish(p1);

            FlashMQTestClient receiver;
            receiver.start();
            receiver.connectClient(ProtocolVersion::Mqtt5);

            receiver.subscribe("+/+", subscribeQos);

            receiver.waitForMessageCount(1);

            const uint8_t expQos = std::min(sendQos, subscribeQos);

            auto ro = receiver.receivedObjects.lock();

            MYCASTCOMPARE(ro->receivedPublishes.size(), 1);
            MYCASTCOMPARE(ro->receivedPublishes.front().getQos(), expQos);
            MYCASTCOMPARE(ro->receivedPublishes.front().getTopic(), "topic1/FOOBAR");
            MYCASTCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), payload);
            MYCASTCOMPARE(ro->receivedPublishes.front().getRetain(), true);
        }
    }

    MYCASTCOMPARE(9, testCount);
}

/**
 * @brief Messages published while a persistent-session client is offline must, on reconnect (clean start
 * false), all be delivered with QoS min(publish QoS, subscribe QoS); a further reconnect must deliver nothing.
 * Runs for publish/subscribe QoS 1..2 and three subscription filters.
 */
void MainTests::testQosDowngradeOnOfflineClients()
{
    int testCount = 0;

    std::vector subscribePaths {"topic1/FOOBAR", "+/+", "#"};

    for (uint8_t sendQos = 1; sendQos < 3; sendQos++)
    {
        for (uint8_t subscribeQos = 1; subscribeQos < 3; subscribeQos++)
        {
            for (const std::string &subscribePath : subscribePaths)
            {
                testCount++;

                // First start with clean_start to reset the session.
                std::unique_ptr receiver = std::make_unique();
                receiver->start();
                receiver->connectClient(ProtocolVersion::Mqtt5, true, 600, [](Connect &connect) {
                    connect.clientid = "TheReceiver";
                });
                receiver->subscribe(subscribePath, subscribeQos);
                receiver->disconnect(ReasonCodes::Success);
                receiver.reset();

                const std::string payload = "We are testing";

                FlashMQTestClient sender;
                sender.start();
                sender.connectClient(ProtocolVersion::Mqtt311);

                Publish p1("topic1/FOOBAR", payload, sendQos);
                for (int i = 0; i < 10; i++)
                {
                    sender.publish(p1);
                }

                // Now we connect again, and we should now pick up the existing session.
                receiver = std::make_unique();
                receiver->start();
                receiver->connectClient(ProtocolVersion::Mqtt5, false, 600, [](Connect &connect) {
                    connect.clientid = "TheReceiver";
                });

                receiver->waitForMessageCount(10);

                const uint8_t expQos = std::min(sendQos, subscribeQos);

                {
                    auto ro = receiver->receivedObjects.lock();

                    MYCASTCOMPARE(ro->receivedPublishes.size(), 10);
                    QVERIFY(std::all_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getQos() == expQos;}));
                    QVERIFY(std::all_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getTopic() == "topic1/FOOBAR";}));
                    QVERIFY(std::all_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [&](MqttPacket &pack) { return pack.getPayloadCopy() == payload;}));
                }

                receiver.reset();

                // Now we connect again, and we should get nothing
                receiver = std::make_unique();
                receiver->start();
                receiver->connectClient(ProtocolVersion::Mqtt5, false, 600, [](Connect &connect) {
                    connect.clientid = "TheReceiver";
                });

                // Short grace period for any (unexpected) redelivery to arrive before checking.
                usleep(100000);

                auto ro = receiver->receivedObjects.lock();
                QVERIFY(ro->receivedPublishes.empty());
            }
        }
    }

    MYCASTCOMPARE(12, testCount);
}

/**
 * @brief User properties on a publish must reach an MQTT5 subscriber intact and in order, while an
 * MQTT 3.1.1 subscriber gets none (getUserProperties() returns nullptr).
 */
void MainTests::testUserProperties()
{
    FlashMQTestClient sender;
    sender.start();
    sender.connectClient(ProtocolVersion::Mqtt5);

    FlashMQTestClient receiver5;
    receiver5.start();
    receiver5.connectClient(ProtocolVersion::Mqtt5);
    receiver5.subscribe("#", 1);

    FlashMQTestClient receiver3;
    receiver3.start();
    receiver3.connectClient(ProtocolVersion::Mqtt311);
    receiver3.subscribe("#", 1);

    Publish pub("I'm/going/to/leave/a/message", "boo", 1);
    pub.addUserProperty("mykey", "myval");
    pub.addUserProperty("mykeyhaha", "myvalhaha");
    sender.publish(pub);

    receiver5.waitForMessageCount(1);
    receiver3.waitForMessageCount(1);

    {
        auto ro5 = receiver5.receivedObjects.lock();
        MqttPacket &pack5 = ro5->receivedPublishes.front();
        const std::vector> *properties5 = pack5.getUserProperties();
        QVERIFY(properties5);
        MYCASTCOMPARE(properties5->size(), 2);

        const std::pair &firstPair = properties5->operator[](0);
        const std::pair &secondPair = properties5->operator[](1);

        QCOMPARE(firstPair.first, "mykey");
        QCOMPARE(firstPair.second, "myval");

        QCOMPARE(secondPair.first, "mykeyhaha");
        QCOMPARE(secondPair.second, "myvalhaha");
    }

    {
        // MQTT 3.1.1 has no user properties, so none may be forwarded.
        auto ro3 = receiver3.receivedObjects.lock();
        MqttPacket &pack3 = ro3->receivedPublishes.front();
        const std::vector> *properties3 = pack3.getUserProperties();
        QVERIFY(properties3 == nullptr);
    }
}

void MainTests::testMessageExpiry()
{
    std::unique_ptr receiver;

    auto makeReceiver = [&](){
        receiver = std::make_unique();
        receiver->start();
        receiver->connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &c) {
            c.clientid = "pietje";
        });
    };

    makeReceiver();
    receiver->subscribe("#", 2);

    FlashMQTestClient sender;
    sender.start();
    sender.connectClient(ProtocolVersion::Mqtt5, true, 120);

    Publish publishMe("a/b/c/d/e", "smoke", 1);
    publishMe.setExpireAfter(1);
    sender.publish(publishMe);

    // First we test that a message with an expiry set is still immediately delivered to active clients.
receiver->waitForMessageCount(1); { auto ro = receiver->receivedObjects.lock(); QVERIFY(ro->receivedPublishes.size() == 1); QCOMPARE(ro->receivedPublishes.front().getTopic(), "a/b/c/d/e"); QCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), "smoke"); } // Then we test delivering it to an offline client and see if we get it if we are fast enough. receiver.reset(); publishMe.setExpireAfter(1); sender.publish(publishMe); usleep(300000); makeReceiver(); receiver->waitForMessageCount(1); { auto ro = receiver->receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 1); QCOMPARE(ro->receivedPublishes.front().getTopic(), "a/b/c/d/e"); QCOMPARE(ro->receivedPublishes.front().getPayloadCopy(), "smoke"); } // Then we test delivering it to an offline client that comes back too late. publishMe.setExpireAfter(1); receiver.reset(); sender.publish(publishMe); usleep(2100000); makeReceiver(); usleep(100000); receiver->waitForMessageCount(0); auto ro = receiver->receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } /** * @brief MainTests::testExpiredQueuedMessages Tests whether expiring messages clear out and make room in the send quota / receive maximum. 
*/ void MainTests::testExpiredQueuedMessages() { ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("max_qos_msg_pending_per_client 32"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; cleanup(); init(args); std::vector versions { ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5 }; for (ProtocolVersion version : versions) { FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5, true, 120); std::unique_ptr receiver; auto makeReceiver = [&]() { receiver = std::make_unique(); receiver->start(); receiver->connectClient(version, false, 120, [](Connect &c) { c.clientid = "ReceiverDude01"; }); }; makeReceiver(); receiver->subscribe("#", 2); receiver.reset(); Publish expiringPub("this/one/expires", "Calimex", 1); expiringPub.setExpireAfter(1); sender.publish(expiringPub); for (int i = 0; i < 35; i++) { Publish normalPub(formatString("topic/%d", i), "adsf", 1); sender.publish(normalPub); } makeReceiver(); try { receiver->waitForMessageCount(32); } catch (std::exception &ex) { auto ro = receiver->receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 32); } { auto ro = receiver->receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 32); QVERIFY(std::any_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) { return p.getTopic() == "this/one/expires"; })); } receiver.reset(); for (int i = 0; i < 20; i++) { Publish normalPub(formatString("late/topic/%d", i), "adsf", 1); sender.publish(normalPub); } sender.publish(expiringPub); usleep(2000000); // Now that we've waited, there should be an extra spot. 
for (int i = 0; i < 15; i++) { Publish normalPub(formatString("topic/%d", i), "adsf", 1); sender.publish(normalPub); } makeReceiver(); try { receiver->waitForMessageCount(32); } catch (std::exception &ex) { auto ro = receiver->receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 32); } auto ro = receiver->receivedObjects.lock(); MYCASTCOMPARE(ro->receivedPublishes.size(), 32); QVERIFY(std::none_of(ro->receivedPublishes.begin(), ro->receivedPublishes.end(), [](MqttPacket &p) { return p.getTopic() == "this/one/expires"; })); int pi = 0; for (int i = 0; i < 20; i++) { QCOMPARE(ro->receivedPublishes[pi++].getTopic(), formatString("late/topic/%d", i)); } for (int i = 0; i < 12; i++) { QCOMPARE(ro->receivedPublishes[pi++].getTopic(), formatString("topic/%d", i)); } MYCASTCOMPARE(pi, 32); } } /** * @brief MainTests::testQoSPublishQueue tests the queue and it's insertion order linked list. */ void MainTests::testQoSPublishQueue() { QoSPublishQueue q; uint16_t id = 1; std::shared_ptr qp; { Publish p1("one", "onepayload", 1); q.queuePublish(std::move(p1), id++, std::optional()); qp = q.popNext(); QVERIFY(qp); QCOMPARE(qp->getPublish().topic, "one"); QCOMPARE(qp->getPublish().payload, "onepayload"); qp = q.popNext(); QVERIFY(!qp); } { Publish p1("two", "asdf", 1); Publish p2("three", "wer", 1); q.queuePublish(std::move(p1), id++, std::optional()); q.queuePublish(std::move(p2), id++, std::optional()); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "two"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "three"); qp = q.popNext(); QVERIFY(!qp); } { Publish p1("four", "asdf", 1); Publish p2("five", "wer", 1); Publish p3("six", "wer", 1); q.queuePublish(std::move(p1), id++, std::optional()); q.queuePublish(std::move(p2), id++, std::optional()); q.queuePublish(std::move(p3), id++, std::optional()); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "four"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "five"); qp = q.popNext(); 
QCOMPARE(qp->getPublish().topic, "six"); qp = q.popNext(); QVERIFY(!qp); } // Remove middle { uint16_t idToRemove = 0; Publish p1("seven", "asdf", 1); Publish p2("eight", "wer", 1); Publish p3("nine", "wer", 1); q.queuePublish(std::move(p1), id++, std::optional()); idToRemove = id; q.queuePublish(std::move(p2), id++, std::optional()); q.queuePublish(std::move(p3), id++, std::optional()); q.erase(idToRemove); Publish p4("tool2eW7", "wer", 1); q.queuePublish(std::move(p4), id++, std::optional()); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "seven"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "nine"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "tool2eW7"); qp = q.popNext(); QVERIFY(!qp); } // Remove first { uint16_t idToRemove = 0; Publish p1("ten", "asdf", 1); Publish p2("eleven", "wer", 1); Publish p3("twelve", "wer", 1); idToRemove = id; q.queuePublish(std::move(p1), id++, std::optional()); q.queuePublish(std::move(p2), id++, std::optional()); q.queuePublish(std::move(p3), id++, std::optional()); q.erase(idToRemove); Publish p4("iew2Bie1", "wer", 1); q.queuePublish(std::move(p4), id++, std::optional()); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "eleven"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "twelve"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "iew2Bie1"); qp = q.popNext(); QVERIFY(!qp); } // Remove last { uint16_t idToRemove = 0; Publish p1("13", "asdf", 1); Publish p2("14", "wer", 1); Publish p3("15", "wer", 1); q.queuePublish(std::move(p1), id++, std::optional()); q.queuePublish(std::move(p2), id++, std::optional()); idToRemove = id; q.queuePublish(std::move(p3), id++, std::optional()); q.erase(idToRemove); Publish p4("16", "wer", 1); q.queuePublish(std::move(p4), id++, std::optional()); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "13"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "14"); qp = q.popNext(); QCOMPARE(qp->getPublish().topic, "16"); qp = q.popNext(); QVERIFY(!qp); } } /** * relies 
on the leak sanitizer as part of ASAN.
 *
 * Two publishes are queued and never popped, so they are still held by the queue when it goes out of scope.
 */
void MainTests::testQoSPublishQueueMemoryLeak()
{
    QoSPublishQueue q;
    uint16_t id = 1;

    Publish p1("one", "onepayload", 1);
    q.queuePublish(std::move(p1), id++, std::optional());

    Publish p2("one", "onepayload", 1);
    q.queuePublish(std::move(p2), id++, std::optional());

    FMQ_VERIFY(q.size() == 2);
}

/**
 * Converting a time point 11 seconds in the past to an age and back must give the original time point,
 * compared at whole-second resolution.
 */
void MainTests::testTimePointToAge()
{
    std::chrono::time_point now = std::chrono::steady_clock::now();
    auto past = now - std::chrono::seconds(11);
    auto p = ageFromTimePoint(past);
    auto pastCalculated = timepointFromAge(p);

    std::chrono::seconds diff = std::chrono::duration_cast(pastCalculated - past);
    MYCASTCOMPARE(diff.count(), 0);
}

/**
 * Logins against a mosquitto_password_file with anonymous access disabled.
 *
 * The file has two users whose hashes use different formats ("$6$" for "one", "$7$" for "two"). Correct
 * passwords must be accepted. A wrong password and an unknown user must be refused with NotAuthorized.
 */
void MainTests::testMosquittoPasswordFile()
{
    ConfFileTemp passwd_file;
    passwd_file.writeLine("one:$6$JCNyGIZwxpaB++iTCwiT2e80YX6mEFymRCkRpHkm50dNP8IfHMWz97BdadZVsZCCC9yr7/OXxAbdfAVk71xqyA==$AL25hdhMm0CkQ3/nxtgGJ96xfSv6hCAf7aHZby8mZWnkNxmvRnuu6fHWi6yvyr1EjPD4P9vmIvKwqvdKEVDLLQ==");
    passwd_file.writeLine("two:$7$101$QVgLoPCu8Lb9A6HRYFhcsIsYqE1QR5elwDr7oioyNw7n5OMqdpM0Xk+Iacbj+ZvXiIVihFYEVDgJMkr8vAR08A==$xTJ1tbPTZcaJH+ie9gXUDumHqdJYpGCMXW/asC/qMrdobawqU2tpBHzvJnm2VfsYCwgchOCegI8RvYt1IAUivg==");
    passwd_file.closeFile();

    ConfFileTemp confFile;
    confFile.writeLine(formatString("mosquitto_password_file %s", passwd_file.getFilePath().c_str()));
    confFile.writeLine("allow_anonymous false");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "one";
            connect.password = "one";
        });

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "two";
            connect.password = "two";
        });

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "two";
            connect.password = "wrongpasswordforexistinguser";
        });

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "erqwerq";
            connect.password = "nonexistinguser";
        });

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }
}

/**
 * The global setting is allow_anonymous false; listener 2883 overrides it to true, listener 2884 inherits false.
 *
 * Known users authenticate identically on both. Anonymous clients and unknown users are only accepted on
 * 2883. An existing user with an empty password is never treated as anonymous.
 */
void MainTests::testOverrideAllowAnonymousToTrue()
{
    ConfFileTemp passwd_file;
    passwd_file.writeLine("one:$6$JCNyGIZwxpaB++iTCwiT2e80YX6mEFymRCkRpHkm50dNP8IfHMWz97BdadZVsZCCC9yr"
                          "7/OXxAbdfAVk71xqyA==$AL25hdhMm0CkQ3/nxtgGJ96xfSv6hCAf7aHZby8mZWnkNxmvRnuu6f"
                          "HWi6yvyr1EjPD4P9vmIvKwqvdKEVDLLQ==");
    passwd_file.closeFile();

    ConfFileTemp confFile;
    confFile.writeLine(formatString("mosquitto_password_file %s", passwd_file.getFilePath().c_str()));
    confFile.writeLine("allow_anonymous false");
    confFile.writeLine(R"( listen { protocol mqtt port 2883 allow_anonymous true } listen { protocol mqtt port 2884 })");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    // With users, both listeners should act the same.
    for (int port : {2883, 2884})
    {
        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "one";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
        }

        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "wrong";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
        }
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2883);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2884);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "doesntexist";
            connect.password = "asdf";
        }, 2883); // allow_anonymous true override

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "doesntexist";
            connect.password = "asdf";
        }, 2884); // allow_anonymous false global setting.

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }

    // Test empty password for existing user. It should not think it's anonymous.
    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "one";
            connect.password = "";
        }, 2883); // allow_anonymous true override

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }
}

/**
 * The global setting is allow_anonymous true (plus zero_byte_username_is_anonymous true); listener 2883
 * overrides it to false, listener 2884 inherits true.
 *
 * Known users authenticate identically on both listeners. Anonymous clients and unknown users are only
 * accepted on 2884. A zero-byte username counts as anonymous there.
 */
void MainTests::testOverrideAllowAnonymousToFalse()
{
    ConfFileTemp passwd_file;
    passwd_file.writeLine("one:$6$JCNyGIZwxpaB++iTCwiT2e80YX6mEFymRCkRpHkm50dNP8IfHMWz97BdadZVsZCCC9yr"
                          "7/OXxAbdfAVk71xqyA==$AL25hdhMm0CkQ3/nxtgGJ96xfSv6hCAf7aHZby8mZWnkNxmvRnuu6f"
                          "HWi6yvyr1EjPD4P9vmIvKwqvdKEVDLLQ==");
    passwd_file.closeFile();

    ConfFileTemp confFile;
    confFile.writeLine(formatString("mosquitto_password_file %s", passwd_file.getFilePath().c_str()));
    confFile.writeLine("allow_anonymous true");
    confFile.writeLine("zero_byte_username_is_anonymous true");
    confFile.writeLine(R"( listen { protocol mqtt port 2883 allow_anonymous false } listen { protocol mqtt port 2884 })");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    // With users, both listeners should act the same.
    for (int port : {2883, 2884})
    {
        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "one";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
        }

        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "wrong";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
        }
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2883);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2884);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    // Test 'zero_byte_username_is_anonymous true'
    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt311, false, 120, [](Connect &connect) {
            connect.username = "";
            connect.password = "wrong";
        }, 2884);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "doesntexist";
            connect.password = "asdf";
        }, 2883); // allow_anonymous false override

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
            connect.username = "doesntexist";
            connect.password = "asdf";
        }, 2884); // allow_anonymous true global setting.

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }
}

/**
 * allow_anonymous is false globally and explicitly false on listener 2883; listener 2884 inherits false.
 * Anonymous clients must be refused on both listeners, while known users still authenticate normally.
 */
void MainTests::testKeepAllowAnonymousFalse()
{
    ConfFileTemp passwd_file;
    passwd_file.writeLine("one:$6$JCNyGIZwxpaB++iTCwiT2e80YX6mEFymRCkRpHkm50dNP8IfHMWz97BdadZVsZCCC9yr"
                          "7/OXxAbdfAVk71xqyA==$AL25hdhMm0CkQ3/nxtgGJ96xfSv6hCAf7aHZby8mZWnkNxmvRnuu6f"
                          "HWi6yvyr1EjPD4P9vmIvKwqvdKEVDLLQ==");
    passwd_file.closeFile();

    ConfFileTemp confFile;
    confFile.writeLine(formatString("mosquitto_password_file %s", passwd_file.getFilePath().c_str()));
    confFile.writeLine("allow_anonymous false");
    confFile.writeLine(R"( listen { protocol mqtt port 2883 allow_anonymous false } listen { protocol mqtt port 2884 })");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    // With users, both listeners should act the same.
    for (int port : {2883, 2884})
    {
        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "one";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
        }

        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "wrong";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
        }
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2883);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2884);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }
}

/**
 * No password file is loaded.
 *
 * Listener 2883 (explicit true) and 2884 (inherited global true) accept any client, whatever the
 * username/password. Listener 2885 (explicit false) refuses anonymous clients.
 */
void MainTests::testAllowAnonymousWithoutPasswordsLoaded()
{
    ConfFileTemp confFile;
    confFile.writeLine("allow_anonymous true");
    confFile.writeLine(R"( listen { protocol mqtt port 2883 allow_anonymous true } listen { protocol mqtt port 2884 } listen { protocol mqtt port 2885 allow_anonymous false })");
    confFile.closeFile();

    std::vector args {"--config-file", confFile.getFilePath()};

    cleanup();
    init(args);

    // Without a password file, any credentials are accepted on both listeners that allow anonymous access.
    for (int port : {2883, 2884})
    {
        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "one";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
        }

        {
            FlashMQTestClient client;
            client.start();
            client.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &connect) {
                connect.username = "one";
                connect.password = "wrong";
            }, port);

            auto ro = client.receivedObjects.lock();
            auto ack = ro->receivedPackets.front();
            ConnAckData ackData = ack.parseConnAckData();
            QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
        }
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2883);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2884);

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::Success);
    }

    {
        FlashMQTestClient client;
        client.start();
        client.connectClient(ProtocolVersion::Mqtt5, 2885); // allow_anonymous false

        auto ro = client.receivedObjects.lock();
        auto ack = ro->receivedPackets.front();
        ConnAckData ackData = ack.parseConnAckData();
        QCOMPARE(ackData.reasonCode, ReasonCodes::NotAuthorized);
    }
}

/**
 * Network::match() for IPv4: 12.13.14.15 must match itself, its /24 networks and 0.0.0.0/0, and must not match
 * other hosts or networks.
 */
void MainTests::testAddrMatchesSubnetIpv4()
{
    struct sockaddr_in reference_addr;
    memset(&reference_addr, 0, sizeof(struct sockaddr_in));
    reference_addr.sin_family = AF_INET;
    QVERIFY(inet_pton(AF_INET, "12.13.14.15", &reference_addr.sin_addr) == 1);

    const std::list positives = {"12.13.14.15", "12.13.14.15/24", "12.13.14.00/24", "0.0.0.0/0"};
    const std::list negatives = {"12.13.99.00/24", "12.13.14.16/32", "11.13.14.15/32", "12.13.14.16"};

    for(const std::string &network : positives)
    {
        Network net(network);
        QVERIFY2(net.match(&reference_addr), formatString("'%s' failed", network.c_str()).c_str());
    }

    for(const std::string &network : negatives)
    {
        Network net(network);
        QVERIFY2(!net.match(&reference_addr), formatString("'%s' failed", network.c_str()).c_str());
    }
}

/**
 * Network::match() for IPv6 against 2001:db8::1: fixed positive and negative cases, followed by sweeps over
 * the prefix length for an equal network (/0 .. /127) and a network that differs in its first bit (/1 .. /127).
 */
void MainTests::testAddrMatchesSubnetIpv6()
{
    struct sockaddr_in6 reference_addr;
    memset(&reference_addr, 0, sizeof(struct sockaddr_in6));
    reference_addr.sin6_family = AF_INET6;
    QVERIFY(inet_pton(AF_INET6, "2001:db8::1", &reference_addr.sin6_addr) == 1);

    const std::list positives = {"2001:db8::1", "2001:db8:0::1", "2001:db8::1337:2/64", "2001:db8:2001:db8:1337::2/32", "3001:db8:2001:db8:1337::2/0"};
    const std::list negatives = {"2002:db8::1", "2001:db8::2", "2003:db8::1/64"};

    for(const std::string &network : positives)
    {
        Network net(network);
        QVERIFY2(net.match(&reference_addr), formatString("Equality '%s' failed", network.c_str()).c_str());
    }

    for(const std::string &network : negatives)
    {
        Network net(network);
        QVERIFY2(!net.match(&reference_addr), formatString("Inequality '%s' failed", network.c_str()).c_str());
    }

    for (int i = 0; i < 128; i++)
    {
        std::string addressWithMask = formatString("2001:db8:0::1/%d", i);
        Network net(addressWithMask);
        QVERIFY2(net.match(&reference_addr), formatString("Fail mask range equal check: %s", addressWithMask.c_str()).c_str());
    }

    for (int i = 1; i < 128; i++)
    {
        std::string addressWithMask = formatString("f001:db8:0::1/%d", i);
        Network net(addressWithMask);

        // Matching through a copy also exercises Network's copy constructor.
        Network netCopy(net);
        QVERIFY2(!netCopy.match(&reference_addr), formatString("Fail mask range not equal check: %s", addressWithMask.c_str()).c_str());
    }
}

/**
 * A client subscribed with the third subscribe() argument false (presumably no-local off) receives its own
 * publish back.
 */
void MainTests::testPublishToItself()
{
    FlashMQTestClient client;
    client.start();
    client.connectClient(ProtocolVersion::Mqtt5);

    client.subscribe("mytopic", 1, false);

    try
    {
        client.publish("mytopic", "mypayload", 1);
    }
    catch (std::exception &ex)
    {
        QVERIFY2(false, ex.what());
    }

    try
    {
        client.waitForMessageCount(1);
    }
    catch (std::exception &ex)
    {
        QVERIFY2(false, ex.what());
    }

    auto ro = client.receivedObjects.lock();
    MYCASTCOMPARE(ro->receivedPublishes.size(), 1);
}

/**
 * Counterpart of testPublishToItself: with the third subscribe() argument true (no-local), the client must not
 * receive its own publish.
 */
void MainTests::testNoLocalPublishToItself()
{
    FlashMQTestClient client;
    client.start();
    client.connectClient(ProtocolVersion::Mqtt5);

    client.subscribe("mytopic", 1, true);

    try
    {
        client.publish("mytopic", "mypayload", 1);
    }
    catch (std::exception &ex)
    {
        QVERIFY2(false, ex.what());
    }

    // Nothing is supposed to arrive, so wait a fixed time instead of waiting for a message.
    usleep(1000000);

    auto ro = client.receivedObjects.lock();
    MYCASTCOMPARE(ro->receivedPublishes.size(), 0);
}

/**
 * Subscribes one client to subscribe_topic in a fresh SubscriptionStore, then asserts that publishing on
 * publish_topic yields match_count receivers.
 *
 * Throws std::runtime_error when either topic is invalid, because that is a broken test case rather than a
 * matching failure.
 */
void MainTests::testTopicMatchingInSubscriptionTreeHelper(const std::string &subscribe_topic, const std::string &publish_topic, int match_count)
{
    if (!isValidSubscribePath(subscribe_topic))
        throw std::runtime_error("Invalid test: invalid subscribe topic: " + subscribe_topic);
    if (!isValidPublishPath(publish_topic))
        throw std::runtime_error("Invalid test: invalid publish topic: " + publish_topic);

    SubscriptionStore store;

    const std::vector subscribe_subtopics = splitTopic(subscribe_topic);
    const std::vector publish_subtopics = splitTopic(publish_topic);

    // Left default-constructed (empty) before it is handed to the client.
    std::shared_ptr td;
    const Settings *settings = ThreadGlobals::getSettings();
    std::shared_ptr client = std::make_shared(ClientType::Normal, -1, td, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, *settings, false);
    client->setClientProperties(ProtocolVersion::Mqtt5, "mytestclient", {}, "myusername", true, 60);
    store.registerClientAndKickExistingOne(client);
    store.addSubscription(client->getSession(), subscribe_subtopics, 0, false, false, "", 0);

    std::vector receivers;
    store.publishRecursively(publish_subtopics.begin(), publish_subtopics.end(), store.root.get(), receivers, "fakeclientid", {});

    QVERIFY2(std::distance(receivers.begin(), receivers.end()) == match_count, publish_topic.c_str());
}

/**
 * Table of subscription filter / publish topic pairs, including '+' and '#' wildcards, empty levels, leading
 * and trailing slashes, and non-matching cases (expected count 0).
 */
void MainTests::testTopicMatchingInSubscriptionTree()
{
    testTopicMatchingInSubscriptionTreeHelper("match/this/topic/#", "match/this/topic/wer");
    testTopicMatchingInSubscriptionTreeHelper("match/this/topic/#", "match/this/topic/wer/zxcvzxcv");
    testTopicMatchingInSubscriptionTreeHelper("match/+/topic/#", "match/this/topic/wer/zxcvzxcv");
    testTopicMatchingInSubscriptionTreeHelper("match/+/topic/+", "match/wer/topic/bbb");
    testTopicMatchingInSubscriptionTreeHelper("match/+/topic/+", "match/wer/topic/bbb/eee", 0);
    testTopicMatchingInSubscriptionTreeHelper("match/+/topic/+/wer/#", "match/uwoeirv.m/topic/bbb/eee", 0);
    testTopicMatchingInSubscriptionTreeHelper("#", "match/wer/topic/bbb/eee");
    testTopicMatchingInSubscriptionTreeHelper("+/+/topic/+", "zcccvcv/wer/topic/haha");
    testTopicMatchingInSubscriptionTreeHelper("a/b/c/d/#", "a/b/c/d/e");
    testTopicMatchingInSubscriptionTreeHelper("a/#", "a");
    testTopicMatchingInSubscriptionTreeHelper("a/#", "a/b");
    testTopicMatchingInSubscriptionTreeHelper("/#", "/b");
    testTopicMatchingInSubscriptionTreeHelper("/#", "/b/wer");
    testTopicMatchingInSubscriptionTreeHelper("a/b/c/d/#", "a/b/c/d/e/f");
    testTopicMatchingInSubscriptionTreeHelper("a/b/c/d/#", "a/b/c/", 0);
    testTopicMatchingInSubscriptionTreeHelper("a/b/c/d/#", "a/b/c", 0);
    testTopicMatchingInSubscriptionTreeHelper("a/b/c/d/#", "a/b/c/d");

    // Taken from testTopicsMatch(), but that test will be removed.

    testTopicMatchingInSubscriptionTreeHelper("#", "asdf/b/sdf");
    testTopicMatchingInSubscriptionTreeHelper("#", "/one/two/asdf");
    testTopicMatchingInSubscriptionTreeHelper("#", "/one/two/asdf/");
    testTopicMatchingInSubscriptionTreeHelper("+/+/+/+/+", "/one/two/asdf/");
    testTopicMatchingInSubscriptionTreeHelper("+/+/#", "/one/two/asdf/");
    testTopicMatchingInSubscriptionTreeHelper("+/+/#", "/1234567890abcdef/two/asdf/");
    testTopicMatchingInSubscriptionTreeHelper("+/+/#", "/1234567890abcdefg/two/asdf/");
    testTopicMatchingInSubscriptionTreeHelper("+/+/#", "/1234567890abcde/two/asdf/");
    testTopicMatchingInSubscriptionTreeHelper("+/+/#", "1234567890abcde//two/asdf/");

    testTopicMatchingInSubscriptionTreeHelper("+/santa", "/one/two/asdf/", 0);
    testTopicMatchingInSubscriptionTreeHelper("+/+/+/+/", "/one/two/asdf/a", 0);
    testTopicMatchingInSubscriptionTreeHelper("+/one/+/+/", "/one/two/asdf/a", 0);
}

/**
 * startsWith(s, prefix): the empty prefix matches everything, including the empty string; a string starts
 * with itself and with its own prefixes, but not with a longer string or a non-leading substring.
 */
void MainTests::testStartsWith()
{
    FMQ_VERIFY(startsWith("", ""));
    FMQ_VERIFY(startsWith("a", ""));
    FMQ_VERIFY(startsWith("abcd", "abc"));
    FMQ_VERIFY(startsWith("abc", "abc"));

    FMQ_VERIFY(!startsWith("abc", "abcd"));
    FMQ_VERIFY(!startsWith("abcd", "bcd"));
    FMQ_VERIFY(!startsWith("", "a"));
}

/**
 * Runs the broker through MainAppAsFork, after cleanup(), and checks a basic publish/subscribe round trip
 * through it.
 */
void MainTests::forkingTestForkingTestServer()
{
    cleanup();

    MainAppAsFork app;
    app.start();
    app.waitForStarted();

    FlashMQTestClient sender;
    FlashMQTestClient receiver;

    sender.start();
    receiver.start();

    receiver.connectClient(ProtocolVersion::Mqtt5);
    receiver.subscribe("#", 0);

    sender.connectClient(ProtocolVersion::Mqtt5);
    sender.publish("bla", "payload", 0);

    receiver.waitForMessageCount(1);

    auto ro = receiver.receivedObjects.lock();
    MYCASTCOMPARE(ro->receivedPublishes.size(), 1);
}

/**
 * There was a regression in 1.17.2 that caused publish packets to be sent before connacks. This tests for that.
 *
 * A QoS 2 message is queued for an offline session. On session pickup, the first three packets must be
 * CONNACK, PUBLISH and PUBREL, in that order.
 */
void MainTests::testPacketOrderOnSessionPickup()
{
    // First start with clean_start to reset the session.
    std::unique_ptr receiver = std::make_unique();
    receiver->start();
    receiver->connectClient(ProtocolVersion::Mqtt5, true, 600, [](Connect &connect) {
        connect.clientid = "TheReceiver";
    });
    receiver->subscribe("subscribe/path", 2);
    receiver->disconnect(ReasonCodes::Success);
    receiver.reset();

    const std::string payload = "We are testing";

    FlashMQTestClient sender;
    sender.start();
    sender.connectClient(ProtocolVersion::Mqtt311);

    Publish p1("subscribe/path", payload, 2);
    sender.publish(p1);

    // Now we connect again, and we should now pick up the existing session.
    receiver = std::make_unique();
    receiver->start();
    receiver->connectClient(ProtocolVersion::Mqtt5, false, 600, [](Connect &connect) {
        connect.clientid = "TheReceiver";
    });

    receiver->waitForPacketCount(3);

    auto ro = receiver->receivedObjects.lock();

    FMQ_COMPARE(ro->receivedPackets.at(0).packetType, PacketType::CONNACK);
    FMQ_COMPARE(ro->receivedPackets.at(1).packetType, PacketType::PUBLISH);
    FMQ_COMPARE(ro->receivedPackets.at(2).packetType, PacketType::PUBREL);
}

/**
 * For MQTT 3.1.1 and MQTT 5: a second connection that reuses client id "TheReceiver" with a different username
 * must be refused with NotAuthorized (MQTT 5 reason code, or the MQTT 3.1.1 CONNACK return code).
 */
void MainTests::testSessionTakeoverOtherUsername()
{
    for (ProtocolVersion p : {ProtocolVersion::Mqtt311, ProtocolVersion::Mqtt5})
    {
        FlashMQTestClient client1;
        client1.start();
        client1.connectClient(p, true, 600, [](Connect &connect) {
            connect.clientid = "TheReceiver";
            connect.username = "mark";
        });

        {
            auto ro1 = client1.receivedObjects.lock();
            auto &pack = ro1->receivedPackets.at(0);
            FMQ_COMPARE(pack.packetType, PacketType::CONNACK);
            ConnAckData ackData = pack.parseConnAckData();
            FMQ_COMPARE(ackData.reasonCode, ReasonCodes::Success);
        }

        FlashMQTestClient client2;
        client2.start();
        client2.connectClient(p, true, 600, [](Connect &connect) {
            connect.clientid = "TheReceiver";
            connect.username = "marktwain";
        });

        {
            auto ro2 = client2.receivedObjects.lock();
            auto &pack = ro2->receivedPackets.at(0);
            FMQ_COMPARE(pack.packetType, PacketType::CONNACK);
            ConnAckData ackData = pack.parseConnAckData();

            // The raw code differs per protocol version, so compare as integers.
            int expectedCode = p == ProtocolVersion::Mqtt5 ? static_cast(ReasonCodes::NotAuthorized) : static_cast(ConnAckReturnCodes::NotAuthorized);
            FMQ_COMPARE(static_cast(ackData.reasonCode), expectedCode);
        }
    }
}

/**
 * The correlation data set on an MQTT 5 publish must reach an MQTT 5 subscriber unchanged, together with the
 * topic and payload.
 */
void MainTests::testCorrelationData()
{
    FlashMQTestClient client1;
    client1.start();
    client1.connectClient(ProtocolVersion::Mqtt5);
    client1.subscribe("several/sub/topics", 1);

    {
        FlashMQTestClient sender;
        sender.start();
        sender.connectClient(ProtocolVersion::Mqtt5);

        Publish pub("several/sub/topics", "payload", 1);
        pub.correlationData = "INsI8czE8y3IZBxY";
        sender.publish(pub);
    }

    client1.waitForMessageCount(1);

    auto ro = client1.receivedObjects.lock();
    auto &pack = ro->receivedPublishes.at(0);

    FMQ_COMPARE(pack.publishData.correlationData, "INsI8czE8y3IZBxY");
    FMQ_COMPARE(pack.getTopic(), "several/sub/topics");
    FMQ_COMPARE(pack.getPayloadView(), "payload");
}

================================================ FILE: FlashMQTests/utiltests.cpp ================================================ #include #include "maintests.h" #include "testhelpers.h" #include "utils.h" #include "exceptions.h" #include "conffiletemp.h" #include "nocopy.h" void MainTests::testStringValuesParsing() { { std::vector result = parseValuesWithOptionalQuoting("one two three"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "one"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting("\"one\" \"two\" three"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "one"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting("\"one\" two \"three\""); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "one"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting("\"one\" \"two\" \"three\""); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "one"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), 
"three"); } { std::vector result = parseValuesWithOptionalQuoting("\"o'ne\" \"two'\" \"three\""); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "o'ne"); FMQ_COMPARE(result.at(1), "two'"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting("'one' 'two' 'three'"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "one"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting("'o\"ne' 'two' 'three'"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "o\"ne"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting("'o\"ne' 'two' 'three' "); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), "o\"ne"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting(""); MYCASTCOMPARE(result.size(), 0); } { std::vector result = parseValuesWithOptionalQuoting(R"delim(one"")delim"); MYCASTCOMPARE(result.size(), 1); FMQ_COMPARE(result.at(0), R"delim(one)delim"); } { std::vector result = parseValuesWithOptionalQuoting(R"delim(one"two")delim"); MYCASTCOMPARE(result.size(), 1); FMQ_COMPARE(result.at(0), R"delim(onetwo)delim"); } } void MainTests::testStringValuesParsingEscaping() { { std::vector result = parseValuesWithOptionalQuoting(R"delim(on\"e two three)delim"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), R"delim(on"e)delim"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting(R"delim(t\\wo three)delim"); MYCASTCOMPARE(result.size(), 2); FMQ_COMPARE(result.at(0), R"delim(t\wo)delim"); FMQ_COMPARE(result.at(1), "three"); } { std::vector result = parseValuesWithOptionalQuoting(R"delim(on\"e 'two' three)delim"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), R"delim(on"e)delim"); 
FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), "three"); } { std::vector result = parseValuesWithOptionalQuoting(R"delim(on\"e 'two' 'thr"ee')delim"); MYCASTCOMPARE(result.size(), 3); FMQ_COMPARE(result.at(0), R"delim(on"e)delim"); FMQ_COMPARE(result.at(1), "two"); FMQ_COMPARE(result.at(2), R"delim(thr"ee)delim"); } } void MainTests::testStringValuesFuzz() { std::vector words; std::minstd_rand rnd; rnd.seed(16893578); for (uint32_t i = 0; i < 10000; i++) { uint32_t v = rnd() % 16; for (uint32_t j = 0; j < v; j++) { uint32_t w = rnd() % 10; std::string word; for (uint32_t k = 0; k < w; k++) { char random_char = rnd() % 256; word.push_back(random_char); } words.push_back(word); } } FMQ_VERIFY(words.size() > 10000); for (char q : {'"', '\''}) { std::string longline; for(const std::string &w : words) { std::string w2; for(char c : w) { if (c == '"' || c == '\'' || c == '\\') w2.push_back('\\'); w2.push_back(c); } longline.push_back(q); longline.append(w2); longline.push_back(q); longline.push_back(' '); } std::vector parsed = parseValuesWithOptionalQuoting(longline); FMQ_COMPARE(parsed, words); } } void MainTests::testStringValuesInvalid() { try { std::vector result = parseValuesWithOptionalQuoting("on\\\"e 'two' 'thr\"ee"); FMQ_VERIFY(false); } catch (ConfigFileException &ex) { std::string s(ex.what()); FMQ_VERIFY(s.find("Unterminated quote") != std::string::npos); } catch (std::exception &ex) { FMQ_VERIFY(false); } try { std::vector result = parseValuesWithOptionalQuoting("This is an \\invalid escape."); FMQ_VERIFY(false); } catch (ConfigFileException &ex) { std::string s(ex.what()); FMQ_VERIFY(s.find("Invalid escape") != std::string::npos); } catch (std::exception &ex) { FMQ_VERIFY(false); } try { std::vector result = parseValuesWithOptionalQuoting("Escaping space is not support \\ bla"); FMQ_VERIFY(false); } catch (ConfigFileException &ex) { std::string s(ex.what()); FMQ_VERIFY(s.find("Invalid escape") != std::string::npos); } catch (std::exception 
&ex) { FMQ_VERIFY(false); } } /** * @brief MainTests::testPreviouslyValidConfigFile tests a config file that worked before the new parser. It uses as many of * the allowed syntax as possible. */ void MainTests::testPreviouslyValidConfigFile() { // Base64 is required because my editor starts interferring with the content, like leading tabs on lines. const std::string b64( "I3RocmVhZF9jb3VudCA2NAoKI2xvZ19maWxlIC90bXAvYm9lMy5sb2cnCgojIComXigqXiooJl4K" "IykoKikoKgogICAgICAjICkoKikoCgkjIHRoaXMgaXMgYSB0YWIKCiMgVGhpcyBsaW5lIGNvbnRh" "aW5zIGEgdGFiCglhbGxvd19hbm9ueW1vdXMgdHJ1ZQoKI292ZXJsb2FkX21vZGUgY2xvc2VfbmV3" "X2NsaWVudHMKCnBsdWdpbl9vcHRfb25lICAgICAgYWJjY2M5YyNjZGVmQUVCRjRfLS86KysrCgog" "ICAgcGx1Z2luX3RpbWVyX3BlcmlvZCAgICAgNDAwMAoKbG9nX2xldmVsIG5vdGljZSAgIAoKd2Vi" "c29ja2V0X3NldF9yZWFsX2lwX2Zyb20gICAgICAgIDJhOjMzOjozMzowMC82NAoKbGlzdGVuIHsK" "ICBwcm90b2NvbCBtcXR0CiAgcG9ydCAxODgzCiAgdGNwX25vZGVsYXkgZmFsc2UKCn0KCmxpc3Rl" "biB7CiAgcHJvdG9jb2wgbXF0dAogIHBvcnQgODA3MAoKfQoKYnJpZGdlIHsKICBsb2NhbF91c2Vy" "bmFtZSBhYmNkZSNmQUVCRjRfLS86KwogIGFkZHJlc3MgZGVtby5mbGFzaG1xLm9yZwogIGNsaWVu" "dGlkX3ByZWZpeCBCckQKICBwdWJsaXNoIGZpdmUgMgogIHN1YnNjcmliZSBhc2RmYXNkZmVlL3dl" "ci8rIDIKICBzdWJzY3JpYmUgX19fNDQvd2VyLyMgMQogIGNsaWVudGlkX3ByZWZpeCB6SDUzXy06" "Ol8KICB0Y3Bfbm9kZWxheSB0cnVlCn0K"); const std::vector b64_bytes = base64Decode(b64); const std::string b64_string = std::string(b64_bytes.begin(), b64_bytes.end()); ConfFileTemp config; config.writeLine(b64_string); config.closeFile(); ConfigFileParser parser(config.getFilePath()); parser.loadFile(false); Settings settings = parser.getSettings(); std::list bridges = settings.stealBridges(); MYCASTCOMPARE(bridges.size(), 1); BridgeConfig &bridge = bridges.front(); FMQ_COMPARE(bridge.subscribes.at(0).topic, "asdfasdfee/wer/+"); FMQ_COMPARE(bridge.subscribes.at(0).qos, (uint8_t)2); FMQ_COMPARE(bridge.local_username, "abcde#fAEBF4_-/:+"); FMQ_COMPARE(bridge.clientidPrefix, "zH53_-::_"); } void MainTests::testNoCopy() { std::string mystring = "haha"; 
NoCopy s; s = mystring; { NoCopy s2(s); FMQ_VERIFY(!s2); } { NoCopy s3; s3 = s; FMQ_VERIFY(!s3); s3 = mystring; FMQ_COMPARE(s3.value(), mystring); } } void MainTests::testBase64() { const std::string s = getSecureRandomString(1024); const std::vector bytes(s.begin(), s.end()); const std::string b64 = base64Encode(bytes.data(), bytes.size()); const auto decoded = base64Decode(b64); FMQ_COMPARE(decoded, bytes); } ================================================ FILE: FlashMQTests/websockettests.cpp ================================================ #include "maintests.h" #include "testhelpers.h" #include #include #include #include #include "threadglobals.h" #include "filecloser.h" void pollFd(int fd, bool throw_on_timeout) { struct pollfd polls[1]; memset(polls, 0, sizeof(struct pollfd)); polls[0].fd = fd; polls[0].events = POLLIN; const int rc = poll(polls, 1, 1000); if (rc == 0 && throw_on_timeout) throw std::runtime_error("Poll readFromSocket timed out"); if (rc < 0) throw std::runtime_error(strerror(errno)); } std::vector readFromSocket(int fd, bool throw_on_timeout, size_t expected_bytes=0) { std::vector answer; char buf[1024]; pollFd(fd, throw_on_timeout); ssize_t n = 0; while ((n = read(fd, buf, 1024)) != 0) { if (n > 0) answer.insert(answer.end(), buf, buf + n); else if (errno == EWOULDBLOCK) { if (answer.size() < expected_bytes) pollFd(fd, throw_on_timeout); else break; } else throw std::runtime_error(strerror(errno)); } return answer; } void MainTests::testWebsocketPing() { try { Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); int listen_socket = socket(AF_INET, SOCK_STREAM, 0); FileCloser listener_closer(listen_socket); int optval = 1; check(setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); BindAddr bindAddr(AF_INET, "127.0.0.1", 22123); check(bind(listen_socket, 
bindAddr.get(), bindAddr.getLen())); check(listen(listen_socket, 64)); int client_socket = socket(AF_INET, SOCK_STREAM, 0); int flags = fcntl(listen_socket, F_GETFL); check(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK )); std::shared_ptr c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, HaProxyMode::Off, nullptr, settings, false)); c1->addToEpoll(EPOLLIN); std::shared_ptr client = c1; t->giveClient(std::move(c1)); ::connect(client_socket, bindAddr.get(), bindAddr.getLen()); int socket_to_client = accept(listen_socket, nullptr, nullptr); FileCloser socket_to_client_closer(socket_to_client); if (socket_to_client < 0) throw std::runtime_error("Couldn't accept socket."); flags = fcntl(listen_socket, F_GETFL); check(fcntl(socket_to_client, F_SETFL, flags | O_NONBLOCK )); int error = 0; socklen_t optlen = sizeof(int); int count = 0; do { check(getsockopt(client_socket, SOL_SOCKET, SO_ERROR, &error, &optlen)); } while(error == EINPROGRESS && count++ < 1000); if (error > 0 && error != EINPROGRESS) throw std::runtime_error(strerror(error)); std::ifstream input("plainwebsocketpacket1_handshake.dat", std::ios::binary); std::vector websocketstart(std::istreambuf_iterator(input), {}); { write(socket_to_client, websocketstart.data(), websocketstart.size()); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); std::string answer_string(answer.begin(), answer.end()); QVERIFY(startsWith(answer_string, "HTTP/1.1 101 Switching Protocols")); } // We now have an upgraded connection, and can test websocket frame decoding. { size_t l = 0; std::vector pingFrame(1024); pingFrame[l++] = 0x09; // opcode 9 pingFrame[l++] = 0x00; // Unmasked. 
payload length; write(socket_to_client, pingFrame.data(), l); pollFd(client_socket, true); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); MYCASTCOMPARE(answer.at(0), 0x8A); // 'final bit', final fragment of message, opcode A (pong). MYCASTCOMPARE(answer.at(1), 0x00); // Zero payload. } { size_t l = 0; std::vector pingFrameWithPayload(1024); pingFrameWithPayload[l++] = 0x09; // opcode 9 pingFrameWithPayload[l++] = 0x05; // Unmasked. payload length; pingFrameWithPayload[l++] = 'h'; pingFrameWithPayload[l++] = 'e'; pingFrameWithPayload[l++] = 'l'; pingFrameWithPayload[l++] = 'l'; pingFrameWithPayload[l++] = 'o'; write(socket_to_client, pingFrameWithPayload.data(), l); pollFd(client_socket, true); client->readFdIntoBuffer(); client->writeBufIntoFd(); { std::vector answer = readFromSocket(socket_to_client, true); int i = 0; MYCASTCOMPARE(answer.at(i++), 0x8A); // 'final bit', final fragment of message, opcode A (pong). MYCASTCOMPARE(answer.at(i++), 0x05); // Payload length MYCASTCOMPARE(answer.at(i++), 'h'); MYCASTCOMPARE(answer.at(i++), 'e'); MYCASTCOMPARE(answer.at(i++), 'l'); MYCASTCOMPARE(answer.at(i++), 'l'); MYCASTCOMPARE(answer.at(i++), 'o'); } // Again, but don't send all data. This would get stuck in a loop before, which should be fixed now. write(socket_to_client, pingFrameWithPayload.data(), l-1); pollFd(client_socket, true); client->readFdIntoBuffer(); client->writeBufIntoFd(); usleep(10000); { std::vector answer = readFromSocket(socket_to_client, false); QVERIFY(answer.empty()); } // And complete the last byte write(socket_to_client, pingFrameWithPayload.data() + (l-1), 1); pollFd(client_socket, true); client->readFdIntoBuffer(); client->writeBufIntoFd(); { std::vector answer = readFromSocket(socket_to_client, true); int i = 0; MYCASTCOMPARE(answer.at(i++), 0x8A); // 'final bit', final fragment of message, opcode A (pong). 
MYCASTCOMPARE(answer.at(i++), 0x05); // Payload length MYCASTCOMPARE(answer.at(i++), 'h'); MYCASTCOMPARE(answer.at(i++), 'e'); MYCASTCOMPARE(answer.at(i++), 'l'); MYCASTCOMPARE(answer.at(i++), 'l'); MYCASTCOMPARE(answer.at(i++), 'o'); } { int m = 0; char mask[4] = {31,11,66,120}; size_t l = 0; std::vector pingFrameWithMaskedPayload(1024); pingFrameWithMaskedPayload[l++] = 0x09; // opcode 9 pingFrameWithMaskedPayload[l++] = 0x86; // Unmasked. payload length; pingFrameWithMaskedPayload[l++] = mask[0]; pingFrameWithMaskedPayload[l++] = mask[1]; pingFrameWithMaskedPayload[l++] = mask[2]; pingFrameWithMaskedPayload[l++] = mask[3]; pingFrameWithMaskedPayload[l++] = 'a' ^ mask[m++ % 4]; pingFrameWithMaskedPayload[l++] = 'b' ^ mask[m++ % 4]; pingFrameWithMaskedPayload[l++] = 'c' ^ mask[m++ % 4]; pingFrameWithMaskedPayload[l++] = 'd' ^ mask[m++ % 4]; pingFrameWithMaskedPayload[l++] = 'e' ^ mask[m++ % 4]; pingFrameWithMaskedPayload[l++] = 'f' ^ mask[m++ % 4]; write(socket_to_client, pingFrameWithMaskedPayload.data(), l); pollFd(client_socket, true); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); std::string answer_string(answer.begin() + 2, answer.end()); QCOMPARE(answer_string.c_str(), "abcdef"); } } } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } /** * At some point there was a bug where websocketBytesToReadBuffer() would spin if you'd specify a large frame length, * because it kept trying to read bytes that weren't there. The fix for it still had the issue that when a small * amount of bytes were left to read, it would consider it not enough and break out of the loop. 
*/ void MainTests::testWebsocketCorruptLengthFrame() { try { Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); int listen_socket = socket(AF_INET, SOCK_STREAM, 0); FileCloser listener_closer(listen_socket); int optval = 1; check(setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); BindAddr bindAddr(AF_INET, "127.0.0.1", 22123); check(bind(listen_socket, bindAddr.get(), bindAddr.getLen())); check(listen(listen_socket, 64)); int client_socket = socket(AF_INET, SOCK_STREAM, 0); int flags = fcntl(listen_socket, F_GETFL); check(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK )); std::shared_ptr c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, HaProxyMode::Off, nullptr, settings, false)); c1->addToEpoll(EPOLLIN); std::shared_ptr client = c1; t->giveClient(std::move(c1)); ::connect(client_socket, bindAddr.get(), bindAddr.getLen()); int socket_to_client = accept(listen_socket, nullptr, nullptr); FileCloser socket_to_client_closer(socket_to_client); if (socket_to_client < 0) throw std::runtime_error("Couldn't accept socket."); flags = fcntl(listen_socket, F_GETFL); check(fcntl(socket_to_client, F_SETFL, flags | O_NONBLOCK )); int error = 0; socklen_t optlen = sizeof(int); int count = 0; do { check(getsockopt(client_socket, SOL_SOCKET, SO_ERROR, &error, &optlen)); } while(error == EINPROGRESS && count++ < 1000); if (error > 0 && error != EINPROGRESS) throw std::runtime_error(strerror(error)); std::ifstream input("plainwebsocketpacket1_handshake.dat", std::ios::binary); std::vector websocketstart(std::istreambuf_iterator(input), {}); { write(socket_to_client, websocketstart.data(), websocketstart.size()); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); std::string 
answer_string(answer.begin(), answer.end()); QVERIFY(startsWith(answer_string, "HTTP/1.1 101 Switching Protocols")); } // We now have an upgraded connection, and can test websocket frame decoding. { size_t l = 0; std::vector frame(1024); frame[l++] = 0x02; // opcode 2, binary frame[l++] = 127; // Unmasked. payload length; // Huge extended payload length frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0b00010000; // invalid connect packet frame[l++] = 12; // length is too short. frame[l++] = 0; frame[l++] = 4; frame[l++] = 'M'; frame[l++] = 'Q'; frame[l++] = 'T'; frame[l++] = 'T'; frame[l++] = 4; frame[l++] = 0; // Keepalive frame[l++] = 0; frame[l++] = 60; // Clientid frame[l++] = 0; frame[l++] = 1; frame[l++] = 'a'; write(socket_to_client, frame.data(), l); pollFd(client_socket, true); client->readFdIntoBuffer(); std::vector packets; client->bufferToMqttPackets(packets, client); for (auto &pack : packets) { pack.handle(client); } QVERIFY2(false, "You shouldn't end up here."); std::vector answer = readFromSocket(socket_to_client, true); } } catch (std::exception &ex) { const std::string msg(ex.what()); QVERIFY(strContains(msg, "Invalid packet: header specifies invalid length.")); } } void MainTests::testWebsocketHugePing() { try { Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); int listen_socket = socket(AF_INET, SOCK_STREAM, 0); FileCloser listener_closer(listen_socket); int optval = 1; check(setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); BindAddr bindAddr(AF_INET, "127.0.0.1", 22123); check(bind(listen_socket, bindAddr.get(), bindAddr.getLen())); check(listen(listen_socket, 64)); int client_socket = socket(AF_INET, SOCK_STREAM, 0); int flags = 
fcntl(listen_socket, F_GETFL); check(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK )); std::shared_ptr c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, HaProxyMode::Off, nullptr, settings, false)); c1->addToEpoll(EPOLLIN); std::shared_ptr client = c1; t->giveClient(std::move(c1)); ::connect(client_socket, bindAddr.get(), bindAddr.getLen()); int socket_to_client = accept(listen_socket, nullptr, nullptr); FileCloser socket_to_client_closer(socket_to_client); if (socket_to_client < 0) throw std::runtime_error("Couldn't accept socket."); flags = fcntl(listen_socket, F_GETFL); check(fcntl(socket_to_client, F_SETFL, flags | O_NONBLOCK )); int error = 0; socklen_t optlen = sizeof(int); int count = 0; do { check(getsockopt(client_socket, SOL_SOCKET, SO_ERROR, &error, &optlen)); } while(error == EINPROGRESS && count++ < 1000); if (error > 0 && error != EINPROGRESS) throw std::runtime_error(strerror(error)); std::ifstream input("plainwebsocketpacket1_handshake.dat", std::ios::binary); std::vector websocketstart(std::istreambuf_iterator(input), {}); { write(socket_to_client, websocketstart.data(), websocketstart.size()); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); std::string answer_string(answer.begin(), answer.end()); QVERIFY(startsWith(answer_string, "HTTP/1.1 101 Switching Protocols")); } // We now have an upgraded connection, and can test websocket frame decoding. { size_t l = 0; std::vector frame(1024); frame[l++] = 0x09; // opcode 9, ping frame[l++] = 127; // Unmasked. 
payload length; // Huge extended payload length frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 0x99; frame[l++] = 'h'; frame[l++] = 'e'; frame[l++] = 'l'; frame[l++] = 'l'; frame[l++] = 'o'; write(socket_to_client, frame.data(), l); pollFd(client_socket, true); client->readFdIntoBuffer(); client->writeBufIntoFd(); usleep(10000); QVERIFY2(false, "Can't end up here. We must have gotten an exception"); { std::vector answer = readFromSocket(socket_to_client, false); QVERIFY(answer.empty()); } } } catch (std::exception &ex) { std::string msg(ex.what()); QVERIFY(strContains(msg, "The option 'client_max_write_buffer_size / 2' is lower than the ping frame we're are supposed to pong back")); QVERIFY2(true, ex.what()); } } /** * @brief test sending many medium sized websocket ping frames, exceeding the initial buffer site of 1k. This * tests for stall conditions, for instance. */ void MainTests::testWebsocketManyBigPingFrames() { try { Settings settings; MYCASTCOMPARE(settings.clientMaxWriteBufferSize, 1048576); MYCASTCOMPARE(settings.clientInitialBufferSize, 1024); std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); int listen_socket = socket(AF_INET, SOCK_STREAM, 0); FileCloser listener_closer(listen_socket); int optval = 1; check(setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); BindAddr bindAddr(AF_INET, "127.0.0.1", 22123); check(bind(listen_socket, bindAddr.get(), bindAddr.getLen())); check(listen(listen_socket, 64)); int client_socket = socket(AF_INET, SOCK_STREAM, 0); int flags = fcntl(listen_socket, F_GETFL); check(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK )); std::shared_ptr c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, HaProxyMode::Off, 
nullptr, settings, false)); c1->addToEpoll(EPOLLIN); std::shared_ptr client = c1; t->giveClient(std::move(c1)); ::connect(client_socket, bindAddr.get(), bindAddr.getLen()); int socket_to_client = accept(listen_socket, nullptr, nullptr); FileCloser socket_to_client_closer(socket_to_client); if (socket_to_client < 0) throw std::runtime_error("Couldn't accept socket."); flags = fcntl(listen_socket, F_GETFL); check(fcntl(socket_to_client, F_SETFL, flags | O_NONBLOCK )); int error = 0; socklen_t optlen = sizeof(int); int count = 0; do { check(getsockopt(client_socket, SOL_SOCKET, SO_ERROR, &error, &optlen)); } while(error == EINPROGRESS && count++ < 1000); if (error > 0 && error != EINPROGRESS) throw std::runtime_error(strerror(error)); std::ifstream input("plainwebsocketpacket1_handshake.dat", std::ios::binary); std::vector websocketstart(std::istreambuf_iterator(input), {}); { write(socket_to_client, websocketstart.data(), websocketstart.size()); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); std::string answer_string(answer.begin(), answer.end()); QVERIFY(startsWith(answer_string, "HTTP/1.1 101 Switching Protocols")); } // We now have an upgraded connection, and can test websocket frame decoding. { size_t l = 0; std::vector frame(1048576); for (int z = 0; z < 10; z++) { frame[l++] = 0x09; // opcode 9, ping frame[l++] = 126; // Unmasked. payload length signals following two bytes specify length; const uint16_t payload_size = 32768; frame[l++] = static_cast((payload_size & 0xFF00) >> 8); frame[l++] = static_cast(payload_size & 0x00FF); frame[l++] = 'h'; frame[l++] = 'e'; frame[l++] = 'l'; frame[l++] = 'l'; frame[l++] = 'o'; l += payload_size - 5; } write(socket_to_client, frame.data(), l); // Hacky, I know. 
while (true) { try { pollFd(client_socket, true); client->readFdIntoBuffer(); } catch (std::exception &ex) { break; } } client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true, 327720); MYCASTCOMPARE(answer.size(), 327720); size_t k = 0; for (int z = 0; z < 10; z++) { QVERIFY(static_cast(answer[k++]) == 0x8a); // Final frame, opcode A (pong) QVERIFY(static_cast(answer[k++]) == 126); // Payload length const uint8_t size_msb = answer[k++]; const uint8_t size_lsb = answer[k++]; const uint16_t size = (size_msb << 8) | size_lsb; QCOMPARE(size, 32768); QCOMPARE(answer[k++], 'h'); QCOMPARE(answer[k++], 'e'); QCOMPARE(answer[k++], 'l'); QCOMPARE(answer[k++], 'l'); QCOMPARE(answer[k++], 'o'); for (int i = 0; i < size - 5; i++) { QVERIFY(static_cast(answer[k++]) == 0); } } QCOMPARE(l,k); MYCASTCOMPARE(l, 327720); QVERIFY(l == answer.size()); } } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } void MainTests::testWebsocketClose() { try { Settings settings; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr store(new SubscriptionStore()); std::shared_ptr t(new ThreadData(0, settings, pluginLoader, std::weak_ptr())); int listen_socket = socket(AF_INET, SOCK_STREAM, 0); FileCloser listener_closer(listen_socket); int optval = 1; check(setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); BindAddr bindAddr(AF_INET, "127.0.0.1", 22123); check(bind(listen_socket, bindAddr.get(), bindAddr.getLen())); check(listen(listen_socket, 64)); int client_socket = socket(AF_INET, SOCK_STREAM, 0); int flags = fcntl(listen_socket, F_GETFL); check(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK )); std::shared_ptr c1(new Client(ClientType::Normal, client_socket, t, FmqSsl(), ConnectionProtocol::WebsocketMqtt, HaProxyMode::Off, nullptr, settings, false)); c1->addToEpoll(EPOLLIN); std::shared_ptr client = c1; t->giveClient(std::move(c1)); ::connect(client_socket, bindAddr.get(), bindAddr.getLen()); int 
socket_to_client = accept(listen_socket, nullptr, nullptr); FileCloser socket_to_client_closer(socket_to_client); if (socket_to_client < 0) throw std::runtime_error("Couldn't accept socket."); flags = fcntl(listen_socket, F_GETFL); check(fcntl(socket_to_client, F_SETFL, flags | O_NONBLOCK )); int error = 0; socklen_t optlen = sizeof(int); int count = 0; do { check(getsockopt(client_socket, SOL_SOCKET, SO_ERROR, &error, &optlen)); } while(error == EINPROGRESS && count++ < 1000); if (error > 0 && error != EINPROGRESS) throw std::runtime_error(strerror(error)); std::ifstream input("plainwebsocketpacket1_handshake.dat", std::ios::binary); std::vector websocketstart(std::istreambuf_iterator(input), {}); { write(socket_to_client, websocketstart.data(), websocketstart.size()); client->readFdIntoBuffer(); client->writeBufIntoFd(); std::vector answer = readFromSocket(socket_to_client, true); std::string answer_string(answer.begin(), answer.end()); QVERIFY(startsWith(answer_string, "HTTP/1.1 101 Switching Protocols")); } // We now have an upgraded connection, and can test websocket frame decoding. { { size_t l = 0; std::vector closeFrame(1024); closeFrame[l++] = 0x08; // opcode 8 = close closeFrame[l++] = 0x00; // Unmasked. payload length; write(socket_to_client, closeFrame.data(), l); } { size_t l = 0; std::vector pingFrameWithPayload(1024); pingFrameWithPayload[l++] = 0x09; // opcode 9 pingFrameWithPayload[l++] = 0x05; // Unmasked. 
payload length; pingFrameWithPayload[l++] = 'h'; pingFrameWithPayload[l++] = 'e'; pingFrameWithPayload[l++] = 'l'; pingFrameWithPayload[l++] = 'l'; pingFrameWithPayload[l++] = 'o'; write(socket_to_client, pingFrameWithPayload.data(), l); } pollFd(client_socket, true); DisconnectStage ds = client->readFdIntoBuffer(); QVERIFY(ds == DisconnectStage::Now); } } catch (std::exception &ex) { QVERIFY2(false, ex.what()); } } ================================================ FILE: FlashMQTests/willtests.cpp ================================================ #include #include "maintests.h" #include "testhelpers.h" #include "flashmqtestclient.h" #include "conffiletemp.h" #include "mainappasfork.h" #include "flashmqtempdir.h" /** * @brief MainTests::testMqtt3will tests that an MQTT 3.1.1 will (set with QoS 1) reaches every subscriber, each at the QoS it subscribed * with: the QoS 0 subscriber gets it at QoS 0 and the QoS 1 subscriber at QoS 1, i.e. not demoted because of the other subscriber. */ void MainTests::testMqtt3will() { std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will"; will->payload = "mypayload"; will->qos = 1; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt311); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311, false, 300); receiver.subscribe("my/will", 0); FlashMQTestClient receiver2; receiver2.start(); receiver2.connectClient(ProtocolVersion::Mqtt311, false, 300); receiver2.subscribe("my/will", 1); sender.reset(); receiver.waitForMessageCount(1); receiver2.waitForMessageCount(1); { auto ro = receiver.receivedObjects.lock(); MqttPacket pubPack = ro->receivedPublishes.front(); std::shared_ptr client = receiver.getClient(); pubPack.parsePublishData(client); QCOMPARE(pubPack.getPublishData().topic, "my/will"); QCOMPARE(pubPack.getPublishData().payload, "mypayload"); QCOMPARE(pubPack.getPublishData().qos, 0); } // The second receiver subscribed at a QoS non-0, and we want to see if we actually get it, and that it wasn't demoted by the other subscriber.
{ auto ro2 = receiver2.receivedObjects.lock(); MqttPacket pubPack2 = ro2->receivedPublishes.front(); std::shared_ptr client = receiver2.getClient(); pubPack2.parsePublishData(client); QCOMPARE(pubPack2.getPublishData().topic, "my/will"); QCOMPARE(pubPack2.getPublishData().payload, "mypayload"); QCOMPARE(pubPack2.getPublishData().qos, 1); } } /** * @brief MainTests::testMqtt3NoWillOnDisconnect tests that no will is published when an MQTT 3.1.1 client disconnects with a * normal DISCONNECT (ReasonCodes::Success). */ void MainTests::testMqtt3NoWillOnDisconnect() { std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will/testMqtt3NoWillOnDisconnect"; will->payload = "mypayload"; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt311); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe("my/will/testMqtt3NoWillOnDisconnect", 0); receiver.clearReceivedLists(); sender->disconnect(ReasonCodes::Success); sender.reset(); usleep(250000); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPackets.empty()); } /** * @brief MainTests::testMqtt5NoWillOnDisconnect is the MQTT 5 version of testMqtt3NoWillOnDisconnect. * * The receiver has to subscribe to the will's own topic; with any other topic nothing is received either way, and the test * can't fail. */ void MainTests::testMqtt5NoWillOnDisconnect() { std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will/testMqtt5NoWillOnDisconnect"; will->payload = "mypayload"; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt5); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5); receiver.subscribe("my/will/testMqtt5NoWillOnDisconnect", 0); receiver.clearReceivedLists(); sender->disconnect(ReasonCodes::Success); sender.reset(); usleep(250000); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPackets.empty()); } /** * @brief MainTests::testMqtt5DelayedWill tests that a will with a 2 second will delay is not received within 250 ms of the * sender going away, but does arrive afterwards. */ void MainTests::testMqtt5DelayedWill() { std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will/testMqtt5DelayedWill"; will->payload = "mypayload"; will->will_delay = 2; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt5, true, 60); FlashMQTestClient receiver;
receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, true, 60); receiver.subscribe("my/will/testMqtt5DelayedWill", 0); receiver.clearReceivedLists(); sender.reset(); usleep(250000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPackets.empty()); } receiver.waitForMessageCount(1, 5); { auto ro = receiver.receivedObjects.lock(); MqttPacket pubPack = ro->receivedPublishes.front(); std::shared_ptr client = receiver.getClient(); pubPack.parsePublishData(client); QCOMPARE(pubPack.getPublishData().topic, "my/will/testMqtt5DelayedWill"); QCOMPARE(pubPack.getPublishData().payload, "mypayload"); QCOMPARE(pubPack.getPublishData().qos, 0); } } /** * @brief MainTests::testMqtt5DelayedWillAlwaysOnSessionEnd tests that a delayed will is published when the sender's session * ends (it expires after 2 s), even though the 120 s will delay hasn't passed yet. */ void MainTests::testMqtt5DelayedWillAlwaysOnSessionEnd() { std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will/testMqtt5DelayedWillAlwaysOnSessionEnd"; will->payload = "mypayload"; will->will_delay = 120; // This long delay should not matter, because the session expires after 2s. sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt5, true, 2); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, true, 60); receiver.subscribe("my/will/testMqtt5DelayedWillAlwaysOnSessionEnd", 0); receiver.clearReceivedLists(); sender.reset(); usleep(1000000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPackets.empty()); } receiver.waitForMessageCount(1, 2); { auto ro = receiver.receivedObjects.lock(); MqttPacket pubPack = ro->receivedPublishes.front(); std::shared_ptr client = receiver.getClient(); pubPack.parsePublishData(client); QCOMPARE(pubPack.getPublishData().topic, "my/will/testMqtt5DelayedWillAlwaysOnSessionEnd"); QCOMPARE(pubPack.getPublishData().payload, "mypayload"); QCOMPARE(pubPack.getPublishData().qos, 0); } } /** * @brief MainTests::testWillOnSessionTakeOvers tests sending wills for both persistent and non-persistent sessions.
* * Mosquitto is more liberal with not sending wills and will also not send one when you're taking over a persistent session. But, to me it seems * that the specs say you always send wills on client disconnects. * * See https://docs.oasis-open.org/mqtt/mqtt/v5.0/mqtt-v5.0.html "3.1.4 CONNECT Actions" * * See testOverrideWillDelayOnSessionDestructionByTakeOver() for more details. */ void MainTests::testWillOnSessionTakeOvers() { /* Run the whole scenario twice: once with cleanStart false (persistent session) and once with cleanStart true. */ std::list cleanStarts { false, true}; for (bool cleanStart : cleanStarts) { FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe("my/will", 0); FlashMQTestClient sender; sender.start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will"; will->payload = "mypayload"; sender.setWill(will); sender.connectClient(ProtocolVersion::Mqtt311, cleanStart, 0, [](Connect &connect){ connect.clientid = "OneOfOne"; }); /* sender2 connects with the same client id "OneOfOne", taking over sender's session; that takeover must get sender's will published. */ FlashMQTestClient sender2; sender2.start(); std::shared_ptr will2 = std::make_shared(); will2->topic = "my/will"; will2->payload = "mypayload"; sender2.setWill(will2); sender2.connectClient(ProtocolVersion::Mqtt311, cleanStart, 0, [](Connect &connect){ connect.clientid = "OneOfOne"; }); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MqttPacket pubPack = ro->receivedPublishes.front(); std::shared_ptr client = receiver.getClient(); pubPack.parsePublishData(client); QCOMPARE(pubPack.getPublishData().topic, "my/will"); QCOMPARE(pubPack.getPublishData().payload, "mypayload"); QCOMPARE(pubPack.getPublishData().qos, 0); } } /** * @brief MainTests::testOverrideWillDelayOnSessionDestructionByTakeOver tests that when you connect with a second 'clean start' client, the delayed * will of the session you're destroying is sent. * * Mosquitto is more liberal with not sending wills and will also not send one when you're taking over a persistent session.
But, to me it seems * that the specs say you always send wills on client disconnects and actually hasten delayed wills when you kill the session containing the delayed will * by connecting a new session with the same ID using 'clean start'. * * See https://docs.oasis-open.org/mqtt/mqtt/v5.0/mqtt-v5.0.html "3.1.4 CONNECT Actions" */ void MainTests::testOverrideWillDelayOnSessionDestructionByTakeOver() { FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe("my/will", 0); /* The sender has to use MQTT 5: MQTT 3.1.1 has no will properties, so the will delay can't be transmitted, the will would not be delayed at all, and this test would pass without testing the delay override. */ FlashMQTestClient sender; sender.start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will"; will->payload = "mypayload"; will->will_delay = 120; sender.setWill(will); sender.connectClient(ProtocolVersion::Mqtt5, false, 300, [](Connect &connect){ connect.clientid = "OneOfOne"; }); /* Taking over "OneOfOne" with clean start destroys the session holding the 120 s delayed will, which should get that will sent without waiting out its delay. */ FlashMQTestClient sender2; sender2.start(); sender2.connectClient(ProtocolVersion::Mqtt311, true, 300, [](Connect &connect){ connect.clientid = "OneOfOne"; }); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); MqttPacket pubPack = ro->receivedPublishes.front(); std::shared_ptr client = receiver.getClient(); pubPack.parsePublishData(client); QCOMPARE(pubPack.getPublishData().topic, "my/will"); QCOMPARE(pubPack.getPublishData().payload, "mypayload"); QCOMPARE(pubPack.getPublishData().qos, 0); } /** * @brief MainTests::testDisabledWills copied from testMqtt3will, but then wills disabled.
*/ void MainTests::testDisabledWills() { ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("wills_enabled no"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; /* Presumably restarts the test broker with the config above (cleanup()/init() are test-fixture helpers). */ cleanup(); init(args); std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will"; will->payload = "mypayload"; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt311); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt311); receiver.subscribe("my/will", 0); sender.reset(); usleep(500000); receiver.waitForMessageCount(0); auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPublishes.empty()); } /** * @brief MainTests::testMqtt5DelayedWillsDisabled same as testMqtt5DelayedWill, but then with wills disabled. */ void MainTests::testMqtt5DelayedWillsDisabled() { ConfFileTemp confFile; confFile.writeLine("allow_anonymous yes"); confFile.writeLine("wills_enabled no"); confFile.closeFile(); std::vector args {"--config-file", confFile.getFilePath()}; /* Presumably restarts the test broker with the config above (cleanup()/init() are test-fixture helpers). */ cleanup(); init(args); std::unique_ptr sender = std::make_unique(); sender->start(); std::shared_ptr will = std::make_shared(); will->topic = "my/will/testMqtt5DelayedWill"; will->payload = "mypayload"; will->will_delay = 1; sender->setWill(will); sender->connectClient(ProtocolVersion::Mqtt5, true, 60); FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, true, 60); receiver.subscribe("my/will/testMqtt5DelayedWill", 0); receiver.clearReceivedLists(); sender.reset(); /* Wait well past the 1 s will delay; with wills disabled, nothing may arrive. */ usleep(4000000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPackets.empty()); } receiver.waitForMessageCount(0); usleep(250000); { auto ro = receiver.receivedObjects.lock(); QVERIFY(ro->receivedPackets.empty()); } } /** * @brief Test saving sessions and reloading delayed wills that have reached a delay time of zero.
These are loaded in the main thread * on start-up, so don't have normal thread data / event loops available to them. * * It tests the bug that FlashMQ crashes on sending wills immediately on loading them from disk on start-up. */ void MainTests::forkingTestSaveAndLoadDelayedWill() { FlashMQTempDir tmpdir; cleanup(); ConfFileTemp confFile; confFile.writeLine(R"( allow_anonymous true log_level debug )"); confFile.writeLine("storage_dir " + tmpdir.getPath().string()); confFile.closeFile(); const std::vector args {"--config-file", confFile.getFilePath()}; /* First server: leave behind a session whose will has a 1 s delay. The server is stopped roughly 500 ms later, presumably before that delay has passed. */ MainAppAsFork app(args); app.start(); app.waitForStarted(); FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5, false, 120, [](Connect &c) { Publish pub("my/delayed/will/topic", "my delayed will payload", 0); c.will = std::make_shared(pub); c.will->will_delay = 1; }); sender.disconnect(ReasonCodes::DisconnectWithWill); sender.waitForQuit(); // We have to quit now to not have threads when we fork again later. std::this_thread::sleep_for(std::chrono::milliseconds(500)); app.stop(); std::cerr << "Waiting with starting a new server to allow will delay to expire..." << std::endl; std::this_thread::sleep_for(std::chrono::seconds(3)); std::cerr << "Starting new server..." << std::endl; /* Second server: loads the stored will from storage_dir, by now with its delay expired. */ MainAppAsFork app2(args); app2.start(); app2.waitForStarted(); // Is the new server really running? See header doc. try { FlashMQTestClient receiver; receiver.start(); receiver.connectClient(ProtocolVersion::Mqtt5, false, 0); /* NOTE(review): this 'sender' shadows the outer one (which has already quit); a distinct name would be clearer. */ FlashMQTestClient sender; sender.start(); sender.connectClient(ProtocolVersion::Mqtt5, false, 0); receiver.subscribe("pukapuka/boo", 0); sender.publish("pukapuka/boo", "haha", 0); receiver.waitForMessageCount(1); auto ro = receiver.receivedObjects.lock(); FMQ_COMPARE(ro->receivedPublishes.front().getTopic(), "pukapuka/boo"); } catch (std::exception &ex) { std::string msg = "We did not get a response. The server did not start perhaps? Check dmesg.
Exception msg: "; msg += ex.what(); FMQ_VERIFY2(false, msg.c_str()); } app2.stop(); } ================================================ FILE: LICENSE ================================================ This Open Software License (the “License”) applies to any original work of authorship (the “Original Work”) whose owner (the “Licensor”) has placed the following licensing notice adjacent to the copyright notice for the Original Work: Licensed under the Open Software License version 3.0 1) Grant of Copyright License. Licensor grants You a worldwide, royalty-free, non-exclusive, sublicensable license, for the duration of the copyright, to do the following: a) to reproduce the Original Work in copies, either alone or as part of a collective work; b) to translate, adapt, alter, transform, modify, or arrange the Original Work, thereby creating derivative works (“Derivative Works”) based upon the Original Work; c) to distribute or communicate copies of the Original Work and Derivative Works to the public, with the proviso that copies of Original Work or Derivative Works that You distribute or communicate shall be licensed under this Open Software License; d) to perform the Original Work publicly; and e) to display the Original Work publicly. 2) Grant of Patent License. Licensor grants You a worldwide, royalty-free, non-exclusive, sublicensable license, under patent claims owned or controlled by the Licensor that are embodied in the Original Work as furnished by the Licensor, for the duration of the patents, to make, use, sell, offer for sale, have made, and import the Original Work and Derivative Works. 3) Grant of Source Code License. The term “Source Code” means the preferred form of the Original Work for making modifications to it and all available documentation describing how to modify the Original Work. Licensor agrees to provide a machine-readable copy of the Source Code of the Original Work along with each copy of the Original Work that Licensor distributes.
Licensor reserves the right to satisfy this obligation by placing a machine-readable copy of the Source Code in an information repository reasonably calculated to permit inexpensive and convenient access by You for as long as Licensor continues to distribute the Original Work. 4) Exclusions From License Grant. Neither the names of Licensor, nor the names of any contributors to the Original Work, nor any of their trademarks or service marks, may be used to endorse or promote products derived from this Original Work without express prior permission of the Licensor. Except as expressly stated herein, nothing in this License grants any license to Licensor’s trademarks, copyrights, patents, trade secrets or any other intellectual property. No patent license is granted to make, use, sell, offer for sale, have made, or import embodiments of any patent claims other than the licensed claims defined in Section 2. No license is granted to the trademarks of Licensor even if such marks are included in the Original Work. Nothing in this License shall be interpreted to prohibit Licensor from licensing under terms different from this License any Original Work that Licensor otherwise would have a right to license. 5) External Deployment. The term “External Deployment” means the use, distribution, or communication of the Original Work or Derivative Works in any way such that the Original Work or Derivative Works may be used by anyone other than You, whether those works are distributed or communicated to those persons or made available as an application intended for use over a network. As an express condition for the grants of license hereunder, You must treat any External Deployment by You of the Original Work or a Derivative Work as a distribution under section 1(c). 6) Attribution Rights. 
You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Original Work, as well as any notices of licensing and any descriptive text identified therein as an “Attribution Notice.” You must cause the Source Code for any Derivative Works that You create to carry a prominent Attribution Notice reasonably calculated to inform recipients that You have modified the Original Work. 7) Warranty of Provenance and Disclaimer of Warranty. Licensor warrants that the copyright in and to the Original Work and the patent rights granted herein by Licensor are owned by the Licensor or are sublicensed to You under the terms of this License with the permission of the contributor(s) of those copyrights and patent rights. Except as expressly stated in the immediately preceding sentence, the Original Work is provided under this License on an “AS IS” BASIS and WITHOUT WARRANTY, either express or implied, including, without limitation, the warranties of non-infringement, merchantability or fitness for a particular purpose. THE ENTIRE RISK AS TO THE QUALITY OF THE ORIGINAL WORK IS WITH YOU. This DISCLAIMER OF WARRANTY constitutes an essential part of this License. No license to the Original Work is granted by this License except under this disclaimer. 8) Limitation of Liability. Under no circumstances and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Licensor be liable to anyone for any indirect, special, incidental, or consequential damages of any character arising as a result of this License or the use of the Original Work including, without limitation, damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses. This limitation of liability shall not apply to the extent applicable law prohibits such limitation. 9) Acceptance and Termination. 
If, at any time, You expressly assented to this License, that assent indicates your clear and irrevocable acceptance of this License and all of its terms and conditions. If You distribute or communicate copies of the Original Work or a Derivative Work, You must make a reasonable effort under the circumstances to obtain the express assent of recipients to the terms of this License. This License conditions your rights to undertake the activities listed in Section 1, including your right to create Derivative Works based upon the Original Work, and doing so without honoring these terms and conditions is prohibited by copyright law and international treaty. Nothing in this License is intended to affect copyright exceptions and limitations (including “fair use” or “fair dealing”). This License shall terminate immediately and You may no longer exercise any of the rights granted to You by this License upon your failure to honor the conditions in Section 1(c). 10) Termination for Patent Action. This License shall terminate automatically and You may no longer exercise any of the rights granted to You by this License as of the date You commence an action, including a cross-claim or counterclaim, against Licensor or any licensee alleging that the Original Work infringes a patent. This termination provision shall not apply for an action alleging patent infringement by combinations of the Original Work with other software or hardware. 11) Jurisdiction, Venue and Governing Law. Any action or suit relating to this License may be brought only in the courts of a jurisdiction wherein the Licensor resides or in which Licensor conducts its primary business, and under the laws of that jurisdiction excluding its conflict-of-law provisions. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. 
Any use of the Original Work outside the scope of this License or after its termination shall be subject to the requirements and penalties of copyright or patent law in the appropriate jurisdiction. This section shall survive the termination of this License. 12) Attorneys’ Fees. In any action to enforce the terms of this License or seeking damages relating thereto, the prevailing party shall be entitled to recover its costs and expenses, including, without limitation, reasonable attorneys’ fees and costs incurred in connection with such action, including any appeal of such action. This section shall survive the termination of this License. 13) Miscellaneous. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. 14) Definition of “You” in This License. “You” throughout this License, whether in upper or lower case, means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License. For legal entities, “You” includes any entity that controls, is controlled by, or is under common control with you. For purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 15) Right to Use. You may use the Original Work in all ways not otherwise restricted or conditioned by this License or by law, and Licensor promises not to interfere with or be responsible for such uses by You. 16) Modification of This License. This License is Copyright © 2005 Lawrence Rosen. Permission is granted to copy, distribute, or communicate this License without modification. Nothing in this License permits You to modify this License as applied to the Original Work or to Derivative Works. 
However, You may modify the text of this License and copy, distribute or communicate your modified version (the “Modified License”) and apply it to other original works of authorship subject to the following conditions: (i) You may not indicate in any way that your Modified License is the “Open Software License” or “OSL” and you may not use those names in the name of your Modified License; (ii) You must replace the notice specified in the first paragraph above with the notice “Licensed under ” or with a notice of your own that is not confusingly similar to the notice in this License; and (iii) You may not claim that your original works are open source software unless your Modified License has been approved by Open Source Initiative (OSI) and You comply with its license review and certification process. ================================================ FILE: README.md ================================================ # FlashMQ ![building](https://github.com/halfgaar/FlashMQ/actions/workflows/building.yml/badge.svg) ![testing](https://github.com/halfgaar/FlashMQ/actions/workflows/testing.yml/badge.svg) ![linting](https://github.com/halfgaar/FlashMQ/actions/workflows/linting.yml/badge.svg) ![docker](https://github.com/halfgaar/FlashMQ/actions/workflows/docker.yml/badge.svg) FlashMQ is a high-performance, light-weight MQTT broker/server, designed to take good advantage of multi-CPU environments. Builds (AppImage and a Debian/Ubuntu apt server) are provided on [www.flashmq.org](https://www.flashmq.org/download/). The apt repository's GPG key fingerprint is `874E6552A197BDD069A1526E1788BCA932A92BDF`. You can confirm it by comparing it to the one on the download page. ## Building from source Building from source should be done with `build.sh`. It's best to checkout a tagged release first. See `git tag`. 
If you build manually with `cmake` with default options, you won't have `-DCMAKE_BUILD_TYPE=Release` and you will have debug code enabled and no compiler optimizations (`-O3`). In other words, it's essentially a debug build, but without debugging symbols.

## Docker

Official Docker images aren't available yet, but building your own Docker image can be done with the provided `Dockerfile`. Be sure to use the new buildx builder, as the `Dockerfile` is not compatible with the legacy build plugin.

You can build using the `Dockerfile` only (replace `<tag>` with a proper version tag, like `v1.25.0`):

```
docker build --file Dockerfile https://github.com/halfgaar/FlashMQ.git#<tag> -t halfgaar/flashmq
```

Or, build using a local git clone (building a specific version tag is recommended):

```
git checkout <tag>
docker build . -t halfgaar/flashmq
```

To run:

```
docker run -p 1883:1883 -v /srv/flashmq/etc/:/etc/flashmq --user 1000:1000 halfgaar/flashmq
```

Create extra volumes as you need, for the persistence DB file, logs, password files, plugin, etc.

For development you can target the build stage to get an image you can use for development:

```
docker build . --build-arg BUILD_TYPE=Debug --target=build
```

## Plugins

A plugin interface is defined and documented in `flashmq_plugin.h`. It allows custom authentication and other behavior. See the `examples` directory for example implementations of this interface.

## Commercial services

If your company requires commercial FlashMQ services, go to [www.flashmq.com](https://www.flashmq.com). Services include:

- the development of custom FlashMQ plugins;
- MQTT integration advice; and
- managed FlashMQ.

Send an email to service@flashmq.com to be notified for early trials.
================================================ FILE: acksender.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "acksender.h" #include "mqttpacket.h" #include "client.h" AckSender::AckSender(uint8_t qos, uint16_t packetId, ProtocolVersion protocolVersion, std::shared_ptr &client) : qos(qos), packetId(packetId), protocolVersion(protocolVersion), client(client) { } AckSender::~AckSender() { if (!sent) sendNow(); } void AckSender::sendNow() { this->sent = true; if (qos == 0) return; const PacketType responseType = qos == 1 ? PacketType::PUBACK : PacketType::PUBREC; PubResponse pubAck(this->protocolVersion, responseType, ackCode, packetId); MqttPacket response(pubAck); client->writeMqttPacket(response); } void AckSender::setAckCode(ReasonCodes ackCode) { this->ackCode = ackCode; } ================================================ FILE: acksender.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef ACKSENDER_H #define ACKSENDER_H #include #include "types.h" class AckSender { uint8_t qos; uint16_t packetId; ProtocolVersion protocolVersion = ProtocolVersion::None; std::shared_ptr &client; ReasonCodes ackCode = ReasonCodes::Success; bool sent = false; public: AckSender(const AckSender &other) = delete; AckSender(AckSender &&other) = delete; AckSender(uint8_t qos, uint16_t packetId, ProtocolVersion protocolVersion, std::shared_ptr &client); ~AckSender(); void sendNow(); void setAckCode(ReasonCodes ackCode); }; #endif // ACKSENDER_H ================================================ FILE: acltree.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include #include "acltree.h" #include "utils.h" #include "exceptions.h" /** * @brief AclNode::getChildren gets the children node, and makes it if not there. Use in places that you know this level in the tree exists or should be created. * @param subtopic * @return */ AclNode *AclNode::getChildren(const std::string &subtopic, bool registerPattern) { std::unique_ptr &node = children[subtopic]; if (!node) { node = std::make_unique(); if (registerPattern) { if (subtopic == "%u") this->_hasUserWildcard = true; if (subtopic == "%c") this->_hasClientidWildcard = true; } } return node.get(); } /** * @brief AclNode::getChildren is a const version, and will dererence the end iterator (crash) if it doesn't exist. So, hence the assert. Don't to that. 
* @param subtopic * @return */ const AclNode *AclNode::getChildren(const std::string &subtopic) const { assert(children.find(subtopic) != children.end()); auto node_it = children.find(subtopic); return node_it->second.get(); } AclNode *AclNode::getChildrenPlus() { if (!childrenPlus) childrenPlus = std::make_unique(); return childrenPlus.get(); } const AclNode *AclNode::getChildrenPlus() const { assert(childrenPlus); return childrenPlus.get(); } bool AclNode::hasChildrenPlus() const { return childrenPlus.operator bool(); } bool AclNode::hasChild(const std::string &subtopic) const { if (children.empty()) return false; auto child_it = children.find(subtopic); return child_it != children.end(); } bool AclNode::hasPoundGrants() const { return !grantsPound.empty(); } bool AclNode::hasUserWildcard() const { return this->_hasUserWildcard; } bool AclNode::hasClientidWildcard() const { return _hasClientidWildcard; } bool AclNode::isEmpty() const { return this->empty; } void AclNode::addGrant(AclGrant grant) { this->empty = false; grants.push_back(grant); } void AclNode::addGrantPound(AclGrant grant) { this->empty = false; grantsPound.push_back(grant); } const std::vector &AclNode::getGrantsPound() const { return this->grantsPound; } const std::vector &AclNode::getGrants() const { return this->grants; } AclTree::AclTree() { collectedPermissions.reserve(16); } /** * @brief AclTree::addTopic adds a fixed topic or pattern to the ACL tree. 
* @param pattern * @param aclGrant * @param type * @param username is ignored for 'pattern' type, because patterns apply to all users; */ void AclTree::addTopic(const std::string &pattern, AclGrant aclGrant, AclTopicType type, const std::string &username) { const std::vector subtopics = splitTopic(pattern); AclNode *curEnd = &rootAnonymous; if (type == AclTopicType::Patterns) curEnd = &rootPatterns; else if (!username.empty()) { curEnd = &rootPerUser[username]; } for (const auto &subtop : subtopics) { AclNode *subnode = nullptr; if (subtop == "+") subnode = curEnd->getChildrenPlus(); else if (subtop == "#") { curEnd->addGrantPound(aclGrant); return; } else subnode = curEnd->getChildren(subtop, type == AclTopicType::Patterns); curEnd = subnode; } curEnd->addGrant(aclGrant); } void AclTree::findPermissionRecursive(std::vector::const_iterator cur_published_subtopic_it, std::vector::const_iterator end, const AclNode *this_node, std::vector &collectedPermissions, const std::string &username, const std::string &clientid) const { const std::string &cur_published_subtop = *cur_published_subtopic_it; if (cur_published_subtopic_it == end) { const std::vector &grants = this_node->getGrants(); collectedPermissions.insert(collectedPermissions.end(), grants.begin(), grants.end()); if (this_node->hasPoundGrants()) { const std::vector &grantsPound = this_node->getGrantsPound(); collectedPermissions.insert(collectedPermissions.end(), grantsPound.begin(), grantsPound.end()); } return; } if (this_node->hasPoundGrants()) { const std::vector &grants = this_node->getGrantsPound(); collectedPermissions.insert(collectedPermissions.end(), grants.begin(), grants.end()); } const auto next_subtopic_it = ++cur_published_subtopic_it; if (this_node->hasChild(cur_published_subtop)) { const AclNode *sub_node = this_node->getChildren(cur_published_subtop); findPermissionRecursive(next_subtopic_it, end, sub_node, collectedPermissions, username, clientid); } if (this_node->hasUserWildcard() && 
cur_published_subtop == username) { const AclNode *sub_node = this_node->getChildren("%u"); findPermissionRecursive(next_subtopic_it, end, sub_node, collectedPermissions, username, clientid); } if (this_node->hasClientidWildcard() && cur_published_subtop == clientid) { const AclNode *sub_node = this_node->getChildren("%c"); findPermissionRecursive(next_subtopic_it, end, sub_node, collectedPermissions, username, clientid); } if (this_node->hasChildrenPlus()) { findPermissionRecursive(next_subtopic_it, end, this_node->getChildrenPlus(), collectedPermissions, username, clientid); } } /** * @brief AclTree::findPermission tests permissions as loaded from the Mosquitto-compatible acl_file. * @param subtopicsPublish * @param access Whether to test read access or write access (`AclGrant::Read` or `AclGrant::Write` respectively). * @param username The user to test permission for. * @return * * It behaves like Mosquitto's ACL file. Some of that behavior is a bit limited, but sticking to it for compatability: * * - If your user is authenticated, there must a user specific definition for that user; it won't fall back on anonymous ACLs. * - You can't combine ACLs, like 'all clients read bla/#' and add 'user john readwrite bla/#. User specific ACLs don't add * to the general (anonymous) ACLs. * - You can't specify 'any authenticated user'. */ AuthResult AclTree::findPermission(const std::vector &subtopicsPublish, AclGrant access, const std::string &username, const std::string &clientid) { assert(access == AclGrant::Read || access == AclGrant::Write); // Empty clientid is when FlashMQ itself publishes, and that is fine for 'write'. on 'read', it should still never happen. 
assert(!(clientid.empty() && access == AclGrant::Read )); collectedPermissions.clear(); if (username.empty() && !rootAnonymous.isEmpty()) findPermissionRecursive(subtopicsPublish.begin(), subtopicsPublish.end(), &rootAnonymous, collectedPermissions, username, clientid); else { auto it = rootPerUser.find(username); if (it != rootPerUser.end()) { AclNode &rootOfUser = it->second; if (!rootOfUser.isEmpty()) findPermissionRecursive(subtopicsPublish.begin(), subtopicsPublish.end(), &rootOfUser, collectedPermissions, username, clientid); } } if (std::find(collectedPermissions.begin(), collectedPermissions.end(), AclGrant::Deny) != collectedPermissions.end()) return AuthResult::acl_denied; if (!rootPatterns.isEmpty()) findPermissionRecursive(subtopicsPublish.begin(), subtopicsPublish.end(), &rootPatterns, collectedPermissions, username, clientid); if (collectedPermissions.empty()) return AuthResult::acl_denied; bool allowed = false; for(AclGrant grant : collectedPermissions) { // A deny always overrides all other declarations. if (grant == AclGrant::Deny) return AuthResult::acl_denied; if (access == AclGrant::Read && (grant == AclGrant::Read || grant == AclGrant::ReadWrite)) allowed = true; if (access == AclGrant::Write && (grant == AclGrant::Write || grant == AclGrant::ReadWrite)) allowed = true; } AuthResult result = allowed ? 
AuthResult::success : AuthResult::acl_denied; return result; } /** * @brief Parses an ACL grant keyword: 'read', 'write', 'readwrite' or 'deny'. The input is matched after passing it through str_tolower(). * @param s the keyword as written in the configuration. * @return the matching AclGrant. * @throws ConfigFileException when the keyword is none of the above. */ AclGrant stringToAclGrant(const std::string &s) { const std::string s2 = str_tolower(s); AclGrant x = AclGrant::Deny; if (s2 == "read") x = AclGrant::Read; else if (s2 == "write") x = AclGrant::Write; else if (s2 == "readwrite") x = AclGrant::ReadWrite; else if (s2 == "deny") x = AclGrant::Deny; else throw ConfigFileException(formatString("Acl grant '%s' is invalid", s.c_str())); return x; } ================================================ FILE: acltree.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef ACLTREE_H #define ACLTREE_H #include #include #include #include "logger.h" enum class AclGrant { Deny, Read, Write, ReadWrite }; enum class AclTopicType { Strings, Patterns }; AclGrant stringToAclGrant(const std::string &s); /** * @brief Permissions for an MQTT topic path is a tree of `AclNode`s. Topic paths are broken up and matched down the tree. A '#' wildcard will match * all following subtopics, so therefore '#' is a 'grant', not a 'child'. */ class AclNode { bool empty = false; std::unordered_map> children; std::unique_ptr childrenPlus; // The + sign in MQTT represents a single-level wildcard std::vector grants; std::vector grantsPound; // The # sign.
This is short-hand for avoiding one memory access through a layer of std::unique_ptr bool _hasUserWildcard = false; // %u bool _hasClientidWildcard = false; // %c public: AclNode *getChildren(const std::string &subtopic, bool registerPattern); const AclNode *getChildren(const std::string &subtopic) const; AclNode *getChildrenPlus(); const AclNode *getChildrenPlus() const; bool hasChildrenPlus() const; bool hasChild(const std::string &subtopic) const; bool hasPoundGrants() const; bool hasUserWildcard() const; bool hasClientidWildcard() const; bool isEmpty() const; void addGrant(AclGrant grant); void addGrantPound(AclGrant grant); const std::vector &getGrants() const; const std::vector &getGrantsPound() const; }; /** * @brief The AclTree class represents (Mosquitto compatible) permissions from mosquitto_acl_file. It's not thread safe, and designed for per-thread use. */ class AclTree { Logger *logger = Logger::getInstance(); AclNode rootAnonymous; std::unordered_map rootPerUser; AclNode rootPatterns; std::vector collectedPermissions; void findPermissionRecursive(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, const AclNode *node, std::vector &collectedPermissions, const std::string &username, const std::string &clientid) const; public: AclTree(); void addTopic(const std::string &pattern, AclGrant aclGrant, AclTopicType type, const std::string &username = std::string()); AuthResult findPermission(const std::vector &subtopicsPublish, AclGrant access, const std::string &username, const std::string &clientid); }; #endif // ACLTREE_H ================================================ FILE: backgroundworker.cpp ================================================ #include #include #include #include "backgroundworker.h" #include "logger.h" /** * @brief Worker thread body. Waits on wakeup_fd with poll() (1000 ms timeout), moves all queued tasks out of the locked task list and runs them, until 'running' is false. A std::exception escaping a task is logged and does not stop the loop. */ void BackgroundWorker::doWork() { struct pollfd fds[1]; memset(fds, 0, sizeof(struct pollfd)); fds[0].fd = wakeup_fd; fds[0].events = POLLIN; while (running) { executing_task = false; int fd_count = poll(fds, 1,
1000); if (fd_count == 0) continue; if (fd_count < 0) { Logger::getInstance()->log(LOG_ERR) << "poll() error in BackgroundWorker: " << strerror(errno); continue; } if (fds[0].revents & POLLIN) { uint64_t _; if (read(fds[0].fd, &_, sizeof(uint64_t)) < 0) { Logger::getInstance()->log(LOG_ERR) << "Error while reading from wakeup_fd: " << strerror(errno); } } if (!running) continue; std::list> copied_tasks; { auto locked_tasks = tasks.lock(); copied_tasks = std::move(*locked_tasks); locked_tasks->clear(); } /* addTask() pushes to the front, so iterating front-to-back runs the newest task of a batch first. */ for(auto &f : copied_tasks) { executing_task = true; try { f(); } catch (const std::exception &ex) { Logger *logger = Logger::getInstance(); logger->log(LOG_ERR) << "Error in BackgroundWorker::do_work: " << ex.what(); } } } } /** * @brief Signals the worker thread by writing to the eventfd; a failed write is logged, not thrown. */ void BackgroundWorker::wake_up_thread() { uint64_t one = 1; if (write(wakeup_fd, &one, sizeof(uint64_t)) < 0) { Logger::getInstance()->log(LOG_ERR) << "BackgroundWorker::wake_up_thread: " << strerror(errno); } } BackgroundWorker::BackgroundWorker() { /* EFD_CLOEXEC keeps the wake-up descriptor from leaking into programs started with exec(). */ wakeup_fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); if (wakeup_fd < 0) { throw std::runtime_error("Failed to initialize eventfd in background worker: " + std::string(strerror(errno))); } } BackgroundWorker::~BackgroundWorker() { this->stop(); if (t.joinable()) t.join(); if (wakeup_fd >= 0) { close(wakeup_fd); wakeup_fd = -1; } } /** * @brief Starts the worker thread (named "BgTasks"). Does nothing while a previously started thread is still joinable. */ void BackgroundWorker::start() { auto locked_tasks = tasks.lock(); if (t.joinable()) return; t = std::thread(&BackgroundWorker::doWork, this); pthread_t native = this->t.native_handle(); pthread_setname_np(native, "BgTasks"); } void BackgroundWorker::stop() { this->running = false; this->wake_up_thread(); } void BackgroundWorker::waitForStop() { if (t.joinable()) t.join(); } /** * @brief Queues a task at the front of the task list and wakes the worker. * @param f the task; it is moved into the queue. * @param only_if_idle when true, the task is dropped if the worker is currently executing a task. */ void BackgroundWorker::addTask(std::function f, bool only_if_idle) { if (only_if_idle && executing_task) return; { auto locked_tasks = tasks.lock(); locked_tasks->push_front(std::move(f)); } wake_up_thread(); } ================================================ FILE: backgroundworker.h
================================================ #ifndef BACKGROUNDWORKER_H #define BACKGROUNDWORKER_H #include #include #include #include #include #include <atomic> #include "mutexowned.h" class BackgroundWorker { std::thread t; /* 'running' is written by stop() and read by the worker thread; 'executing_task' is written by the worker thread and read by addTask(). Both are atomic so these cross-thread accesses are not a data race. */ std::atomic<bool> running{true}; std::atomic<bool> executing_task{false}; int wakeup_fd = -1; MutexOwned>> tasks; void doWork(); void wake_up_thread(); public: BackgroundWorker(); ~BackgroundWorker(); void start(); void stop(); void waitForStop(); void addTask(std::function f, bool only_if_idle); }; #endif // BACKGROUNDWORKER_H ================================================ FILE: bindaddr.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "bindaddr.h" #include #include #include #include #include #include #include #include "utils.h" #include "logger.h" /** * @brief Prepares the socket address used by bind_socket(). For AF_INET/AF_INET6 an empty bindAddress means the wildcard address; otherwise it must be a numeric address accepted by inet_pton(). For AF_UNIX, bindAddress is the socket path (at most 100 chars). * @throws std::runtime_error on an unsupported family, an invalid port, an unparsable address, or a missing/too long path. */ BindAddr::BindAddr( int family, const std::string &bindAddress, int port, const std::optional &user, const std::optional &group, const std::optional &mode ) : unixsock_user(user), unixsock_group(group), unixsock_mode(mode) { if (!(family == AF_INET || family == AF_INET6 || family == AF_UNIX)) throw std::runtime_error("BindAddr: unsupported address family."); if (family == AF_UNIX) { if (bindAddress.empty()) throw std::runtime_error("Binding to a unix socket requires a path."); } else { if (port <= 0 || port > 0xFFFF) throw std::runtime_error("IP listen port invalid."); } this->family = family; if (family == AF_INET) { struct sockaddr_in in_addr_v4 {}; this->len = sizeof(in_addr_v4); /* An unparsable address must not fall through: the zero-initialised struct would silently mean 'any address'. */ if (bindAddress.empty()) in_addr_v4.sin_addr.s_addr = INADDR_ANY; else if (inet_pton(AF_INET, bindAddress.c_str(), &in_addr_v4.sin_addr) != 1) throw std::runtime_error("Invalid IPv4 bind address: '" + bindAddress + "'."); in_addr_v4.sin_port = htons(port); in_addr_v4.sin_family = AF_INET; std::memcpy(dat.data(), &in_addr_v4, sizeof(in_addr_v4)); } if (family == AF_INET6) { struct sockaddr_in6 in_addr_v6 {};
this->len = sizeof(in_addr_v6); /* Same reasoning as for IPv4: never let a bad address become the all-zero wildcard. */ if (bindAddress.empty()) in_addr_v6.sin6_addr = IN6ADDR_ANY_INIT; else if (inet_pton(AF_INET6, bindAddress.c_str(), &in_addr_v6.sin6_addr) != 1) throw std::runtime_error("Invalid IPv6 bind address: '" + bindAddress + "'."); in_addr_v6.sin6_port = htons(port); in_addr_v6.sin6_family = AF_INET6; std::memcpy(dat.data(), &in_addr_v6, sizeof(in_addr_v6)); } if (family == AF_UNIX) { struct sockaddr_un path {}; this->len = sizeof(path); if (bindAddress.length() > 100) throw std::runtime_error("Unix domain socket path can't be longer than 100 chars."); path.sun_family = AF_UNIX; std::memcpy(path.sun_path, bindAddress.data(), bindAddress.size()); std::memcpy(dat.data(), &path, sizeof(path)); this->unixsock_path = bindAddress; } } /** * @brief Binds socket_fd to this address. For unix sockets it then applies the optional owner and group (each given as a number or a name) and mode; problems in that step are logged as warnings instead of thrown. */ void BindAddr::bind_socket(int socket_fd) { check(bind(socket_fd, get(), len)); if (family == AF_UNIX) { FMQ_ENSURE(this->unixsock_path); std::optional uid; std::optional gid; if (unixsock_user) { std::optional parsed_uid = try_stoul(unixsock_user.value()); if (parsed_uid) uid = parsed_uid.value(); else { std::optional data = get_pw_name(unixsock_user.value()); if (data) uid = data->uid; } if (!uid) Logger::getInstance()->log(LOG_WARNING) << "Could not resolve owner as name or uid '" << unixsock_user.value() << "' on " << unixsock_path.value() << "."; } if (unixsock_group) { std::optional parsed_gid = try_stoul(unixsock_group.value()); if (parsed_gid) gid = parsed_gid.value(); else { std::optional data = get_gr_name(unixsock_group.value()); if (data) gid = data->gid; } if (!gid) Logger::getInstance()->log(LOG_WARNING) << "Could not resolve group as name or gid '" << unixsock_group.value() << "' on " << unixsock_path.value() << "."; } if (uid || gid) { if (chown(unixsock_path.value().c_str(), uid.value_or(-1), gid.value_or(-1)) < 0) { Logger::getInstance()->log(LOG_WARNING) << "Can't change owner/group of '" << unixsock_path.value() << "'.
Reason: " << strerror(errno) << "."; } } if (unixsock_mode) { if (chmod(unixsock_path.value().c_str(), this->unixsock_mode.value()) < 0) { Logger::getInstance()->log(LOG_WARNING) << "Can't change mode of '" << unixsock_path.value() << "'. Reason: " << strerror(errno) << "."; } } } } ================================================ FILE: bindaddr.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef BINDADDR_H #define BINDADDR_H #include #include #include #include /** * @brief The BindAddr struct helps creating the resource for bind(). It uses an intermediate struct sockaddr to avoid compiler * warnings and type aliasing violations, and this class helps a bit with resource management of it. */ class BindAddr { std::vector dat = std::vector(sizeof (struct sockaddr_storage)); sa_family_t family = AF_UNSPEC; socklen_t len = 0; std::optional unixsock_path; const std::optional unixsock_user; const std::optional unixsock_group; const std::optional unixsock_mode; public: BindAddr() = delete; BindAddr( int family, const std::string &bindAddress, int port, const std::optional &user={}, const std::optional &group={}, const std::optional &mode={}); BindAddr(const BindAddr &other) = delete; BindAddr(BindAddr &&other) = delete; BindAddr &operator=(const BindAddr &other) = delete; BindAddr &operator=(BindAddr &&other) = delete; void bind_socket(int socket_fd); const sockaddr *get() const { return reinterpret_cast(dat.data()); } socklen_t getLen() const { return len; } }; #endif // BINDADDR_H ================================================ FILE: bridgeconfig.cpp ================================================ #include "bridgeconfig.h" #include #include #include #include #include "utils.h" #include "exceptions.h"
#include "bridgeinfodb.h" #include "globals.h" /* Returns the shared-subscription name for this prefix; on first use a value from getSecureRandomString(12) is generated and remembered. */ std::string BridgeClientGroupIds::getClientGroupShareName(const std::string &client_id_prefix) { std::string &s = this->bridge_share_names[client_id_prefix]; if (s.empty()) s = getSecureRandomString(12); return s; } void BridgeClientGroupIds::setClientGroupShareName(const std::string &client_id_prefix, const std::string &share_name) { this->bridge_share_names[client_id_prefix] = share_name; } /* Returns the client group id for this prefix; on first use a value from getSecureRandomString(12) is generated and remembered. */ std::string BridgeClientGroupIds::getClientGroupId(const std::string &client_id_prefix) { std::string &s = this->bridge_group_ids[client_id_prefix]; if (s.empty()) s = getSecureRandomString(12); return s; } /* Loads share names written by saveShareNames(); they are only applied when 'real' is true. An empty path does nothing. */ void BridgeClientGroupIds::loadShareNames(const std::string &path, bool real) { if (path.empty()) return; try { BridgeInfoDb db(path); db.openRead(); std::list data = db.readInfo(); if (real) { for (const BridgeInfoForSerializing &d : data) { bridge_share_names[d.prefix] = d.client_group_share_name; } } } catch (const PersistenceFileCantBeOpened &) { /* Best effort: when the file can't be opened, share names are generated on demand by getClientGroupShareName(). */ } } void BridgeClientGroupIds::saveShareNames(const std::string &path) const { if (path.empty()) return; std::list bridgeInfos; for (const auto &p : bridge_share_names) { BridgeInfoForSerializing ser; ser.prefix = p.first; ser.client_group_share_name = p.second; bridgeInfos.emplace_back(std::move(ser)); } BridgeInfoDb bridgeInfoDb(path); bridgeInfoDb.openWrite(); bridgeInfoDb.saveInfo(bridgeInfos); } bool BridgeTopicPath::isValidQos() const { return qos < 3; } bool BridgeTopicPath::operator==(const BridgeTopicPath &other) const { return this->topic == other.topic && this->qos == other.qos; } BridgeState::BridgeState(const BridgeConfig &config) : c(config) { } FMQSockaddr BridgeState::popDnsResult() { if (dnsResults.empty()) throw std::runtime_error("Trying to get DNS results when there are none"); FMQSockaddr addr = dnsResults.front(); dnsResults.pop_front(); return addr; } void BridgeState::initSSL(bool reloadCertificates) { if (this->c.tlsMode == BridgeTLSMode::None) return; if
(reloadCertificates) this->sslInitialized = false; if (this->sslInitialized) return; sslctx.emplace(TLS_client_method()); sslctx->setMinimumTlsVersion(c.minimumTlsVersion); SSL_CTX_set_mode(sslctx->get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); const char *privkey = c.sslPrivkey.empty() ? nullptr : c.sslPrivkey.c_str(); const char *fullchain = c.sslFullchain.empty() ? nullptr : c.sslFullchain.c_str(); if (fullchain) { if (SSL_CTX_use_certificate_chain_file(sslctx->get(), fullchain) != 1) { ERR_print_errors_cb(logSslError, NULL); throw std::runtime_error("Loading bridge SSL fullchain failed. This was after test loading the certificate, so is very unexpected."); } } if (privkey) { if (SSL_CTX_use_PrivateKey_file(sslctx->get(), privkey, SSL_FILETYPE_PEM) != 1) { ERR_print_errors_cb(logSslError, NULL); throw std::runtime_error("Loading bridge SSL privkey failed. This was after test loading the certificate, so is very unexpected."); } if (SSL_CTX_check_private_key(sslctx->get()) != 1) { ERR_print_errors_cb(logSslError, NULL); throw std::runtime_error("Verifying bridge SSL privkey failed. This was after test loading the certificate, so is very unexpected."); } } const char *ca_file = c.caFile.empty() ? nullptr : c.caFile.c_str(); const char *ca_dir = c.caDir.empty() ? nullptr : c.caDir.c_str(); if (ca_file || ca_dir) { if (SSL_CTX_load_verify_locations(sslctx->get(), ca_file, ca_dir) != 1) { ERR_print_errors_cb(logSslError, NULL); throw std::runtime_error("Loading ca_dir/ca_file failed.
This was after test loading the certificate, so is very unexpected."); } } else { if (SSL_CTX_set_default_verify_paths(sslctx->get()) != 1) { ERR_print_errors_cb(logSslError, NULL); throw std::runtime_error("Setting default SSL paths failed."); } } this->sslInitialized = true; } bool BridgeState::timeForNewReconnectAttempt() { /* * When we are part of a group of connections to a server, don't back off reconnection, because we want to keep all * of the individual connections on-line, and not have some lag behind. */ if (c.getFmqClientGroupId().has_value()) return true; int next = 1; if (reconnectCounter > 0) next = baseReconnectInterval; if (reconnectCounter > 10) next = baseReconnectInterval + 300; if (next > 1 && intervalLogged != next) { intervalLogged = next; Logger *logger = Logger::getInstance(); logger->log(LOG_NOTICE) << "Bridge '" << c.clientidPrefix << "' connection failure count: " << reconnectCounter << ". Increasing reconnect interval to " << next << " seconds."; } return lastReconnectAttempt + std::chrono::seconds(next) < std::chrono::steady_clock::now(); } void BridgeState::registerReconnect() { lastReconnectAttempt = std::chrono::steady_clock::now(); reconnectCounter++; } void BridgeState::resetReconnectCounter() { lastReconnectAttempt = std::chrono::time_point(); reconnectCounter = 0; intervalLogged = 0; } void BridgeState::resetThreadOwners() { session.reset_thread(); threadData.reset_thread(); } /** * @brief BridgeConfig::setClientId is for setting the client ID on start to the one from a saved state. That's why it only works when the prefix matches. * @param prefix * @param id */ void BridgeConfig::setClientId(const std::string &prefix, const std::string &id) { // This is protection against calling this method too early. assert(!clientid.empty()); // Should never happen, but just in case; an empty client id can get confusing.
if (id.empty()) return; if (prefix != this->clientidPrefix) return; this->clientid = id; } void BridgeConfig::setClientId() { if (!clientid.empty()) return; if (clientidPrefix.length() > this->client_id_max_length) throw std::runtime_error("clientidPrefix can't be longer than " + std::to_string(this->client_id_max_length)); std::ostringstream oss; oss << clientidPrefix << "_" << getSecureRandomString(10); clientid = oss.str(); } /* Appends '_' and the connection number to clientidPrefix, and raises the allowed prefix length by the same number of characters. */ void BridgeConfig::appendConnectionNumber(size_t no) { const std::string no_s = std::to_string(no); this->client_id_max_length += no_s.length() + 1; this->clientidPrefix.append("_").append(no_s); } void BridgeConfig::setSharedSubscriptionName(const std::string &share_name) { setSharedSubscriptionName(publishes, share_name); setSharedSubscriptionName(subscribes, share_name); } /* Rewrites each topic to '$share/' + share_name followed by the subtopics that remain after parseSubscriptionShare(). */ void BridgeConfig::setSharedSubscriptionName(std::vector &topics, const std::string &share_name) { for (BridgeTopicPath &t : topics) { std::vector subtopics = splitTopic(t.topic); /* Output parameters of parseSubscriptionShare() that are not needed here. Named explicitly because identifiers containing a double underscore are reserved. */ std::string unused_out_1; std::string unused_out_2; parseSubscriptionShare(subtopics, unused_out_1, unused_out_2); std::string new_topic("$share/"); new_topic.append(share_name); for (const std::string &s : subtopics) { new_topic.append("/"); new_topic.append(s); } t.topic = std::move(new_topic); } } const std::string &BridgeConfig::getClientid() const { return clientid; } const std::optional &BridgeConfig::getFmqClientGroupId() const { return this->fmq_client_group_id; } /** * @brief Validates the bridge settings, throwing on the first problem (ConfigFileException for the checks made directly here). Also fills in the default port when none is set: 8883 with TLS, 1883 without. */ void BridgeConfig::isValid() { if (sslPrivkey.empty() != sslFullchain.empty()) throw ConfigFileException("Specify both 'privkey' and 'fullchain' or neither."); if (tlsMode > BridgeTLSMode::None) { if (port == 0) { port = 8883; } if (sslFullchain.size() || sslPrivkey.size()) { testSsl(sslFullchain, sslPrivkey); } testSslVerifyLocations(caFile, caDir, "Loading bridge ca_file/ca_dir failed."); } else { if (port == 0) { port = 1883; } } if (address.empty()) throw ConfigFileException("No address specified in bridge"); if (publishes.empty() && subscribes.empty()) throw ConfigFileException("No subscribe or publish
paths defined in bridge."); if (!caDir.empty() && !caFile.empty()) throw ConfigFileException("Specify only one 'ca_file' or 'ca_dir'"); if (clientidPrefix.length() > client_id_max_length) throw ConfigFileException("clientidPrefix can't be longer than " + std::to_string(client_id_max_length)); if (protocolVersion <= ProtocolVersion::Mqtt311 && remote_password.has_value() && !remote_username.has_value()) throw ConfigFileException("MQTT 3.1.1 and lower require a username when you set a password."); if (local_prefix && !endsWith(local_prefix.value(), "/")) throw ConfigFileException("Option 'local_prefix' must end in a '/'."); if (remote_prefix && !endsWith(remote_prefix.value(), "/")) throw ConfigFileException("Option 'remote_prefix' must end in a '/'."); if (connection_count > 1 && protocolVersion < ProtocolVersion::Mqtt5) throw ConfigFileException("Using multiple bridge connections needs at least MQTT5"); if (connection_count > 1) { auto check = [](const BridgeTopicPath &x) { if (startsWith(x.topic, "$share")) { throw ConfigFileException("Bridges with multiple connections can't already define share names in the topics."); } }; std::for_each(publishes.begin(), publishes.end(), check); std::for_each(subscribes.begin(), subscribes.end(), check); } } std::vector BridgeConfig::multiply() const { std::vector result; result.reserve(this->connection_count); const std::string share_name = globals->bridgeClientGroupIds.getClientGroupShareName(this->clientidPrefix); const std::string group_id = globals->bridgeClientGroupIds.getClientGroupId(this->clientidPrefix); for (size_t i = 0; i < this->connection_count; i++) { result.push_back(*this); if (this->connection_count > 1) { result.back().setSharedSubscriptionName(share_name); /* * This means that when people have an existing bridge config with `use_saved_clientid`, it will lose its state when * they change the amount of connections from 1 to something else. That's good, because otherwise one of the * connections will get session state that no longer applies to it.
*/ result.back().appendConnectionNumber(i); result.back().fmq_client_group_id = group_id; } result.back().setClientId(); result.back().connection_count = 1; result.back().isValid(); } return result; } /** * @brief Compares the configured settings of two bridges. The members 'owner' and 'queueForDelete', and the generated clientid / fmq_client_group_id, are not compared. */ bool BridgeConfig::operator ==(const BridgeConfig &other) const { return this->address == other.address && this->port == other.port && this->inet_protocol == other.inet_protocol && this->tlsMode == other.tlsMode && this->sslFullchain == other.sslFullchain && this->sslPrivkey == other.sslPrivkey && this->caFile == other.caFile && this->caDir == other.caDir && this->minimumTlsVersion == other.minimumTlsVersion && this->protocolVersion == other.protocolVersion && this->bridgeProtocolBit == other.bridgeProtocolBit && this->keepalive == other.keepalive && this->clientidPrefix == other.clientidPrefix && this->publishes == other.publishes && this->subscribes == other.subscribes && this->local_username == other.local_username && this->remote_username == other.remote_username && this->remote_password == other.remote_password && this->remoteCleanStart == other.remoteCleanStart && this->localCleanStart == other.localCleanStart && this->remoteSessionExpiryInterval == other.remoteSessionExpiryInterval && this->localSessionExpiryInterval == other.localSessionExpiryInterval && this->remoteRetainAvailable == other.remoteRetainAvailable && this->useSavedClientId == other.useSavedClientId && this->maxOutgoingTopicAliases == other.maxOutgoingTopicAliases && this->maxIncomingTopicAliases == other.maxIncomingTopicAliases && this->tcpNoDelay == other.tcpNoDelay && this->local_prefix == other.local_prefix && this->remote_prefix == other.remote_prefix && this->connection_count == other.connection_count && this->maxBufferSize == other.maxBufferSize; } bool BridgeConfig::operator !=(const BridgeConfig &other) const { bool r = *this == other; return !r; } ================================================ FILE: bridgeconfig.h ================================================ #ifndef BRIDGECONFIG_H #define BRIDGECONFIG_H #include #include #include
#include #include "session.h" #include "dnsresolver.h" #include "sslctxmanager.h" #include "utils.h" #include "threadlocked.h" enum class BridgeTLSMode { None, Unverified, On }; struct BridgeTopicPath { std::string topic; uint8_t qos = 0; bool isValidQos() const; bool operator==(const BridgeTopicPath &other) const; }; /** * @brief The BridgeClientGroupIds class manages the random IDs used in fmq_client_group_id and shared * subscription names for them. * * They need to remain constant during the program's lifetime, and also be settable when loading state from disk. */ class BridgeClientGroupIds { std::unordered_map bridge_group_ids; std::unordered_map bridge_share_names; public: BridgeClientGroupIds() = default; BridgeClientGroupIds(const BridgeClientGroupIds &other) = delete; BridgeClientGroupIds(BridgeClientGroupIds &&other) = delete; BridgeClientGroupIds &operator=(const BridgeClientGroupIds &other) = delete; BridgeClientGroupIds &operator=(BridgeClientGroupIds &&other) = delete; std::string getClientGroupShareName(const std::string &client_id_prefix); void setClientGroupShareName(const std::string &client_id_prefix, const std::string &share_name); std::string getClientGroupId(const std::string &client_id_prefix); void loadShareNames(const std::string &path, bool real); void saveShareNames(const std::string &path) const; }; class BridgeConfig { std::string clientid; std::optional fmq_client_group_id; // For a custom feature of no-local shared subscriptions.
size_t client_id_max_length = 10; void setClientId(); void appendConnectionNumber(size_t no); static void setSharedSubscriptionName(std::vector &topics, const std::string &share_name); public: ListenerProtocol inet_protocol = ListenerProtocol::IPv46; std::string address; uint16_t port = 0; BridgeTLSMode tlsMode = BridgeTLSMode::None; std::string sslFullchain; std::string sslPrivkey; std::string caFile; std::string caDir; ProtocolVersion protocolVersion = ProtocolVersion::Mqtt311; bool bridgeProtocolBit = true; std::string clientidPrefix = "fmqbridge"; size_t connection_count = 1; std::optional local_username; std::optional remote_username; std::optional remote_password; bool remoteCleanStart = true; uint32_t remoteSessionExpiryInterval = 0; bool localCleanStart = true; uint32_t localSessionExpiryInterval = 0; uint16_t keepalive = 60; uint16_t maxIncomingTopicAliases = 0; uint16_t maxOutgoingTopicAliases = 0; bool useSavedClientId = false; bool remoteRetainAvailable = true; std::vector subscribes; std::vector publishes; std::weak_ptr owner; bool queueForDelete = false; bool tcpNoDelay = false; TLSVersion minimumTlsVersion = TLSVersion::TLSv1_1; std::optional maxBufferSize; std::optional local_prefix; std::optional remote_prefix; void setClientId(const std::string &prefix, const std::string &id); const std::string &getClientid() const; const std::optional &getFmqClientGroupId() const; void isValid(); std::vector multiply() const; void setSharedSubscriptionName(const std::string &share_name); bool operator ==(const BridgeConfig &other) const; bool operator !=(const BridgeConfig &other) const; }; class BridgeState { bool sslInitialized = false; std::chrono::time_point lastReconnectAttempt; int reconnectCounter = 0; const int baseReconnectInterval = (get_random_int() % 30) + 30; int intervalLogged = 0; public: const BridgeConfig c; ThreadLocked> session; ThreadLocked> threadData; // kind of hacky, but I need it later.
std::optional sslctx; BridgeState(const BridgeConfig &config); DnsResolver dns; std::list dnsResults; FMQSockaddr popDnsResult(); void initSSL(bool reloadCertificates); bool timeForNewReconnectAttempt(); void registerReconnect(); void resetReconnectCounter(); void resetThreadOwners(); }; #endif // BRIDGECONFIG_H ================================================ FILE: bridgeinfodb.cpp ================================================ #include "bridgeinfodb.h" using std::unordered_map; BridgeInfoForSerializing::BridgeInfoForSerializing(const BridgeConfig &bridge) : prefix(bridge.clientidPrefix), clientId(bridge.getClientid()), client_group_share_name("") // In this context (constructing from BridgeConfig) we don't save this. It's done at an earlier stage. { } std::list BridgeInfoForSerializing::getBridgeInfosForSerializing(const unordered_map &input) { std::list result; for (const auto &pair : input) { const BridgeConfig &bridge = pair.second; result.emplace_back(bridge); } return result; } BridgeInfoDb::BridgeInfoDb(const std::string &filePath) : PersistenceFile(filePath) { } void BridgeInfoDb::openWrite() { PersistenceFile::openWrite(MAGIC_STRING_BRIDGEINFO_FILE_V2); } void BridgeInfoDb::openRead() { const std::string current_magic_string(MAGIC_STRING_BRIDGEINFO_FILE_V2); PersistenceFile::openRead(current_magic_string); if (detectedVersionString == MAGIC_STRING_BRIDGEINFO_FILE_V1) readVersion = ReadVersion::v1; else if (detectedVersionString == current_magic_string) readVersion = ReadVersion::v2; else throw std::runtime_error("Unknown file version."); } void BridgeInfoDb::saveInfo(const std::list &bridgeInfos) { if (!f) return; writeUint32(bridgeInfos.size()); for (const BridgeInfoForSerializing &b : bridgeInfos) { writeString(b.prefix); writeString(b.clientId); writeString(b.client_group_share_name); } } /* Reads back the records written by saveInfo(). V1 files have no client_group_share_name field. A record cut short by the end of the file is dropped. */ std::list BridgeInfoDb::readInfo() { std::list result; if (!f) return result; while (!feof(f)) { bool eofFound = false; uint32_t number_of_bridges = readUint32(eofFound);
if (eofFound) continue; for (uint32_t i = 0; i < number_of_bridges; i++) { BridgeInfoForSerializing r; r.prefix = readString(eofFound); r.clientId = readString(eofFound); if (readVersion >= ReadVersion::v2) r.client_group_share_name = readString(eofFound); /* Don't keep a record whose fields could not all be read. */ if (eofFound) break; result.push_back(std::move(r)); } } return result; } ================================================ FILE: bridgeinfodb.h ================================================ #ifndef BRIDGEINFODB_H #define BRIDGEINFODB_H #include #include "persistencefile.h" #include "bridgeconfig.h" #define MAGIC_STRING_BRIDGEINFO_FILE_V1 "BridgeInfoDbV1" #define MAGIC_STRING_BRIDGEINFO_FILE_V2 "BridgeInfoDbV2" struct BridgeInfoForSerializing { std::string prefix; std::string clientId; std::string client_group_share_name; BridgeInfoForSerializing() = default; BridgeInfoForSerializing(const BridgeConfig &bridge); static std::list getBridgeInfosForSerializing(const std::unordered_map &input); }; class BridgeInfoDb : private PersistenceFile { enum class ReadVersion { unknown, v1, v2 }; ReadVersion readVersion = ReadVersion::unknown; public: BridgeInfoDb(const std::string &filePath); void openWrite(); void openRead(); void saveInfo(const std::list &bridgeInfos); std::list readInfo(); }; #endif // BRIDGEINFODB_H ================================================ FILE: build.sh ================================================ #!/bin/bash thisfile=$(readlink --canonicalize "$0") thisdir=$(dirname "$thisfile") script_name=$(basename "$0") DEFAULT_BUILD_TYPE="Release" usage() { cat <] [ --asan ] $script_name --help Display this help 'Debug', 'Release' or 'RelWithDebInfo'; default: $DEFAULT_BUILD_TYPE EOF } FAIL_ON_DOC_FAILURE="" BUILD_TYPE="$DEFAULT_BUILD_TYPE" NJOBS_OVERRIDE="" FMQ_ASAN="" while [[ "$#" -gt 0 ]]; do case "$1" in --help|-h) usage exit 0 ;; --fail-on-doc-failure) FAIL_ON_DOC_FAILURE="yeah" shift ;; --njobs-override) shift NJOBS_OVERRIDE="$1" shift ;; --asan) FMQ_ASAN=1 shift ;; --*|-*) echo -e "\e[31mUnknown option:
\e[1m$1\e[0m" >&2 usage >&2 exit 2 ;; Debug|Release|RelWithDebInfo) if [[ "$#" -gt 1 ]]; then echo -e "\e[31mRelease type (\e[1m$1\e[22m) must be the last argument\e[0m" >&2 usage >&2 exit 2 fi BUILD_TYPE="$1" shift ;; *) echo -e "\e[31mUnknown positional argument: \e[1m$1\e[0m" >&2 usage >&2 exit 2 ;; esac done if ! [[ "$NJOBS_OVERRIDE" =~ ^[0-9]*$ ]] ; then >&2 echo "--njobs-override must be a number" exit 1 fi if ! make -C "$thisdir/man"; then if [[ -z "$FAIL_ON_DOC_FAILURE" ]]; then echo -e "\e[33mIgnoring failed man page builds; run \e[1m$script_name\e[22m with the \e[1m--fail-on-doc-failure\e[22m option to make such failures fatal.\e[0m" else echo -e "\e[31mMaking the man pages failed; dying now in obedience of the \e[1m--fail-on-doc-failure\e[22m option.\e[0m" exit 3 fi fi BUILD_DIR="FlashMQBuild$BUILD_TYPE" set -eu if [[ -e "$BUILD_DIR" ]]; then >&2 echo "$BUILD_DIR already exists. You can run 'make' in it, if you want. " else mkdir "$BUILD_DIR" fi nprocs=4 if _nprocs=$(nproc); then nprocs="$_nprocs" fi if [[ -n "$NJOBS_OVERRIDE" ]]; then nprocs="$NJOBS_OVERRIDE" fi args=() if [[ -n "$FMQ_ASAN" ]]; then args+=("-DFMQ_ASAN=1") fi cd "$BUILD_DIR" cmake -DCMAKE_BUILD_TYPE="$BUILD_TYPE" ${args[@]+"${args[@]}"} "$thisdir" make -j "$nprocs" cpack FLASHMQ_VERSION=$(./flashmq --version | grep -Ei 'Flashmq.*version.*' | grep -Eo '[0-9]+\.[0-9]+\.[0-9]+[^ ]*') if command -v linuxdeploy-x86_64.AppImage &> /dev/null; then linuxdeploy-x86_64.AppImage --create-desktop-file --icon-file "../flashmq.png" --appdir "AppImageDir" --executable "flashmq" --output appimage mv flashmq-*.AppImage "flashmq-${FLASHMQ_VERSION}-linux-amd64.AppImage" fi ================================================ FILE: checkedsharedptr.h ================================================ #ifndef CHECKEDSHAREDPTR_H #define CHECKEDSHAREDPTR_H #include #include #include template class CheckedSharedPtr { std::shared_ptr d; public: CheckedSharedPtr() = default; CheckedSharedPtr(const std::shared_ptr &org) : d(org) {
} CheckedSharedPtr &operator=(const std::shared_ptr &other) { d = other; return *this; } T &operator*() const { assert(d); if (!d) throw std::runtime_error("CheckedSharedPtr is null"); return *d; } T *operator->() const { assert(d); if (!d) throw std::runtime_error("CheckedSharedPtr is null"); return d.get(); } /* NOTE(review): this conversion is implicit; consider making it 'explicit' once callers have been checked. */ operator bool() const { return static_cast(d); } void reset() { d.reset(); } }; #endif // CHECKEDSHAREDPTR_H ================================================ FILE: checkedweakptr.h ================================================ #ifndef CHECKEDWEAKPTR_H #define CHECKEDWEAKPTR_H #include class __attribute__((visibility("default"))) PointerNullException : public std::exception { public: virtual const char* what() const noexcept override { return "CheckedWeakPtr pointer null"; } }; template class CheckedWeakPtr { std::weak_ptr d; public: CheckedWeakPtr() = default; CheckedWeakPtr(const std::shared_ptr &org) : d(org) { } CheckedWeakPtr &operator=(const std::shared_ptr &other) { d = other; return *this; } /* Returns a strong reference, or throws PointerNullException when the object is gone. */ std::shared_ptr lock() const { std::shared_ptr d2 = d.lock(); if (!d2) throw PointerNullException(); return d2; } bool expired() const { return d.expired(); } }; #endif // CHECKEDWEAKPTR_H ================================================ FILE: cirbuf.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details.
*/ #include "cirbuf.h" #include #include #include #include #include #include "logger.h" #include "utils.h" CirBuf::CirBuf(size_t size) : size(size) { if (size == 0) return; assert(isPowerOfTwo(size)); buf = (char*)malloc(size); if (buf == NULL) throw std::runtime_error("Malloc error constructing buffer."); #ifndef NDEBUG memset(buf, 0, size); #endif } CirBuf::~CirBuf() { free(buf); buf = nullptr; } uint32_t CirBuf::usedBytes() const { int result = (head - tail) & (size-1); return result; } uint32_t CirBuf::freeSpace() const { int result = (tail - (head + 1)) & (size-1); return result; } uint32_t CirBuf::maxWriteSize() const { int end = size - 1 - head; int n = (end + tail) & (size-1); int result = n <= end ? n : end+1; return result; } uint32_t CirBuf::maxReadSize() const { int end = size - tail; int n = (head + end) & (size-1); int result = n < end ? n : end; return result; } char *CirBuf::headPtr() { return &buf[head]; } char *CirBuf::tailPtr() { return &buf[tail]; } void CirBuf::advanceHead(uint32_t n) { assert(n <= freeSpace()); if (n > freeSpace()) throw std::runtime_error("Trying to advance buffer head more then there bytes free are in the buffer."); head = (head + n) & (size -1); assert(tail != head); // Putting things in the buffer must never end on tail, because tail == head == empty. 
} void CirBuf::advanceTail(uint32_t n) { assert(n <= usedBytes()); if (n > usedBytes()) throw std::runtime_error("Trying to advance buffer tail more then there bytes used are in the buffer."); tail = (tail + n) & (size -1); } char CirBuf::peakAhead(uint32_t offset) const { assert(offset < usedBytes()); if (offset >= usedBytes()) throw std::runtime_error("Trying to peek more bytes then there are in the buffer."); char b = buf[(tail + offset) & (size - 1)]; return b; } void CirBuf::ensureFreeSpace(const size_t n, const size_t max) { if (n <= freeSpace()) return; const size_t _usedBytes = usedBytes(); size_t mul = 1; while((mul * size - _usedBytes - 1) < n && (mul*size) < max) { mul = mul << 1; } doubleCapacity(mul); } void CirBuf::doubleCapacity(uint factor) { if (factor == 1) return; assert(isPowerOfTwo(factor)); if ((static_cast(size) * factor) > 2147483648) throw std::runtime_error("Trying to exceed circular buffer beyond its 2 GB limit."); uint32_t newSize = size * factor; char *newBuf = (char*)realloc(buf, newSize); if (newBuf == NULL) throw std::runtime_error("Malloc error doubling buffer size."); #ifdef TESTING // I use this to detect the affected memory locations. memset(&newBuf[size], 68, newSize - size); #endif uint32_t maxRead = maxReadSize(); buf = newBuf; if (head < tail) { std::memcpy(&buf[tail + maxRead], buf, head); } head = tail + usedBytes(); size = newSize; #ifndef NDEBUG Logger *logger = Logger::getInstance(); logger->logf(LOG_DEBUG, "New buf size: %d", size); #endif #ifdef TESTING memset(&buf[head], 5, maxWriteSize()); #endif primedForSizeReset = false; } uint32_t CirBuf::getCapacity() const { return size; } void CirBuf::resetCapacityIfEligable(size_t size) { // Ensuring the reset will only happen the second time the timer event hits. 
if (!primedForSizeReset) { primedForSizeReset = true; return; } if (usedBytes() > 0) return; resetCapacity(size); } void CirBuf::resetCapacity(size_t newSize) { assert(usedBytes() == 0); primedForSizeReset = false; if (this->size == newSize) return; char *newBuf = (char*)realloc(buf, newSize); if (newBuf == NULL) throw std::runtime_error("Malloc error resizing buffer."); buf = newBuf; this->size = newSize; head = 0; tail = 0; #ifndef NDEBUG Logger *logger = Logger::getInstance(); logger->logf(LOG_DEBUG, "Reset buf size: %d", size); memset(buf, 0, newSize); #endif } void CirBuf::reset() { head = 0; tail = 0; #ifndef NDEBUG memset(buf, 0, size); #endif } void CirBuf::write(uint8_t b) { ensureFreeSpace(1); buf[head] = b; advanceHead(1); } void CirBuf::write(uint8_t b, uint8_t b2) { ensureFreeSpace(2); buf[head] = b; advanceHead(1); buf[head] = b2; advanceHead(1); } void CirBuf::write(const void *buf, size_t count) { assert(size > 0); ensureFreeSpace(count); ssize_t len_left = count; size_t src_i = 0; while (len_left > 0) { const size_t len = std::min(len_left, maxWriteSize()); assert(len > 0); const char *src = &reinterpret_cast(buf)[src_i]; std::memcpy(headPtr(), src, len); advanceHead(len); src_i += len; len_left -= len; } assert(len_left == 0); assert(src_i == count); } std::vector CirBuf::peekAllToVector() { const uint32_t tail_org = tail; std::vector result = readAllToVector(); tail = tail_org; return result; } std::vector CirBuf::readToVector(const uint32_t max) { assert(size > 0); uint32_t bytes_left = std::min(max, usedBytes()); std::vector result(bytes_left); int guard = 0; auto pos = result.begin(); while (bytes_left > 0 && guard++ < 10) { const uint32_t readlen = std::min(maxReadSize(), bytes_left); assert(readlen <= maxReadSize()); std::copy(tailPtr(), tailPtr() + readlen, pos); advanceTail(readlen); pos += readlen; bytes_left -= readlen; } assert(guard < 3); assert(bytes_left == 0); assert(pos == result.end()); return result; } std::vector 
CirBuf::readAllToVector() { return readToVector(std::numeric_limits::max()); } /** * @brief CirBuf::operator == simplistic comparision. It doesn't take the fact that it's circular into account. * @param other * @return * * It was created for unit testing. read() and write() are non-const, so taking the circular properties into account * would need more/duplicate code that I don't need at this point. */ bool CirBuf::operator==(const CirBuf &other) const { #ifdef NDEBUG throw std::exception(); // you can't use it in release builds, because new buffers aren't zeroed. #endif return tail == 0 && other.tail == 0 && usedBytes() == other.usedBytes() && std::memcmp(buf, other.buf, size) == 0; } ================================================ FILE: cirbuf.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef CIRBUF_H #define CIRBUF_H #include #include #include #include #include #include #include // Optimized circular buffer, works only with sizes power of two. 
class CirBuf
{
#ifdef TESTING
    friend class MainTests;
#endif

    char *buf = NULL; // Backing storage; stays NULL for a buffer constructed with size 0.
    uint32_t head = 0; // Write position: the next byte goes to buf[head].
    uint32_t tail = 0; // Read position: the next byte comes from buf[tail]. head == tail means empty.
    uint32_t size = 0; // Capacity in bytes; expected to be a power of two (or 0).
    bool primedForSizeReset = false; // Set by the first resetCapacityIfEligable() call; a later call may then reset.
public:
    // Not copyable or movable: the object owns a raw buffer.
    CirBuf(const CirBuf &other) = delete;
    CirBuf(CirBuf &&other) = delete;
    CirBuf(size_t size);
    ~CirBuf();

    CirBuf &operator=(const CirBuf &other) = delete;
    CirBuf &operator=(CirBuf &&other) = delete;

    uint32_t usedBytes() const; // Bytes stored.
    uint32_t freeSpace() const; // Bytes that can still be written (one slot always stays unused).
    uint32_t maxWriteSize() const; // Contiguous bytes writable at headPtr().
    uint32_t maxReadSize() const; // Contiguous bytes readable at tailPtr().
    char *headPtr();
    char *tailPtr();
    void advanceHead(uint32_t n); // Commit n bytes written at headPtr().
    void advanceTail(uint32_t n); // Consume n bytes read at tailPtr().
    char peakAhead(uint32_t offset) const; // Read the byte at tail + offset without consuming it.
    void ensureFreeSpace(const size_t n, const size_t max = UINT_MAX);
    void doubleCapacity(uint factor = 2);
    uint32_t getCapacity() const;
    void resetCapacityIfEligable(size_t size);
    void resetCapacity(size_t size);
    void reset(); // Discard contents, keep capacity.

    void write(uint8_t b);
    void write(uint8_t b, uint8_t b2);
    void write(const void *buf, size_t count);
    std::vector peekAllToVector();
    std::vector readToVector(const uint32_t max);
    std::vector readAllToVector();

    /**
     * Write the whole range into the buf, making space as needed.
     *
     * The iterators must support 'end - begin' and 'pos + len' (random access). The data is copied in contiguous
     * chunks of at most maxWriteSize() bytes; 'guard' is only a safety limit on the number of chunks.
     */
    template void writerange(InputIt begin, InputIt end)
    {
        const auto input_size = end - begin;
        ensureFreeSpace(input_size);

        size_t len_left = input_size;
        int guard = 0;
        auto pos = begin;
        while (len_left > 0 && guard++ < 10)
        {
            const size_t len = std::min(len_left, maxWriteSize());
            std::copy(pos, pos + len, headPtr());
            advanceHead(len);
            pos += len;
            len_left -= len;
        }

        assert(len_left == 0);
        assert(pos == end);
    }

    bool operator==(const CirBuf &other) const;
};

#endif // CIRBUF_H
================================================
FILE: client.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify
it under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/

#include "client.h"

#include
#include
#include
#include
#include
#include

#include "logger.h"
#include "utils.h"
#include "threadglobals.h"
#include "subscriptionstore.h"
#include "exceptions.h"
#include "settings.h"
#include "threaddata.h"
#include "globals.h"
#include "http.h"

/**
 * @brief Plain holder for the given session registration parameters.
 */
StowedClientRegistrationData::StowedClientRegistrationData(bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) :
    clean_start(clean_start),
    clientReceiveMax(clientReceiveMax),
    sessionExpiryInterval(sessionExpiryInterval)
{

}

/**
 * @brief Plain holder for the outcome of an asynchronous authentication.
 *
 * NOTE(review): authMethod is taken as 'const std::string' (a copy), unlike authData; a const reference would
 * avoid the extra copy.
 */
AsyncAuthResult::AsyncAuthResult(AuthResult result, const std::string authMethod, const std::string &authData) :
    result(result),
    authMethod(authMethod),
    authData(authData)
{

}

/**
 * @brief Constructs the write buffer's 'buf' member with the given initial size.
 */
Client::WriteBuf::WriteBuf(size_t size) :
    buf(size)
{

}

/**
 * @brief Client::Client
 * @param type
 * @param fd socket; must be -1 or above 2 (asserted). A valid fd is switched to non-blocking mode.
 * @param threadData
 * @param ssl
 * @param connectionProtocol
 * @param haProxyMode
 * @param addr
 * @param settings The client is constructed in the main thread, so we need to use its settings copy
 * @param fuzzMode
 */
Client::Client(
    ClientType type, int fd, std::shared_ptr threadData, FmqSsl &&ssl, ConnectionProtocol connectionProtocol,
    HaProxyMode haProxyMode, const struct sockaddr *addr, const Settings &settings, bool fuzzMode) :
    fd(fd),
    fuzzMode(fuzzMode),
    maxOutgoingPacketSize(settings.maxPacketSize),
    maxIncomingPacketSize(settings.maxPacketSize),
    maxQos(settings.maxQos),
    mqtt3QoSExceedAction(settings.mqtt3QoSExceedAction),
    maxIncomingTopicAliasValue(settings.maxIncomingTopicAliasValue), // Retaining snapshot of current setting, to not confuse clients when the setting changes.
    ioWrapper(std::move(ssl), connectionProtocol, settings.clientInitialBufferSize, this),
    readbuf(settings.clientInitialBufferSize),
    writebuf(settings.clientInitialBufferSize),
    clientType(type),
    mHaProxyMode(haProxyMode),
    threadData(threadData),
    addr(addr)
{
    assert(fd == -1 || fd > 2);

    if (haProxyMode > HaProxyMode::Off)
        ioWrapper.setHaProxy();

    if (fd > -1)
    {
        // NOTE(review): the F_GETFL result is not checked for -1 before being OR'ed with O_NONBLOCK.
        int flags = fcntl(fd, F_GETFL);
        fcntl(fd, F_SETFL, flags | O_NONBLOCK);
    }

    // Build the human-readable transport description used in log lines, e.g. "TCP/Websocket/SSL".
    std::string haproxy_s;
    if (haProxyMode == HaProxyMode::On)
        haproxy_s = "/HAProxy";
    else if (haProxyMode >= HaProxyMode::HaProxyClientVerification)
        haproxy_s = "/HAProxySslClientVerification";

    const std::string ssl_s = this->ioWrapper.isSsl() ? "/SSL" : "/Non-SSL";
    const std::string websocket_s = connectionProtocol == ConnectionProtocol::WebsocketMqtt ? "/Websocket" : "";
    transportStr = formatString("TCP%s%s%s", haproxy_s.c_str(), websocket_s.c_str(), ssl_s.c_str());

    // Avoid giving this log line for dummy clients.
    if (addr && clientType != ClientType::LocalBridge)
    {
        logger->log(LOG_NOTICE) << "Accepting connection from: " << repr_endpoint();
    }
}

/**
 * @brief Client::~Client logs the removal, hands the will, session and bridge state to the owning thread via
 * queueClientDisconnectActions(), and removes the fd from epoll.
 */
Client::~Client()
{
    // Dummy clients, that I sometimes need just because the interface demands it but there's not actually a client, have no thread.
    if (this->epoll_fd < 0)
        return;

    if (disconnectReason.empty())
        disconnectReason = "not specified";

    logger->logf(disconnectReasonLogLevel, "Removing client '%s'. Reason(s): %s", repr().c_str(), disconnectReason.c_str());

    std::shared_ptr td = this->threadData.lock();

    if (td)
    {
        td->queueClientDisconnectActions(
            authenticated, this->getClientId(), std::move(willPublish), std::move(session), std::move(bridgeState), disconnectReason);
    }

    // NOTE(review): session and willPublish are only moved out when the owning thread still exists; if it is gone
    // while they are set, these asserts would fire in debug builds.
    assert(!session);
    assert(!willPublish);

    if (fd.get() > 0) // this check is essentially for testing, when working with a dummy fd.
    {
        if (epoll_ctl(this->epoll_fd, EPOLL_CTL_DEL, fd.get(), NULL) != 0)
            logger->logf(LOG_ERR, "Removing fd %d of client '%s' from epoll produced error: %s", fd.get(), repr().c_str(), strerror(errno));
        this->epoll_fd = -1;
    }
}

/**
 * @brief Client::addToEpoll registers the fd with the owning thread's epoll instance, once. It does nothing when
 * already registered or when the owning thread no longer exists.
 * @param events epoll event mask to register with.
 */
void Client::addToEpoll(uint32_t events)
{
    if (this->epoll_fd > -1)
        return;

    auto owner_thread = threadData.lock();

    if (!owner_thread)
        return;

    struct epoll_event ev{};
    ev.data.fd = fd.get();
    ev.events = events;
    check(epoll_ctl(owner_thread->getEpollFd(), EPOLL_CTL_ADD, fd.get(), &ev));

    this->epoll_fd = owner_thread->getEpollFd();
}

bool Client::isSslAccepted() const
{
    return ioWrapper.isSslAccepted();
}

bool Client::isSsl() const
{
    return ioWrapper.isSsl();
}

/**
 * @brief Client::readHaProxyData advances the PROXY protocol handshake: first the header (which may replace addr),
 * then any additional bytes. Must not be called once the stage is DoneOrNotNeeded (asserted).
 * @return the stage after this call.
 */
HaProxyStage Client::readHaProxyData()
{
    assert(this->ioWrapper.getHaProxyStage() != HaProxyStage::DoneOrNotNeeded);

    if (this->ioWrapper.getHaProxyStage() == HaProxyStage::HeaderPending)
    {
        const std::optional new_sock_addr = this->ioWrapper.readHaProxyHeader(this->fd.get());

        if (new_sock_addr)
            addr = new_sock_addr.value();
    }

    if (this->ioWrapper.getHaProxyStage() == HaProxyStage::AdditionalBytesPending)
    {
        this->ioWrapper.readHaProxyAdditionalData(this->fd.get());
    }

    return this->getHaProxyStage();
}

bool Client::getSslReadWantsWrite() const
{
    return ioWrapper.getSslReadWantsWrite();
}

bool Client::getSslWriteWantsRead() const
{
    return ioWrapper.getSslWriteWantsRead();
}

ProtocolVersion Client::getProtocolVersion() const
{
    return protocolVersion;
}

void Client::setProtocolVersion(ProtocolVersion version)
{
    this->protocolVersion = version;
}

/**
 * @brief Client::connectToBridgeTarget starts connecting to addr on the bridge's configured port.
 *
 * When connect() succeeds immediately, setBridgeConnected() is called. On EINPROGRESS it returns silently;
 * any other error is only logged.
 */
void Client::connectToBridgeTarget(FMQSockaddr addr)
{
    this->lastActivity = std::chrono::steady_clock::now();

    std::shared_ptr bridge = this->bridgeState.lock();

    if(!bridge)
        return;

    this->outgoingConnection = true;

    if (bridge->c.tcpNoDelay)
    {
        int tcp_nodelay_optval = 1;
        check(setsockopt(fd.get(), IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay_optval, sizeof(tcp_nodelay_optval)));
    }

    addr.setPort(bridge->c.port);
    int rc = connect(fd.get(), addr.getSockaddr(), addr.getSize());

    if (rc < 0)
    {
        if (errno != EINPROGRESS)
            logger->logf(LOG_WARNING, "Client connect error: %s", strerror(errno));
        return;
    }

    assert(rc == 0);

    setBridgeConnected();
}

void Client::setMaxBufSizeOverride(uint32_t val)
{
    this->maxBufSizeOverride = val;
}

/**
 * @brief Client::startOrContinueSslHandshake drives the TLS handshake. When this call completes it, the SSL version
 * is recorded and, for outgoing (bridge) connections, the login packet is written.
 */
void Client::startOrContinueSslHandshake()
{
    const bool acceptedBefore = isSslAccepted();

    ioWrapper.startOrContinueSslHandshake();

    if (!acceptedBefore && isSslAccepted())
    {
        ssl_version = ioWrapper.getSslVersion();

        if (this->outgoingConnection)
        {
            writeLoginPacket();
        }
    }
}

/**
 * @brief Client::setDisconnectStage only ever moves the disconnect stage forward; lower values are ignored.
 */
void Client::setDisconnectStage(DisconnectStage val)
{
    if (val <= this->disconnectStage)
        return;

    this->disconnectStage = val;
}

/**
 * @brief Client::readFdIntoBuffer reads from the fd (through the websocket/SSL layers) into readbuf until the read
 * would block, the peer disconnects, or the buffer is full and may not grow further (reading is then paused).
 * @return DisconnectStage::Now on disconnect, otherwise the current disconnect stage.
 */
DisconnectStage Client::readFdIntoBuffer()
{
    if (this->disconnectStage == DisconnectStage::Now)
        return DisconnectStage::Now;

    IoWrapResult error = IoWrapResult::Success;
    int n = 0;
    while (readbuf.freeSpace() > 0 && (n = ioWrapper.readWebsocketAndOrSsl(fd.get(), readbuf.headPtr(), readbuf.maxWriteSize(), &error)) != 0)
    {
        if (n > 0)
        {
            readbuf.advanceHead(n);
        }

        if (error == IoWrapResult::Interrupted)
            continue;
        if (error == IoWrapResult::Wouldblock || error == IoWrapResult::Disconnected)
            break;

        // Make sure we either always have enough space for a next call of this method, or stop reading the fd.
        if (readbuf.freeSpace() == 0)
        {
            const Settings *settings = ThreadGlobals::getSettings();

            // I guess I should have just made a 'max buffer size' option, and not distinguish between read/write?
            const uint32_t maxBufSize = this->maxBufSizeOverride.value_or(settings->clientMaxWriteBufferSize);
            const uint32_t maxBufOrBigPacketSize= std::max(this->maxIncomingPacketSize, maxBufSize);

            // We always grow for another iteration when there are still decoded websocket/SSL bytes, because epoll doesn't tell us that buffer has data.
            if (readbuf.getCapacity() * 2 <= maxBufOrBigPacketSize || error == IoWrapResult::WantRead || ioWrapper.hasProcessedBufferedBytesToRead())
            {
                readbuf.doubleCapacity();
            }
            else
            {
                setReadyForReading(false);
                break;
            }
        }
    }

    if (error == IoWrapResult::Disconnected)
        return DisconnectStage::Now;

    lastActivity = std::chrono::steady_clock::now();

    return this->disconnectStage;
}

/**
 * @brief Client::writeText queues raw (non-MQTT) text. Asserted to be used only before a websocket upgrade, and for
 * non-websocket connections only while unauthenticated.
 */
void Client::writeText(const std::string &text)
{
    assert(ioWrapper.getConnectionProtocol() == ConnectionProtocol::WebsocketMqtt || !authenticated );
    assert(ioWrapper.getWebsocketState() == WebsocketState::NotUpgraded);

    auto write_buf_locked = writebuf.lock();
    write_buf_locked->buf.writerange(text.begin(), text.end());

    setReadyForWriting(true, write_buf_locked);
}

/**
 * @brief Client::writePing queues a hardcoded PINGREQ: fixed header byte 0xC0 and a remaining length of 0.
 */
void Client::writePing()
{
    auto write_buf_locked = writebuf.lock();
    write_buf_locked->buf.write(0b11000000, 0);
    setReadyForWriting(true, write_buf_locked);
}

/**
 * @brief Client::writeMqttPacket serializes packet into the write buffer, growing it as far as allowed.
 * @return BiggerThanPacketLimit when the packet exceeds the peer's maximum, BufferFull when a QoS 0 publish doesn't
 *         fit, otherwise Success. Writing a DISCONNECT schedules the connection to close after flushing.
 */
PacketDropReason Client::writeMqttPacket(const MqttPacket &packet)
{
    const size_t packetSize = packet.getSizeIncludingNonPresentHeader();

    // "Where a Packet is too large to send, the Server MUST discard it without sending it and then behave as if it had completed
    // sending that Application Message [MQTT-3.1.2-25]."
    if (packetSize > this->maxOutgoingPacketSize)
    {
        return PacketDropReason::BiggerThanPacketLimit;
    }

    const Settings *settings = ThreadGlobals::getSettings();

    // After introducing the client_max_write_buffer_size with low default, this makes it somewhat backwards compatible with the default big packet size.
    const uint32_t maxBufSize = this->maxBufSizeOverride.value_or(settings->clientMaxWriteBufferSize);
    const uint32_t growBufMaxTo = std::max(maxBufSize, packetSize * 2);

    auto write_buf_locked = writebuf.lock();

    // Grow as far as we can. We have to make room for one MQTT packet.
    write_buf_locked->buf.ensureFreeSpace(packetSize, growBufMaxTo);

    // And drop a publish when it doesn't fit, even after resizing. This means we do allow pings. And
    // QoS packet are queued and limited elsewhere.
    if (packet.packetType == PacketType::PUBLISH && packet.getQos() == 0 && packetSize > write_buf_locked->buf.freeSpace())
    {
        return PacketDropReason::BufferFull;
    }

    packet.readIntoBuf(write_buf_locked->buf);

    if (packet.packetType == PacketType::PUBLISH)
    {
        ThreadGlobals::getThreadData()->sentMessageCounter.inc();
    }
    else if (packet.packetType == PacketType::DISCONNECT)
        setDisconnectStage(DisconnectStage::SendPendingAppData);

    setReadyForWriting(true, write_buf_locked);

    return PacketDropReason::Success;
}

/**
 * @brief Client::writeMqttPacketAndBlameThisClient writes a publish taken from copyFactory, using an outgoing topic
 * alias when the protocol and the peer allow it. A new alias is only recorded once the packet was actually written.
 * @return the drop reason from writing the packet.
 */
PacketDropReason Client::writeMqttPacketAndBlameThisClient(
    PublishCopyFactory ©Factory, uint8_t max_qos, uint16_t packet_id, bool retain, uint32_t subscriptionIdentifier,
    const std::optional &topic_override)
{
    uint16_t topic_alias = 0;
    uint16_t topic_alias_next = 0;
    bool skip_topic = false;

    /*
     * Required for two reasons:
     *
     * 1) Upon first use of an alias, we need to hold the lock until we know the packet is actually not dropped.
     * 2) Upon first use of an alias, we need to make sure another sender using the same topic won't get
     *    their packet sent first.
     *
     * I'm not fully happy that by doing this, we'll be holding two mutexes at the same time: this one and the buffer
     * write mutex, but it's OK for now. They are never locked in opposite order, so deadlocks shouldn't happen.
     */
    MutexLocked locked_aliases_extended;

    const std::string &topic = topic_override.value_or(copyFactory.getTopic());

    if (protocolVersion >= ProtocolVersion::Mqtt5 && this->maxOutgoingTopicAliasValue > 0)
    {
        MutexLocked locked_aliases = outgoingTopicAliases.lock();

        auto alias_pos = locked_aliases->aliases.find(topic);

        if (alias_pos != locked_aliases->aliases.end())
        {
            // Alias already known to the peer: send it without the topic.
            topic_alias = alias_pos->second;
            skip_topic = true;
        }
        else if (locked_aliases->cur_alias < this->maxOutgoingTopicAliasValue)
        {
            // First use of a new alias: keep the lock until we know whether the packet went out.
            topic_alias_next = locked_aliases->cur_alias + 1;
            topic_alias = topic_alias_next;
            locked_aliases_extended = std::move(locked_aliases);
        }
    }

    MqttPacket *p = copyFactory.getOptimumPacket(max_qos, this->protocolVersion, topic_alias, skip_topic, subscriptionIdentifier, topic_override);

    assert(static_cast(p->getQos()) == static_cast(max_qos));
    assert(PublishCopyFactory::getPublishLayoutCompareKey(this->protocolVersion, p->getQos()) ==
           PublishCopyFactory::getPublishLayoutCompareKey(p->getProtocolVersion(), p->getQos()));

    if (p->getQos() > 0)
    {
        // This may change the packet ID and QoS of the incoming packet for each subscriber, but because we don't store that packet anywhere,
        // that should be fine.
        p->setPacketId(packet_id);
        p->setQos(copyFactory.getEffectiveQos(max_qos));
    }

    p->setRetain(retain);

    PacketDropReason dropReason = writeMqttPacketAndBlameThisClient(*p);

    if (dropReason == PacketDropReason::Success && topic_alias_next > 0)
    {
        locked_aliases_extended->aliases[topic] = topic_alias_next;
        locked_aliases_extended->cur_alias = topic_alias_next;
    }

    return dropReason;
}

// Helper method to avoid the exception ending up at the sender of messages, which would then get disconnected.
PacketDropReason Client::writeMqttPacketAndBlameThisClient(const MqttPacket &packet) { try { return this->writeMqttPacket(packet); } catch (std::exception &ex) { std::shared_ptr td = this->threadData.lock(); if (td) td->removeClientQueued(fd.get()); return PacketDropReason::ClientError; } } // Ping responses are always the same, so hardcoding it for optimization. void Client::writePingResp() { auto write_buf_locked = writebuf.lock(); write_buf_locked->buf.write(0b11010000, 0); setReadyForWriting(true, write_buf_locked); } void Client::writeLoginPacket() { std::shared_ptr config = this->bridgeState.lock(); if (!config) throw std::runtime_error("No bridge config in bridge?"); Connect connectInfo(protocolVersion, clientid); connectInfo.username = config->c.remote_username; connectInfo.password = config->c.remote_password; connectInfo.clean_start = config->c.remoteCleanStart; connectInfo.keepalive = config->c.keepalive; connectInfo.sessionExpiryInterval = config->c.remoteSessionExpiryInterval; connectInfo.maxIncomingTopicAliasValue = config->c.maxIncomingTopicAliases; connectInfo.bridgeProtocolBit = config->c.bridgeProtocolBit; connectInfo.fmq_client_group_id = config->c.getFmqClientGroupId(); MqttPacket pack(connectInfo); writeMqttPacket(pack); } void Client::writeBufIntoFd() { auto write_buf_locked = writebuf.lock(std::try_to_lock); if (!write_buf_locked.get_lock().owns_lock()) return; // We can abort the write; the client is about to be removed anyway. 
if (this->disconnectStage == DisconnectStage::Now) return; IoWrapResult error = IoWrapResult::Success; int n; while (write_buf_locked->buf.usedBytes() > 0 || ioWrapper.hasPendingWrite()) { n = ioWrapper.writeWebsocketAndOrSsl(fd.get(), write_buf_locked->buf.tailPtr(), write_buf_locked->buf.maxReadSize(), &error); if (n > 0) write_buf_locked->buf.advanceTail(n); if (error == IoWrapResult::Interrupted) continue; if (error == IoWrapResult::Wouldblock) break; } const bool data_pending = write_buf_locked->buf.usedBytes() > 0 || ioWrapper.hasPendingWrite() || error == IoWrapResult::Wouldblock; if (this->disconnectStage == DisconnectStage::SendPendingAppData && !data_pending) { this->disconnectStage = DisconnectStage::Now; } setReadyForWriting(data_pending, write_buf_locked); } const FMQSockaddr &Client::getAddr() const { return this->addr; } std::string Client::repr() { std::string transport(this->transportStr); if (!ssl_version.empty()) { transport = transport + " (" + ssl_version + ")"; } if (this->protocolVersion == ProtocolVersion::None) { std::ostringstream oss; oss << "[fd=" << fd.get() << ", transport='" << transport << "', address='" << this->addr.getText() << "']"; return oss.str(); } std::string bridge; if (clientType == ClientType::Mqtt3DefactoBridge) bridge = "Mqtt3Bridge "; else if (clientType == ClientType::LocalBridge) bridge = "LocalBridge "; std::string fmq_client_group_id_part; if (this->fmq_client_group_id) { fmq_client_group_id_part.append("fmq_client_group_id='"); fmq_client_group_id_part.append(this->fmq_client_group_id.value()); fmq_client_group_id_part.append("', "); } std::string s = formatString( "[%sClientID='%s', %susername='%s', fd=%d, keepalive=%ds, transport='%s', address='%s', prot=%s, clean=%d]", bridge.c_str(), clientid.c_str(), fmq_client_group_id_part.c_str(), username.c_str(), fd.get(), keepalive, transport.c_str(), this->addr.getText().c_str(), protocolVersionString(protocolVersion).c_str(), this->clean_start); return s; } 
std::string Client::repr_endpoint() { std::string s = formatString("address='%s', transport='%s', fd=%d", this->addr.getText().c_str(), this->transportStr.c_str(), fd.get()); return s; } /** * @brief Client::keepAliveExpired * @return * * [MQTT-3.1.2-24]: "If the Keep Alive value is non-zero and the Server does not receive a Control Packet from the * Client within one and a half times the Keep Alive time period, it MUST disconnect the Network Connection to * the Client as if the network had failed." */ bool Client::keepAliveExpired() { if (keepalive == 0) return false; const std::chrono::time_point now = std::chrono::steady_clock::now(); std::chrono::seconds x(keepalive + keepalive/2); bool result = (lastActivity + x) < now; return result; } std::string Client::getKeepAliveInfoString() const { std::chrono::seconds secondsSinceLastActivity = std::chrono::duration_cast(std::chrono::steady_clock::now() - lastActivity); std::string s = formatString("authenticated=%s, keep-alive=%ss, last activity=%s seconds ago.", std::to_string(authenticated).c_str(), std::to_string(keepalive).c_str(), std::to_string(secondsSinceLastActivity.count()).c_str()); return s; } void Client::resetBuffersIfEligible() { const Settings *settings = ThreadGlobals::getSettings(); const size_t initialBufferSize = settings->clientInitialBufferSize; readbuf.resetCapacityIfEligable(initialBufferSize); ioWrapper.resetBuffersIfEligible(); auto write_buf_locked = writebuf.lock(); write_buf_locked->buf.resetCapacityIfEligable(initialBufferSize); } void Client::setTopicAlias(const uint16_t alias_id, const std::string &topic) { if (alias_id == 0) throw ProtocolError("Client tried to set topic alias 0, which is a protocol error.", ReasonCodes::ProtocolError); if (topic.empty()) return; // The specs actually say "The Client MUST NOT send a Topic Alias [...] to the Server greater than this value [Topic Alias Maximum]". So, it's not about count. 
if (alias_id > this->maxIncomingTopicAliasValue) throw ProtocolError(formatString("Client tried to set more topic aliases than the server max of %d per client", this->maxIncomingTopicAliasValue), ReasonCodes::TopicAliasInvalid); this->incomingTopicAliases[alias_id] = topic; } const std::string &Client::getTopicAlias(const uint16_t id) const { auto pos = this->incomingTopicAliases.find(id); if (pos == this->incomingTopicAliases.end()) throw ProtocolError("Requesting topic alias ID (" + std::to_string(id) + ") that wasn't set before.", ReasonCodes::TopicAliasInvalid); return pos->second; } /** * @brief We use this for doing the checks on client traffic, as opposed to using settings.maxPacketSize, because the latter than change on config reload, * possibly resulting in exceeding what the other side uses as maximum. * @return */ uint32_t Client::getMaxIncomingPacketSize() const { return this->maxIncomingPacketSize; } /** * @brief We use this to send back in the connack, so we know we don't race with the value from settings, which may change during the connection handshake. * @return */ uint16_t Client::getMaxIncomingTopicAliasValue() const { return this->maxIncomingTopicAliasValue; } void Client::sendOrQueueWill() { if (this->threadData.expired()) return; if (!this->willPublish) return; std::shared_ptr store = globals->subscriptionStore; store->queueOrSendWillMessage(willPublish, session); this->willPublish.reset(); } /** * @brief Client::setRegistrationData sets parameters for the session to be registered. We set them as arguments here to * possibly use later, because with extended authentication, session registration doesn't happen on the first CONNECT packet. 
* @param clean_start also stored in this->clean_start immediately.
 * @param client_receive_max
 * @param sessionExpiryInterval
 */
void Client::setRegistrationData(bool clean_start, uint16_t client_receive_max, uint32_t sessionExpiryInterval)
{
    this->clean_start = clean_start;
    this->registrationData = std::make_unique(clean_start, client_receive_max, sessionExpiryInterval);
}

const std::unique_ptr &Client::getRegistrationData() const
{
    return this->registrationData;
}

void Client::clearRegistrationData()
{
    this->registrationData.reset();
}

/**
 * @brief Client::stageConnack saves the success connack for later use.
 * @param c
 *
 * The connack to be generated is known on the initial connect packet, but in extended authentication, the client won't get it
 * until the authentication is complete.
 */
void Client::stageConnack(std::unique_ptr &&c)
{
    this->stagedConnack = std::move(c);
}

/**
 * @brief Client::sendConnackSuccess writes the staged connack, logs the login and drops the staged copy.
 * @throws ProtocolError when no connack was staged.
 */
void Client::sendConnackSuccess()
{
    if (!stagedConnack)
    {
        throw ProtocolError("Programming bug: trying to send a prepared connack when there is none.", ReasonCodes::ProtocolError);
    }

    ConnAck &connAck = *this->stagedConnack.get();
    MqttPacket response(connAck);
    writeMqttPacket(response);
    logger->logf(LOG_NOTICE, "Client '%s' logged in successfully", repr().c_str());
    this->stagedConnack.reset();
}

/**
 * @brief Client::sendConnackDeny writes a refusing connack with the given reason and schedules the connection to
 * close once pending data has been sent.
 */
void Client::sendConnackDeny(ReasonCodes reason)
{
    ConnAck connDeny(protocolVersion, reason, false);
    MqttPacket response(connDeny);
    setDisconnectReason("Access denied");
    setDisconnectStage(DisconnectStage::SendPendingAppData);
    writeMqttPacket(response);
    logger->logf(LOG_NOTICE, "User '%s' access denied", username.c_str());
}

/**
 * @brief Client::addAuthReturnDataToStagedConnAck adds authentication data to the staged connack; empty data is ignored.
 * @throws ProtocolError when there is data but no staged connack.
 */
void Client::addAuthReturnDataToStagedConnAck(const std::string &authData)
{
    if (authData.empty())
        return;

    if (!stagedConnack)
    {
        throw ProtocolError("Programming bug: trying to add auth return data when there is no staged connack.", ReasonCodes::ProtocolError);
    }

    stagedConnack->propertyBuilder->writeAuthenticationData(authData);
}

void Client::setExtendedAuthenticationMethod(const std::string &authMethod)
{
    this->extendedAuthenticationMethod = authMethod;
}

const std::string &Client::getExtendedAuthenticationMethod() const
{
    return this->extendedAuthenticationMethod;
}

std::shared_ptr Client::lockThreadData()
{
    return this->threadData.lock();
}

/**
 * @brief Client::setBridgeState turns this client into the local side of an outgoing bridge connection and copies
 * the connection parameters (protocol, client id, credentials, keep-alive, TLS verification) from the bridge config.
 */
void Client::setBridgeState(std::shared_ptr bridgeState)
{
    this->bridgeState = bridgeState;
    this->outgoingConnection = true;
    this->clientType = ClientType::LocalBridge;

    if (bridgeState)
    {
        this->protocolVersion = bridgeState->c.protocolVersion;
        this->clean_start = bridgeState->c.localCleanStart;
        this->clientid = bridgeState->c.getClientid();
        this->fmq_client_group_id = bridgeState->c.getFmqClientGroupId();
        this->username = bridgeState->c.local_username.value_or(std::string());
        this->keepalive = bridgeState->c.keepalive;
        this->addr.setAddressName(bridgeState->c.address);

        // Not setting maxOutgoingTopicAliasValue, because that must remain 0 until the other side says (in the connack) we can use aliases.
        this->maxIncomingTopicAliasValue = bridgeState->c.maxIncomingTopicAliases;

        if (bridgeState->c.tlsMode > BridgeTLSMode::None)
        {
            // Only BridgeTLSMode::On verifies the peer certificate; other TLS modes connect without verification.
            const int mode = bridgeState->c.tlsMode == BridgeTLSMode::On ? SSL_VERIFY_PEER : SSL_VERIFY_NONE;
            ioWrapper.setSslVerify(mode, bridgeState->c.address);
        }
    }
}

bool Client::isOutgoingConnection() const
{
    return this->outgoingConnection;
}

std::shared_ptr Client::getBridgeState()
{
    return this->bridgeState.lock();
}

/**
 * @brief Client::setBridgeConnected marks the outgoing TCP connection as established, clears the bridge's DNS
 * results, and continues with either the TLS handshake or the login packet.
 */
void Client::setBridgeConnected()
{
    this->outgoingConnectionEstablished = true;

    std::shared_ptr bridge = this->bridgeState.lock();

    if (bridge)
    {
        bridge->dnsResults.clear();
    }

    if (isSsl())
        this->startOrContinueSslHandshake();
    else
        this->writeLoginPacket();
}

/**
 * @brief Client::detectOutgoingConnectionEstablished checks SO_ERROR on the socket after a non-blocking connect.
 * @throws BadClientException when the socket reports an error other than EINPROGRESS.
 *
 * NOTE(review): when getsockopt itself fails (rc != 0), error stays 0 and the failure is silently ignored.
 */
void Client::detectOutgoingConnectionEstablished()
{
    int error = 0;
    socklen_t optlen = sizeof(int);

    int rc = getsockopt(fd.get(), SOL_SOCKET, SO_ERROR, &error, &optlen);

    if (rc == 0 && error == 0)
    {
        setBridgeConnected();
    }

    if (error > 0 && error != EINPROGRESS)
        throw BadClientException(strerror(error), LOG_WARNING);
}

bool Client::getOutgoingConnectionEstablished() const
{
    return this->outgoingConnectionEstablished;
}

/**
 * @brief Client::setClientType sets the type on the client and, when there is one, on its session.
 */
void Client::setClientType(ClientType val)
{
    this->clientType = val;

    if (!session)
        return;

    session->setClientType(val);
}

#ifndef NDEBUG
/**
 * @brief Debug-only passthrough to IoWrapper::setFakeUpgraded().
 */
void Client::setFakeUpgraded()
{
    ioWrapper.setFakeUpgraded();
}
#endif

/**
 * @brief Client::setReadyForWriting takes the write buffer lock and updates the write interest.
 */
void Client::setReadyForWriting(bool val)
{
    auto write_buf_locked = writebuf.lock();
    setReadyForWriting(val, write_buf_locked);
}

/**
 * @brief Client::setReadyForWriting updates EPOLLOUT interest for the fd; the caller holds the write buffer lock.
 *
 * Write interest is forced on while the SSL layer reports that a read needs a write. The epoll registration is only
 * modified when the value actually changes.
 */
void Client::setReadyForWriting(bool val, MutexLocked &writebuf)
{
#ifndef NDEBUG
    if (fuzzMode)
        return;
#endif

#ifdef TESTING
    if (fd.get() == 0)
        return;
#endif

    if (this->disconnectStage == DisconnectStage::Now)
        return;

    if (ioWrapper.getSslReadWantsWrite())
        val = true;

    // This looks a bit like a race condition, but all calls to this method should be under lock of writeBufMutex, so it should be OK.
    if (val == writebuf->readyForWriting)
        return;

    writebuf->readyForWriting = val;

    struct epoll_event ev {};
    ev.data.fd = fd.get();
    ev.events = readyForReading*EPOLLIN | val*EPOLLOUT;
    check(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd.get(), &ev));
}

/**
 * @brief Client::setReadyForReading updates EPOLLIN interest for the fd, keeping the current write interest.
 * The epoll registration is only modified when the value actually changes.
 */
void Client::setReadyForReading(bool val)
{
#ifndef NDEBUG
    if (fuzzMode)
        return;
#endif

#ifdef TESTING
    if (fd.get() == 0)
        return;
#endif

    if (this->disconnectStage == DisconnectStage::Now)
        return;

    // This looks a bit like a race condition, but all calls to this method are from a thread's event loop, so we should be OK.
    if (val == this->readyForReading)
        return;

    readyForReading = val;

    struct epoll_event ev {};
    ev.data.fd = fd.get();

    {
        // The write buffer lock protects readyForWriting while it is read into the event mask.
        auto write_buf_locked = writebuf.lock();

        ev.events = readyForReading*EPOLLIN | write_buf_locked->readyForWriting*EPOLLOUT;
        check(epoll_ctl(this->epoll_fd, EPOLL_CTL_MOD, fd.get(), &ev));
    }
}

/**
 * @brief Client::setAddr replaces the textual address, but only when the current peer address matches the
 * configured 'set real IP from' ranges.
 */
void Client::setAddr(const std::string &address)
{
    const Settings *settings = ThreadGlobals::getSettings();

    if (!settings->matchAddrWithSetRealIpFrom(this->addr.getSockaddr()))
        return;

    addr.setAddress(address);
}

/**
 * @brief Client::tryAcmeRedirect answers an HTTP ACME challenge request (on an unauthenticated, non-SSL connection)
 * with a redirect to the configured URL.
 * @return true when a redirect was queued.
 * @throws std::runtime_error for an ACME-only client without a redirect URL.
 */
bool Client::tryAcmeRedirect()
{
    if (authenticated)
        return false;

    if (isSsl())
        return false;

    if (getConnectionProtocol() == ConnectionProtocol::AcmeOnly && !acmeRedirectUrl)
        throw std::runtime_error("ACME-only client has no redirect URL.");

    if (!acmeRedirectUrl)
        return false;

    std::optional req = parseHttpHeader(readbuf);

    if (!req)
        return false;

    if (!req.value())
        return false;

    const HttpRequest::Data &d = req.value().value();

    if (startsWith(d.request, "/.well-known/acme-challenge/"))
    {
        respondWithRedirectURL(d.request);
        return true;
    }

    return false;
}

/**
 * @brief Client::respondWithRedirectURL writes an HTTP redirect to the ACME redirect URL (trailing '/' trimmed) with
 * the request path appended, then schedules a disconnect after sending.
 * @throws std::runtime_error when no redirect URL is configured.
 */
void Client::respondWithRedirectURL(const std::string &request)
{
    if (!acmeRedirectUrl)
        throw std::runtime_error("Trying to redirect ACME without URL defined.");

    std::string redirected_request = *acmeRedirectUrl;
    rtrim(redirected_request, '/');
    redirected_request.append(request);

    std::string response = generateRedirect(redirected_request);
    writeText(response);
    setDisconnectReason("Redirecting ACME request");
    setDisconnectStage(DisconnectStage::SendPendingAppData);
}

/**
 * @brief Client::bufferToMqttPackets parses complete packets from readbuf into packetQueueIn, and resumes reading
 * when that freed space in the buffer.
 */
void Client::bufferToMqttPackets(std::vector &packetQueueIn, std::shared_ptr &sender)
{
    MqttPacket::bufferToMqttPackets(readbuf, packetQueueIn, sender);
    setReadyForReading(readbuf.freeSpace() > 0);
}

/**
 * @brief Client::setClientProperties overload that uses the current max packet size setting and no outgoing
 * topic aliases.
 */
void Client::setClientProperties(
    ProtocolVersion protocolVersion, const std::string &clientId, const std::optional &fmq_client_group_id,
    const std::string username, bool connectPacketSeen, uint16_t keepalive)
{
    const Settings *settings = ThreadGlobals::getSettings();
    setClientProperties(protocolVersion, clientId, fmq_client_group_id, username, connectPacketSeen, keepalive, settings->maxPacketSize, 0);
}

/**
 * @brief Client::setClientProperties stores the identity and limits negotiated on CONNECT.
 */
void Client::setClientProperties(
    ProtocolVersion protocolVersion, const std::string &clientId, const std::optional &fmq_client_group_id,
    const std::string username, bool connectPacketSeen, uint16_t keepalive, uint32_t maxOutgoingPacketSize,
    uint16_t maxOutgoingTopicAliasValue)
{
    this->protocolVersion = protocolVersion;
    this->clientid = clientId;
    this->fmq_client_group_id = fmq_client_group_id;
    this->username = username;
    this->connectPacketSeen = connectPacketSeen;
    this->keepalive = keepalive;
    this->maxOutgoingPacketSize = maxOutgoingPacketSize;
    this->maxOutgoingTopicAliasValue = maxOutgoingTopicAliasValue;
}

void Client::setClientProperties(bool connectPacketSeen, uint16_t keepalive, uint32_t maxOutgoingPacketSize, uint16_t maxOutgoingTopicAliasValue, bool supportsRetained)
{
    if (logger->wouldLog(LOG_DEBUG))
    {
        logger->log(LOG_DEBUG) << "Client '" << repr() << "' properties set: keep_alive=" << keepalive << ", max_outgoing_packet_size=" << maxOutgoingPacketSize
                               << ", max_outgoing_topic_aliases=" << maxOutgoingTopicAliasValue << ".";
    }

    this->connectPacketSeen = connectPacketSeen;
    this->keepalive = keepalive;
    this->maxOutgoingPacketSize = maxOutgoingPacketSize;
    this->maxOutgoingTopicAliasValue = maxOutgoingTopicAliasValue;
this->supportsRetained = supportsRetained; } void Client::stageWill(WillPublish &&willPublish) { this->stagedWillPublish = std::make_shared(std::move(willPublish)); this->stagedWillPublish->client_id = this->clientid; this->stagedWillPublish->username = this->username; } void Client::setWillFromStaged() { this->willPublish = std::move(stagedWillPublish); } void Client::assignSession(const std::shared_ptr &session) { this->session = session; } std::shared_ptr Client::getSession() { if (!this->session) throw std::runtime_error("Client has no session in getSession(). It was probably meant to be discarded."); return this->session; } void Client::setDisconnectReason(const std::string &reason, const int logLevel) { #ifndef TESTING // Because of testing trickery, we can't assert this in testing. #ifndef NDEBUG auto td = this->threadData.lock(); if (td) { assert(pthread_self() == td->thread_id); } #endif #endif if (!this->disconnectReason.empty()) this->disconnectReason += ", "; this->disconnectReason.append(reason); // Purposefully disallow downgrading it again. if (logLevel > this->disconnectReasonLogLevel) this->disconnectReasonLogLevel = logLevel; } void Client::setDisconnectReasonFromSocketError() { int error = 0; socklen_t errlen = sizeof(error); if (getsockopt(fd.get(), SOL_SOCKET, SO_ERROR, &error, &errlen) == 0) { setDisconnectReason(strerror(error)); } } /** * @brief Client::getSecondsTillKeepAliveAction gets the amount of seconds from now at which this client should be killed when * it was quiet, or in case of outgoing client, when a new ping is required. * @return * * "If the Keep Alive value is non-zero and the Server does not receive an MQTT Control Packet from the Client within one and a * half times the Keep Alive time period, it MUST close the Network Connection to the Client as if the network had failed [MQTT-3.1.2-22]. 
*/
std::chrono::seconds Client::getSecondsTillKeepAliveAction() const
{
    // For outgoing (bridge) connections this is the interval at which we need to send a ping ourselves.
    if (isOutgoingConnection())
        return std::chrono::seconds(this->keepalive);

    // Clients that haven't authenticated yet get a fixed 30 seconds.
    if (!this->authenticated)
        return std::chrono::seconds(30);

    // Keep alive 0: no keep-alive action at all.
    if (this->keepalive == 0)
        return std::chrono::seconds(0);

    // One and a half times the keep alive (MQTT-3.1.2-22), plus 2 seconds of slack.
    const uint32_t timeOfSilenceMeansKill = this->keepalive + (this->keepalive / 2) + 2;
    std::chrono::time_point killTime = this->lastActivity + std::chrono::seconds(timeOfSilenceMeansKill);

    std::chrono::seconds secondsTillKillTime = std::chrono::duration_cast(killTime - std::chrono::steady_clock::now());

    // We floor it, but also protect against the theoretically impossible negative value. Kill time shouldn't be in the past, because then we would
    // have killed it already.
    if (secondsTillKillTime < std::chrono::seconds(5))
        return std::chrono::seconds(5);

    return secondsTillKillTime;
}

/**
 * @brief Client::getLocalPrefix returns the local prefix of the assigned session.
 * @throws std::runtime_error when no session is assigned.
 * NOTE(review): the exception text mentions getSession(); it appears to be copied from there.
 */
const std::optional &Client::getLocalPrefix() const
{
    if (!this->session)
        throw std::runtime_error("Client has no session in getSession(). It was probably meant to be discarded.");

    return session->getLocalPrefix();
}

/**
 * @brief Client::getRemotePrefix returns the remote prefix of the assigned session.
 * @throws std::runtime_error when no session is assigned.
 * NOTE(review): the exception text mentions getSession(); it appears to be copied from there.
 */
const std::optional &Client::getRemotePrefix() const
{
    if (!this->session)
        throw std::runtime_error("Client has no session in getSession(). It was probably meant to be discarded.");

    return session->getRemotePrefix();
}

/**
 * @brief Client::clearWill drops the active and the staged will, and the session's will when a session is assigned.
 */
void Client::clearWill()
{
    willPublish.reset();
    stagedWillPublish.reset();

    if (session)
        session->clearWill();
}

void Client::setClientId(const std::string &id)
{
    this->clientid = id;
}

/**
 * @brief Client::getMutableUsername gives write access to the stored username in place.
 */
std::string &Client::getMutableUsername()
{
    return this->username;
}

/**
 * @brief Client::setSslVerify stores the X509 client verification mode and configures TLS peer verification.
 *
 * Any mode above None requests peer verification. An empty string is passed as the second argument of
 * IoWrapper::setSslVerify(); by contrast, setBridgeState() passes the bridge address there.
 */
void Client::setSslVerify(X509ClientVerification verificationMode)
{
    const int mode = verificationMode > X509ClientVerification::None ? SSL_VERIFY_PEER : SSL_VERIFY_NONE;

    this->x509ClientVerification = verificationMode;
    ioWrapper.setSslVerify(mode, "");
}

/**
 * @brief Client::getUsernameFromPeerCertificate takes the username from the common name (CN) in the subject of the
 * client's TLS certificate.
 * @return the first CN entry of the subject. Empty when the connection is not TLS, X509 client verification is None,
 * or the subject has no CN.
 * @throws ProtocolError (BadUserNameOrPassword) when no peer certificate is available.
 * @throws ProtocolError (MalformedPacket) when the CN is not valid UTF-8.
 * @throws std::runtime_error when OpenSSL accessors unexpectedly return null.
 */
std::optional Client::getUsernameFromPeerCertificate()
{
    if (!ioWrapper.isSsl() || x509ClientVerification == X509ClientVerification::None)
        return std::optional();

    X509Manager client_cert = ioWrapper.getPeerCertificate();

    if (!client_cert)
        throw ProtocolError("Client did not provide X509 peer certificate", ReasonCodes::BadUserNameOrPassword);

    X509_NAME *x509_name = X509_get_subject_name(client_cert.get());

    // Start searching from the beginning (-1), so this finds the first CN entry.
    int index = X509_NAME_get_index_by_NID(x509_name, NID_commonName, -1);

    if (index < 0)
        return std::optional();

    X509_NAME_ENTRY *name_entry = X509_NAME_get_entry(x509_name, index);

    if (!name_entry)
        throw std::runtime_error("X509_NAME_get_entry failed. This should be impossible.");

    ASN1_STRING *asn1_string = X509_NAME_ENTRY_get_data(name_entry);

    if (!asn1_string)
        throw std::runtime_error("Cannot obtain asn1 string from x509 certificate.");

    // The ASN1 data is length-delimited, so exactly 'len' bytes are copied.
    const auto len = ASN1_STRING_length(asn1_string);
    const unsigned char *str = ASN1_STRING_get0_data(asn1_string);

    if (!str)
        throw std::runtime_error("ASN1_STRING_get0_data failed. This should be impossible.");

    std::vector str_data(str, str + len);
    std::string username = make_string(str_data, 0, len);

    if (!isValidUtf8(username))
        throw ProtocolError("Common name from peer certificate is not valid UTF8.", ReasonCodes::MalformedPacket);

    return username;
}

X509ClientVerification Client::getX509ClientVerification() const
{
    return x509ClientVerification;
}

void Client::setAllowAnonymousOverride(const AllowListenerAnonymous allow)
{
    allowAnonymousOverride = allow;
}

AllowListenerAnonymous Client::getAllowAnonymousOverride() const
{
    return allowAnonymousOverride;
}

/**
 * @brief Client::setAcmeRedirect stores a copy of the ACME redirect URL.
 * NOTE(review): an empty optional leaves any previously set URL in place rather than clearing it.
 */
void Client::setAcmeRedirect(const std::optional &url)
{
    if (url)
        this->acmeRedirectUrl = std::make_unique(url.value());
}

/**
 * @brief Client::addPacketToAfterAsyncQueue defers a packet so handleAfterAsyncQueue() can handle it later.
 * @throws ProtocolError when more than 64 packets are already queued. The check runs before the push, so at most 65
 * packets are held.
 */
void Client::addPacketToAfterAsyncQueue(MqttPacket &&p)
{
    if (!packetQueueAfterAsync)
        packetQueueAfterAsync = std::make_unique>();

    if (packetQueueAfterAsync->size() > 64)
        throw ProtocolError("Client sending too many packets without waiting for CONNACK. This is likely an abuser", ReasonCodes::ImplementationSpecificError);

    packetQueueAfterAsync->push_back(std::move(p));
}

/**
 * @brief Client::handleAfterAsyncQueue ends the async-authenticating state and handles the deferred packets in order.
 *
 * Does nothing when the client was not async authenticating. The queue is detached from the client before handling,
 * so the member is empty while the packets are processed.
 */
void Client::handleAfterAsyncQueue(std::shared_ptr &sender)
{
    if (!this->asyncAuthenticating)
        return;

    this->asyncAuthenticating = false;

    if (!this->packetQueueAfterAsync)
        return;

    std::unique_ptr> packets = std::move(this->packetQueueAfterAsync);
    this->packetQueueAfterAsync.reset();

    for (MqttPacket &p : *packets)
    {
        p.handle(sender);
    }
}

/**
 * @brief Client::setAsyncAuthResult stores a copy of the asynchronous authentication result and requests EPOLLOUT.
 * Requesting EPOLLOUT presumably makes the client's event loop pick the result up (TODO confirm with the EPOLLOUT handler).
 */
void Client::setAsyncAuthResult(const AsyncAuthResult &v)
{
    this->asyncAuthResult = std::make_unique(v);
    setReadyForWriting(true);
}

/**
 * @brief Client::stealAsyncAuthResult hands over ownership of the stored async auth result, leaving none behind.
 */
std::unique_ptr Client::stealAsyncAuthResult()
{
    std::unique_ptr r(std::move(this->asyncAuthResult));
    this->asyncAuthResult.reset();
    return r;
}

================================================ FILE: client.h ================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify
it under the terms of The
Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/

#ifndef CLIENT_H
#define CLIENT_H

#include #include #include #include #include #include #include #include #include

#include "forward_declarations.h"
#include "mqttpacket.h"
#include "cirbuf.h"
#include "types.h"
#include "iowrapper.h"
#include "bridgeconfig.h"
#include "enums.h"
#include "fdmanaged.h"
#include "mutexowned.h"
#include "fmqssl.h"
#include "publishcopyfactory.h"

// NOTE(review): 'LENGH' is misspelled; left as is because other files may refer to this name.
#define MQTT_HEADER_LENGH 2

/**
 * @brief The StowedClientRegistrationData struct keeps the details of how a client doing extended authentication is
 * to be registered, so the registration can be completed once the authentication succeeds.
 */
struct StowedClientRegistrationData
{
    const bool clean_start;
    const uint16_t clientReceiveMax;
    const uint32_t sessionExpiryInterval;

    StowedClientRegistrationData(bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval);
};

/**
 * @brief The DisconnectStage enum tracks how far a client is in being disconnected.
 */
enum class DisconnectStage
{
    NotInitiated,
    SendPendingAppData, // Set this before making a client ready to write, and EPOLLOUT will take care of it.
    Now // From here on, Client::setReadyForReading()/setReadyForWriting() no longer change the epoll registration.
};

/**
 * @brief The AsyncAuthResult struct holds the outcome of an asynchronous authentication, together with the
 * authentication method and data.
 */
struct AsyncAuthResult
{
    AuthResult result;
    std::string authMethod;
    std::string authData;

public:
    AsyncAuthResult(AuthResult result, const std::string authMethod, const std::string &authData);
};

/**
 * @brief The Client class represents one network connection. It is either an incoming client or the local end of an
 * outgoing bridge connection (see setBridgeState()).
 *
 * It owns the socket fd, the read and write buffers and the I/O wrapper. It also holds the per-connection MQTT state:
 * protocol version, keep alive, topic aliases, wills, session, authentication state and so on. Copying and moving
 * are deleted.
 */
class Client
{
    // Outgoing topic alias bookkeeping; presumably the payload of the 'outgoingTopicAliases' MutexOwned member below.
    struct OutgoingTopicAliases
    {
        uint16_t cur_alias = 0;
        std::unordered_map aliases;
    };

    // The write buffer together with whether EPOLLOUT is currently requested, so both share one lock.
    struct WriteBuf
    {
        CirBuf buf;
        bool readyForWriting = false;

        WriteBuf(size_t size);
    };

    friend class IoWrapper;

    FdManaged fd;
    bool fuzzMode = false;

    ProtocolVersion protocolVersion = ProtocolVersion::None;

    uint32_t maxOutgoingPacketSize;
    const uint32_t maxIncomingPacketSize;
    std::optional maxBufSizeOverride;
    uint8_t maxQos = 2;
    Mqtt3QoSExceedAction mqtt3QoSExceedAction = Mqtt3QoSExceedAction::Disconnect;

    uint16_t maxOutgoingTopicAliasValue = 0;
    uint16_t maxIncomingTopicAliasValue = 0;

    IoWrapper ioWrapper;
    std::string transportStr;

    CirBuf readbuf;
    MutexOwned writebuf;

    bool authenticated = false;
    bool connectPacketSeen = false;
    bool readyForReading = true; // Whether EPOLLIN is requested; only changed from the client's own event loop.
    DisconnectStage disconnectStage = DisconnectStage::NotInitiated;

    bool outgoingConnection = false;
    bool outgoingConnectionEstablished = false;
    ClientType clientType = ClientType::Normal;
    bool supportsRetained = true; // Interestingly, only SERVERS can tell CLIENTS they don't support it (in CONNACK). The CONNECT packet has no field for it.

    std::string disconnectReason;
    int disconnectReasonLogLevel = LOG_NOTICE;

    std::chrono::time_point lastActivity = std::chrono::steady_clock::now();

    std::string ssl_version;

    std::string clientid;
    std::string username;
    std::optional fmq_client_group_id;
    uint16_t keepalive = 10;
    bool clean_start = false;

    X509ClientVerification x509ClientVerification = X509ClientVerification::None;
    AllowListenerAnonymous allowAnonymousOverride = AllowListenerAnonymous::None;
    HaProxyMode mHaProxyMode = HaProxyMode::Off;
    std::unique_ptr acmeRedirectUrl;

    std::shared_ptr stagedWillPublish;
    std::shared_ptr willPublish;

    int epoll_fd = -1;
    LockedWeakPtr threadData; // The thread (data) that this client 'lives' in.

    std::shared_ptr session;

    std::unordered_map incomingTopicAliases;
    MutexOwned outgoingTopicAliases;

    std::string extendedAuthenticationMethod;
    std::unique_ptr stagedConnack;
    std::unique_ptr registrationData;

    Logger *logger = Logger::getInstance();

    FMQSockaddr addr;

    std::weak_ptr bridgeState; // Non-owning reference, set by setBridgeState().

    // Async authentication: packets deferred by addPacketToAfterAsyncQueue() are handled by handleAfterAsyncQueue().
    bool asyncAuthenticating = false;
    std::unique_ptr> packetQueueAfterAsync;
    std::unique_ptr asyncAuthResult;

    void setReadyForWriting(bool val);
    void setReadyForWriting(bool val, MutexLocked &writebuf);
    void setReadyForReading(bool val);
    void setAddr(const std::string &address);

public:
    uint8_t preAuthPacketCounter = 0;

    Client(
        ClientType type, int fd, std::shared_ptr threadData, FmqSsl &&ssl, ConnectionProtocol connectionProtocol, HaProxyMode haProxyMode,
        const struct sockaddr *addr, const Settings &settings, bool fuzzMode=false);
    Client(const Client &other) = delete;
    Client(Client &&other) = delete;
    ~Client();

    void addToEpoll(uint32_t events);

    int getFd() { return fd.get();}
    bool isSslAccepted() const;
    bool isSsl() const;
    const std::optional &getHaProxySslCnName() const { return this->ioWrapper.getHaProxySslCnName(); }
    HaProxyMode getHaProxyMode() const { return this->mHaProxyMode; }
    HaProxyStage getHaProxyStage() const { return this->ioWrapper.getHaProxyStage();};
    HaProxyConnectionType getHaProxyConnectionType() const { return this->ioWrapper.getHaProxyConnectionType(); }
    HaProxyStage readHaProxyData();
    bool getSslReadWantsWrite() const;
    bool getSslWriteWantsRead() const;
    ProtocolVersion getProtocolVersion() const;
    ConnectionProtocol getConnectionProtocol() const { return this->ioWrapper.getConnectionProtocol(); }
    void setProtocolVersion(ProtocolVersion version);
    void connectToBridgeTarget(FMQSockaddr addr);
    void setMaxBufSizeOverride(uint32_t val);

    void startOrContinueSslHandshake();
    void setDisconnectStage(DisconnectStage val);

    DisconnectStage readFdIntoBuffer();
    bool tryAcmeRedirect();
    void respondWithRedirectURL(const std::string &request);
    void bufferToMqttPackets(std::vector &packetQueueIn, std::shared_ptr &sender);

    // Overloads for setting connection/identity properties; see client.cpp for which fields each one sets.
    void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::optional &fmq_client_group_id, const std::string username, bool connectPacketSeen, uint16_t keepalive);
    void setClientProperties(ProtocolVersion protocolVersion, const std::string &clientId, const std::optional &fmq_client_group_id, const std::string username, bool connectPacketSeen, uint16_t keepalive,
                             uint32_t maxOutgoingPacketSize, uint16_t maxOutgoingTopicAliasValue);
    void setClientProperties(bool connectPacketSeen, uint16_t keepalive, uint32_t maxOutgoingPacketSize, uint16_t maxOutgoingTopicAliasValue, bool supportsRetained);

    void setWill(const std::string &topic, const std::string &payload, bool retain, uint8_t qos);
    void stageWill(WillPublish &&willPublish);
    void setWillFromStaged();
    void clearWill();

    void setAuthenticated(bool value) { authenticated = value;}
    bool getAuthenticated() { return authenticated; }
    bool hasConnectPacketSeen() { return connectPacketSeen; }
    void setHasConnectPacketSeen() { connectPacketSeen = true; }
    const std::string &getClientId() { return this->clientid; }
    void setClientId(const std::string &id);
    const std::string &getUsername() const { return this->username; }
    const std::optional &getFmqClientGroupId() const { return this->fmq_client_group_id; }
    std::string &getMutableUsername();
    std::shared_ptr &getWill() { return this->willPublish; }
    const std::shared_ptr &getStagedWill() { return this->stagedWillPublish; }
    void assignSession(const std::shared_ptr &session);
    std::shared_ptr getSession();
    void setDisconnectReason(const std::string &reason, const int logLevel=-1);
    void setDisconnectReasonFromSocketError();
    std::chrono::seconds getSecondsTillKeepAliveAction() const;
    const std::optional &getLocalPrefix() const;
    const std::optional &getRemotePrefix() const;

    void writeText(const std::string &text);
    void writePing();
    void writePingResp();
    void writeLoginPacket();
    PacketDropReason writeMqttPacket(const MqttPacket &packet);
    PacketDropReason writeMqttPacketAndBlameThisClient(
        PublishCopyFactory ©Factory, uint8_t max_qos, uint16_t packet_id, bool retain, uint32_t subscriptionIdentifier,
        const std::optional &topic_override);
    PacketDropReason writeMqttPacketAndBlameThisClient(const MqttPacket &packet);
    void writeBufIntoFd();
    DisconnectStage getDisconnectStage() const { return disconnectStage; }

    const FMQSockaddr &getAddr() const;

    std::string repr();
    std::string repr_endpoint();
    bool keepAliveExpired();
    std::string getKeepAliveInfoString() const;
    void resetBuffersIfEligible();

    void setTopicAlias(const uint16_t alias_id, const std::string &topic);
    const std::string &getTopicAlias(const uint16_t id) const;

    uint32_t getMaxIncomingPacketSize() const;
    uint16_t getMaxIncomingTopicAliasValue() const;

    void sendOrQueueWill();

    void setRegistrationData(bool clean_start, uint16_t client_receive_max, uint32_t sessionExpiryInterval);
    const std::unique_ptr &getRegistrationData() const;
    void clearRegistrationData();

    void stageConnack(std::unique_ptr &&c);
    void sendConnackSuccess();
    void sendConnackDeny(ReasonCodes reason);
    void addAuthReturnDataToStagedConnAck(const std::string &authData);

    void setExtendedAuthenticationMethod(const std::string &authMethod);
    const std::string &getExtendedAuthenticationMethod() const;

    std::shared_ptr lockThreadData();

    // Bridge (outgoing connection) support.
    void setBridgeState(std::shared_ptr bridgeState);
    bool isOutgoingConnection() const;
    std::shared_ptr getBridgeState();
    void setBridgeConnected();
    void detectOutgoingConnectionEstablished();
    bool getOutgoingConnectionEstablished() const;
    ClientType getClientType() const { return clientType; }
    void setClientType(ClientType val);
    bool isRetainedAvailable() const {return supportsRetained; };

#ifdef TESTING
    std::function onPacketReceived;
#endif

#ifndef NDEBUG
    void setFakeUpgraded();
#endif

    // X509 client certificates, anonymous-access override and ACME redirects.
    void setSslVerify(X509ClientVerification verificationMode);
    std::optional getUsernameFromPeerCertificate();
    X509ClientVerification getX509ClientVerification() const;
    void setAllowAnonymousOverride(const AllowListenerAnonymous allow);
    AllowListenerAnonymous getAllowAnonymousOverride() const;
    void setAcmeRedirect(const std::optional &url);
    const std::unique_ptr &getAcmeRedirectUrl() const { return this->acmeRedirectUrl; };

    // Asynchronous authentication.
    void setAsyncAuthenticating() { this->asyncAuthenticating = true; }
    bool getAsyncAuthenticating() const { return this->asyncAuthenticating; }
    void addPacketToAfterAsyncQueue(MqttPacket &&p);
    void handleAfterAsyncQueue(std::shared_ptr &sender);
    void setAsyncAuthResult(const AsyncAuthResult &v);
    bool hasAsyncAuthResult() const { return this->asyncAuthResult.operator bool() ; }
    std::unique_ptr stealAsyncAuthResult();

    bool getCleanStart() const { return clean_start;}
    Mqtt3QoSExceedAction getMqtt3QoSExceedAction() const { return this->mqtt3QoSExceedAction;}
    void setMqtt3QoSExceedAction(Mqtt3QoSExceedAction action) { this->mqtt3QoSExceedAction = action;}
    uint8_t getMaxQos() const { return this->maxQos; }
    void setMaxQos(uint8_t qos) { this->maxQos = qos; }
};

#endif // CLIENT_H
================================================ FILE: clientacceptqueue.cpp ================================================
#include
#include "utils.h"
#include "logger.h"
#include "clientacceptqueue.h"

/**
 * @brief ClientAcceptQueue::wakeUp signals event_fd by writing 1 to it (an eventfd adds written values to its counter).
 */
void ClientAcceptQueue::wakeUp()
{
    uint64_t one = 1;
    check(write(event_fd.get(), &one, sizeof(uint64_t)));
}

/**
 * @brief ClientAcceptQueue::readFd reads (and thereby resets) event_fd's counter. Read errors are logged, not thrown.
 */
void ClientAcceptQueue::readFd()
{
    uint64_t eventfd_value = 0;
    if (read(event_fd.get(), &eventfd_value, sizeof(uint64_t)) < 0)
        Logger::getInstance()->log(LOG_ERROR) << "Error reading fd in ClientAcceptQueue: " << strerror(errno);
}

/**
 * @brief ClientAcceptQueue::giveClient queues a client for the receiving thread and takes over ownership.
 *
 * wakeUp() is only called when the queue was empty before this call; presumably a non-empty queue already has a
 * wake-up pending.
 */
void ClientAcceptQueue::giveClient(std::shared_ptr &&client)
{
    bool wakeUpNeeded = true;

    {
        auto locked = clients.lock();
        wakeUpNeeded = locked->empty();
        locked->emplace_back(std::move(client)); // We must give up ownership here, to avoid calling the client destructor in the main thread.
    }

    if (wakeUpNeeded)
        wakeUp();
}

/**
 * @brief ClientAcceptQueue::takeClients removes and returns all queued clients.
 * The moved-from vector is cleared explicitly, because a moved-from vector's state is unspecified.
 */
std::vector> ClientAcceptQueue::takeClients()
{
    std::vector> clientsTaken;

    {
        auto locked = clients.lock();
        clientsTaken = std::move(*locked);
        locked->clear();
    }

    return clientsTaken;
}

/**
 * @brief ClientAcceptQueue::giveBridge queues a bridge for the receiving thread and takes over ownership.
 * NOTE(review): unlike giveClient(), this does not call wakeUp(); confirm the receiving thread checks takeBridges() some other way.
 */
void ClientAcceptQueue::giveBridge(std::shared_ptr &&bridge)
{
    auto locked = bridges.lock();
    locked->emplace_back(std::move(bridge)); // We must give up ownership here, to avoid calling the bridge's destructor in the main thread.
}

/**
 * @brief ClientAcceptQueue::takeBridges removes and returns all queued bridges.
 */
std::vector> ClientAcceptQueue::takeBridges()
{
    std::vector> bridgesTaken;

    {
        auto locked = bridges.lock();
        bridgesTaken = std::move(*locked);
        locked->clear();
    }

    return bridgesTaken;
}

================================================ FILE: clientacceptqueue.h ================================================
#ifndef CLIENTACCEPTQUEUE_H
#define CLIENTACCEPTQUEUE_H

#include
#include
#include "mutexowned.h"
#include "fdmanaged.h"
#include "client.h"

/**
 * @brief The ClientAcceptQueue struct is a hand-over point for clients and bridges given to a thread by another thread.
 * Both lists are mutex-protected. event_fd, a non-blocking eventfd, is used by wakeUp()/readFd() to signal the
 * receiving side.
 */
struct ClientAcceptQueue
{
    MutexOwned>> clients;
    MutexOwned>> bridges;
    FdManaged event_fd = FdManaged(eventfd(0, EFD_NONBLOCK));

    void wakeUp();
    void readFd();
    void giveClient(std::shared_ptr &&client);
    std::vector> takeClients();
    void giveBridge(std::shared_ptr &&bridge);
    std::vector> takeBridges();
};

#endif // CLIENTACCEPTQUEUE_H

================================================ FILE: configfileparser.cpp ================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify
it under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/ #include "configfileparser.h" #include #include #include #include #include #include #include #include #include #include #include #include "exceptions.h" #include "utils.h" #include "globber.h" #include "globals.h" /** * @brief Like std::stoi, but demands that the entire value is consumed * @param key Unused except for informing the user in case of problems * @param value the string to parse * @return the parsed integer */ int full_stoi(const std::string &key, const std::string &value) { size_t ptr; int newVal = std::stoi(value, &ptr); if (ptr != value.length()) { throw ConfigFileException(formatString("%s's value of '%s' can't be parsed to a number", key.c_str(), value.c_str())); } return newVal; } /** * @brief Like std::stoul, but demands that the entire value is consumed * @param key Unused except for informing the user in case of problems * @param value the string to parse * @return the parsed unsigned long **/ unsigned long full_stoul(const std::string &key, const std::string &value, int base=10) { size_t ptr; unsigned long newVal = std::stoul(value, &ptr, base); if (ptr != value.length()) { throw ConfigFileException(formatString("%s's value of '%s' can't be parsed to a number", key.c_str(), value.c_str())); } return newVal; } void ConfigFileParser::testCorrectNumberOfValues(const std::string &key, size_t expected_values, const std::vector &values) { if (values.size() != expected_values) { std::ostringstream oss; oss << "Option " << key << " expected " << expected_values << ", got " << values.size() << " arguments"; if (values.size() > expected_values) { oss << ". Superflous ones: "; for (size_t i = expected_values; i < values.size(); i++) { const std::string &rest = values.at(i); oss << rest; if (i + 1 >= values.size()) oss << "."; else oss << ", "; } } throw ConfigFileException(oss.str()); } } /** * @brief ConfigFileParser::testKeyValidity tests if two strings match and whether it's a valid config key. 
 * @param key the key as found in the config file.
 * @param matchKey the key the caller wants to compare against.
 * @param validKeys the set of keys that are valid in the current context (e.g. root level, listen or bridge block).
 * @return true when key equals matchKey.
 * @throws ConfigFileException when key is not in validKeys; a close match from validKeys is suggested when one is found.
 * @throws ConfigFileException ("BUG: ...") when matchKey itself is missing from validKeys.
 *
 * Use of this function prevents adding config keys that you forget to add to the sets with valid keys.
 */
bool ConfigFileParser::testKeyValidity(const std::string &key, const std::string &matchKey, const std::set &validKeys) const
{
    auto valid_key_it = validKeys.find(key);
    if (valid_key_it == validKeys.end())
    {
        std::ostringstream oss;
        oss << "Config key '" << key << "' is not valid (here).";

        auto alternative = findCloseStringMatch(validKeys.begin(), validKeys.end(), key);

        if (alternative != validKeys.end())
        {
            // The space before the question mark is to make copying using mouse-double-click possible.
            oss << " Did you mean: " << *alternative << " ?";
        }

        throw ConfigFileException(oss.str());
    }

    // Guards against the programmer comparing with a key that was never registered in the constructor.
    {
        auto valid_key_it = validKeys.find(matchKey);
        if (valid_key_it == validKeys.end())
        {
            std::ostringstream oss;
            oss << "BUG: you still need to add '" << matchKey << "' as valid config key.";
            throw ConfigFileException(oss.str());
        }
    }

    return key == matchKey;
}

/**
 * @brief ConfigFileParser::checkFileExistsAndReadable verifies that a path is a readable regular file of at most max_size bytes.
 * @param key the config key the path came from, used in error messages.
 * @param pathToCheck the file path.
 * @param max_size the maximum allowed file size in bytes.
 * @throws ConfigFileException when the file is not readable, stat() fails, it is not a regular file, or it is too big.
 */
void ConfigFileParser::checkFileExistsAndReadable(const std::string &key, const std::string &pathToCheck, ssize_t max_size)
{
    if (access(pathToCheck.c_str(), R_OK) != 0)
    {
        std::ostringstream oss;
        oss << "Error for '" << key << "': " << pathToCheck << " is not there or not readable";
        throw ConfigFileException(oss.str());
    }

    struct stat statbuf {};
    if (stat(pathToCheck.c_str(), &statbuf) < 0)
        throw ConfigFileException(formatString("Reading stat of '%s' failed.", pathToCheck.c_str()));

    if (!S_ISREG(statbuf.st_mode))
    {
        throw ConfigFileException(formatString("Error for '%s': '%s' is not a regular file.", key.c_str(), pathToCheck.c_str()));
    }

    if (statbuf.st_size > max_size)
    {
        throw ConfigFileException(formatString("Error for '%s': '%s' is bigger than %zd bytes.", key.c_str(), pathToCheck.c_str(), max_size));
    }
}

/**
 * @brief ConfigFileParser::checkFileOrItsDirWritable verifies that filepath can be written.
 *
 * Either the file exists and is writable, or it does not exist and its parent directory is writable (so it can be created).
 * @throws ConfigFileException otherwise.
 */
void ConfigFileParser::checkFileOrItsDirWritable(const std::string &filepath)
{
    if (access(filepath.c_str(), F_OK) == 0)
    {
        if (access(filepath.c_str(), W_OK) != 0)
        {
            std::string msg = formatString("File '%s' is there, but not writable", filepath.c_str());
            throw ConfigFileException(msg);
        }
        return;
    }

    std::string dirname(dirnameOf(filepath));

    if (access(dirname.c_str(), W_OK) != 0)
    {
        std::string msg = formatString("File '%s' is not there and can't be created, because '%s' is also not writable", filepath.c_str(), dirname.c_str());
        throw ConfigFileException(msg);
    }
}

/**
 * @brief ConfigFileParser::checkDirExists verifies that dir exists and is a directory.
 * @param key the config key the path came from, used in error messages.
 * @throws ConfigFileException when stat() fails or the path is not a directory.
 */
void ConfigFileParser::checkDirExists(const std::string &key, const std::string &dir)
{
    struct stat statbuf {};
    if (stat(dir.c_str(), &statbuf) < 0)
        throw ConfigFileException(formatString("Error for '%s': path '%s' does not exist or reading stat failed.", key.c_str(), dir.c_str()));

    if (!S_ISDIR(statbuf.st_mode))
    {
        throw ConfigFileException(formatString("Error for '%s': '%s' is not a directory.", key.c_str(), dir.c_str()));
    }
}

/**
 * @brief ConfigFileParser::ConfigFileParser stores the config file path and registers every accepted config key.
 *
 * Keys are registered per context: root level (validKeys), listen blocks (validListenKeys) and bridge blocks
 * (validBridgeKeys). testKeyValidity() rejects any key that is not registered here.
 */
ConfigFileParser::ConfigFileParser(const std::string &path) :
    path(path)
{
    // Root level keys.
    validKeys.insert("plugin");
    validKeys.insert("plugin_serialize_init");
    validKeys.insert("plugin_serialize_auth_checks");
    validKeys.insert("plugin_timer_period");
    validKeys.insert("log_file");
    validKeys.insert("quiet");
    validKeys.insert("allow_unsafe_clientid_chars");
    validKeys.insert("allow_unsafe_username_chars");
    validKeys.insert("client_initial_buffer_size");
    validKeys.insert("max_packet_size");
    validKeys.insert("max_qos");
    validKeys.insert("mqtt3_qos_exceed_action");
    validKeys.insert("log_debug");
    validKeys.insert("log_subscriptions");
    validKeys.insert("log_publishes");
    validKeys.insert("log_level");
    validKeys.insert("mosquitto_password_file");
    validKeys.insert("mosquitto_acl_file");
    validKeys.insert("allow_anonymous");
    validKeys.insert("rlimit_nofile");
    validKeys.insert("expire_sessions_after_seconds");
    validKeys.insert("thread_count");
    validKeys.insert("storage_dir");
    validKeys.insert("max_qos_msg_pending_per_client");
    validKeys.insert("max_qos_bytes_pending_per_client");
    validKeys.insert("wills_enabled");
    validKeys.insert("retained_messages_mode");
    validKeys.insert("retained_messages_node_limit");
    validKeys.insert("expire_retained_messages_after_seconds");
    validKeys.insert("retained_message_node_lifetime");
    validKeys.insert("expire_retained_messages_time_budget_ms");
    validKeys.insert("websocket_set_real_ip_from");
    validKeys.insert("shared_subscription_targeting");
    validKeys.insert("max_incoming_topic_alias_value");
    validKeys.insert("max_outgoing_topic_alias_value");
    validKeys.insert("client_max_write_buffer_size");
    validKeys.insert("retained_messages_delivery_limit");
    validKeys.insert("include_dir");
    validKeys.insert("rebuild_subscription_tree_interval_seconds");
    validKeys.insert("minimum_wildcard_subscription_depth");
    validKeys.insert("max_topic_split_depth");
    validKeys.insert("wildcard_subscription_deny_mode");
    validKeys.insert("zero_byte_username_is_anonymous");
    validKeys.insert("overload_mode");
    validKeys.insert("max_event_loop_drift");
    validKeys.insert("set_retained_message_defer_timeout");
    validKeys.insert("set_retained_message_defer_timeout_spread");
    validKeys.insert("save_state_interval");
    validKeys.insert("subscription_node_lifetime");
    validKeys.insert("subscription_identifiers_enabled");
    validKeys.insert("persistence_data_to_save");
    validKeys.insert("max_string_length");

    // Keys valid inside a 'listen' block.
    validListenKeys.insert("port");
    validListenKeys.insert("protocol");
    validListenKeys.insert("fullchain");
    validListenKeys.insert("privkey");
    validListenKeys.insert("inet_protocol");
    validListenKeys.insert("inet4_bind_address");
    validListenKeys.insert("inet6_bind_address");
    validListenKeys.insert("unix_socket_path");
    validListenKeys.insert("unix_socket_user");
    validListenKeys.insert("unix_socket_group");
    validListenKeys.insert("unix_socket_mode");
    validListenKeys.insert("haproxy");
    validListenKeys.insert("client_verification_ca_file");
    validListenKeys.insert("client_verification_ca_dir");
    validListenKeys.insert("client_verification_still_do_authn");
    validListenKeys.insert("allow_anonymous");
    validListenKeys.insert("tcp_nodelay");
    validListenKeys.insert("minimum_tls_version");
    validListenKeys.insert("overload_mode");
    validListenKeys.insert("acme_redirect_url");
    validListenKeys.insert("drop_on_absent_certificate");
    validListenKeys.insert("max_buffer_size");
    validListenKeys.insert("max_qos");
    validListenKeys.insert("mqtt3_qos_exceed_action");
    validListenKeys.insert("only_allow_from");
    validListenKeys.insert("deny_from");

    // Keys valid inside a 'bridge' block.
    validBridgeKeys.insert("local_username");
    validBridgeKeys.insert("remote_username");
    validBridgeKeys.insert("remote_password");
    validBridgeKeys.insert("remote_clean_start");
    validBridgeKeys.insert("remote_session_expiry_interval");
    validBridgeKeys.insert("remote_retain_available");
    validBridgeKeys.insert("local_clean_start");
    validBridgeKeys.insert("local_session_expiry_interval");
    validBridgeKeys.insert("subscribe");
    validBridgeKeys.insert("publish");
    validBridgeKeys.insert("clientid_prefix");
    validBridgeKeys.insert("use_saved_clientid");
    validBridgeKeys.insert("inet_protocol");
    validBridgeKeys.insert("address");
    validBridgeKeys.insert("fullchain");
    validBridgeKeys.insert("privkey");
    validBridgeKeys.insert("ca_file");
    validBridgeKeys.insert("ca_dir");
    validBridgeKeys.insert("port");
    validBridgeKeys.insert("tls");
    validBridgeKeys.insert("protocol_version");
    validBridgeKeys.insert("bridge_protocol_bit");
    validBridgeKeys.insert("keepalive");
    validBridgeKeys.insert("max_outgoing_topic_aliases");
    validBridgeKeys.insert("max_incoming_topic_aliases");
    validBridgeKeys.insert("tcp_nodelay");
    validBridgeKeys.insert("local_prefix");
    validBridgeKeys.insert("remote_prefix");
    validBridgeKeys.insert("minimum_tls_version");
    validBridgeKeys.insert("connection_count");
    validBridgeKeys.insert("max_buffer_size");
}

/**
 * @brief ConfigFileParser::readFileRecursively reads a config file line by line and expands 'include_dir' lines.
 *
 * For an 'include_dir' line, the lines of every '*.conf' file in that directory are read recursively and appended
 * right after it. The include_dir line itself is kept. The file must be readable and at most 10 MiB.
 * @param path the file to read; an empty path yields no lines.
 * @return all lines, with included files expanded.
 * @throws ConfigFileException when the file or an included directory cannot be used.
 *
 * NOTE(review): there is no guard against include cycles. An included file whose include_dir points back at its own
 * directory would recurse without bound.
 */
std::list ConfigFileParser::readFileRecursively(const std::string &path) const
{
    std::list lines;

    if (path.empty())
        return lines;

    checkFileExistsAndReadable("application config file", path, 1024*1024*10);

    std::ifstream infile(path, std::ios::in);

    if (!infile.is_open())
    {
        std::ostringstream oss;
        oss << "Error loading " << path;
        throw ConfigFileException(oss.str());
    }

    for(std::string line; getline(infile, line ); )
    {
        lines.push_back(line);

        // Cheap substring pre-check before running the regex on every line.
        if (strContains(line, "include_dir"))
        {
            std::smatch matches;

            if (std::regex_match(line, matches, key_value_regex))
            {
                const std::string key = matches[1].str();
                const std::string value = matches[2].str();

                if (key == "include_dir")
                {
                    Logger *logger = Logger::getInstance();

                    checkDirExists(key, value);

                    Globber globber;
                    std::vector files = globber.getGlob(value + "/*.conf");

                    if (files.empty())
                    {
                        logger->logf(LOG_WARNING, "Including '%s' yielded 0 files.", value.c_str());
                    }

                    for(const std::string &path_from_glob : files)
                    {
                        std::list newLines = readFileRecursively(path_from_glob);
                        lines.insert(lines.cend(), newLines.begin(), newLines.end());
                    }
                }
            }
        }
    }

    return lines;
}

void ConfigFileParser::loadFile(bool test)
{
    if (path.empty())
        return;

    const std::list unprocessed_lines = readFileRecursively(path);
    std::list lines;

    int blockDepth = 0;

    std::ostringstream oss;
    int linenr = 0;

    // First parse the file and keep the valid lines.
    for (std::string line : unprocessed_lines)
    {
        trim(line);
        linenr++;

        if (startsWith(line, "#"))
            continue;

        if (line.empty())
            continue;

        // The regex matcher can be made to crash on very long lines, so we're protecting ourselves.
if (line.length() > 256) { throw ConfigFileException(formatString("Error at line %d in '%s': line suspiciouly long.", linenr, path.c_str())); } std::smatch matches; const bool blockStartMatch = std::regex_search(line, matches, block_regex_start); const bool blockEndMatch = std::regex_search(line, matches, block_regex_end); if (!std::regex_search(line, matches, key_value_regex) && !blockStartMatch && !blockEndMatch) { oss << "Line '" << line << "' invalid"; throw ConfigFileException(oss.str()); } if (blockStartMatch) blockDepth++; if (blockEndMatch) blockDepth--; lines.push_back(line); } if (blockDepth != 0) { throw ConfigFileException("Unmatched curly braces"); } std::unordered_map pluginOpts; std::stack curParseLevel; std::shared_ptr curListener; std::optional curBridge; std::list preMultipliedBridges; Settings tmpSettings; curParseLevel.push(ConfigParseLevel::Root); const std::set blockNames {"listen", "bridge"}; // Then once we know the config file is valid, process it. for (std::string &line : lines) { std::smatch matches; if (std::regex_match(line, matches, block_regex_start)) { const std::string &key = matches[1].str(); if (testKeyValidity(key, "listen", blockNames)) { if (curParseLevel.top() != ConfigParseLevel::Root) throw ConfigFileException("Block '" + key + "' must be a root level"); curParseLevel.push(ConfigParseLevel::Listen); curListener = std::make_shared(); } else if (testKeyValidity(key, "bridge", blockNames)) { if (curParseLevel.top() != ConfigParseLevel::Root) throw ConfigFileException("Block '" + key + "' must be a root level"); curParseLevel.push(ConfigParseLevel::Bridge); curBridge = std::make_optional(); } else { std::ostringstream oss; oss << "'" << key << "' is not a valid block."; auto alt = findCloseStringMatch(blockNames.begin(), blockNames.end(), key); if (alt != blockNames.end()) { oss << " Did you mean: " << *alt << " ?"; } throw ConfigFileException(oss.str()); } continue; } else if (std::regex_match(line, matches, block_regex_end)) { 
if (curParseLevel.top() == ConfigParseLevel::Listen) { curListener->isValid(); if (curListener->dropListener()) { if (test) { Logger::getInstance()->log(LOG_NOTICE) << "Approved missing certificates: dropping " << curListener->getProtocolName() << " listener, port " << curListener->port; } } else { tmpSettings.listeners.push_back(curListener); } curListener.reset(); } else if (curParseLevel.top() == ConfigParseLevel::Bridge) { curBridge->isValid(); preMultipliedBridges.push_back(std::move(curBridge.value())); } curParseLevel.pop(); if (curParseLevel.empty()) { throw ConfigFileException("Too many closing '}'"); } continue; } std::regex_match(line, matches, key_value_regex); std::string key = matches[1].str(); const std::string value_unparsed = matches[2].str(); const std::vector values = parseValuesWithOptionalQuoting(value_unparsed); const std::string &value = values.at(0); size_t number_of_expected_values = 1; // Most lines only accept 1 argument, a select few 2. std::string valueTrimmed = value; trim(valueTrimmed); try { if (curParseLevel.top() == ConfigParseLevel::Listen) { if (testKeyValidity(key, "protocol", validListenKeys)) { if (value != "mqtt" && value != "websockets" && value != "acme") throw ConfigFileException(formatString("Protocol '%s' is not a valid listener protocol", value.c_str())); if (value == "mqtt") curListener->connectionProtocol = ConnectionProtocol::Mqtt; else if (value == "websockets") curListener->connectionProtocol = ConnectionProtocol::WebsocketMqtt; else if (value == "acme") curListener->connectionProtocol = ConnectionProtocol::AcmeOnly; } else if (testKeyValidity(key, "port", validListenKeys)) { curListener->port = full_stoi(key, value); } else if (testKeyValidity(key, "fullchain", validListenKeys)) { curListener->sslFullchain = value; } if (testKeyValidity(key, "privkey", validListenKeys)) { curListener->sslPrivkey = value; } if (testKeyValidity(key, "inet_protocol", validListenKeys)) { if (value == "ip4") curListener->protocol = 
ListenerProtocol::IPv4; else if (value == "ip6") curListener->protocol = ListenerProtocol::IPv6; else if (value == "ip4_ip6") curListener->protocol = ListenerProtocol::IPv46; else if (value == "unix") curListener->protocol = ListenerProtocol::Unix; else throw ConfigFileException(formatString("Invalid inet protocol: %s", value.c_str())); } if (testKeyValidity(key, "inet4_bind_address", validListenKeys)) { curListener->inet4BindAddress = value; } if (testKeyValidity(key, "inet6_bind_address", validListenKeys)) { curListener->inet6BindAddress = value; } if (testKeyValidity(key, "unix_socket_path", validListenKeys)) { curListener->unixSocketPath = value; } if (testKeyValidity(key, "unix_socket_user", validListenKeys)) { /* * We're not verifying the existence of the user. This would make it non-deterministic between * environments, and perhaps users want to already have the config ready before the users are * present on the system. */ curListener->unixSocketUser = valueTrimmed; } if (testKeyValidity(key, "unix_socket_group", validListenKeys)) { // See comment above about not checking the the presence of the group. 
curListener->unixSocketGroup = valueTrimmed; } if (testKeyValidity(key, "unix_socket_mode", validListenKeys)) { curListener->unixSocketMode = full_stoul(key, value, 8); } if (testKeyValidity(key, "haproxy", validListenKeys)) { if (valueTrimmed == "client_verification") curListener->haProxyMode = HaProxyMode::HaProxyClientVerification; else if (valueTrimmed == "client_verification_with_authn") curListener->haProxyMode = HaProxyMode::HaProxyClientVerficiationWithAuthn; else { const bool val = stringTruthiness(value); if (val) curListener->haProxyMode = HaProxyMode::On; } } if (testKeyValidity(key, "client_verification_ca_file", validListenKeys)) { checkFileExistsAndReadable(key, valueTrimmed, 1024*1024); curListener->clientVerificationCaFile = valueTrimmed; } if (testKeyValidity(key, "client_verification_ca_dir", validListenKeys)) { checkDirExists(key, value); curListener->clientVerificationCaDir = valueTrimmed; } if (testKeyValidity(key, "client_verification_still_do_authn", validListenKeys)) { bool val = stringTruthiness(value); curListener->clientVerifictionStillDoAuthn = val; } if (testKeyValidity(key, "allow_anonymous", validListenKeys)) { bool val = stringTruthiness(value); curListener->allowAnonymous = val ? 
AllowListenerAnonymous::Yes : AllowListenerAnonymous::No; } if (testKeyValidity(key, "tcp_nodelay", validListenKeys)) { bool val = stringTruthiness(value); curListener->tcpNoDelay = val; } if (testKeyValidity(key, "minimum_tls_version", validListenKeys)) { if (valueTrimmed == "tlsv1.3") curListener->minimumTlsVersion = TLSVersion::TLSv1_3; else if (valueTrimmed == "tlsv1.2") curListener->minimumTlsVersion = TLSVersion::TLSv1_2; else if (valueTrimmed == "tlsv1.1") curListener->minimumTlsVersion = TLSVersion::TLSv1_1; else throw ConfigFileException("Value '" + valueTrimmed + "' is not a valid value for " + key); } if (testKeyValidity(key, "overload_mode", validListenKeys)) { const std::string _val = str_tolower(value); if (_val == "log") curListener->overloadMode = OverloadMode::Log; else if (_val == "close_new_clients") curListener->overloadMode = OverloadMode::CloseNewClients; else throw ConfigFileException(formatString("Value '%s' for '%s' is invalid.", value.c_str(), key.c_str())); } if (testKeyValidity(key, "acme_redirect_url", validListenKeys)) { curListener->acmeRedirectURL = valueTrimmed; } if (testKeyValidity(key, "drop_on_absent_certificate", validListenKeys)) { curListener->dropOnAbsentCertificates = stringTruthiness(value); } if (testKeyValidity(key, "max_buffer_size", validListenKeys)) { size_t val = full_stoul(key, valueTrimmed); if (val > 1073741824) throw ConfigFileException("Bridge's " + key + " cannot be bigger than 1 GB"); curListener->maxBufferSize = val; } if (testKeyValidity(key, "max_qos", validListenKeys)) { const int qos = full_stoi(key, value); if (qos < 0 || qos > 2) { std::ostringstream oss; oss << "Value " << qos << " for " << key << " is invalid."; throw ConfigFileException(oss.str()); } curListener->maxQos = qos; } if (testKeyValidity(key, "mqtt3_qos_exceed_action", validListenKeys)) { if (valueTrimmed == "disconnect") curListener->mqtt3QoSExceedAction = Mqtt3QoSExceedAction::Disconnect; else if (valueTrimmed == "drop") 
curListener->mqtt3QoSExceedAction = Mqtt3QoSExceedAction::Drop; else { std::ostringstream oss; oss << "Value '" << valueTrimmed << "' for " << key << " is invalid."; throw ConfigFileException(oss.str()); } } if (testKeyValidity(key, "only_allow_from", validListenKeys)) { Network net(valueTrimmed); curListener->exclusiveAllowList.push_back(net); } if (testKeyValidity(key, "deny_from", validListenKeys)) { Network net(valueTrimmed); curListener->denyList.push_back(net); } testCorrectNumberOfValues(key, number_of_expected_values, values); continue; } else if (curParseLevel.top() == ConfigParseLevel::Bridge) { if (testKeyValidity(key, "local_username", validBridgeKeys)) { curBridge->local_username = value; } if (testKeyValidity(key, "remote_username", validBridgeKeys)) { curBridge->remote_username = value; } if (testKeyValidity(key, "remote_password", validBridgeKeys)) { curBridge->remote_password = value; } if (testKeyValidity(key, "remote_clean_start", validBridgeKeys)) { curBridge->remoteCleanStart = stringTruthiness(value); } if (testKeyValidity(key, "remote_session_expiry_interval", validBridgeKeys)) { curBridge->remoteSessionExpiryInterval = value_to_int_ranged(key, value); } if (testKeyValidity(key, "local_clean_start", validBridgeKeys)) { curBridge->localCleanStart = stringTruthiness(value); } if (testKeyValidity(key, "local_session_expiry_interval", validBridgeKeys)) { curBridge->localSessionExpiryInterval = value_to_int_ranged(key, value); } if (testKeyValidity(key, "subscribe", validBridgeKeys)) { if (!isValidUtf8(value) || !isValidSubscribePath(value)) throw ConfigFileException(formatString("Path '%s' is not a valid subscribe match", value.c_str())); BridgeTopicPath topicPath; if (values.size() >= 2) { number_of_expected_values = 2; const std::string &qosstr = values.at(1); if (!qosstr.empty()) { topicPath.qos = value_to_int_ranged(key, qosstr, 0, 2); } } topicPath.topic = value; curBridge->subscribes.push_back(topicPath); } if (testKeyValidity(key, 
"publish", validBridgeKeys)) { if (!isValidUtf8(value) || !isValidSubscribePath(value)) throw ConfigFileException(formatString("Path '%s' is not a valid publish match", value.c_str())); BridgeTopicPath topicPath; if (values.size() >= 2) { number_of_expected_values = 2; const std::string &qosstr = values.at(1); if (!qosstr.empty()) { topicPath.qos = value_to_int_ranged(key, qosstr, 0, 2); } } topicPath.topic = value; curBridge->publishes.push_back(topicPath); } if (testKeyValidity(key, "clientid_prefix", validBridgeKeys)) { if (value.length() > 10) throw ConfigFileException("Value for 'clientid_prefix' can't be longer than 10 chars"); if (!isValidShareName(value)) throw ConfigFileException("Value for 'clientid_prefix' contains invalid charachters"); curBridge->clientidPrefix = value; } if (testKeyValidity(key, "address", validBridgeKeys)) { curBridge->address = value; } if (testKeyValidity(key, "port", validBridgeKeys)) { curBridge->port = value_to_int_ranged(key, value); } if (testKeyValidity(key, "protocol_version", validBridgeKeys)) { ProtocolVersion v; if (value == "mqtt3.1") v = ProtocolVersion::Mqtt31; else if (value == "mqtt3.1.1") v = ProtocolVersion::Mqtt311; else if (value == "mqtt5") v = ProtocolVersion::Mqtt5; else throw ConfigFileException(formatString("Value '%s' is not valid for 'protocol_version'", value.c_str())); curBridge->protocolVersion = v; } if (testKeyValidity(key, "bridge_protocol_bit", validBridgeKeys)) { curBridge->bridgeProtocolBit = stringTruthiness(value); } if (testKeyValidity(key, "keepalive", validBridgeKeys)) { curBridge->keepalive = value_to_int_ranged(key, value, 10); } if (testKeyValidity(key, "tls", validBridgeKeys)) { BridgeTLSMode mode = BridgeTLSMode::None; if (value == "unverified") mode = BridgeTLSMode::Unverified; else if (value == "on") mode = BridgeTLSMode::On; else if (value == "off") mode = BridgeTLSMode::None; else throw ConfigFileException(formatString("Value '%s' is not valid for 'tls'", value.c_str())); 
curBridge->tlsMode = mode; } if (testKeyValidity(key, "fullchain", validBridgeKeys)) { checkFileExistsAndReadable(key, value, 1024*1024*100); curBridge->sslFullchain = value; } if (testKeyValidity(key, "privkey", validBridgeKeys)) { checkFileExistsAndReadable(key, value, 1024*1024*100); curBridge->sslPrivkey = value; } if (testKeyValidity(key, "ca_file", validBridgeKeys)) { checkFileExistsAndReadable(key, value, 1024*1024*100); curBridge->caFile = value; } if (testKeyValidity(key, "ca_dir", validBridgeKeys)) { checkDirExists(key, value); curBridge->caDir = value; } if (testKeyValidity(key, "max_incoming_topic_aliases", validBridgeKeys)) { curBridge->maxIncomingTopicAliases = value_to_int_ranged(key, value); } if (testKeyValidity(key, "max_outgoing_topic_aliases", validBridgeKeys)) { curBridge->maxOutgoingTopicAliases = value_to_int_ranged(key, value); } if (testKeyValidity(key, "inet_protocol", validBridgeKeys)) { if (value == "ip4") curBridge->inet_protocol = ListenerProtocol::IPv4; else if (value == "ip6") curBridge->inet_protocol = ListenerProtocol::IPv6; else if (value == "ip4_ip6") curBridge->inet_protocol = ListenerProtocol::IPv46; else throw ConfigFileException(formatString("Invalid inet protocol: %s", value.c_str())); } if (testKeyValidity(key, "use_saved_clientid", validBridgeKeys)) { curBridge->useSavedClientId = stringTruthiness(value); } if (testKeyValidity(key, "remote_retain_available", validBridgeKeys)) { curBridge->remoteRetainAvailable = stringTruthiness(value); } if (testKeyValidity(key, "tcp_nodelay", validBridgeKeys)) { curBridge->tcpNoDelay = true; } if (testKeyValidity(key, "local_prefix", validBridgeKeys)) { if (value.empty()) throw ConfigFileException("Option '" + key + "' can't be empty."); if (!endsWith(value, "/")) throw ConfigFileException("Option '" + key + "' must end in a '/'."); curBridge->local_prefix = value; } if (testKeyValidity(key, "remote_prefix", validBridgeKeys)) { if (value.empty()) throw ConfigFileException("Option '" + 
key + "' can't be empty."); if (!endsWith(value, "/")) throw ConfigFileException("Option '" + key + "' must end in a '/'."); curBridge->remote_prefix = value; } if (testKeyValidity(key, "minimum_tls_version", validBridgeKeys)) { if (valueTrimmed == "tlsv1.3") curBridge->minimumTlsVersion = TLSVersion::TLSv1_3; else if (valueTrimmed == "tlsv1.2") curBridge->minimumTlsVersion = TLSVersion::TLSv1_2; else if (valueTrimmed == "tlsv1.1") curBridge->minimumTlsVersion = TLSVersion::TLSv1_1; else throw ConfigFileException("Value '" + valueTrimmed + "' is not a valid value for " + key); } if (testKeyValidity(key, "connection_count", validBridgeKeys)) { if (valueTrimmed == "auto") { curBridge->connection_count = get_nprocs(); } else curBridge->connection_count = full_stoul(key, valueTrimmed); } if (testKeyValidity(key, "max_buffer_size", validBridgeKeys)) { size_t val = full_stoul(key, valueTrimmed); if (val > 1073741824) throw ConfigFileException("Bridge's " + key + " cannot be bigger than 1 GB"); curBridge->maxBufferSize = val; } testCorrectNumberOfValues(key, number_of_expected_values, values); continue; } const std::string plugin_opt_ = "plugin_opt_"; if (startsWith(key, plugin_opt_)) { key.replace(0, plugin_opt_.length(), ""); pluginOpts[key] = value; } else { if (testKeyValidity(key, "plugin", validKeys)) { checkFileExistsAndReadable(key, value, 1024*1024*100); tmpSettings.pluginPath = value; } if (testKeyValidity(key, "log_file", validKeys)) { checkFileOrItsDirWritable(value); tmpSettings.logPath = value; } if (testKeyValidity(key, "quiet", validKeys)) { Logger::getInstance()->log(LOG_WARNING) << "The config option '" << key << "' is deprecated. 
Use log_level instead."; bool tmp = stringTruthiness(value); tmpSettings.quiet = tmp; } if (testKeyValidity(key, "allow_unsafe_clientid_chars", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.allowUnsafeClientidChars = tmp; } if (testKeyValidity(key, "allow_unsafe_username_chars", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.allowUnsafeUsernameChars = tmp; } if (testKeyValidity(key, "plugin_serialize_init", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.pluginSerializeInit = tmp; } if (testKeyValidity(key, "plugin_serialize_auth_checks", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.pluginSerializeAuthChecks = tmp; } if (testKeyValidity(key, "client_initial_buffer_size", validKeys)) { int newVal = full_stoi(key, value); if (!isPowerOfTwo(newVal)) throw ConfigFileException("client_initial_buffer_size value " + value + " is not a power of two."); tmpSettings.clientInitialBufferSize = newVal; } if (testKeyValidity(key, "max_packet_size", validKeys)) { int newVal = full_stoi(key, value); if (newVal > ABSOLUTE_MAX_PACKET_SIZE) { std::ostringstream oss; oss << "Value for max_packet_size " << newVal << " is higher than absolute maximum " << ABSOLUTE_MAX_PACKET_SIZE; throw ConfigFileException(oss.str()); } tmpSettings.maxPacketSize = newVal; } if (testKeyValidity(key, "max_qos", validKeys)) { tmpSettings.maxQos = value_to_int_ranged(key, value, 0, 2); } if (testKeyValidity(key, "mqtt3_qos_exceed_action", validKeys)) { if (valueTrimmed == "disconnect") tmpSettings.mqtt3QoSExceedAction = Mqtt3QoSExceedAction::Disconnect; else if (valueTrimmed == "drop") tmpSettings.mqtt3QoSExceedAction = Mqtt3QoSExceedAction::Drop; else { std::ostringstream oss; oss << "Value '" << valueTrimmed << "' for " << key << " is invalid."; throw ConfigFileException(oss.str()); } } if (testKeyValidity(key, "log_debug", validKeys)) { Logger::getInstance()->log(LOG_WARNING) << "The config option '" << key << "' is deprecated. 
Use log_level instead."; bool tmp = stringTruthiness(value); tmpSettings.logDebug = tmp; } if (testKeyValidity(key, "log_level", validKeys)) { const std::string v = str_tolower(value); LogLevel level = LogLevel::None; if (v == "debug") level = LogLevel::Debug; else if (v == "info") level = LogLevel::Info; else if (v == "notice") level = LogLevel::Notice; else if (v == "warning") level = LogLevel::Warning; else if (v == "error") level = LogLevel::Warning; else if (v == "none") level = LogLevel::None; else throw ConfigFileException("Invalid log level: " + value); tmpSettings.logLevel = level; } if (testKeyValidity(key, "log_subscriptions", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.logSubscriptions = tmp; } if (testKeyValidity(key, "log_publishes", validKeys)) { tmpSettings.logPublishes = stringTruthiness(value); } if (testKeyValidity(key, "mosquitto_password_file", validKeys)) { checkFileExistsAndReadable("mosquitto_password_file", value, 1024*1024*1024); tmpSettings.mosquittoPasswordFile = value; } if (testKeyValidity(key, "mosquitto_acl_file", validKeys)) { checkFileExistsAndReadable("mosquitto_acl_file", value, 1024*1024*1024); tmpSettings.mosquittoAclFile = value; } if (testKeyValidity(key, "allow_anonymous", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.allowAnonymous = tmp; } if (testKeyValidity(key, "rlimit_nofile", validKeys)) { int newVal = full_stoi(key, value); if (newVal <= 0) { throw ConfigFileException(formatString("Value '%d' is negative.", newVal)); } tmpSettings.rlimitNoFile = newVal; } if (testKeyValidity(key, "expire_sessions_after_seconds", validKeys)) { const uint32_t newVal{value_to_int_ranged(key, value)}; if (newVal > 0 && newVal < 60) // 0 means disable { throw ConfigFileException(formatString("expire_sessions_after_seconds value '%d' is invalid. 
Valid values are 0, or 60 or higher.", newVal)); } tmpSettings.expireSessionsAfterSeconds = newVal; } if (testKeyValidity(key, "plugin_timer_period", validKeys)) { int newVal = full_stoi(key, value); if (newVal < 0) { throw ConfigFileException(formatString("plugin_timer_period value '%d' is invalid. Valid values are 0 or higher. 0 means disabled.", newVal)); } tmpSettings.pluginTimerPeriod = newVal; } if (testKeyValidity(key, "storage_dir", validKeys)) { std::string newPath = value; rtrim(newPath, '/'); checkWritableDir(newPath); tmpSettings.storageDir = newPath; } if (testKeyValidity(key, "thread_count", validKeys)) { int newVal = full_stoi(key, value); if (newVal < 0) { throw ConfigFileException(formatString("thread_count value '%d' is invalid. Valid values are 0 or higher. 0 means auto.", newVal)); } tmpSettings.threadCount = newVal; } if (testKeyValidity(key, "max_qos_msg_pending_per_client", validKeys)) { tmpSettings.maxQosMsgPendingPerClient = value_to_int_ranged(key, value, 32); } if (testKeyValidity(key, "max_qos_bytes_pending_per_client", validKeys)) { tmpSettings.maxQosBytesPendingPerClient = value_to_int_ranged(key, value, 4096); } if (testKeyValidity(key, "max_incoming_topic_alias_value", validKeys)) { tmpSettings.maxIncomingTopicAliasValue = value_to_int_ranged(key, value); } if (testKeyValidity(key, "max_outgoing_topic_alias_value", validKeys)) { tmpSettings.maxOutgoingTopicAliasValue = value_to_int_ranged(key, value); } if (testKeyValidity(key, "wills_enabled", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.willsEnabled = tmp; } if (testKeyValidity(key, "retained_messages_mode", validKeys)) { const std::string _val = str_tolower(value); if (_val == "enabled") tmpSettings.retainedMessagesMode = RetainedMessagesMode::Enabled; else if (_val == "enabled_without_persistence") tmpSettings.retainedMessagesMode = RetainedMessagesMode::EnabledWithoutPersistence; else if (_val == "enabled_without_retaining") tmpSettings.retainedMessagesMode = 
RetainedMessagesMode::EnabledWithoutRetaining; else if (_val == "downgrade") tmpSettings.retainedMessagesMode = RetainedMessagesMode::Downgrade; else if (_val == "drop") tmpSettings.retainedMessagesMode = RetainedMessagesMode::Drop; else if (_val == "disconnect_with_error") tmpSettings.retainedMessagesMode = RetainedMessagesMode::DisconnectWithError; else throw ConfigFileException(formatString("Value '%s' for '%s' is invalid.", value.c_str(), key.c_str())); } if (testKeyValidity(key, "expire_retained_messages_after_seconds", validKeys)) { uint32_t newVal = full_stoi(key, value); if (newVal < 1) { throw ConfigFileException(formatString("expire_retained_messages_after_seconds value '%d' is invalid. Valid values are between 1 and 4294967295.", newVal)); } tmpSettings.expireRetainedMessagesAfterSeconds = std::chrono::seconds(newVal); } if (testKeyValidity(key, "retained_message_node_lifetime", validKeys)) { const int val = full_stoi(key, value); if (val < 0) throw ConfigFileException("Option '" + key + "' must 0 or higher."); tmpSettings.retainedMessageNodeLifetime = std::chrono::seconds(val); } if (testKeyValidity(key, "expire_retained_messages_time_budget_ms", validKeys)) { Logger::getInstance()->log(LOG_WARNING) << "The config option '" << key << "' is deprecated."; } if (testKeyValidity(key, "websocket_set_real_ip_from", validKeys)) { Network net(value); tmpSettings.setRealIpFrom.push_back(std::move(net)); } if (testKeyValidity(key, "shared_subscription_targeting", validKeys)) { const std::string _val = str_tolower(value); if (_val == "round_robin") tmpSettings.sharedSubscriptionTargeting = SharedSubscriptionTargeting::RoundRobin; else if (_val == "sender_hash") tmpSettings.sharedSubscriptionTargeting = SharedSubscriptionTargeting::SenderHash; else if (_val == "first") tmpSettings.sharedSubscriptionTargeting = SharedSubscriptionTargeting::First; else throw ConfigFileException(formatString("Value '%s' for '%s' is invalid.", value.c_str(), key.c_str())); } if 
(testKeyValidity(key, "client_max_write_buffer_size", validKeys)) { tmpSettings.clientMaxWriteBufferSize = value_to_int_ranged(key, value, 4096); } if (testKeyValidity(key, "retained_messages_delivery_limit", validKeys)) { Logger::getInstance()->log(LOG_WARNING) << "The config option '" << key << "' is deprecated. Use 'retained_messages_node_limit' instead."; } if (testKeyValidity(key, "retained_messages_node_limit", validKeys)) { const uint32_t newVal{value_to_int_ranged(key, value)}; if (newVal == 0) throw ConfigFileException("Set '" + key + "' higher than 0, or use 'retained_messages_mode'."); tmpSettings.retainedMessagesNodeLimit = newVal; } if (testKeyValidity(key, "minimum_wildcard_subscription_depth", validKeys)) { tmpSettings.minimumWildcardSubscriptionDepth = value_to_int_ranged(key, value); } if (testKeyValidity(key, "max_topic_split_depth", validKeys)) { tmpSettings.maxTopicSplitDepth = value_to_int_ranged(key, value); } if (testKeyValidity(key, "wildcard_subscription_deny_mode", validKeys)) { const std::string _val = str_tolower(value); if (_val == "deny_all") tmpSettings.wildcardSubscriptionDenyMode = WildcardSubscriptionDenyMode::DenyAll; else if (_val == "deny_retained_only") tmpSettings.wildcardSubscriptionDenyMode = WildcardSubscriptionDenyMode::DenyRetainedOnly; else throw ConfigFileException(formatString("Value '%s' for '%s' is invalid.", value.c_str(), key.c_str())); } if (testKeyValidity(key, "zero_byte_username_is_anonymous", validKeys)) { tmpSettings.zeroByteUsernameIsAnonymous = stringTruthiness(value); } if (testKeyValidity(key, "overload_mode", validKeys)) { const std::string _val = str_tolower(value); if (_val == "log") tmpSettings.overloadMode = OverloadMode::Log; else if (_val == "close_new_clients") tmpSettings.overloadMode = OverloadMode::CloseNewClients; else throw ConfigFileException(formatString("Value '%s' for '%s' is invalid.", value.c_str(), key.c_str())); } if (testKeyValidity(key, "max_event_loop_drift", validKeys)) { const 
int val = full_stoi(key, value); if (val < 500) { throw ConfigFileException("Option '" + key + "' must be higher than 500 ms."); } tmpSettings.maxEventLoopDrift = std::chrono::milliseconds(val); } if (testKeyValidity(key, "set_retained_message_defer_timeout", validKeys)) { const int val = full_stoi(key, value); if (val < 0) throw ConfigFileException("Option '" + key + "' must 0 or higher."); tmpSettings.setRetainedMessageDeferTimeout = std::chrono::milliseconds(val); } if (testKeyValidity(key, "set_retained_message_defer_timeout_spread", validKeys)) { const int val = full_stoi(key, value); if (val < 0) throw ConfigFileException("Option '" + key + "' must 0 or higher."); tmpSettings.setRetainedMessageDeferTimeoutSpread = std::chrono::milliseconds(val); } if (testKeyValidity(key, "save_state_interval", validKeys)) { const int val = full_stoi(key, value); if (val < 300) throw ConfigFileException("Option '" + key + "' must 300 or higher."); tmpSettings.saveStateInterval = std::chrono::seconds(val); } if (testKeyValidity(key, "subscription_node_lifetime", validKeys)) { const int val = full_stoi(key, value); if (val < 0) throw ConfigFileException("Option '" + key + "' must 0 or higher."); tmpSettings.subscriptionNodeLifetime = std::chrono::seconds(val); } if (testKeyValidity(key, "subscription_identifiers_enabled", validKeys)) { bool tmp = stringTruthiness(value); tmpSettings.subscriptionIdentifierEnabled = tmp; } if (testKeyValidity(key, "persistence_data_to_save", validKeys)) { tmpSettings.persistenceDataToSave.clearAll(); int new_correct_number_of_args = 0; for(std::string arg : values) { new_correct_number_of_args++; bool set = true; if (arg.length() > 0 && arg.at(0) == '!') { set = false; arg.erase(0, 1); } if (arg == "all") { if (set) tmpSettings.persistenceDataToSave.setAll(); else tmpSettings.persistenceDataToSave.clearAll(); } else if (arg == "sessions_and_subscriptions") { if (set) 
tmpSettings.persistenceDataToSave.setFlag(PersistenceDataToSave::SessionsAndSubscriptions); else tmpSettings.persistenceDataToSave.clearFlag(PersistenceDataToSave::SessionsAndSubscriptions); } else if (arg == "retained_messages") { if (set) tmpSettings.persistenceDataToSave.setFlag(PersistenceDataToSave::RetainedMessages); else tmpSettings.persistenceDataToSave.clearFlag(PersistenceDataToSave::RetainedMessages); } else if (arg == "bridge_info") { if (set) tmpSettings.persistenceDataToSave.setFlag(PersistenceDataToSave::BridgeInfo); else tmpSettings.persistenceDataToSave.clearFlag(PersistenceDataToSave::BridgeInfo); } else throw ConfigFileException("Value '" + arg + "' is not a valid mode for " + key); } number_of_expected_values = new_correct_number_of_args; } if (testKeyValidity(key, "max_string_length", validKeys)) { tmpSettings.maxStringLength = value_to_int_ranged(key, valueTrimmed); } } } catch (std::invalid_argument &ex) // catch for the stoi() { throw ConfigFileException(ex.what()); } testCorrectNumberOfValues(key, number_of_expected_values, values); } checkUniqueBridgeNames(preMultipliedBridges); const std::string share_name_path = tmpSettings.getGeneratedShareNamesFilePath(); globals->bridgeClientGroupIds.loadShareNames(share_name_path, !test); for (const BridgeConfig &bc : preMultipliedBridges) { std::vector many = bc.multiply(); for (BridgeConfig &b : many) { tmpSettings.bridges.push_back(std::move(b)); } } tmpSettings.authOptCompatWrap = AuthOptCompatWrap(pluginOpts); tmpSettings.flashmqpluginOpts = std::move(pluginOpts); if (!test) { globals->bridgeClientGroupIds.saveShareNames(share_name_path); this->settings = tmpSettings; } } const Settings &ConfigFileParser::getSettings() { return settings; } ================================================ FILE: configfileparser.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can 
redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef CONFIGFILEPARSER_H #define CONFIGFILEPARSER_H #include #include #include #include #include #include #include #include #include "settings.h" enum class ConfigParseLevel { Root, Listen, Bridge }; template typename std::enable_if::value, long long>::type value_to_int(const std::string &key, const std::string &value) { try { size_t len = 0; long long newVal{std::stoll(value, &len)}; if (len != value.length()) { throw std::exception(); } return newVal; } catch (std::exception &ex) { throw ConfigFileException(formatString("%s's value of '%s' can't be parsed to a number", key.c_str(), value.c_str())); } } template typename std::enable_if::value, long long unsigned>::type value_to_int(const std::string &key, const std::string &value) { try { size_t len = 0; long long unsigned newVal{std::stoull(value, &len)}; if (len != value.length()) { throw std::exception(); } return newVal; } catch (std::exception &ex) { throw ConfigFileException(formatString("%s's value of '%s' can't be parsed to a number", key.c_str(), value.c_str())); } } template T value_to_int_ranged(const std::string &key, const std::string &value, const T min=std::numeric_limits::min(), const T max=std::numeric_limits::max()) { const auto newVal{value_to_int(key, value)}; if (newVal < min || newVal > max) { std::ostringstream oss; oss << "Value '" << value << "' out of range for '" << key << "', which must be between "; if (sizeof(T) == 1) oss << static_cast(min); else oss << min; oss << " and "; if (sizeof(T) == 1) oss << static_cast(max); else oss << max; throw ConfigFileException(oss.str()); } return static_cast(newVal); } class ConfigFileParser { const std::string path; std::set validKeys; std::set validListenKeys; std::set validBridgeKeys; const std::regex key_value_regex = std::regex("^([\\w\\-]+)\\s+(.+)$"); const std::regex block_regex_start = 
std::regex("^([a-zA-Z0-9_\\-]+) *\\{$"); const std::regex block_regex_end = std::regex("^\\}$"); Settings settings; void static testCorrectNumberOfValues(const std::string &key, size_t expected_values, const std::vector &values); bool testKeyValidity(const std::string &key, const std::string &matchKey, const std::set &validKeys) const; public: void static checkFileExistsAndReadable(const std::string &key, const std::string &pathToCheck, ssize_t max_size = std::numeric_limits::max()); void static checkFileOrItsDirWritable(const std::string &filepath); void static checkDirExists(const std::string &key, const std::string &dir); ConfigFileParser(const std::string &path); void loadFile(bool test); std::list readFileRecursively(const std::string &path) const; const Settings &getSettings(); }; #endif // CONFIGFILEPARSER_H ================================================ FILE: debian/conffiles ================================================ /etc/flashmq/flashmq.conf ================================================ FILE: debian/flashmq.service ================================================ [Unit] Description=FlashMQ MQTT server After=network.target [Service] Type=notify User=root Group=root LimitNOFILE=infinity ExecStart=/usr/bin/flashmq --config-file /etc/flashmq/flashmq.conf ExecReload=/bin/kill -HUP $MAINPID Restart=on-failure RestartSec=5s TimeoutSec=300 [Install] WantedBy=multi-user.target ================================================ FILE: debian/postinst ================================================ #!/bin/bash -e FRESH_INSTALL=true if [[ "$1" == "configure" && "$2" != "" ]]; then FRESH_INSTALL=false fi if "$FRESH_INSTALL"; then echo "Fresh installation: enabling FlashMQ systemd service." systemctl enable flashmq.service else echo "This is not a fresh installation: not (re)enabling FlashMQ systemd service." # In case the service file changes, and to prevent systemd warnings 'service file changed' on upgrade. 
systemctl daemon-reload fi if systemctl is-enabled --quiet flashmq.service; then echo "FlashMQ is marked as enabled, so starting it." systemctl start flashmq.service else echo "FlashMQ is marked as disabled, so not starting it." fi ================================================ FILE: debian/postrm ================================================ #!/bin/bash -e if [[ "$1" != "upgrade" ]]; then echo "Disabling FlashMQ systemd service" systemctl disable flashmq.service || echo "Ignoring..." fi ================================================ FILE: debian/preinst ================================================ #!/bin/bash -e if [[ ! -f /etc/lsb-release && ! -f /etc/debian_version ]]; then echo "This is not a Debian or Ubuntu based system? Hmm" exit 1 fi if ! command -v systemctl -v &> /dev/null; then echo "This is not a systemd-based system. File a bug-report at https://github.com/halfgaar/FlashMQ/issues" exit 1 fi ================================================ FILE: debian/prerm ================================================ #!/bin/bash -e echo "Stopping FlashMQ systemd service" if systemctl is-active --quiet flashmq.service; then systemctl stop flashmq.service fi if systemctl is-active --quiet flashmq.service; then echo "FlashMQ failed to stop, according to systemctl." exit 1 fi ================================================ FILE: derivablecounter.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "derivablecounter.h" void DerivableCounter::inc(uint64_t n) { val += n; } uint64_t DerivableCounter::get() const { return val; } /** * @brief DerivableCounter::getPerSecond returns the amount per second since last time this method was called. 
* @return The average increase of the counter per second since the previous call, computed as (val - valPrevious) * 1000 / (elapsed milliseconds + 1); the + 1 avoids a division by zero and biases the result slightly low. * * Even though the class is not meant to be thread-safe, this method does use a mutex, because obtaining the value can be * scheduled in different threads. The mutex only serializes callers of this method: inc() does not take it, so val is read here without synchronization against concurrent inc() calls. */ double DerivableCounter::getPerSecond() { std::lock_guard locker(timeMutex); std::chrono::time_point now = std::chrono::steady_clock::now(); std::chrono::milliseconds msSinceLastTime = std::chrono::duration_cast(now - timeOfPrevious); uint64_t messagesTimes1000 = (val - valPrevious) * 1000; double result = messagesTimes1000 / static_cast(msSinceLastTime.count() + 1); // branchless avoidance of div by 0; timeOfPrevious = now; valPrevious = val; return result; } ================================================ FILE: derivablecounter.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef DERIVABLECOUNTER_H #define DERIVABLECOUNTER_H #include #include /** * @brief The DerivableCounter is a counter which can derive val/dt. * * It's not thread-safe, to avoid unnecessary locking. You should have counters per thread.
*/ class DerivableCounter { uint64_t val = 0; uint64_t valPrevious = 0; std::chrono::time_point timeOfPrevious = std::chrono::steady_clock::now(); std::mutex timeMutex; public: void inc(uint64_t n = 1); uint64_t get() const; double getPerSecond(); }; #endif // DERIVABLECOUNTER_H ================================================ FILE: dnsresolver.cpp ================================================ #include "dnsresolver.h" #include #include #include #include #include #include "utils.h" /* Returns the resolver to its idle state: clears curName and lookup.ar_name and frees the addrinfo list held in lookup.ar_result (if any). */ void DnsResolver::freeStuff() { curName.clear(); lookup.ar_name = nullptr; freeaddrinfo(lookup.ar_result); lookup.ar_result = nullptr; } /* Fills in the lookup hints (any address family, stream sockets, AI_V4MAPPED | AI_ADDRCONFIG) and attaches them to the gaicb that every query reuses. */ DnsResolver::DnsResolver() { request.ai_family = AF_UNSPEC; request.ai_socktype = SOCK_STREAM; request.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; lookup.ar_request = &request; } /* Cancels an outstanding lookup and polls, for at most about 1000 x 1 ms, until it is no longer EAI_INPROGRESS before freeing its results. */ DnsResolver::~DnsResolver() { gai_cancel(&lookup); int n = 0; while (gai_error(&lookup) == EAI_INPROGRESS && n++ < 1000) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); } freeStuff(); } /* Starts an asynchronous lookup of 'text' via getaddrinfo_a(GAI_NOWAIT), restricted to AF_INET or AF_INET6 when 'protocol' asks for it; poll getResult() for the outcome until 'timeout' has passed. Throws std::runtime_error if the request cannot be submitted. NOTE(review): unlike the destructor, this does not wait for a previously submitted request to leave EAI_INPROGRESS before calling freeStuff() and reusing 'lookup'. gai_cancel() can return EAI_NOTCANCELED for a request that is already being processed, so confirm that a still-running old request cannot write into 'lookup' after it has been reused. */ void DnsResolver::query(const std::string &text, ListenerProtocol protocol, std::chrono::milliseconds timeout) { gai_cancel(&lookup); freeStuff(); curName = text; getResultsTimeout = std::chrono::steady_clock::now() + timeout; lookup.ar_name = curName.c_str(); lookup.ar_service = nullptr; request.ai_family = AF_UNSPEC; if (protocol == ListenerProtocol::IPv4) request.ai_family = AF_INET; else if (protocol == ListenerProtocol::IPv6) request.ai_family = AF_INET6; struct gaicb *lookups[1]; lookups[0] = &lookup; int err = 0; if ((err = getaddrinfo_a(GAI_NOWAIT, lookups, 1, nullptr)) != 0) { const std::string errorstr(gai_strerror(err)); throw std::runtime_error(formatString("Dns lookup failure: %s", errorstr.c_str())); } } /* Polls the lookup started by query(): returns an empty list while it is still in progress and the resolved addresses once it has finished. Throws std::runtime_error when no query is pending, on timeout, on a lookup error, or when no address was returned. */ std::list DnsResolver::getResult() { if (curName.empty() || lookup.ar_name == nullptr) { throw std::runtime_error("No DNS query in progress"); } if (std::chrono::steady_clock::now() > getResultsTimeout) { const std::string name = this->curName; /* NOTE(review): the lookup is neither cancelled nor awaited here, yet freeStuff() resets it while it may still be running; confirm this is safe. */ freeStuff(); throw
std::runtime_error(formatString("DNS query for '%s' timed out.", name.c_str())); } std::list results; int err = gai_error(&lookup); if (err == EAI_INPROGRESS) return results; if (err != 0) { const std::string errorstr(gai_strerror(err)); freeStuff(); throw std::runtime_error(formatString("Dns lookup failure: %s", errorstr.c_str())); } struct addrinfo *cur_result = lookup.ar_result; while (cur_result) { FMQSockaddr result_sockaddr(cur_result->ai_addr); results.push_back(result_sockaddr); cur_result = cur_result->ai_next; } freeStuff(); if (results.empty()) { throw std::runtime_error("No error received but also no DNS result?"); } return results; } bool DnsResolver::idle() const { return (curName.empty() || lookup.ar_name == nullptr); } ================================================ FILE: dnsresolver.h ================================================ #ifndef DNSRESOLVER_H #define DNSRESOLVER_H #include #include #include #include #include #include "listener.h" #include "fmqsockaddr.h" /** * @brief The DnsResolver class does async DNS with getaddrinfo_a. * * Note that getaddrinfo_a uses threads. Had we known that, we would have probably rolled out our own. That * would also have prevented the fork problem: * * It turns out that getaddrinfo_a is not compatible with fork(). If it's been used once, the forked process can't * do DNS queries anymore. This is an issue for the test binary only at this point. 
*/ class DnsResolver { struct gaicb lookup {}; struct addrinfo request {}; std::string curName; std::chrono::time_point getResultsTimeout; void freeStuff(); public: DnsResolver(); DnsResolver(const DnsResolver &other) = delete; DnsResolver(DnsResolver &&other) = delete; DnsResolver &operator=(const DnsResolver &other) = delete; ~DnsResolver(); void query(const std::string &text, ListenerProtocol protocol, std::chrono::milliseconds timeout); std::list getResult(); bool idle() const; }; #endif // DNSRESOLVER_H ================================================ FILE: driftcounter.cpp ================================================ #include "driftcounter.h" #include #include #include "settings.h" DriftCounter::DriftCounter() { } /* Records the drift between call_time and now (clamped so it is never negative), both as the latest drift and in the ring buffer that getAvgDrift() averages. */ void DriftCounter::update(std::chrono::time_point call_time) { auto now = std::chrono::steady_clock::now(); std::chrono::milliseconds diff = std::chrono::duration_cast(now - call_time); diff = std::max(diff, std::chrono::milliseconds(0)); /* The index is masked with (size() - 1), so every slot is only visited if many_drifts.size() is a power of two. */ many_drifts.at(many_index++ & (many_drifts.size() - 1)) = diff; last_update = now; last_drift = diff; } /* Returns the drift recorded by the last update(). If update() has not run for more than 2 * HEARTBEAT_INTERVAL ms, the time elapsed since that last update is returned instead. If update() never ran, the initial last_drift (0 ms) is returned. */ std::chrono::milliseconds DriftCounter::getDrift() const { auto now = std::chrono::steady_clock::now(); const auto last_update_copy = last_update.value_or(std::chrono::steady_clock::now()); if (last_update_copy + (std::chrono::milliseconds(HEARTBEAT_INTERVAL) * 2 ) < now) return std::chrono::duration_cast(now - last_update_copy); return last_drift; } /** * @brief DriftCounter::getAvgDrift returns the average of the last many_drifts.size() updates (slots that were never written count as 0 ms). It's not thread safe, so you may * average updates from different calls to update().
* @return */ std::chrono::milliseconds DriftCounter::getAvgDrift() const { const double total = std::accumulate(many_drifts.begin(), many_drifts.end(), std::chrono::milliseconds(0)).count(); std::chrono::milliseconds avg = std::chrono::milliseconds(static_cast(total / static_cast(many_drifts.size()))); return avg; } ================================================ FILE: driftcounter.h ================================================ #ifndef DRIFTCOUNTER_H #define DRIFTCOUNTER_H #include #include #include /** * @brief The DriftCounter class allows measuring drift in threads. * * There is no thread syncing / mutexes. Values can be retrieved from other threads, but some values may * not be super accurate, which is fine. */ class DriftCounter { std::optional> last_update; std::chrono::milliseconds last_drift = std::chrono::milliseconds(0); std::array many_drifts{}; unsigned int many_index = 0; public: DriftCounter(); void update(std::chrono::time_point call_time); std::chrono::milliseconds getDrift() const; std::chrono::milliseconds getAvgDrift() const; }; #endif // DRIFTCOUNTER_H ================================================ FILE: enums.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef ENUMS_H #define ENUMS_H enum class X509ClientVerification { None, X509IsEnough, X509AndUsernamePassword }; enum class AllowListenerAnonymous { None, Yes, No }; enum class TLSVersion { TLSv1_0, TLSv1_1, TLSv1_2, TLSv1_3 }; enum class OverloadMode { Log, CloseNewClients }; enum class ConnectionProtocol { AcmeOnly, Mqtt, WebsocketMqtt }; enum class Mqtt3QoSExceedAction { Disconnect, Drop }; enum class HaProxyMode { Off, On, HaProxyClientVerification, HaProxyClientVerficiationWithAuthn }; #endif // ENUMS_H ================================================ FILE: evpencodectxmanager.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include #include "evpencodectxmanager.h" EvpEncodeCtxManager::EvpEncodeCtxManager() : ctx(EVP_ENCODE_CTX_new(), EVP_ENCODE_CTX_free) { if (!ctx) throw std::runtime_error("Error allocating with EVP_ENCODE_CTX_new()"); EVP_DecodeInit(ctx.get()); } ================================================ FILE: evpencodectxmanager.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef EVPENCODECTXMANAGER_H #define EVPENCODECTXMANAGER_H #include #include struct EvpEncodeCtxManager { std::unique_ptr ctx; EvpEncodeCtxManager(); }; #endif // EVPENCODECTXMANAGER_H ================================================ FILE: examples/plugin_libcurl/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.5) cmake_policy(SET CMP0048 NEW) project(plugin_libcurl VERSION 1.0.0 LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) add_compile_options(-Wall) add_library(plugin_libcurl SHARED vendor/flashmq_plugin.h vendor/flashmq_public.h src/pluginstate.h src/curl_functions.h src/authenticatingclient.h src/authenticatingclient.cpp src/curl_functions.cpp src/pluginstate.cpp src/plugin_libcurl.cpp ) target_include_directories(plugin_libcurl PUBLIC . src) target_link_libraries(plugin_libcurl curl) set_target_properties(plugin_libcurl PROPERTIES VERSION ${PROJECT_VERSION}) set_target_properties(plugin_libcurl PROPERTIES SOVERSION 1) ================================================ FILE: examples/plugin_libcurl/LICENSE ================================================ The FlashMQ example plugin 'plugin_libcurl' is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to ================================================ FILE: examples/plugin_libcurl/README.md ================================================ # Async auth using libcurl This example plugin demonstrates the use of the async authentication interface in `flashmq_plugin.h`, using libcurl. An HTTP request can take a while and in the meantime, we need to keep FlashMQ moving. The async interface allows us to do that. A breakdown: * The file `plugin_libcurl.cpp` contains the implementations of the plugin functions. * Using `CURLMOPT_SOCKETFUNCTION`, libcurl is told which of our functions to call; that function calls `flashmq_poll_add_fd()`. * Using `CURLMOPT_TIMERFUNCTION`, libcurl is told which of our functions to call; that function calls `flashmq_add_task()` and `flashmq_remove_task()`. * On login, it initiates an HTTP request and returns `AuthResult::async`. FlashMQ will keep the client waiting while handling other requests. * Socket activity is reported by `flashmq_plugin_poll_event_received()`, at which time we tell curl to continue with that socket. * In `check_all_active_curls()`, when we detect that the transfer has finished, we make a decision about the authentication by submitting `AuthResult::success` or `AuthResult::login_denied` to `flashmq_continue_async_authentication()`. ## Building Calling `build.sh` should be enough. It requires libcurl to be installed, of course.
## Configuring Use the config option `plugin` to point to the `.so` file, like: ``` allow_anonymous false plugin /home/me/stuff/libplugin_libcurl.so.1.0.0 ``` ## Running Once you have FlashMQ running with the plugin, you can log in with username `deny` to test that denying a login works. If you log in with username `curl`, it will make a request to Google. If it sees HTML in the response, it will pass the authentication. ================================================ FILE: examples/plugin_libcurl/build.sh ================================================ #!/bin/bash thisfile=$(readlink --canonicalize "$0") thisdir=$(dirname "$thisfile") BUILD_TYPE="Release" if [[ "$1" == "Debug" ]]; then BUILD_TYPE="Debug" fi BUILD_DIR="build-plugin-libcurl-$BUILD_TYPE" set -eu if [[ -e "$BUILD_DIR" ]]; then >&2 echo "$BUILD_DIR already exists. Considering fatal error because you should run 'make' in it if you want to keep using it." exit 1 else mkdir "$BUILD_DIR" fi cd "$BUILD_DIR" cmake -DCMAKE_BUILD_TYPE="$BUILD_TYPE" "$thisdir" make -j 4 ================================================ FILE: examples/plugin_libcurl/src/authenticatingclient.cpp ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to */ #include #include "authenticatingclient.h" /* Creates the curl easy handle used for one authentication request. Throws std::runtime_error if curl_easy_init() fails, so the rest of the plugin never operates on a null easy handle. */ ExampleCurlPlugin::AuthenticatingClient::AuthenticatingClient() : easy_handle(curl_easy_init(), curl_easy_cleanup) { if (!easy_handle) throw std::runtime_error("curl_easy_init failed"); } /* Detaches the easy handle from the multi handle it was added to, if that multi handle still exists. This runs in the destructor body, i.e. before the easy_handle member itself is destroyed and cleaned up. */ ExampleCurlPlugin::AuthenticatingClient::~AuthenticatingClient() { auto x = registeredAtMultiHandle.lock(); if (x) { curl_multi_remove_handle(x.get(), easy_handle.get()); } } /* Adds the easy handle to 'curlMulti' and remembers the multi handle (weakly) so the destructor can remove it again. Throws std::runtime_error if libcurl refuses the handle. */ void ExampleCurlPlugin::AuthenticatingClient::addToMulti(std::shared_ptr &curlMulti) { if (curl_multi_add_handle(curlMulti.get(), easy_handle.get()) != CURLM_OK) throw std::runtime_error("curl_multi_add_handle failed"); registeredAtMultiHandle = curlMulti; } ================================================ FILE: examples/plugin_libcurl/src/authenticatingclient.h ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to */ #ifndef AUTHENTICATINGCLIENT_H #define AUTHENTICATINGCLIENT_H #include #include #include "vendor/flashmq_plugin.h" #include // Use a name space to avoid obscure bugs when using type names FlashMQ uses too. namespace ExampleCurlPlugin { struct AuthenticatingClient { std::weak_ptr client; std::vector response; std::unique_ptr easy_handle; std::weak_ptr registeredAtMultiHandle; AuthenticatingClient(); ~AuthenticatingClient(); void addToMulti(std::shared_ptr &curlMulti); }; } #endif // AUTHENTICATINGCLIENT_H ================================================ FILE: examples/plugin_libcurl/src/curl_functions.cpp ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to */ #include "curl_functions.h" #include #include "vendor/flashmq_plugin.h" #include "pluginstate.h" #include "authenticatingclient.h" using namespace ExampleCurlPlugin; /** * @brief This is curl telling us what events to watch for. * @param easy Is an 'easy handle'. You make one per request. * @param s The socket of the HTTP request. * @param what What events to listen for * @param clientp Pointer set with CURLMOPT_SOCKETDATA. It can be whatever you need. * @param socketp Is set with curl_multi_assign or will be NULL. * @return */ int socket_event_watch_notification(CURL *easy, curl_socket_t s, int what, void *clientp, void *socketp) { (void)easy; (void)clientp; (void)socketp; if (what == CURL_POLL_REMOVE) flashmq_poll_remove_fd(s); else { int events = 0; if (what == CURL_POLL_IN) events |= EPOLLIN; else if (what == CURL_POLL_OUT) events |= EPOLLOUT; else if (what == CURL_POLL_INOUT) events = EPOLLIN | EPOLLOUT; else return 1; // We know we get back a socket for curl, but if there are multiple libs we use, we could have used the weak void pointer to associate // a data structure with the socket, which you get back on socket events. flashmq_poll_add_fd(s, events, std::weak_ptr()); } return 0; } int timer_callback(CURLM *multi, long timeout_ms, void *clientp) { PluginState *s = static_cast(clientp); // We also remove the last known task before it executes if curl tells us to install a new one. 
This // is suggested by the unclear and incomplete example at https://curl.se/libcurl/c/CURLMOPT_TIMERFUNCTION.html. if (timeout_ms == -1 || s->current_timer > 0) { flashmq_remove_task(s->current_timer); s->current_timer = 0; } if (timeout_ms >= 0) { auto f = std::bind(&call_timed_curl_multi_socket_action, multi, s); s->current_timer = flashmq_add_task(f, timeout_ms); } return CURLM_OK; } void call_timed_curl_multi_socket_action(CURLM *multi, PluginState *s) { s->current_timer = 0; int a = 0; int rc = curl_multi_socket_action(multi, CURL_SOCKET_TIMEOUT, 0, &a); /* Curl says: "When this function returns error, the state of all transfers are uncertain and they cannot be * continued. curl_multi_socket_action should not be called again on the same multi handle after an error has * been returned, unless first removing all the handles and adding new ones." */ if (rc != CURLM_OK) { s->clearAllNetworkRequests(); return; } check_all_active_curls(s, multi); } void check_all_active_curls(PluginState *p, CURLM *curlMulti) { CURLMsg *msg; int msgs_left; while((msg = curl_multi_info_read(curlMulti, &msgs_left))) { if (msg->msg == CURLMSG_DONE) { CURL *easy = msg->easy_handle; AuthenticatingClient *c = nullptr; curl_easy_getinfo(easy, CURLINFO_PRIVATE, &c); flashmq_logf(LOG_INFO, "Libcurl said: %s", curl_easy_strerror(msg->data.result)); std::string answer(c->response.data(), std::min(9, c->response.size())); p->processNetworkAuthResult(c->client, answer); } } } /** * @brief curl_write_cb Would be more accurately 'read callback', because it's used to read the response from the curl easy handle. * @param data * @param n * @param l * @param userp is whatever you set with CURLOPT_WRITEDATA. 
* @return */ size_t curl_write_cb(char *data, size_t n, size_t l, void *userp) { AuthenticatingClient *ac = static_cast(userp); int pos = ac->response.size(); ac->response.resize(ac->response.size() + n*l); std::memcpy(&ac->response[pos], data, n*l); return n*l; } ================================================ FILE: examples/plugin_libcurl/src/curl_functions.h ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
For more information, please refer to */ #ifndef CURL_FUNCTIONS_H #define CURL_FUNCTIONS_H #include #include #include "pluginstate.h" int socket_event_watch_notification(CURL *easy, curl_socket_t s, int what, void *clientp, void *socketp); int timer_callback(CURLM *multi, long timeout_ms, void *clientp); size_t curl_write_cb(char *data, size_t n, size_t l, void *userp); void call_timed_curl_multi_socket_action(CURLM *multi, ExampleCurlPlugin::PluginState *s); void check_all_active_curls(ExampleCurlPlugin::PluginState *p, CURLM *curlMulti); #endif // CURL_FUNCTIONS_H ================================================ FILE: examples/plugin_libcurl/src/plugin_libcurl.cpp ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
For more information, please refer to */ #include "vendor/flashmq_plugin.h" #include #include #include #include "pluginstate.h" #include "curl_functions.h" #include "authenticatingclient.h" using namespace ExampleCurlPlugin; int flashmq_plugin_version() { return FLASHMQ_PLUGIN_VERSION; } void flashmq_plugin_main_init(std::unordered_map &plugin_opts) { (void)plugin_opts; if (curl_global_init(CURL_GLOBAL_ALL) != 0) throw std::runtime_error("Global curl init failed to init"); } void flashmq_plugin_main_deinit(std::unordered_map &plugin_opts) { (void)plugin_opts; curl_global_cleanup(); } void flashmq_plugin_allocate_thread_memory(void **thread_data, std::unordered_map &plugin_opts) { (void)plugin_opts; PluginState *state = new PluginState(); *thread_data = state; } void flashmq_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map &plugin_opts) { (void)plugin_opts; PluginState *state = static_cast(thread_data); delete state; } /** * @brief flashmq_plugin_init We have nothing to do here, really. * @param thread_data * @param plugin_opts * @param reloading */ void flashmq_plugin_init(void *thread_data, std::unordered_map &plugin_opts, bool reloading) { (void)thread_data; (void)plugin_opts; (void)reloading; } void flashmq_plugin_deinit(void *thread_data, std::unordered_map &plugin_opts, bool reloading) { (void)thread_data; (void)plugin_opts; (void)reloading; } /** * @brief flashmq_plugin_poll_event_received * @param thread_data * @param fd * @param events * @param p A pointer to a data structure we assigned when watching the fd. We only use libcurl so we know we have to give it to * libcurl. Had we also used something else, we would have needed this to figure out what the fd is. 
*/ void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t events, const std::weak_ptr &p) { (void)p; PluginState *s = static_cast(thread_data); int new_events = CURL_CSELECT_ERR; if (events & EPOLLIN) { new_events &= ~CURL_CSELECT_ERR; new_events |= CURL_CSELECT_IN; } if (events & EPOLLOUT) { new_events &= ~CURL_CSELECT_ERR; new_events |= CURL_CSELECT_OUT; } int n = -1; curl_multi_socket_action(s->curlMulti.get(), fd, new_events, &n); check_all_active_curls(s, s->curlMulti.get()); } AuthResult flashmq_plugin_login_check( void *thread_data, const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client) { (void)clientid; (void)userProperties; (void)client; (void)username; (void)password; if (username == "deny") { return AuthResult::login_denied; } if (username == "curl") { PluginState *state = static_cast(thread_data); std::unique_ptr &c = state->networkAuthRequests[client]; if (c) throw std::runtime_error("Client already doing an authentication"); c = std::make_unique(); c->client = client; curl_easy_setopt(c->easy_handle.get(), CURLOPT_WRITEFUNCTION, curl_write_cb); // The function that is called when curl has data from the response for us. curl_easy_setopt(c->easy_handle.get(), CURLOPT_WRITEDATA, c.get()); // The pointer set we get in the above function. curl_easy_setopt(c->easy_handle.get(), CURLOPT_PRIVATE, c.get()); // The pointer set we get back with 'curl_easy_getinfo', for when the request is finished. flashmq_logf(LOG_INFO, "Asking an HTTP server"); // Keep in mind that DNS resovling may be blocking too. You could perhaps resolve the DNS once and use the result. But, // libcurl actually has some DNS caching as well. 
curl_easy_setopt(c->easy_handle.get(), CURLOPT_URL, "http://www.google.com/"); c->addToMulti(state->curlMulti); return AuthResult::async; } return AuthResult::success; } AuthResult flashmq_plugin_acl_check( void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties) { (void)thread_data; (void)access; (void)clientid; (void)username; (void)topic; (void)subtopics; (void)shareName; (void)payload; (void)qos; (void)retain; (void)correlationData; (void)responseTopic; (void)contentType; (void) expiresAt; (void)userProperties; return AuthResult::success; } ================================================ FILE: examples/plugin_libcurl/src/pluginstate.cpp ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to */ #include "pluginstate.h" #include #include "curl_functions.h" ExampleCurlPlugin::PluginState::PluginState() : curlMulti(curl_multi_init(), curl_multi_cleanup) { if (!curlMulti) throw std::runtime_error("Curl failed to init"); curl_multi_setopt(curlMulti.get(), CURLMOPT_SOCKETFUNCTION, socket_event_watch_notification); curl_multi_setopt(curlMulti.get(), CURLMOPT_TIMERFUNCTION, timer_callback); curl_multi_setopt(curlMulti.get(), CURLMOPT_TIMERDATA, this); // We need our plugin state in the timer_callback function. } ExampleCurlPlugin::PluginState::~PluginState() { } void ExampleCurlPlugin::PluginState::processNetworkAuthResult(std::weak_ptr &client, const std::string &answer) { auto pos = this->networkAuthRequests.find(client); if (pos == this->networkAuthRequests.end()) return; // This just checks we get an HTML page back, but you will of course need to do something more useful, like parse JSON, // look at the HTTP status code, etc. if (answer == "networkAuthRequests.erase(pos); } void ExampleCurlPlugin::PluginState::clearAllNetworkRequests() { this->networkAuthRequests.clear(); } ================================================ FILE: examples/plugin_libcurl/src/pluginstate.h ================================================ /* This file is part of FlashMQ example plugin 'plugin_libcurl' and is free and unencumbered software released into the public domain. 
Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. For more information, please refer to */ #ifndef PLUGINSTATE_H #define PLUGINSTATE_H #include #include #include #include #include "vendor/flashmq_plugin.h" #include "authenticatingclient.h" // Use a name space to avoid obscure bugs when using type names FlashMQ uses too. namespace ExampleCurlPlugin { struct PluginState { std::shared_ptr curlMulti; uint32_t current_timer = 0; std::map, std::unique_ptr, std::owner_less>> networkAuthRequests; PluginState(); ~PluginState(); void processNetworkAuthResult(std::weak_ptr &client, const std::string &answer); void clearAllNetworkRequests(); }; } #endif // PLUGINSTATE_H ================================================ FILE: examples/plugin_libcurl/vendor/flashmq_plugin.h ================================================ /* * This file is part of FlashMQ (https://www.flashmq.org). It defines the * plugin interface. 
* * - flashmq_plugin.h defines functions to be implemented by plugins. * - flashmq_public.h describes FlashMQ's functions available to plugins. * * Those two files are public domain and you are encouraged * to copy them to your plugin project for portability. Including those files * in your project does not make it a 'derivative work'. * * Compile like: gcc -fPIC -shared plugin.cpp -o plugin.so * * It's best practice to build your plugin with the same library versions of the * build of FlashMQ you're using. In practice, this means building on the OS * version you're running on. This also means using the AppImage build of FlashMQ * is not really compatible with plugins, because that includes older, and fixed, * versions of various libraries. * * For instance, if you use OpenSSL: by the time your plugin is loaded, FlashMQ will * have already dynamically linked OpenSSL. If you then try to call OpenSSL * functions, you'll run into ABI incompatibilities. */ #ifndef FLASHMQ_PLUGIN_H #define FLASHMQ_PLUGIN_H #include #include #include #include #include "flashmq_public.h" #define FLASHMQ_PLUGIN_VERSION 5 extern "C" { /** * @brief flashmq_plugin_version must return FLASHMQ_PLUGIN_VERSION. * @return FLASHMQ_PLUGIN_VERSION. * * [Must be implemented by plugin] */ int flashmq_plugin_version(); /** * @brief flashmq_plugin_main_init is called once before the event loops start. * @param plugin_opts * * [Can optionally be implemented by plugin] */ void flashmq_plugin_main_init(std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_main_deinit is the complementary pair of flashmq_plugin_main_init(). It's called after the threads have stopped. * @param plugin_opts * * [Can optionally be implemented by plugin] */ void flashmq_plugin_main_deinit(std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_allocate_thread_memory is called once by each thread. Never again. * @param thread_data. Create a memory structure and assign it to *thread_data. * @param plugin_opts. 
Map of flashmq_plugin_opt_* from the config file. * * Only allocate the plugin's memory here, or other things that you really only have to do once. Don't open connections, etc. That's * because the reload mechanism doesn't call this function. * * Because of the multi-core design of FlashMQ, you should treat each thread as its own domain with its own data. You can use static * variables for global scope if you must, or even create threads, but do provide proper locking where necessary. * * You can throw exceptions on errors. * * [Must be implemented by plugin] */ void flashmq_plugin_allocate_thread_memory(void **thread_data, std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_deallocate_thread_memory is called once by each thread. Never again. * @param thread_data. Delete this memory. * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * * You can throw exceptions on errors. * * [Must be implemented by plugin] */ void flashmq_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_init is called on thread start and config reload. It is the main place to initialize the plugin. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * @param reloading. * * The best approach to state keeping is doing everything per thread. You can initialize connections to database servers, load encryption keys, * create maps, etc. However, remember that for instance with libcurl, initing a 'multi handle' is best not done here, or at least not done * AGAIN on reload, because it will distrupt your ongoing transfers. Memory structures that you really only need to init once are best * done in 'flashmq_plugin_allocate_thread_memory()' or even 'flashmq_plugin_main_init()'. * * Keep in mind that libraries you use may not be thread safe (by default). Sometimes they use global scope in treacherous ways. 
As a random * example: Qt's QSqlDatabase needs a unique name for each connection, otherwise it is not thread safe and will crash. It will also hide away * libmysqlclient's requirement to do a global one-time init, that would be best done in 'flashmq_plugin_main_init()'. * * There is the option to set 'plugin_serialize_init true' in the config file, which allows some mitigation in * case you run into problems. * * You can throw exceptions on errors. * * [Must be implemented by plugin] */ void flashmq_plugin_init(void *thread_data, std::unordered_map &plugin_opts, bool reloading); /** * @brief flashmq_plugin_deinit is called on thread stop and config reload. It is the precursor to initializing. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * @param reloading * * You can throw exceptions on errors. * * [Must be implemented by plugin] */ void flashmq_plugin_deinit(void *thread_data, std::unordered_map &plugin_opts, bool reloading); /** * @brief flashmq_plugin_periodic is called every x seconds as defined in the config file. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * * You may need to periodically refresh data from a database, post stats, etc. You can do that from here. It's queued * in each thread at the same time, so you can perform somewhat synchronized events in all threads. * * Note that it's executed in the event loop, so it blocks the thread if you block here. If you need asynchronous operation, * you can make threads yourself. Be sure to synchronize data access properly in that case. * * The setting plugin_timer_period sets this interval in seconds. * * You can throw exceptions on errors. 
* * [Can optionally be implemented by plugin] */ void flashmq_plugin_periodic_event(void *thread_data); /** * @brief flashmq_plugin_alter_subscription can optionally be implemented if you want to be able to change incoming subscriptions. * @param thread_data * @param clientid * @param topic non-const reference which can be changed. * @param subtopics * @param qos non-const reference which can be changed. * @param userProperties * @return boolean indicating whether the subscription was changed. Not returning the truth here results in unpredictable behavior. * * In case of shared subscriptions, you will see the original subscription path, like '$share/myshare/battery/voltage'. You have the * chance to change every aspect of it, like make it non-shared. * * [Can optionally be implemented by plugin] */ bool flashmq_plugin_alter_subscription( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, uint8_t &qos, const std::vector> *userProperties); /** * @brief flashmq_plugin_alter_publish allows changing of the non-const arguments. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @return boolean indicating whether the packet was changed. It saves FlashMQ from having to do a full compare. Not returning the truth here * results in unpredictable behavior. Note: if only changing retain, you can get away with returning false. * * Be aware that changing publishes may incur a (slight) reduction in performance. * * [Can optionally be implemented by plugin] */ bool flashmq_plugin_alter_publish( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::optional &correlationData, std::optional &responseTopic, std::optional &contentType, std::vector> *userProperties); /** * @brief flashmq_plugin_login_check is called on login of a client. 
* @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param username * @param password * @param client Example use is for storing in a async operation and passing to flashmq_continue_async_authentication. * @return * * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error, * because there's nothing else to do: the state of FlashMQ won't change. * * Note that there is a setting 'plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not * thread-safe. It will negate much of FlashMQ's multi-core model. * * The AuthResult::async can be used if your auth check causes blocking IO (like network). You can save the weak pointer to the client * and do the auth in a thread or any kind of async way. FlashMQ's event loop will then continue. You can call flashmq_continue_async_authentication * later with the result. * * [Must be implemented by plugin] */ AuthResult flashmq_plugin_login_check( void *thread_data, const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client); /** * @brief flashmq_plugin_client_disconnected Called when clients disconnect or their keep-alive expire. * @param thread_data * @param clientid * * Is only called for authenticated clients, to avoid spoofing. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_client_disconnected(void *thread_data, const std::string &clientid); /** * @brief flashmq_plugin_on_unsubscribe is called after unsubscribe. Unsubscribe actions can't be manipulated or blocked. * @param topic Does not contain the share name. * @param subtopics Does not contain the share name. 
* * [Can optionally be implemented by plugin] */ void flashmq_plugin_on_unsubscribe( void *thread_data, const std::weak_ptr &session, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, const std::vector> *userProperties); /** * @brief flashmq_plugin_acl_check is called on publish, deliver and subscribe. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param shareName The shared subscription name in a filter like '$share/my_share_name/one/two'. Is only present on AclAccess::subscribe. * @return * * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error, * because there's nothing else to do: the state of FlashMQ won't change. * * Controlling subscribe access can have several benefits. For instance, you may want to avoid subscriptions that cause * a lot of server load. If clients pester you with many subscriptions like '+/+/+/+/+/+/+/+/+/', that causes a lot * of tree walking. Similarly, if all clients subscribe to '#' because it's easy, every single message passing through * the server will have to be ACL checked for every subscriber. * * Note that only MQTT 3.1.1 or higher has a 'failed' return code for subscribing, so older clients will see a normal * ack and won't know it failed. * * Note that there is a setting 'plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not * thread-safe. It will negate much of FlashMQ's multi-core model. * * When the 'access' is 'subscribe' and it's a shared subscription (like '$share/myshare/one/two/three'), you only get * the effective topic filter (like 'one/two/three'). However, since plugin version 4, there is the argument 'shareName' for that. 
* * [Must be implemented by plugin] */ AuthResult flashmq_plugin_acl_check( void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties); /** * @brief flashmq_plugin_extended_auth can be used to implement MQTT 5 extended auth. This is optional. * @param thread_data is the memory you allocated in flashmq_plugin_allocate_thread_memory. * @param clientid * @param stage * @param authMethod * @param authData * @param userProperties are optional (and are nullptr in that case) * @param returnData is a non-const string, that you can set to include data back to the client in an AUTH packet. * @param username is a non-const string. You can set it, which will then apply to ACL checking and show in the logs. * @param client Use this for AuthResult::async. See flashmq_plugin_login_check(). * @return an AuthResult enum class value * * [Can optionally be implemented by plugin] */ AuthResult flashmq_plugin_extended_auth( void *thread_data, const std::string &clientid, ExtendedAuthStage stage, const std::string &authMethod, const std::string &authData, const std::vector> *userProperties, std::string &returnData, std::string &username, const std::weak_ptr &client); /** * @brief Is called when the socket watched by 'flashmq_poll_add_fd()' has an event. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param fd * @param events contains the events as a bit flags. See 'man epoll'. * @param p can be made back into your type with 'std::shared_ptr sp = std::static_pointer_cast(b.lock())'. 
* This allows you to properly lend pointers to the event system that you can actually check for expiration. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t events, const std::weak_ptr &p); } #endif // FLASHMQ_PLUGIN_H ================================================ FILE: examples/plugin_libcurl/vendor/flashmq_public.h ================================================ /* * This file is part of FlashMQ (https://www.flashmq.org). It describes * public FlashMQ functions available to plugins. * * This interface definition is public domain and you are encouraged * to copy it to your plugin project together with flashmq_plugin.h, * for portability. Including this file in your project does not make * it a 'derivative work'. */ #ifndef FLASHMQ_PUBLIC_H #define FLASHMQ_PUBLIC_H #include #include #include #include #include // Compatible with Mosquitto, for (auth) plugin compatability. #define LOG_NONE 0x00 #define LOG_INFO 0x01 #define LOG_NOTICE 0x02 #define LOG_WARNING 0x04 #define LOG_ERR 0x08 #define LOG_ERROR 0x08 #define LOG_DEBUG 0x10 #define LOG_SUBSCRIBE 0x20 #define LOG_UNSUBSCRIBE 0x40 #define LOG_PUBLISH 0x80 // Not compatible with Mosquitto #define API __attribute__((visibility("default"))) extern "C" { class Client; class Session; /** * @brief The AclAccess enum's numbers are compatible with Mosquitto's 'int access'. * * read = reading a publish published by someone else. * write = doing a publish. * subscribe = subscribing. */ enum class AclAccess { none = 0, read = 1, write = 2, subscribe = 4, register_will = 100 }; /** * @brief The AuthResult enum's numbers are compatible with Mosquitto's auth result. * * async = defer the decision until you have the result from an async call, which can be submitted with flashmq_continue_async_authentication(). * * auth_continue = part of MQTT5 extended authentication, which can be a back-and-forth between server and client. 
* * success_without_retained_delivery = allow the subscription action, but don't try to give client the matching retained messages. This * can be used prevent load on the server. For instance, if there are many retained messages and clients subscribe to '#'. This value * is only valid for AclAccess::subscribe, and requires FlashMQ version 1.9.0 or newer. * * success_without_setting_retained = allow the write action, but don't set the retained message (if it was set to retain to begin * with). MQTT5 subscribers with 'retain as published' will still see the retain flag set. This value is only valid for * AclAccess::write and requires a FlashMQ version 1.16.0 or higher. * * success_but_drop_publish / success_but_drop = send success SUBACK or PUBACK back to client (if QoS) but don't process the packet. This * may be useful in combination with flashmq_publish_message() or flashmq_plugin_add_subscription() if you need to publish or subscribe * something new entirely (different topic(s) or payload for instance). This only works with AclAccess::write (FlashMQ 1.20.0 or higher) * and AclAccess::subscribe (FlashMQ 1.21.0 or higher). * * server_not_available = to be used as log-in result, for when you don't have auth data yet, for instance. MQTT3 and MQTT5 both support * sending 'ServerUnavailble' in their CONNACK, when this result is used. Requires FlashMQ 1.17.0 or newer. */ enum class AuthResult { success = 0, auth_method_not_supported = 10, acl_denied = 12, login_denied = 11, error = 13, server_not_available = 14, async = 50, success_without_retained_delivery = 51, success_without_setting_retained = 52, success_but_drop_publish = 53, success_but_drop = 53, auth_continue = -4 }; enum class ExtendedAuthStage { None = 0, Auth = 10, Reauth = 20, Continue = 30 }; /** * @brief The ServerDisconnectReasons enum lists the possible values to initiate client disconnect with. 
* * This is a subset of all MQTT5 reason codes that are allowed in a disconnect packet, and that make sense for plugin use. */ enum class ServerDisconnectReasons { NormalDisconnect = 0, UnspecifiedError = 128, ProtocolError = 130, ImplementationSpecificError = 131, NotAuthorized = 135, ServerBusy = 137, MessageRateTooHigh = 150 }; /** * @brief flashmq_logf calls the internal logger of FlashMQ. The logger mutexes all access, so is thread-safe, and writes to disk * asynchronously, so it won't hold you up. * @param level is any of the levels defined above, starting with LOG_. * @param str format string, handed together with the variadic arguments that follow it to the logger's logf(). * * FlashMQ makes no distinction between INFO and NOTICE. */ void API flashmq_logf(int level, const char *str, ...); /** * @brief flashmq_plugin_remove_client_v4 queues a removal of a client in the proper thread, including session if required. It can be called by * plugin code (meaning this function does not need to be implemented). * @param session Can be obtained with flashmq_get_session_pointer(). * @param alsoSession also remove the session if it would otherwise remain. * @param reasonCode is only for MQTT5, because MQTT3 doesn't have server-initiated disconnect packets. * * Many clients will automatically reconnect, so you'll have to also remove permissions of the client in question, probably. * * Can be called from any thread: the action will be queued properly. * * [New version since plugin version 4] */ void API flashmq_plugin_remove_client_v4(const std::weak_ptr &session, bool alsoSession, ServerDisconnectReasons reasonCode); /** * @brief flashmq_plugin_remove_subscription_v4 removes a client's subscription from the central store. It can be called by plugin code (meaning * this function does not need to be implemented). * @param session Can be obtained with flashmq_get_session_pointer(). * @param topicFilter Like 'one/two/three' or '$share/myshare/one/two/three'. * * It matches only literal filters. So removing '#' would only remove an active subscription on '#', not 'everything'.
* * Will throw exceptions on certain errors, like std::runtime_error when topicFilter is not valid UTF-8 or not a valid subscription filter. * * Can be called from any thread, because the global subscription store is mutexed. * * [New version since plugin version 4] */ void API flashmq_plugin_remove_subscription_v4(const std::weak_ptr &session, const std::string &topicFilter); /** * @brief flashmq_plugin_add_subscription adds a subscription for a session to the central store. It can be called by plugin code. * @param session Can be obtained with flashmq_get_session_pointer(). * @param topicFilter Like 'one/two/three' or '$share/myshare/one/two/three'. * @param qos,noLocal,retainAsPublished,subscriptionIdentifier Subscription options, passed as-is to the subscription store. * @return boolean False when the session has expired or the store reports the subscription as invalid; true otherwise. * * Will throw exceptions on certain errors, like std::runtime_error when topicFilter is not valid UTF-8 or not a valid subscription filter. * * Can be called from any thread, because the global subscription store is mutexed. */ bool API flashmq_plugin_add_subscription( const std::weak_ptr &session, const std::string &topicFilter, uint8_t qos, bool noLocal, bool retainAsPublished, const uint32_t subscriptionIdentifier); /** * @brief flashmq_continue_async_authentication_v4 is to continue/finish async authentication. * @param client The weak pointer given to the login/auth function. If it has expired, nothing happens. * @param result * @param delay_in_ms Delay in milliseconds. Introducing a delay on failure can be a beneficial security feature. * * When you've previously returned AuthResult::async in the authentication check, because you need to perform a network call for instance, * you can submit the final result back to FlashMQ with this function. The action will be queued in the proper thread. * * It uses a weak pointer to Client instead of client id, because clients are in limbo at this point, and a client id isn't necessarily * correct (anymore). The login functions also give this weak pointer so you can store it with the async operation, to be used again later for * a call to this function. * * Can be called from any thread.
* * [New version since plugin version 4, FlashMQ version 1.25.0] */ void API flashmq_continue_async_authentication_v4( const std::weak_ptr &client, AuthResult result, const std::string &authMethod, const std::string &returnData, const uint32_t delay_in_ms); /** * @brief flashmq_publish_message Publish a message from the plugin. * * Can be called from any thread. From a FlashMQ worker thread the message is published directly; from other threads it is * queued as a task on a worker thread. When retain is set, the message is also stored as retained message. */ void API flashmq_publish_message( const std::string &topic, const uint8_t qos, const bool retain, const std::string &payload, uint32_t expiryInterval=0, const std::vector> *userProperties = nullptr, const std::string *responseTopic=nullptr, const std::string *correlationData=nullptr, const std::string *contentType=nullptr); /** * @brief flashmq_get_client_address_v4 * @param client A client pointer as provided by 'flashmq_plugin_login_check'. * @param text If not nullptr, will be assigned the address in text form, like 192.168.1.1 or "2001:0db8:85a3:0000:1319:8a2e:0370:7344". * @param addr If not nullptr (and addrlen is not nullptr either), will be zeroed and then filled as a sockaddr struct, for low level operations. * @param addrlen Size of addr. Supply the length. Afterwards, it holds the number of bytes written: the smaller of the supplied length and * the size of the client's address. * * The text, addr and addrlen must be pointers to local variables in the calling context. * * Note that the sockaddr API is hard to use safely in C++. Use of addr can very easily lead to undetected undefined behavior in * C++ because of type aliasing violations. The only safe way is to avoid using a cast object, instead use memcpy into structs of the * correct type: first 'struct sockaddr', to read the family, then like 'struct sockaddr_in' or 'struct sockaddr_in6'. In other * words, avoid accessing the members of the struct, even those of 'struct sockaddr'. * * Example of initializing the addr variable: * * struct sockaddr_storage addr_mem; * struct sockaddr *addr = reinterpret_cast(&addr_mem); * socklen_t addrlen = sizeof(addr_mem); * * Afterwards, do the memcpy stuff to read it, or pass it verbatim to library functions.
* * [New version since plugin version 4] */ void API flashmq_get_client_address_v4(const std::weak_ptr &client, std::string *text, sockaddr *addr, socklen_t *addrlen); /** * @brief flashmq_get_session_pointer Get reference counted weak pointer of a session. * @param clientid The client ID of the session you're retrieving. * @param username The username is used for verification, as a security measure. * @param sessionOut The result (has to be an output parameter because we can't return it). It is left unchanged when no session * exists for clientid, or when the session's username differs from username. * * The weak pointer will accurately reflect the original session. If it has been replaced with a new one with * the same client ID, this weak pointer will be 'expired'. */ void API flashmq_get_session_pointer(const std::string &clientid, const std::string &username, std::weak_ptr &sessionOut); /** * @brief flashmq_get_client_pointer Get reference counted client pointer of a session. * @param session Can be obtained with flashmq_get_session_pointer(). * @param clientOut The result (has to be an output parameter because we can't return it). Left unchanged if the session has expired. * * Can be used to feed to other functions, or to check if the client is still online. */ void API flashmq_get_client_pointer(const std::weak_ptr &session, std::weak_ptr &clientOut); /** * @brief Allows async operation of outgoing connections you may need to make. It adds the file descriptor to * the epoll listener. * @param fd * @param events epoll events, typically EPOLLIN (ready read) and EPOLLOUT (ready write). Should be or'ed together, * like 'EPOLLOUT | EPOLLIN'. See 'man epoll'. * @param p weak pointer. Can be a weak copy of a shared pointer with proper type. Like p = std::make_shared(). * You'll get it back in 'flashmq_plugin_poll_event_received()'. Use is optional. For libcurl multi socket, * you don't need it. * * You can do this once you have a connection with something external. * * You can also call it again with different events, in which case it will modify the existing entry. If you specify * a non-expired p, it will overwrite the original data associated with the fd.
* * Is meant for the local worker thread. It's a no-op in custom threads. * * Will throw exceptions on error, so be sure to handle them. */ void API flashmq_poll_add_fd(int fd, uint32_t events, const std::weak_ptr &p); /** * @brief Remove the fd from the event polling system. * @param fd * * Closing a socket will also remove it from the epoll system, but if you don't call this function on close, you may get stray * events once the fd number is reused. There is protection against it, but you may end up with unpredictable behavior. * * Is meant for the local worker thread. It's a no-op in custom threads. * * Will throw exceptions on error, so be sure to handle them. */ void API flashmq_poll_remove_fd(uint32_t fd); /** * @brief Call a task later, once. * @param f Function, that can be created with std::bind, for instance. * @param delay_in_ms Delay in milliseconds before f is called. * @return id of the timer, which can be used to remove it. * * The task queue is local to the current worker thread, including the id returned. Unlike flashmq_remove_task(), this is not a * no-op in custom threads: there it throws std::runtime_error. * * This can be necessary for asynchronous interfaces, like libcurl. * * Can throw exceptions. */ uint32_t API flashmq_add_task(std::function f, uint32_t delay_in_ms); /** * @brief Remove a task with id as given by 'flashmq_add_task()'. * @param id * * The task queue is local to the current worker thread, including the id returned. It's a no-op in custom threads. */ void API flashmq_remove_task(uint32_t id); /** * @brief Use this if you don't want your thread to be signalled ready after flashmq_plugin_init(). * * Normally, when threads return from flashmq_plugin_init(), the worker thread signals to the main thread it's done. Once * all threads have done this, the listeners are created and systemd is notified. You may want to defer this, until all * your auth data is loaded, for instance.
* * One simple solution to queue back this event is to use flashmq_add_task() to poll on a timer and check if your data is * loaded, and run flashmq_signal_thread_ready() when it's ready. * * Must be called from a worker thread; from a custom thread it only logs an error. * * [Introduced in FlashMQ 1.26.0] */ void API flashmq_defer_thread_ready(); /** * @brief Counterpart to flashmq_defer_thread_ready(). * * There is no timeout action. If you don't do this, FlashMQ will never be ready. When started as a systemd unit of * type 'notify', it will be restarted after a set time. * * Must be called from the same worker thread that called flashmq_defer_thread_ready(); in a thread that didn't defer, it does * nothing, and from a custom thread it only logs an error. * * [Introduced in FlashMQ 1.26.0] */ void API flashmq_signal_thread_ready(); } #endif // FLASHMQ_PUBLIC_H ================================================ FILE: exceptions.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "exceptions.h" ================================================ FILE: exceptions.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef EXCEPTIONS_H #define EXCEPTIONS_H #include #include #include #include "types.h" #define API __attribute__((visibility("default"))) // Because exceptions are (potentially) visible to plugins, we want to make sure to avoid symbol collisions, so using their own namespace. namespace FlashMQ { /** * @brief The ProtocolError class is handled by the error handler in the worker threads and is used to make decisions about if and how * to inform a client and log the message. * * It's mainly meant for errors that can be communicated with MQTT packets.
*/ class API ProtocolError : public std::runtime_error { public: const ReasonCodes reasonCode; ProtocolError(const std::string &msg, ReasonCodes reasonCode = ReasonCodes::UnspecifiedError) : std::runtime_error(msg), reasonCode(reasonCode) { } }; class API BadClientException : public std::runtime_error { std::optional mLogLevel; public: BadClientException(const std::string &msg, int logLevel=-1) : std::runtime_error(msg) { if (logLevel >= 0) this->mLogLevel = logLevel; } std::optional getLogLevel() const { return this->mLogLevel; } }; class API NotImplementedException : public std::runtime_error { public: NotImplementedException(const std::string &msg) : std::runtime_error(msg) {} }; class API FatalError : public std::runtime_error { public: FatalError(const std::string &msg) : std::runtime_error(msg) {} }; class API ConfigFileException : public std::runtime_error { public: ConfigFileException(const std::string &msg) : std::runtime_error(msg) {} ConfigFileException(std::ostringstream oss) : std::runtime_error(oss.str()) {} }; class API pluginException : public std::runtime_error { public: pluginException(const std::string &msg) : std::runtime_error(msg) {} }; class API BadWebsocketVersionException : public std::runtime_error { public: BadWebsocketVersionException(const std::string &msg) : std::runtime_error(msg) {} }; class API BadHttpRequest : public std::runtime_error { public: BadHttpRequest(const std::string &msg) : std::runtime_error(msg) {} }; } using namespace FlashMQ; #endif // EXCEPTIONS_H ================================================ FILE: fdmanaged.cpp ================================================ #include "fdmanaged.h" #include FdManaged::FdManaged(int fd) : fd(fd) { } FdManaged::~FdManaged() { if (fd > 0) // For some tech debt reasons we do > 0 instead of >= 0, which it should be.
{ close(fd); fd = -1; } } ================================================ FILE: fdmanaged.h ================================================ #ifndef FDMANAGED_H #define FDMANAGED_H class FdManaged { int fd = -1; public: FdManaged() = default; FdManaged(int fd); FdManaged(const FdManaged &other) = delete; FdManaged(FdManaged &&other) = delete; ~FdManaged(); int get() const { return fd; } }; #endif // FDMANAGED_H ================================================ FILE: flags.h ================================================ #ifndef FLAGS_H #define FLAGS_H #include #include /* Bit set over the values of an enum type, stored in a uint32_t. The masks are built with an unsigned 1u, so bit 31 never involves shifting into the sign bit of a signed int. */ template class Flags { uint32_t flags = 0; public: bool hasFlagSet(FlagType val) const { return static_cast(this->flags & (1u << static_cast(val))); } bool hasNone() const { return flags == 0; } bool hasAll() const { return flags == std::numeric_limits::max(); } void setFlag(FlagType val) { flags |= (1u << static_cast(val)); } void clearFlag(FlagType val) { flags &= ~(1u << static_cast(val)); } void clearAll() { flags = 0; } void setAll() { flags = std::numeric_limits::max(); } }; #endif // FLAGS_H ================================================ FILE: flashmq.conf ================================================ # https://www.flashmq.org/documentation/config-file/ log_file /var/log/flashmq/flashmq.log storage_dir /var/lib/flashmq ================================================ FILE: flashmq_plugin.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "flashmq_plugin.h" #include "flashmq_plugin_deprecated.h" #include "logger.h" #include "threaddata.h" #include "threadglobals.h" #include "subscriptionstore.h" #include "globals.h" #include "utils.h" #include "mainapp.h" /* Forwards a variadic log call to the FlashMQ logger. */ void flashmq_logf(int level, const char *str, ...)
{ Logger *logger = Logger::getInstance(); va_list valist; va_start(valist, str); logger->logf(level, str, valist); va_end(valist); } /* Mosquitto-compatible alias of flashmq_logf(). */ void mosquitto_log_printf(int level, const char *fmt, ...) { Logger *logger = Logger::getInstance(); va_list valist; va_start(valist, fmt); logger->logf(level, fmt, valist); va_end(valist); } /** * @brief flashmq_plugin_remove_client for previous plugin versions. * * Disconnects the session's client (when one is attached and has thread data) and, if alsoSession is set, removes the session. */ void flashmq_plugin_remove_client(const std::string &clientid, bool alsoSession, ServerDisconnectReasons reasonCode) { std::shared_ptr store = globals->subscriptionStore; if (!store) /* Same guard as the v4 functions. */ return; std::shared_ptr session = store->lockSession(clientid); if (session) { std::shared_ptr client = session->makeSharedClient(); if (client) { ReasonCodes _code = static_cast(reasonCode); std::shared_ptr td = client->lockThreadData(); if (td) { td->serverInitiatedDisconnect(client, _code, "Removed from plugin"); } } if (alsoSession) store->removeSession(session); } } /** * @brief Disconnects the client attached to the session, if any, and optionally removes the session. * * The session is removed when alsoSession is set, whether or not a client is currently attached, matching * flashmq_plugin_remove_client(). Does nothing when the subscription store is absent or the session has expired. */ void flashmq_plugin_remove_client_v4(const std::weak_ptr &session, bool alsoSession, ServerDisconnectReasons reasonCode) { std::shared_ptr store = globals->subscriptionStore; if (!store) return; std::shared_ptr session_locked = session.lock(); if (!session_locked) return; /* An offline session has no client; that must not prevent the session removal below. */ std::shared_ptr client = session_locked->makeSharedClient(); if (client) { std::shared_ptr td = client->lockThreadData(); if (td) { ReasonCodes _code = static_cast(reasonCode); td->serverInitiatedDisconnect(client, _code, "Removed from plugin"); } } if (alsoSession) store->removeSession(session_locked); } /** * @brief flashmq_plugin_remove_subscription for previous plugin versions.
*/ void flashmq_plugin_remove_subscription(const std::string &clientid, const std::string &topicFilter) { if (!(isValidUtf8(topicFilter) && isValidSubscribePath(topicFilter))) throw std::runtime_error("Unsubscribing from plugin failed: invalid topic filter: " + topicFilter); std::shared_ptr store = globals->subscriptionStore; if (!store) /* Same guard as flashmq_plugin_add_subscription(). */ return; std::shared_ptr session = store->lockSession(clientid); if (!session) return; std::vector subtopics = splitTopic(topicFilter); std::string shareName; std::string _; parseSubscriptionShare(subtopics, shareName, _); store->removeSubscription(session, subtopics, shareName); } /** * @brief Removes the session's subscription that literally matches topicFilter; a '$share/name/' prefix is parsed into the share name. * * Throws std::runtime_error on an invalid filter. Does nothing when the subscription store is absent or the session has expired. */ void flashmq_plugin_remove_subscription_v4(const std::weak_ptr &session, const std::string &topicFilter) { if (!(isValidUtf8(topicFilter) && isValidSubscribePath(topicFilter))) throw std::runtime_error("Unsubscribing from plugin failed: invalid topic filter: " + topicFilter); std::shared_ptr store = globals->subscriptionStore; if (!store) return; std::shared_ptr session_locked = session.lock(); if (!session_locked) return; std::vector subtopics = splitTopic(topicFilter); std::string shareName; std::string _; parseSubscriptionShare(subtopics, shareName, _); store->removeSubscription(session_locked, subtopics, shareName); } /** * @brief Adds a subscription for the session to the central store. * * Throws std::runtime_error on an invalid filter. Returns false when the store is absent, the session has expired, or the store * reports the subscription as invalid; true otherwise. */ bool flashmq_plugin_add_subscription( const std::weak_ptr &session, const std::string &topicFilter, uint8_t qos, bool noLocal, bool retainAsPublished, const uint32_t subscriptionIdentifier) { if (!(isValidUtf8(topicFilter) && isValidSubscribePath(topicFilter))) throw std::runtime_error("Subscribing from plugin failed: invalid topic filter: " + topicFilter); std::shared_ptr store = globals->subscriptionStore; if (!store) return false; std::shared_ptr session_locked = session.lock(); if (!session_locked) return false; std::vector subtopics = splitTopic(topicFilter); std::string shareName; std::string topicDummy; parseSubscriptionShare(subtopics, shareName, topicDummy); const AddSubscriptionType result = store->addSubscription(session_locked, subtopics, qos, noLocal,
retainAsPublished, shareName, subscriptionIdentifier); return result != AddSubscriptionType::Invalid; } /* Pre-v4 entry point: same as the v4 version without a delay. */ void flashmq_continue_async_authentication(const std::weak_ptr &client, AuthResult result, const std::string &authMethod, const std::string &returnData) { flashmq_continue_async_authentication_v4(client, result, authMethod, returnData, 0); } /* Queues the auth continuation in the client's own thread; silently ignored when the client or its thread is gone. */ void flashmq_continue_async_authentication_v4( const std::weak_ptr &client, AuthResult result, const std::string &authMethod, const std::string &returnData, const uint32_t delay_in_ms) { std::shared_ptr c = client.lock(); if (!c) return; std::shared_ptr td = c->lockThreadData(); if (!td) return; td->queueContinuationOfAuthentication(c, result, authMethod, returnData, delay_in_ms); } void flashmq_publish_message(const std::string &topic, const uint8_t qos, const bool retain, const std::string &payload, uint32_t expiryInterval, const std::vector> *userProperties, const std::string *responseTopic, const std::string *correlationData, const std::string *contentType) { /* Stores the message as retained (when requested) and queues it at all matching subscribers. */ auto do_publish = [](Publish &pub){ std::shared_ptr store = globals->subscriptionStore; if (pub.retain) { store->setRetainedMessage(pub, pub.getSubtopics()); } PublishCopyFactory factory(&pub); store->queuePacketAtSubscribers(factory, "", {}); }; auto f2 = [do_publish](std::shared_ptr &pub){ do_publish(*pub); }; auto &td_local = ThreadGlobals::getThreadData(); /* In a worker thread, publish directly. From any other thread, hand a heap-allocated publish to a worker thread as an immediate task. */ if (td_local) { Publish pub(topic, payload, qos, retain, expiryInterval, userProperties, responseTopic, correlationData, contentType); do_publish(pub); return; } std::shared_ptr pub = std::make_shared(topic, payload, qos, retain, expiryInterval, userProperties, responseTopic, correlationData, contentType); auto td = globals->getDeterministicThreadData(); auto f_bound = std::bind(f2, std::move(pub)); td->addImmediateTask(f_bound); } void flashmq_get_client_address(const std::weak_ptr &client, std::string *text, FlashMQSockAddr *addr) { if (addr) { std::memset(addr, 0, sizeof(FlashMQSockAddr)); }
std::shared_ptr c = client.lock(); if (!c) return; const FMQSockaddr &client_addr = c->getAddr(); if (text) *text = client_addr.getText(); if (addr) { const socklen_t addrlen_min = std::min(client_addr.getSize(), sizeof(FlashMQSockAddr)); memcpy(addr->getAddr(), client_addr.getData(), addrlen_min); } } /* Zeroes addr (if given with addrlen), then copies at most *addrlen bytes of the client's address and reports the copied length. */ void flashmq_get_client_address_v4(const std::weak_ptr &client, std::string *text, sockaddr *addr, socklen_t *addrlen) { if (addr && addrlen) std::memset(addr, 0, *addrlen); std::shared_ptr c = client.lock(); if (!c) return; const FMQSockaddr &client_addr = c->getAddr(); if (text) *text = client_addr.getText(); if (addr && addrlen) { const socklen_t addrlen_min = std::min(client_addr.getSize(), *addrlen); memcpy(addr, client_addr.getData(), addrlen_min); *addrlen = addrlen_min; } } void flashmq_poll_add_fd(int fd, uint32_t events, const std::weak_ptr &p) { auto &d = ThreadGlobals::getThreadData(); if (!d) return; d->pollExternalFd(fd, events, p); } void flashmq_poll_remove_fd(uint32_t fd) { auto &d = ThreadGlobals::getThreadData(); if (!d) return; d->pollExternalRemove(fd); } sockaddr *FlashMQSockAddr::getAddr() { return reinterpret_cast(&this->addr_in6); } constexpr int FlashMQSockAddr::getLen() { return sizeof(struct sockaddr_in6); } /* Unlike flashmq_remove_task(), this throws when called from a custom thread (no thread data). */ uint32_t flashmq_add_task(std::function f, uint32_t delay_in_ms) { auto &d = ThreadGlobals::getThreadData(); if (!d) throw std::runtime_error("No thread data?"); return d->addDelayedTask(f, delay_in_ms); } void flashmq_remove_task(uint32_t id) { auto &d = ThreadGlobals::getThreadData(); if (!d) return; d->removeDelayedTask(id); } /* Leaves sessionOut untouched when the store is absent, no session exists, or the username does not match. */ void flashmq_get_session_pointer(const std::string &clientid, const std::string &username, std::weak_ptr &sessionOut) { std::shared_ptr store = globals->subscriptionStore; if (!store) return; std::shared_ptr session = store->lockSession(clientid); if (!session) return; if (session->getUsername() != username) return; sessionOut = session; } void flashmq_get_client_pointer(const std::weak_ptr &session,
std::weak_ptr &clientOut) { std::shared_ptr sessionLocked = session.lock(); if (!sessionLocked) return; clientOut = sessionLocked->makeSharedClient(); } void flashmq_defer_thread_ready() { auto &d = ThreadGlobals::getThreadData(); if (!d) { Logger::getInstance()->log(LOG_ERROR) << "Calling flashmq_defer_thread_ready from custom thread."; return; } d->deferThreadReady = true; } /* Only acts in a worker thread that deferred its ready signal; it then queues the init decrement at the main app. NOTE(review): deferThreadReady is not cleared here, so a repeated call would presumably queue another decrement; confirm whether callers must guard against that. */ void flashmq_signal_thread_ready() { const auto &d = ThreadGlobals::getThreadData(); if (!d) { Logger::getInstance()->log(LOG_ERROR) << "Calling flashmq_signal_thread_ready from custom thread."; return; } if (!d->deferThreadReady) return; std::shared_ptr lockedMainApp = d->mMainApp.lock(); if (!lockedMainApp) return; lockedMainApp->queueThreadInitDecrement(); } ================================================ FILE: flashmq_plugin.h ================================================ /* * This file is part of FlashMQ (https://www.flashmq.org). It defines the * plugin interface. * * - flashmq_plugin.h defines functions to be implemented by plugins. * - flashmq_public.h describes FlashMQ's functions available to plugins. * * Those two files are public domain and you are encouraged * to copy them to your plugin project for portability. Including those files * in your project does not make it a 'derivative work'. * * Compile like: gcc -fPIC -shared plugin.cpp -o plugin.so * * It's best practice to build your plugin with the same library versions of the * build of FlashMQ you're using. In practice, this means building on the OS * version you're running on. This also means using the AppImage build of FlashMQ * is not really compatible with plugins, because that includes older, and fixed, * versions of various libraries. * * For instance, if you use OpenSSL: by the time your plugin is loaded, FlashMQ will * have already dynamically linked OpenSSL. If you then try to call OpenSSL * functions, you'll run into ABI incompatibilities.
*/ #ifndef FLASHMQ_PLUGIN_H #define FLASHMQ_PLUGIN_H #include #include #include #include #include "flashmq_public.h" #define FLASHMQ_PLUGIN_VERSION 5 extern "C" { /** * @brief flashmq_plugin_version must return FLASHMQ_PLUGIN_VERSION. * @return FLASHMQ_PLUGIN_VERSION. * * [Must be implemented by plugin] */ int flashmq_plugin_version(); /** * @brief flashmq_plugin_main_init is called once before the event loops start. * @param plugin_opts Map of flashmq_plugin_opt_* from the config file. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_main_init(std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_main_deinit is the complementary pair of flashmq_plugin_main_init(). It's called after the threads have stopped. * @param plugin_opts Map of flashmq_plugin_opt_* from the config file. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_main_deinit(std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_allocate_thread_memory is called once by each thread. Never again. * @param thread_data. Create a memory structure and assign it to *thread_data. * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * * Only allocate the plugin's memory here, or other things that you really only have to do once. Don't open connections, etc. That's * because the reload mechanism doesn't call this function. * * Because of the multi-core design of FlashMQ, you should treat each thread as its own domain with its own data. You can use static * variables for global scope if you must, or even create threads, but do provide proper locking where necessary. * * You can throw exceptions on errors. * * [Must be implemented by plugin] */ void flashmq_plugin_allocate_thread_memory(void **thread_data, std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_deallocate_thread_memory is called once by each thread. Never again. * @param thread_data. Delete this memory. * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * * You can throw exceptions on errors.
* * [Must be implemented by plugin] */ void flashmq_plugin_deallocate_thread_memory(void *thread_data, std::unordered_map &plugin_opts); /** * @brief flashmq_plugin_init is called on thread start and config reload. It is the main place to initialize the plugin. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * @param reloading whether this call is for a config reload rather than a thread start. * * The best approach to state keeping is doing everything per thread. You can initialize connections to database servers, load encryption keys, * create maps, etc. However, remember that for instance with libcurl, initing a 'multi handle' is best not done here, or at least not done * AGAIN on reload, because it will disrupt your ongoing transfers. Memory structures that you really only need to init once are best * done in 'flashmq_plugin_allocate_thread_memory()' or even 'flashmq_plugin_main_init()'. * * Keep in mind that libraries you use may not be thread safe (by default). Sometimes they use global scope in treacherous ways. As a random * example: Qt's QSqlDatabase needs a unique name for each connection, otherwise it is not thread safe and will crash. It will also hide away * libmysqlclient's requirement to do a global one-time init, that would be best done in 'flashmq_plugin_main_init()'. * * There is the option to set 'plugin_serialize_init true' in the config file, which allows some mitigation in * case you run into problems. * * You can throw exceptions on errors. * * [Must be implemented by plugin] */ void flashmq_plugin_init(void *thread_data, std::unordered_map &plugin_opts, bool reloading); /** * @brief flashmq_plugin_deinit is called on thread stop and config reload. It is the precursor to initializing. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param plugin_opts. Map of flashmq_plugin_opt_* from the config file. * @param reloading * * You can throw exceptions on errors.
* * [Must be implemented by plugin] */ void flashmq_plugin_deinit(void *thread_data, std::unordered_map &plugin_opts, bool reloading); /** * @brief flashmq_plugin_periodic_event is called every x seconds as defined in the config file. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * * You may need to periodically refresh data from a database, post stats, etc. You can do that from here. It's queued * in each thread at the same time, so you can perform somewhat synchronized events in all threads. * * Note that it's executed in the event loop, so it blocks the thread if you block here. If you need asynchronous operation, * you can make threads yourself. Be sure to synchronize data access properly in that case. * * The setting plugin_timer_period sets this interval in seconds. * * You can throw exceptions on errors. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_periodic_event(void *thread_data); /** * @brief flashmq_plugin_alter_subscription can optionally be implemented if you want to be able to change incoming subscriptions. * @param thread_data * @param clientid * @param topic non-const reference which can be changed. * @param subtopics * @param qos non-const reference which can be changed. * @param userProperties * @return boolean indicating whether the subscription was changed. Not returning the truth here results in unpredictable behavior. * * In case of shared subscriptions, you will see the original subscription path, like '$share/myshare/battery/voltage'. You have the * chance to change every aspect of it, like make it non-shared. * * [Can optionally be implemented by plugin] */ bool flashmq_plugin_alter_subscription( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, uint8_t &qos, const std::vector> *userProperties); /** * @brief flashmq_plugin_alter_publish allows changing of the non-const arguments.
* @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @return boolean indicating whether the packet was changed. It saves FlashMQ from having to do a full compare. Not returning the truth here * results in unpredictable behavior. Note: if only changing retain, you can get away with returning false. * * Be aware that changing publishes may incur a (slight) reduction in performance. * * [Can optionally be implemented by plugin] */ bool flashmq_plugin_alter_publish( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::optional &correlationData, std::optional &responseTopic, std::optional &contentType, std::vector> *userProperties); /** * @brief flashmq_plugin_login_check is called on login of a client. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param username * @param password * @param client Example use is for storing in an async operation and passing to flashmq_continue_async_authentication_v4(). * @return * * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error, * because there's nothing else to do: the state of FlashMQ won't change. * * Note that there is a setting 'plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not * thread-safe. It will negate much of FlashMQ's multi-core model. * * The AuthResult::async can be used if your auth check causes blocking IO (like network). You can save the weak pointer to the client * and do the auth in a thread or any kind of async way. FlashMQ's event loop will then continue. You can call flashmq_continue_async_authentication_v4() * later with the result.
* * [Must be implemented by plugin] */ AuthResult flashmq_plugin_login_check( void *thread_data, const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client); /** * @brief flashmq_plugin_client_disconnected Called when clients disconnect or their keep-alive expires. * @param thread_data * @param clientid * * Is only called for authenticated clients, to avoid spoofing. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_client_disconnected(void *thread_data, const std::string &clientid); /** * @brief flashmq_plugin_on_unsubscribe is called after unsubscribe. Unsubscribe actions can't be manipulated or blocked. * @param topic Does not contain the share name. * @param subtopics Does not contain the share name. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_on_unsubscribe( void *thread_data, const std::weak_ptr &session, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, const std::vector> *userProperties); /** * @brief flashmq_plugin_acl_check is called on publish, deliver and subscribe. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param shareName The shared subscription name in a filter like '$share/my_share_name/one/two'. Is only present on AclAccess::subscribe. * @return * * You could throw exceptions here, but that will be slow and pointless. It will just get converted into AuthResult::error, * because there's nothing else to do: the state of FlashMQ won't change. * * Controlling subscribe access can have several benefits. For instance, you may want to avoid subscriptions that cause * a lot of server load. If clients pester you with many subscriptions like '+/+/+/+/+/+/+/+/+/', that causes a lot * of tree walking.
Similarly, if all clients subscribe to '#' because it's easy, every single message passing through * the server will have to be ACL checked for every subscriber. * * Note that only MQTT 3.1.1 or higher has a 'failed' return code for subscribing, so older clients will see a normal * ack and won't know it failed. * * Note that there is a setting 'plugin_serialize_auth_checks'. Use only as a last resort if your plugin is not * thread-safe. It will negate much of FlashMQ's multi-core model. * * When the 'access' is 'subscribe' and it's a shared subscription (like '$share/myshare/one/two/three'), you only get * the effective topic filter (like 'one/two/three'). However, since plugin version 4, there is the argument 'shareName' for that. * * [Must be implemented by plugin] */ AuthResult flashmq_plugin_acl_check( void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties); /** * @brief flashmq_plugin_extended_auth can be used to implement MQTT 5 extended auth. This is optional. * @param thread_data is the memory you allocated in flashmq_plugin_allocate_thread_memory. * @param clientid * @param stage * @param authMethod * @param authData * @param userProperties are optional (and are nullptr in that case) * @param returnData is a non-const string, that you can set to include data back to the client in an AUTH packet. * @param username is a non-const string. You can set it, which will then apply to ACL checking and show in the logs. * @param client Use this for AuthResult::async. See flashmq_plugin_login_check().
* @return an AuthResult enum class value * * [Can optionally be implemented by plugin] */ AuthResult flashmq_plugin_extended_auth( void *thread_data, const std::string &clientid, ExtendedAuthStage stage, const std::string &authMethod, const std::string &authData, const std::vector> *userProperties, std::string &returnData, std::string &username, const std::weak_ptr &client); /** * @brief Is called when the socket watched by 'flashmq_poll_add_fd()' has an event. * @param thread_data is memory allocated in flashmq_plugin_allocate_thread_memory(). * @param fd * @param events contains the events as bit flags. See 'man epoll'. * @param p can be made back into your type with 'std::shared_ptr sp = std::static_pointer_cast(p.lock())'. * This allows you to properly lend pointers to the event system that you can actually check for expiration. * * [Can optionally be implemented by plugin] */ void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t events, const std::weak_ptr &p); } #endif // FLASHMQ_PLUGIN_H ================================================ FILE: flashmq_plugin_deprecated.h ================================================ #ifndef FLASHMQ_PLUGIN_DEPRECATED_H #define FLASHMQ_PLUGIN_DEPRECATED_H #include #include "flashmq_public.h" extern "C" { /* Functions kept for plugins written against previous plugin versions; see the _v4 variants in flashmq_public.h. */ class API FlashMQSockAddr { struct sockaddr_in6 addr_in6; public: struct sockaddr *getAddr(); static constexpr int getLen(); /* NOTE(review): constexpr implies inline, but the definition lives in flashmq_plugin.cpp, so plugins presumably cannot use getLen(); confirm. */ }; void API mosquitto_log_printf(int level, const char *fmt, ...); void API flashmq_get_client_address(const std::weak_ptr &client, std::string *text, FlashMQSockAddr *addr); void API flashmq_plugin_remove_client(const std::string &clientid, bool alsoSession, ServerDisconnectReasons reasonCode); void API flashmq_plugin_remove_subscription(const std::string &clientid, const std::string &topicFilter); void API flashmq_continue_async_authentication(const std::weak_ptr &client, AuthResult result, const std::string &authMethod, const std::string &returnData); } #endif //
FLASHMQ_PLUGIN_DEPRECATED_H ================================================ FILE: flashmq_public.h ================================================ /* * This file is part of FlashMQ (https://www.flashmq.org). It describes * public FlashMQ functions available to plugins. * * This interface definition is public domain and you are encouraged * to copy it to your plugin project together with flashmq_plugin.h, * for portability. Including this file in your project does not make * it a 'derivative work'. */ #ifndef FLASHMQ_PUBLIC_H #define FLASHMQ_PUBLIC_H #include #include #include #include #include // Compatible with Mosquitto, for (auth) plugin compatibility. #define LOG_NONE 0x00 #define LOG_INFO 0x01 #define LOG_NOTICE 0x02 #define LOG_WARNING 0x04 #define LOG_ERR 0x08 #define LOG_ERROR 0x08 #define LOG_DEBUG 0x10 #define LOG_SUBSCRIBE 0x20 #define LOG_UNSUBSCRIBE 0x40 #define LOG_PUBLISH 0x80 // Not compatible with Mosquitto #define API __attribute__((visibility("default"))) extern "C" { class Client; class Session; /** * @brief The AclAccess enum's numbers are compatible with Mosquitto's 'int access'. * * read = reading a publish published by someone else. * write = doing a publish. * subscribe = subscribing. */ enum class AclAccess { none = 0, read = 1, write = 2, subscribe = 4, register_will = 100 }; /** * @brief The AuthResult enum's numbers are compatible with Mosquitto's auth result. * * async = defer the decision until you have the result from an async call, which can be submitted with flashmq_continue_async_authentication() (or its newer variant flashmq_continue_async_authentication_v4()). * * auth_continue = part of MQTT5 extended authentication, which can be a back-and-forth between server and client. * * success_without_retained_delivery = allow the subscription action, but don't try to give the client the matching retained messages. This * can be used to prevent load on the server. For instance, if there are many retained messages and clients subscribe to '#'.
This value * is only valid for AclAccess::subscribe, and requires FlashMQ version 1.9.0 or newer. * * success_without_setting_retained = allow the write action, but don't set the retained message (if it was set to retain to begin * with). MQTT5 subscribers with 'retain as published' will still see the retain flag set. This value is only valid for * AclAccess::write and requires a FlashMQ version 1.16.0 or higher. * * success_but_drop_publish / success_but_drop = send success SUBACK or PUBACK back to client (if QoS) but don't process the packet. This * may be useful in combination with flashmq_publish_message() or flashmq_plugin_add_subscription() if you need to publish or subscribe * something new entirely (different topic(s) or payload for instance). This only works with AclAccess::write (FlashMQ 1.20.0 or higher) * and AclAccess::subscribe (FlashMQ 1.21.0 or higher). * * server_not_available = to be used as log-in result, for when you don't have auth data yet, for instance. MQTT3 and MQTT5 both support * sending 'ServerUnavailable' in their CONNACK, when this result is used. Requires FlashMQ 1.17.0 or newer. */ enum class AuthResult { success = 0, auth_method_not_supported = 10, acl_denied = 12, login_denied = 11, error = 13, server_not_available = 14, async = 50, success_without_retained_delivery = 51, success_without_setting_retained = 52, success_but_drop_publish = 53, success_but_drop = 53, auth_continue = -4 }; enum class ExtendedAuthStage { None = 0, Auth = 10, Reauth = 20, Continue = 30 }; /** * @brief The ServerDisconnectReasons enum lists the possible values to initiate client disconnect with. * * This is a subset of all MQTT5 reason codes that are allowed in a disconnect packet, and that make sense for plugin use.
*/ enum class ServerDisconnectReasons { NormalDisconnect = 0, UnspecifiedError = 128, ProtocolError = 130, ImplementationSpecificError = 131, NotAuthorized = 135, ServerBusy = 137, MessageRateTooHigh = 150 }; /** * @brief flashmq_logf calls the internal logger of FlashMQ. The logger mutexes all access, so is thread-safe, and writes to disk * asynchronously, so it won't hold you up. * @param level is any of the levels defined above, starting with LOG_. * @param str format string for the variadic arguments that follow it. * * FlashMQ makes no distinction between INFO and NOTICE. */ void API flashmq_logf(int level, const char *str, ...); /** * @brief flashmq_plugin_remove_client queues a removal of a client in the proper thread, including session if required. It can be called by * plugin code (meaning this function does not need to be implemented). * @param session Can be obtained with flashmq_get_session_pointer(). * @param alsoSession also remove the session if it would otherwise remain. * @param reasonCode is only for MQTT5, because MQTT3 doesn't have server-initiated disconnect packets. * * Many clients will automatically reconnect, so you'll have to also remove permissions of the client in question, probably. * * Can be called from any thread: the action will be queued properly. * * [New version since plugin version 4] */ void API flashmq_plugin_remove_client_v4(const std::weak_ptr &session, bool alsoSession, ServerDisconnectReasons reasonCode); /** * @brief flashmq_plugin_remove_subscription removes a client's subscription from the central store. It can be called by plugin code (meaning * this function does not need to be implemented). * @param session Can be obtained with flashmq_get_session_pointer(). * @param topicFilter Like 'one/two/three' or '$share/myshare/one/two/three'. * * It matches only literal filters. So removing '#' would only remove an active subscription on '#', not 'everything'. * * Will throw exceptions on certain errors.
* * Can be called from any thread, because the global subscription store is mutexed. * * [New version since plugin version 4] */ void API flashmq_plugin_remove_subscription_v4(const std::weak_ptr &session, const std::string &topicFilter); /** * @brief flashmq_plugin_add_subscription adds a subscription to a session. * @param session Can be obtained with flashmq_get_session_pointer(). * @param topicFilter Like 'one/two/three' or '$share/myshare/one/two/three'. * @param qos, noLocal, retainAsPublished, subscriptionIdentifier correspond to the MQTT subscription options of the same names. * @return boolean True when session found and subscription actually added. * * Will throw exceptions on certain errors. * * Can be called from any thread, because the global subscription store is mutexed. */ bool API flashmq_plugin_add_subscription( const std::weak_ptr &session, const std::string &topicFilter, uint8_t qos, bool noLocal, bool retainAsPublished, const uint32_t subscriptionIdentifier); /** * @brief flashmq_continue_async_authentication is to continue/finish async authentication. * @param client * @param result * @param delay_in_ms Introducing a delay on failure can be a beneficial security feature. * * When you've previously returned AuthResult::async in the authentication check, because you need to perform a network call for instance, * you can submit the final result back to FlashMQ with this function. The action will be queued in the proper thread. * * It uses a weak pointer to Client instead of client id, because clients are in limbo at this point, and a client id isn't necessarily * correct (anymore). The login functions also give this weak pointer so you can store it with the async operation, to be used again later for * a call to this function. * * Can be called from any thread. 
* * [New version since plugin version 4, FlashMQ version 1.25.0] */ void API flashmq_continue_async_authentication_v4( const std::weak_ptr &client, AuthResult result, const std::string &authMethod, const std::string &returnData, const uint32_t delay_in_ms); /** * @brief flashmq_publish_message Publish a message from the plugin.
* * Can be called from any thread. */ void API flashmq_publish_message( const std::string &topic, const uint8_t qos, const bool retain, const std::string &payload, uint32_t expiryInterval=0, const std::vector> *userProperties = nullptr, const std::string *responseTopic=nullptr, const std::string *correlationData=nullptr, const std::string *contentType=nullptr); /** * @brief flashmq_get_client_address_v4 retrieves the network address of a client. * @param client A client pointer as provided by 'flashmq_plugin_login_check'. * @param text If not nullptr, will be assigned the address in text form, like 192.168.1.1 or "2001:0db8:85a3:0000:1319:8a2e:0370:7344". * @param addr If not nullptr, will fill a sockaddr struct, for low level operations. * @param addrlen Size of addr. Supply the length. Afterwards, the actual size will be reported. * * The text, addr and addrlen must be pointers to local variables in the calling context. * * Note that the sockaddr API is hard to use safely in C++. Use of addr can very easily lead to undetected undefined behavior in * C++ because of type aliasing violations. The only safe way is to avoid using a cast object, instead use memcpy into structs of the * correct type: first 'struct sockaddr', to read the family, then like 'struct sockaddr_in' or 'struct sockaddr_in6'. In other * words, avoid accessing the members of the struct, even those of 'struct sockaddr'. * * Example of initializing the addr variable: * * struct sockaddr_storage addr_mem; * struct sockaddr *addr = reinterpret_cast(&addr_mem); * socklen_t addrlen = sizeof(addr_mem); * * Afterwards, do the memcpy stuff to read it, or pass it verbatim to library functions. * * [New version since plugin version 4] */ void API flashmq_get_client_address_v4(const std::weak_ptr &client, std::string *text, sockaddr *addr, socklen_t *addrlen); /** * @brief flashmq_get_session_pointer Get reference counted weak pointer of a session. * @param clientid The client ID of the session you're retrieving.
* @param username The username is used for verification, as a security measure. * @param sessionOut The result (has to be an output parameter because we can't return it). * * The weak pointer will accurately reflect the original session. If it has been replaced with a new one with * the same client ID, this weak pointer will be 'expired'. */ void API flashmq_get_session_pointer(const std::string &clientid, const std::string &username, std::weak_ptr &sessionOut); /** * @brief flashmq_get_client_pointer Get a weak client pointer of a session. * @param session Can be obtained with flashmq_get_session_pointer(). * @param clientOut The result (has to be an output parameter because we can't return it). * * Can be used to feed to other functions, or to check if the client is still online. */ void API flashmq_get_client_pointer(const std::weak_ptr &session, std::weak_ptr &clientOut); /** * @brief Allows async operation of outgoing connections you may need to make. It adds the file descriptor to * the epoll listener. * @param fd * @param events epoll events, typically EPOLLIN (ready read) and EPOLLOUT (ready write). Should be or'ed together, * like 'EPOLLOUT | EPOLLIN'. See 'man epoll'. * @param p weak pointer. Can be a weak copy of a shared pointer with proper type. Like p = std::make_shared(). * You'll get it back in 'flashmq_plugin_poll_event_received()'. Use is optional. For libcurl multi socket, * you don't need it. * * You can do this once you have a connection with something external. * * You can also call it again with different events, in which case it will modify the existing entry. If you specify * a non-expired p, it will overwrite the original data associated with the fd. * * Is meant for the local worker thread. It's a no-op in custom threads. * * Will throw exceptions on error, so be sure to handle them. */ void API flashmq_poll_add_fd(int fd, uint32_t events, const std::weak_ptr &p); /** * @brief Remove the fd from the event polling system.
* @param fd the descriptor previously given to flashmq_poll_add_fd(). Note: declared as uint32_t here, whereas flashmq_poll_add_fd() takes an int. * * Closing a socket will also remove it from the epoll system, but if you don't call this function on close, you may get stray * events once the fd number is reused. There is protection against it, but you may end up with unpredictable behavior. * * Is meant for the local worker thread. It's a no-op in custom threads. * * Will throw exceptions on error, so be sure to handle them. */ void API flashmq_poll_remove_fd(uint32_t fd); /** * @brief Call a task later, once. * @param f Function, that can be created with std::bind, for instance. * @param delay_in_ms delay before the task runs, in milliseconds. * @return id of the timer, which can be used to remove it. * * The task queue is local to the current worker thread, including the id returned. It's a no-op in custom threads. * * This can be necessary for asynchronous interfaces, like libcurl. * * Can throw exceptions. */ uint32_t API flashmq_add_task(std::function f, uint32_t delay_in_ms); /** * @brief Remove a task with id as given by 'flashmq_add_task()'. * @param id as returned by flashmq_add_task(). * * The task queue is local to the current worker thread, including the id returned. It's a no-op in custom threads. */ void API flashmq_remove_task(uint32_t id); /** * @brief Use this if you don't want your thread to be signalled ready after flashmq_plugin_init(). * * Normally, when threads return from flashmq_plugin_init(), the worker thread signals to the main thread it's done. Once * all threads have done this, the listeners are created and systemd is notified. You may want to defer this, until all * your auth data is loaded, for instance. * * One simple solution to queue back this event is to use flashmq_add_task() to poll on a timer and check if your data is * loaded, and run flashmq_signal_thread_ready() when it's ready. 
* * [Introduced in FlashMQ 1.26.0] */ void API flashmq_defer_thread_ready(); /** * @brief Counterpart to flashmq_defer_thread_ready(). * * There is no timeout action. If you don't do this, FlashMQ will never be ready.
When starting from systemd unit of * type 'notify', it will be restarted after a set time. * * [Introduced in FlashMQ 1.26.0] */ void API flashmq_signal_thread_ready(); } #endif // FLASHMQ_PUBLIC_H ================================================ FILE: flashmqtestclient.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "flashmqtestclient.h" #include #include #include #include #include "threadloop.h" #include "utils.h" #include "client.h" #include "threaddata.h" #include "threadglobals.h" #define TEST_CLIENT_MAX_EVENTS 25

int FlashMQTestClient::clientCount = 0;

void FlashMQTestClient::ReceivedObjects::clear()
{
    receivedPackets.clear();
    receivedPublishes.clear();
}

FlashMQTestClient::FlashMQTestClient() :
    testServerWorkerThreadData(0, settings, pluginLoader, {})
{

}

/**
 * @brief FlashMQTestClient::~FlashMQTestClient properly quits the threads when exiting.
 *
 * This prevents accidental crashes on calling terminate(), and Qt macros that prematurely end the method, skipping explicit waits after the tests.
 */
FlashMQTestClient::~FlashMQTestClient()
{
    waitForQuit();
}

/**
 * @brief Polls 'f' every 10 ms until it returns true, for at most 'timeout' seconds.
 *
 * Throws std::runtime_error("Wait condition failed.") when the condition never became true.
 */
void FlashMQTestClient::waitForCondition(std::function f, int timeout)
{
    const int loopCount = (timeout * 1000) / 10;
    int n = 0;
    while(n++ < loopCount)
    {
        usleep(10000);
        if (f())
            return;
    }

    // Final check; this is also the only evaluation when the timeout is too small for a single loop iteration.
    if (!f())
    {
        throw std::runtime_error("Wait condition failed.");
    }
}

void FlashMQTestClient::clearReceivedLists()
{
    auto received_objects_locked = receivedObjects.lock();
    received_objects_locked->clear();
}

void FlashMQTestClient::setWill(std::shared_ptr &will)
{
    this->will = will;
}

void FlashMQTestClient::disconnect(ReasonCodes reason)
{
    std::shared_ptr client = client_weak.lock();
    client->setDisconnectStage(DisconnectStage::SendPendingAppData);
    Disconnect d(client->getProtocolVersion(), reason);
    client->writeMqttPacket(d);
}

void FlashMQTestClient::start()
{
    testServerWorkerThreadData.start();
}

void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion, int port, bool _waitForConnack)
{
    connectClient(protocolVersion, true, 0, [](Connect&){}, port, _waitForConnack);
}

void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval, int port, bool _waitForConnack)
{
    connectClient(protocolVersion, clean_start, session_expiry_interval, [](Connect&){}, port, _waitForConnack);
}

/**
 * @brief Connects a non-blocking TCP socket to 127.0.0.1:port, wraps it in a Client owned by the test worker thread, and sends CONNECT.
 * @param manipulateConnect is called on the Connect object right before it's serialized, so tests can alter it.
 * @param _waitForConnack when true, blocks until a CONNACK or AUTH packet arrives and then marks the client authenticated.
 *
 * Throws std::runtime_error when connect() fails immediately (anything other than EINPROGRESS).
 */
void FlashMQTestClient::connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval, std::function manipulateConnect, int port, bool _waitForConnack)
{
    int sockfd = check(socket(AF_INET, SOCK_STREAM, 0));

    struct sockaddr_in servaddr{};

    const std::string hostname = "127.0.0.1";

    servaddr.sin_family = AF_INET;
    servaddr.sin_addr.s_addr = inet_addr(hostname.c_str());
    servaddr.sin_port = htons(port);

    int flags = fcntl(sockfd, F_GETFL);
    fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
    int rc = connect(sockfd, reinterpret_cast(&servaddr), sizeof (servaddr));

    if (rc < 0 && errno != EINPROGRESS)
    {
        // The fd is not owned by a Client yet, so nothing else will close it. Save errno first, because close() may overwrite it.
        const int connect_errno = errno;
        ::close(sockfd);
        throw std::runtime_error(strerror(connect_errno));
    }

    const std::string clientid = formatString("testclient_%d", clientCount++);

    std::shared_ptr client = std::make_shared(
        ClientType::Normal, sockfd, testServerWorkerThreadData.getThreadData(), FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off,
        reinterpret_cast(&servaddr), settings);

    this->client_weak = client;

    client->setClientProperties(protocolVersion, clientid, {}, "user", false, 60);

    {
        // Hack to make it work with the rvalue argument whilst not voiding our own client.
        std::shared_ptr dummyToMoveFrom = client;
        client->addToEpoll(EPOLLIN); // Normally the worker thread does this, but we must avoid races, for the code below, that already tries to mod the epoll.
        testServerWorkerThreadData->giveClient(std::move(dummyToMoveFrom));
    }

    // This gets called in the test client's worker thread, but the STL container's minimal thread safety should be enough: only list manipulation is
    // mutexed, elements within are not.
    client->onPacketReceived = [this](MqttPacket &pack)
    {
        std::shared_ptr client = this->client_weak.lock();

        if (pack.packetType == PacketType::PUBLISH)
        {
            pack.parsePublishData(client);

            auto received_objects_locked = receivedObjects.lock();
            MqttPacket copyPacket = pack;
            received_objects_locked->receivedPublishes.push_back(copyPacket);

            if (pack.getPublishData().qos == 1)
            {
                PubResponse pubAck(client->getProtocolVersion(), PacketType::PUBACK, ReasonCodes::Success, pack.getPacketId());
                client->writeMqttPacketAndBlameThisClient(pubAck);
            }
            else if (pack.getPublishData().qos == 2)
            {
                PubResponse pubAck(client->getProtocolVersion(), PacketType::PUBREC, ReasonCodes::Success, pack.getPacketId());
                client->writeMqttPacketAndBlameThisClient(pubAck);
            }
        }
        else if (pack.packetType == PacketType::PUBREL)
        {
            pack.parsePubRelData();
            PubResponse pubComp(client->getProtocolVersion(), PacketType::PUBCOMP, ReasonCodes::Success, pack.getPacketId());
            client->writeMqttPacketAndBlameThisClient(pubComp);
        }
        else if (pack.packetType == PacketType::PUBREC)
        {
            pack.parsePubRecData();
            PubResponse pubRel(client->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, pack.getPacketId());
            client->writeMqttPacketAndBlameThisClient(pubRel);
        }

        auto received_objects_locked = receivedObjects.lock();
        received_objects_locked->receivedPackets.push_back(std::move(pack));
    };

    Connect connect(protocolVersion, client->getClientId());
    connect.will = this->will;
    connect.clean_start = clean_start;
    connect.sessionExpiryInterval = session_expiry_interval;

    manipulateConnect(connect);

    MqttPacket connectPack(connect);
    client->writeMqttPacketAndBlameThisClient(connectPack);

    if (_waitForConnack)
    {
        waitForConnack();
        client->setAuthenticated(true);
    }
}

/**
 * @brief Subscribes to one filter with packet id 66 and waits for the SUBACK.
 *
 * Throws std::runtime_error on a packet id mismatch, and SubAckIsError when any returned code is above the requested QoS.
 */
void FlashMQTestClient::subscribe(const std::string topic, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionIdentifier,
                                  RetainHandling retainHandling)
{
    std::shared_ptr client = client_weak.lock();

    clearReceivedLists();

    const uint16_t packet_id = 66;

    std::vector subs;
    subs.emplace_back(topic, qos);
    subs.back().noLocal = noLocal;
    subs.back().retainAsPublished = retainAsPublished;
    subs.back().retainHandling = retainHandling;
    MqttPacket subPack(client->getProtocolVersion(), packet_id, subscriptionIdentifier, subs);
    client->writeMqttPacketAndBlameThisClient(subPack);

    waitForCondition([&]() {
        auto ro = receivedObjects.lock();
        return !ro->receivedPackets.empty() && ro->receivedPackets.front().packetType == PacketType::SUBACK;
    });

    auto ro = receivedObjects.lock();
    MqttPacket &subAck = ro->receivedPackets.front();
    SubAckData data = subAck.parseSubAckData();

    if (data.packet_id != packet_id)
        throw std::runtime_error("Incorrect packet id in suback");

    if (!std::all_of(data.subAckCodes.begin(), data.subAckCodes.end(), [&](uint8_t x) { return x <= qos ;}))
    {
        throw SubAckIsError("Suback indicates error.");
    }
}

void FlashMQTestClient::unsubscribe(const std::string &topic)
{
    std::shared_ptr client = client_weak.lock();

    clearReceivedLists();

    const uint16_t packet_id = 66;

    std::vector unsubs;
    unsubs.emplace_back(topic);
    MqttPacket unsubPack(client->getProtocolVersion(), packet_id, unsubs);
    client->writeMqttPacketAndBlameThisClient(unsubPack);

    waitForCondition([&]() {
        auto ro = receivedObjects.lock();
        return !ro->receivedPackets.empty() && ro->receivedPackets.front().packetType == PacketType::UNSUBACK;
    });

    // TODO: parse the UNSUBACK and check reason codes.
}

/**
 * @brief Publishes 'pub' (packet id 77 for QoS > 0) and, for QoS 1 and 2, waits for and validates the acknowledgement flow.
 *
 * Packet types are verified before the type-specific parser is run, so an unexpected packet yields the descriptive error below
 * instead of feeding a packet of another type to the PUBACK/PUBREC/PUBCOMP parser.
 */
void FlashMQTestClient::publish(Publish &pub)
{
    std::shared_ptr client = client_weak.lock();

    clearReceivedLists();

    const uint16_t packet_id = 77;

    MqttPacket pubPack(client->getProtocolVersion(), pub);
    if (pub.qos > 0)
        pubPack.setPacketId(packet_id);
    client->writeMqttPacketAndBlameThisClient(pubPack);

    if (pub.qos == 1)
    {
        waitForCondition([&]() {
            auto ro = receivedObjects.lock();
            return !ro->receivedPackets.empty();
        });

        auto ro = receivedObjects.lock();

        MqttPacket &pubAckPack = ro->receivedPackets.front();

        if (pubAckPack.packetType != PacketType::PUBACK)
            throw std::runtime_error("First packet received from server is not a PUBACK, but " + packetTypeToString(pubAckPack.packetType));

        pubAckPack.parsePubAckData();

        if (pubAckPack.getPacketId() != packet_id)
            throw std::runtime_error("Packet ID mismatch between publish and ack on QoS 1 publish.");

        // We may have received publishes along with our acks, if we publish and subscribe to the same topic with
        // one client, so we have to filter out the publishes.
        int metaPacketCount = std::count_if(ro->receivedPackets.begin(), ro->receivedPackets.end(), [](MqttPacket &pack){return pack.packetType != PacketType::PUBLISH;});

        if (metaPacketCount != 1)
            throw std::runtime_error("Packet ID mismatch on QoS 1 publish or packet count wrong.");
    }
    else if (pub.qos == 2)
    {
        waitForCondition([&]() {
            auto ro = receivedObjects.lock();
            return ro->receivedPackets.size() >= 2;
        });

        auto ro = receivedObjects.lock();

        MqttPacket &pubRecPack = ro->receivedPackets.front();
        if (pubRecPack.packetType != PacketType::PUBREC)
            throw std::runtime_error("First packet received from server is not a PUBREC, but " + packetTypeToString(pubRecPack.packetType));
        pubRecPack.parsePubRecData();

        MqttPacket &pubCompPack = ro->receivedPackets.back();
        if (pubCompPack.packetType != PacketType::PUBCOMP)
            throw std::runtime_error("Last packet received from server is not a PUBCOMP.");
        pubCompPack.parsePubComp();

        if (pubRecPack.getPacketId() != packet_id || pubCompPack.getPacketId() != packet_id)
            throw std::runtime_error("Packet ID mismatch on QoS 2 publish.");
    }
}

void FlashMQTestClient::writeAuth(const Auth &auth)
{
    std::shared_ptr client = client_weak.lock();
    MqttPacket pack(auth);
    client->writeMqttPacketAndBlameThisClient(pack);
}

void FlashMQTestClient::publish(const std::string &topic, const std::string &payload, uint8_t qos)
{
    Publish pub(topic, payload, qos);
    publish(pub);
}

void FlashMQTestClient::waitForQuit()
{
    testServerWorkerThreadData->queueQuit();
    testServerWorkerThreadData.waitForQuit();
}

void FlashMQTestClient::waitForConnack()
{
    waitForCondition([&]() {
        auto ro = receivedObjects.lock();
        return std::any_of(ro->receivedPackets.begin(), ro->receivedPackets.end(), [](const MqttPacket &p) {
            return p.packetType == PacketType::CONNACK || p.packetType == PacketType::AUTH;
        });
    });
}

void FlashMQTestClient::waitForDisconnectPacket()
{
    waitForCondition([&]() {
        auto ro = receivedObjects.lock();
        return std::any_of(ro->receivedPackets.begin(), ro->receivedPackets.end(), [](const MqttPacket &p) {
            return p.packetType == PacketType::DISCONNECT;
        });
    });
}

void FlashMQTestClient::waitForMessageCount(const size_t count, int timeout)
{
    waitForCondition([&]() {
        auto ro = receivedObjects.lock();
        return ro->receivedPublishes.size() >= count;
    }, timeout);
}

void FlashMQTestClient::waitForPacketCount(const size_t count, int timeout)
{
    waitForCondition([&]() {
        auto ro = receivedObjects.lock();
        return ro->receivedPackets.size() >= count;
    }, timeout);
}

std::shared_ptr FlashMQTestClient::getClient()
{
    std::shared_ptr client = client_weak.lock();
    return client;
}

std::string FlashMQTestClient::getClientId()
{
    std::shared_ptr client = client_weak.lock();
    return client->getClientId();
}

ProtocolVersion FlashMQTestClient::getProtocolVersion()
{
    std::shared_ptr client = client_weak.lock();
    return client->getProtocolVersion();
}

bool FlashMQTestClient::clientExpired() const
{
    return client_weak.expired();
}

================================================ FILE: flashmqtestclient.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef FLASHMQTESTCLIENT_H #define FLASHMQTESTCLIENT_H #include #include #include "pluginloader.h" #include "settings.h" #include "threaddata.h" #include "checkedweakptr.h" #include "mutexowned.h"

class SubAckIsError : public std::runtime_error
{
public:
    SubAckIsError(const std::string &msg) : std::runtime_error(msg) {}
};

/**
 * @brief The FlashMQTestClient class uses the existing server code as a client, for testing purposes.
 */
class FlashMQTestClient
{
    struct ReceivedObjects
    {
        std::vector receivedPackets;
        std::vector receivedPublishes;

        void clear();
    };

    std::shared_ptr pluginLoader = std::make_shared();
    Settings settings;
    ThreadDataOwner testServerWorkerThreadData;
    CheckedWeakPtr client_weak;
    std::shared_ptr will;

    static int clientCount;

    void waitForCondition(std::function f, int timeout = 1);

public:
    MutexOwned receivedObjects;

    FlashMQTestClient();
    ~FlashMQTestClient();
    void start();
    void connectClient(ProtocolVersion protocolVersion, int port=21883, bool _waitForConnack=true);
    void connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval, int port=21883, bool _waitForConnack=true);
    void connectClient(ProtocolVersion protocolVersion, bool clean_start, uint32_t session_expiry_interval, std::function manipulateConnect, int port=21883, bool _waitForConnack=true);
    void subscribe(const std::string topic, uint8_t qos, bool noLocal=false, bool retainAsPublished=false, uint32_t subscriptionIdentifier=0,
                   RetainHandling retainHandling=RetainHandling::SendRetainedMessagesAtSubscribe);
    void unsubscribe(const std::string &topic);
    void publish(const std::string &topic, const std::string &payload, uint8_t qos);
    void publish(Publish &pub);
    void writeAuth(const Auth &auth);
    void clearReceivedLists();
    void setWill(std::shared_ptr &will);
    void disconnect(ReasonCodes reason);

    void waitForQuit();
    void waitForConnack();
    void waitForDisconnectPacket();
    void waitForMessageCount(const size_t count, int timeout = 1);
    void waitForPacketCount(const size_t count, int timeout = 1);

    std::shared_ptr getClient();
    std::string getClientId();
    ProtocolVersion getProtocolVersion();
    bool clientExpired() const;
};

#endif // FLASHMQTESTCLIENT_H
================================================ FILE: fmqmain.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you
can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include #include #include #include #include #include #include "mainapp.h" #include "utils.h" #include "exceptions.h" std::weak_ptr globalMainApp; void signal_handler(int signal) { std::shared_ptr locked = globalMainApp.lock(); if (!locked) return; if (signal == SIGPIPE) { return; } if (signal == SIGHUP) { locked->queueConfigReload(); } else if (signal == SIGUSR1) { locked->queueReopenLogFile(); } else if (signal == SIGUSR2) { locked->queueMemoryTrim(); } else if (signal == SIGTERM || signal == SIGINT) { locked->queueQuit(); } } int register_signal_handers() { struct sigaction sa {}; sa.sa_handler = &signal_handler; sigemptyset(&sa.sa_mask); sa.sa_flags = SA_RESTART; for (int signal : {SIGHUP, SIGTERM, SIGINT, SIGUSR1, SIGUSR2}) { if (sigaction(signal, &sa, nullptr) != 0) { Logger *logger = Logger::getInstance(); logger->logf(LOG_ERR, "Error registering signal handlers"); return -1; } } sigset_t set; sigemptyset(&set); sigaddset(&set,SIGPIPE); int r; if ((r = sigprocmask(SIG_BLOCK, &set, NULL) != 0)) { return r; } return 0; } int fmqmain(int argc, char *argv[]) { #ifndef OPENSSL_THREADS std::cerr << "Error: FlashMQ was compiled with an OpenSSL without thread support." 
<< std::endl; exit(66); #endif Logger *logger = nullptr; try { logger = Logger::getInstance(); std::shared_ptr mainapp = MainApp::initMainApp(argc, argv); globalMainApp = mainapp; check(register_signal_handers()); std::string sse = "without SSE support"; #ifdef __SSE4_2__ sse = "with SSE4.2 support"; #endif #ifdef NDEBUG logger->logf(LOG_NOTICE, "Starting FlashMQ version %s, release build %s.", FLASHMQ_VERSION, sse.c_str()); #else logger->logf(LOG_NOTICE, "Starting FlashMQ version %s, debug build %s.", FLASHMQ_VERSION, sse.c_str()); #endif mainapp->start(); logger->quit(); } catch (ConfigFileException &ex) { if (logger) logger->quit(); // Not using the logger here, because we may have had all sorts of init errors while setting it up. std::cerr << ex.what() << std::endl; return 99; } catch (std::exception &ex) { if (logger) logger->quit(); // Not using the logger here, because we may have had all sorts of init errors while setting it up. std::cerr << ex.what() << std::endl; return 1; } return 0; } ================================================ FILE: fmqmain.h ================================================ void signal_handler(int signal); int register_signal_handers(); int fmqmain(int argc, char *argv[]); ================================================ FILE: fmqsockaddr.cpp ================================================ #include "fmqsockaddr.h" #include #include #include "utils.h" FMQSockaddr::FMQSockaddr(const sockaddr *addr) : family(getFamilyFromSockAddr(addr)) { if (addr == nullptr) return; if (this->family == AF_INET) { std::memcpy(dat.data(), addr, sizeof(struct sockaddr_in)); } else if (this->family == AF_INET6) { std::memcpy(dat.data(), addr, sizeof(struct sockaddr_in6)); } else if (this->family == AF_UNIX) { std::memcpy(dat.data(), addr, sizeof(struct sockaddr_un)); } else { throw std::runtime_error("Trying to make IPv4 or IPv6 address from address structure that is neither of those"); } this->text = sockaddrToString(addr); } /** * This should 
(mostly) only be necessary when it's needed to pass to a system API. Don't be tempted to cast this back * into a specific sockaddr, because that will result in undefined behavior because of type aliasing * violations, unless your compiler bends the rules. */ const sockaddr *FMQSockaddr::getSockaddr() const { return reinterpret_cast(dat.data()); } const char *FMQSockaddr::getData() const { return dat.data(); } socklen_t FMQSockaddr::getSize() const { if (this->family == AF_INET) return sizeof(struct sockaddr_in); if (this->family == AF_INET6) return sizeof(struct sockaddr_in6); if (this->family == AF_UNIX) return sizeof(struct sockaddr_un); return 0; } void FMQSockaddr::setPort(uint16_t port) { if (this->family == AF_INET) { struct sockaddr_in tmp; std::memcpy(&tmp, dat.data(), sizeof(tmp)); tmp.sin_port = htons(port); std::memcpy(dat.data(), &tmp, sizeof(tmp)); } else if (this->family == AF_INET6) { struct sockaddr_in6 tmp; std::memcpy(&tmp, dat.data(), sizeof(tmp)); tmp.sin6_port = htons(port); std::memcpy(dat.data(), &tmp, sizeof(tmp)); } } void FMQSockaddr::setAddress(const std::string &address) { in_port_t port = 0; if (this->family == AF_INET) { struct sockaddr_in tmp; std::memcpy(&tmp, dat.data(), sizeof(tmp)); port = tmp.sin_port; } else if (this->family == AF_INET6) { struct sockaddr_in6 tmp; std::memcpy(&tmp, dat.data(), sizeof(tmp)); port = tmp.sin6_port; } { struct in_addr a {}; if (inet_pton(AF_INET, address.c_str(), &a) > 0) { sockaddr_in addr {}; addr.sin_addr = a; addr.sin_port = port; addr.sin_family = AF_INET; memcpy(dat.data(), &addr, sizeof(addr)); text = sockaddrToString(getSockaddr()); return; } } { struct in6_addr a {}; if (inet_pton(AF_INET6, address.c_str(), &a) > 0) { sockaddr_in6 addr {}; addr.sin6_addr = a; addr.sin6_port = port; addr.sin6_family = AF_INET6; memcpy(dat.data(), &addr, sizeof(addr)); text = sockaddrToString(getSockaddr()); } } } void FMQSockaddr::setAddressName(const std::string &addressName) { this->name = addressName; } 
/* Returns the name set through setAddressName() when there is one, otherwise the cached textual form of the address. */ const std::string &FMQSockaddr::getText() const { if (name) return name.value(); return text; } int FMQSockaddr::getFamily() const { return this->family; } ================================================ FILE: fmqsockaddr.h ================================================ #ifndef FMQSOCKADDR_H #define FMQSOCKADDR_H #include #include #include #include #include /** * @brief A class for storing socket addresses, and accesses them in a way that avoids type aliasing violations. */ class FMQSockaddr { /* Zero-initialized storage large enough for any sockaddr; only the first getSize() bytes are meaningful for the current family. */ std::vector dat = std::vector(sizeof(sockaddr_storage)); sa_family_t family = AF_UNSPEC; std::string text; /* When set, getText() returns this instead of 'text'. */ std::optional name; public: FMQSockaddr(const struct sockaddr *addr); const struct sockaddr *getSockaddr() const; const char *getData() const; socklen_t getSize() const; void setPort(uint16_t port); void setAddress(const std::string &address); void setAddressName(const std::string &addressName); const std::string &getText() const; int getFamily() const; }; #endif // FMQSOCKADDR_H ================================================ FILE: fmqssl.cpp ================================================ #include #include "fmqssl.h" /* Sends a TLS shutdown (close_notify) on a best-effort basis; the SSL object itself is released by the unique_ptr deleter. */ FmqSsl::~FmqSsl() { if (!d) return; /* * We write the shutdown when we can, but don't take error conditions into account. If socket buffers are full, because * clients disappear for instance, the socket is just closed. We don't care. * * Truncation attacks seem irrelevant. MQTT is frame based, so either end knows if the transmission is done or not. The * close_notify is not used in determining whether to use or discard the received data.
*/ SSL_shutdown(d.get()); } /* Creates an empty wrapper: operator bool() is false and set_fd() does nothing. */ FmqSsl::FmqSsl() : d(nullptr, SSL_free) { } /* Creates a new SSL object from the given context. NOTE(review): a NULL result from SSL_new() is not checked here and yields an empty wrapper. */ FmqSsl::FmqSsl(const SslCtxManager &ssl_ctx) : d(SSL_new(ssl_ctx.get()), SSL_free) { } void FmqSsl::set_fd(int fd) { if (!d) return; SSL_set_fd(d.get(), fd); } ================================================ FILE: fmqssl.h ================================================ #ifndef FMQSSL_H #define FMQSSL_H #include #include #include "sslctxmanager.h" /* Move-only owner of an OpenSSL SSL object, which is freed with SSL_free. */ class FmqSsl { std::unique_ptr d; public: FmqSsl(); FmqSsl(const SslCtxManager &ssl_ctx); FmqSsl(FmqSsl &&) = default; ~FmqSsl(); /* NOTE(review): the defaulted move assignment frees a previously held SSL with SSL_free without calling SSL_shutdown first. */ FmqSsl& operator=(FmqSsl&&) = default; operator bool() const { return d != nullptr; } SSL* get() const { return d.get(); } void set_fd(int fd); }; #endif // FMQSSL_H ================================================ FILE: forward_declarations.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef FORWARD_DECLARATIONS_H #define FORWARD_DECLARATIONS_H #include class Client; class ThreadData; class MqttPacket; class SubscriptionStore; class Session; class Settings; class Mqtt5PropertyBuilder; class SessionsAndSubscriptionsDB; #endif // FORWARD_DECLARATIONS_H ================================================ FILE: fuzz-helper.sh ================================================ #!/bin/bash # # Quick 'n dirty Script to build and run FlashMQ with American Fuzzy Lop. thisfile=$(readlink --canonicalize "$0") thisdir=$(dirname "$thisfile") if [[ -z "$AFL_ROOT" ]]; then echo "ERROR: set AFL_ROOT environment variable" exit 1 fi if [[ -z "$FLASHMQ_SRC" ]]; then echo "ERROR: set FLASHMQ_SRC environment variable" exit 1 fi set -u if [[ !
-d "$FLASHMQ_SRC/fuzztests" ]]; then echo "Folder 'fuzztests' not found in '$FLASHMQ_SRC'" exit 1 fi if [[ "$1" == "build" ]]; then export CC="$AFL_ROOT/afl-gcc" export CXX="$AFL_ROOT/afl-g++" mkdir "fuzzbuild" cd "fuzzbuild" || exit 1 "$thisdir/build.sh" Debug if [[ -f "./FlashMQBuildDebug/flashmq" ]]; then cp -v "./FlashMQBuildDebug/flashmq" .. fi fi if [[ "$1" == "run" ]]; then INPUTDIR="$FLASHMQ_SRC/fuzztests" OUTPUTDIR="fuzzoutput" BINARY="./flashmq" if [[ ! -d "$OUTPUTDIR" ]]; then mkdir "$OUTPUTDIR" fi tmux new-session -s flashmqfuzz -d "'$AFL_ROOT/afl-fuzz' -m 200 -M primary -i '$INPUTDIR' -o '$OUTPUTDIR' '$BINARY' --fuzz-file '@@'; sleep 5" tmux split-window -t flashmqfuzz -v "'$AFL_ROOT/afl-fuzz' -m 200 -S secondary01 -i '$INPUTDIR' -o '$OUTPUTDIR' '$BINARY' --fuzz-file '@@'; sleep 5" tmux split-window -t flashmqfuzz -h "'$AFL_ROOT/afl-fuzz' -m 200 -S secondary02 -i '$INPUTDIR' -o '$OUTPUTDIR' '$BINARY' --fuzz-file '@@'; sleep 5" tmux select-pane -t flashmqfuzz -U tmux split-window -t flashmqfuzz -h "'$AFL_ROOT/afl-fuzz' -m 200 -S secondary03 -i '$INPUTDIR' -o '$OUTPUTDIR' '$BINARY' --fuzz-file '@@'; sleep 5" tmux attach-session -d -t flashmqfuzz fi ================================================ FILE: globals.cpp ================================================ #include "globals.h" #include "threadglobals.h" Globals globals; /* The global instance that globals.h declares 'extern'. */ /** * Normally there's 'ThreadGlobals::getThreadData()', but there are places where we can be called * from non-worker-threads, like plugin custom threads. This gives us the worker's thread one, or * a deterministic other one. Determinism is important to ensure order, when publishing for instance. * * Note that there is overhead here, so use only when required.
*/ CheckedSharedPtr Globals::GlobalsData::getDeterministicThreadData() { CheckedSharedPtr result = ThreadGlobals::getThreadData(); if (result) return result; static thread_local std::weak_ptr thread_data_copy_weak; std::shared_ptr result2 = thread_data_copy_weak.lock(); if (result2) return result2; static std::atomic index {0}; auto locked = threadDatas.lock(); result2 = locked->at(index++ % locked->size()); thread_data_copy_weak = result2; return result2; } ================================================ FILE: globals.h ================================================ #ifndef GLOBALS_H #define GLOBALS_H #include #include #include "subscriptionstore.h" #include "globalstats.h" #include "bridgeconfig.h" #include "checkedsharedptr.h" #include "mutexowned.h" /** * The idea about it being a shared pointer is having globals that are still tied to a MainApp instance (which * should assign a new global object upon creation and destruction). This is mainly for keeping the memory model * between normal FlashMQ and the re-instantiated MainApps in the test program the same, which wouldn't be the * case by when having static variables for globals. */ class Globals { struct GlobalsData { bool quitting = false; pthread_t createdByThread = pthread_self(); std::shared_ptr subscriptionStore = std::make_shared(); GlobalStats stats; BridgeClientGroupIds bridgeClientGroupIds; MutexOwned>> threadDatas; CheckedSharedPtr getDeterministicThreadData(); }; std::shared_ptr data = std::make_shared(); public: GlobalsData *operator->() const { return data.get(); } }; extern Globals globals; #endif // GLOBALS_H ================================================ FILE: globalstats.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include "globalstats.h" GlobalStats::GlobalStats() { } void GlobalStats::setExtra(const std::string &topic, const std::string &payload) { auto locked_data = extras.lock(); locked_data->operator[](topic) = payload; } /* Returns a copy of the extras map, taken while holding its lock. */ std::unordered_map GlobalStats::getExtras() { auto locked_data = extras.lock(); std::unordered_map r = *locked_data; return r; } ================================================ FILE: globalstats.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef GLOBALSTATS_H #define GLOBALSTATS_H #include #include #include #include "derivablecounter.h" #include "mutexowned.h" class GlobalStats { MutexOwned> extras; public: GlobalStats(); DerivableCounter socketConnects; void setExtra(const std::string &topic, const std::string &payload); std::unordered_map getExtras(); }; #endif // GLOBALSTATS_H ================================================ FILE: globber.cpp ================================================ #include "globber.h" #include #include GlobT::GlobT() { } GlobT::~GlobT() { globfree(&result); } /** * Expands 'pattern' with glob(3), without any flags. * @return the matching paths in the order glob() produced them; empty when nothing matches. * @throws std::runtime_error when glob() fails for any reason other than 'no match'. */ std::vector Globber::getGlob(const std::string &pattern) { GlobT result; const int rc = glob(pattern.c_str(), 0, nullptr, &result.result); if (!(rc == 0 || rc == GLOB_NOMATCH)) { throw std::runtime_error("Glob failed"); } std::vector filenames; filenames.reserve(result.result.gl_pathc); for(size_t i = 0; i < result.result.gl_pathc; ++i) { filenames.push_back(std::string(result.result.gl_pathv[i])); } return filenames; } ================================================ FILE: globber.h ================================================ #ifndef GLOBBER_H #define GLOBBER_H #include #include #include /* Owns a glob_t and releases it with globfree() in the destructor. Copying is disabled because a copy would share the same gl_pathv allocation and free it a second time. */ class GlobT { friend class Globber; glob_t result {}; public: GlobT(); ~GlobT(); GlobT(const GlobT &) = delete; GlobT &operator=(const GlobT &) = delete; }; class Globber { public: std::vector getGlob(const std::string
&pattern); }; #endif // GLOBBER_H ================================================ FILE: haproxy.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "haproxy.h" #include #include #include "utils.h" #include "exceptions.h" /** * @brief read_ha_proxy_helper This function simplyfies reading for HAProxy in that it doesn't need to deal with incomplete data (because HAProxy frames * are always complete), so those cases can just be an error. * @param fd * @param buf * @param nbytes * @return */ size_t read_ha_proxy_helper(int fd, void *buf, size_t nbytes) { if (nbytes == 0) return 0; ssize_t n; while ((n = read(fd, buf, nbytes)) != 0) { if (n == 0) throw BadClientException("Client disconnected before all HA proxy data could be read"); if (n < 0) { if (errno == EINTR) continue; else if (errno == EAGAIN || errno == EWOULDBLOCK) throw BadClientException("Incomplete HAProxy data"); else check(n); } break; } if (static_cast(nbytes) != n) throw BadClientException("Not an HAProxy frame."); return n; } std::optional get_ha_proxy_pp2_field(const std::unordered_map> &fields, int key) { auto pos = fields.find(key); if (pos == fields.end()) return {}; return std::get(pos->second); } HaProxySslData read_ha_proxy_pp2_ssl(const std::vector &data, int &recurse_counter) { HaProxySslData result; size_t pos = 0; result.client = data.at(pos++); for (int shift = 24; shift >= 0; shift -= 8) { result.verify = (data.at(pos++) << shift); } auto sub_vec = make_vector(data, pos, data.size() - pos); const std::unordered_map> fields = read_ha_proxy_pp2_tlv(sub_vec, recurse_counter); result.ssl_version = get_ha_proxy_pp2_field(fields, PP2_SUBTYPE_SSL_VERSION); result.ssl_cn = get_ha_proxy_pp2_field(fields, PP2_SUBTYPE_SSL_CN); if 
(result.ssl_version) { if (!isValidUtf8Generic(result.ssl_version.value())) throw BadClientException("HAProxy pp2 ssl text fields are not valid UTF8."); } if (result.ssl_cn) { if (!isValidUtf8Generic(result.ssl_cn.value())) throw BadClientException("HAProxy pp2 ssl text fields are not valid UTF8."); } // We don't use this fields (yet), so ignorning for now. Here for future reference. //result.ssl_cipher = get_ha_proxy_pp2_field(fields, PP2_SUBTYPE_SSL_CIPHER); //result.ssl_sig_alg = get_ha_proxy_pp2_field(fields, PP2_SUBTYPE_SSL_SIG_ALG); //result.ssl_key_alg = get_ha_proxy_pp2_field(fields, PP2_SUBTYPE_SSL_KEY_ALG); return result; } /** * See info about Type-Length-Value (TLV vectors in https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt * * Many of the fields aren't actually strings, but since we're only using SSL data right now, we'll just read those other * fields as raw data and not use them. */ std::unordered_map> read_ha_proxy_pp2_tlv(const std::vector &data, int &recurse_counter) { recurse_counter++; if (recurse_counter > 2) throw BadClientException("Invalid HAProxy TLV structure."); size_t pos = 0; std::unordered_map> result; while (pos < data.size()) { const uint8_t type{data.at(pos++)}; const uint8_t length_hi{data.at(pos++)}; const uint8_t length_lo{data.at(pos++)}; const size_t len{static_cast((length_hi << 8) + length_lo)}; if (pos + len > data.size()) throw BadClientException("Client specifies invalid length in haproxy TLV"); if (type == PP2_TYPE_SSL) { auto sub_vec = make_vector(data, pos, len); const auto emplace_result = result.try_emplace(type, read_ha_proxy_pp2_ssl(sub_vec, recurse_counter)); pos += len; if (!emplace_result.second) { std::ostringstream oss; oss << "Client specifies haproxy PP2 field 0x" << std::hex << type << " multiple times."; throw BadClientException(oss.str()); } continue; } const auto emplace_result = result.try_emplace(type, make_string(data, pos, len)); pos += len; if (!emplace_result.second) { std::ostringstream 
oss; oss << "Client specifies haproxy PP2 field 0x" << std::hex << type << " multiple times."; throw BadClientException(oss.str()); } } if (pos != data.size()) throw BadClientException("Client specifies invalid length in haproxy TLV"); return result; } ================================================ FILE: haproxy.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef HAPROXY_H #define HAPROXY_H #include #include #include #include #include #include #include #include #define PP2_TYPE_ALPN 0x01 #define PP2_TYPE_AUTHORITY 0x02 #define PP2_TYPE_CRC32C 0x03 #define PP2_TYPE_NOOP 0x04 #define PP2_TYPE_UNIQUE_ID 0x05 #define PP2_TYPE_SSL 0x20 #define PP2_SUBTYPE_SSL_VERSION 0x21 #define PP2_SUBTYPE_SSL_CN 0x22 #define PP2_SUBTYPE_SSL_CIPHER 0x23 #define PP2_SUBTYPE_SSL_SIG_ALG 0x24 #define PP2_SUBTYPE_SSL_KEY_ALG 0x25 #define PP2_TYPE_NETNS 0x30 struct proxy_hdr_v2 { uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ uint8_t ver_cmd; /* protocol version and command */ uint8_t fam; /* protocol family and address */ uint16_t len; /* number of following bytes part of the header */ }; /* for TCP/UDP over IPv4, len = 12 */ struct proxy_ipv4_addr { uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port; }; /* for TCP/UDP over IPv6, len = 36 */ struct proxy_ipv6_addr { uint8_t src_addr[16]; uint8_t dst_addr[16]; uint16_t src_port; uint16_t dst_port; }; /* for AF_UNIX sockets, len = 216 */ struct proxy_unix_addr { uint8_t src_addr[108]; uint8_t dst_addr[108]; }; union proxy_addr { struct proxy_ipv4_addr ipv4_addr; struct proxy_ipv6_addr ipv6_addr; struct proxy_unix_addr proxy_unix_addr; }; struct HaProxySslData { uint8_t client = 0; uint32_t verify = 0; std::optional ssl_version; 
std::optional ssl_cn; std::optional ssl_cipher; std::optional ssl_sig_alg; std::optional ssl_key_alg; }; size_t read_ha_proxy_helper(int fd, void *buf, size_t nbytes); std::unordered_map> read_ha_proxy_pp2_tlv(const std::vector &data, int &recurse_counter); enum class HaProxyConnectionType { None, Remote, Local }; #endif // HAPROXY_H ================================================ FILE: http.cpp ================================================ #include "http.h" #include #include #include #include #include "utils.h" #include "exceptions.h" /* Computes the Sec-WebSocket-Accept value: base64 of the SHA-1 digest of the client key followed by the fixed websocket GUID. */ std::string generateWebsocketAcceptString(const std::string &websocketKey) { unsigned char md_value[EVP_MAX_MD_SIZE]; unsigned int md_len; std::unique_ptr mdctx(EVP_MD_CTX_new(), EVP_MD_CTX_free); const EVP_MD *md = EVP_sha1(); EVP_DigestInit_ex(mdctx.get(), md, NULL); const std::string keyPlusMagic = websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; EVP_DigestUpdate(mdctx.get(), keyPlusMagic.c_str(), keyPlusMagic.length()); EVP_DigestFinal_ex(mdctx.get(), md_value, &md_len); std::string base64 = base64Encode(md_value, md_len); return base64; } /** * Builds a complete '400 Bad Request' response telling the client which websocket version is supported. * The header block must be terminated by an empty line (CRLF CRLF); without it the peer keeps waiting for more headers. */ std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion) { std::ostringstream oss; oss << "HTTP/1.1 400 Bad Request\r\n"; oss << "Sec-WebSocket-Version: " << wantedVersion << "\r\n"; oss << "\r\n"; oss.flush(); return oss.str(); } std::string generateBadHttpRequestReponse(const std::string &msg) { std::ostringstream oss; oss << "HTTP/1.1 400 Bad Request\r\n"; oss << "\r\n"; oss << msg; oss.flush(); return oss.str(); } std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol) { std::ostringstream oss; oss << "HTTP/1.1 101 Switching Protocols\r\n"; oss << "Upgrade: websocket\r\n"; oss << "Connection: Upgrade\r\n"; oss << "Sec-WebSocket-Accept: " << acceptString << "\r\n"; oss << "Sec-WebSocket-Protocol: " << subprotocol << "\r\n"; oss << "\r\n"; oss.flush(); return oss.str(); } std::string generateRedirect(const std::string
&location) { const std::string text("Redirecting ACME request."); std::ostringstream oss; oss << "HTTP/1.1 301 Redirect\r\n"; oss << "Content-Type: text/plain\r\n"; oss << "Content-Length: " << text.size() << "\r\n"; oss << "Location: " << location << "\r\n"; oss << "\r\n"; oss << text; oss.flush(); return oss.str(); } /** * Tries to parse a websocket upgrade request from the bytes currently in 'buf'. The buffer is only read through * peekAllToVector(); this function calls nothing that consumes from it. * @return an empty optional while the header is not complete yet; a HttpRequest holding a BadHttpRequest when the * data is not an acceptable request; otherwise the parsed header fields. */ std::optional parseHttpHeader(CirBuf &buf) { if (buf.usedBytes() >= 16384) { return BadHttpRequest("Too much data for HTTP request"); } HttpRequest::Data result; std::vector lines; { bool doubleEmptyLine = false; // meaning, the HTTP header is complete const std::vector buf_data = buf.peekAllToVector(); const std::string beginning(buf_data.data(), std::min(4, buf_data.size())); if (buf_data.size() >= 4 && beginning != "GET ") { return BadHttpRequest("HTTP request should start with GET."); } const std::string s(buf_data.data(), buf_data.size()); std::istringstream is(s); for (std::string line; std::getline(is, line);) { /* getline() only sets eof when it ran out of data before finding '\n'. Such a trailing fragment (e.g. the lone '\r' of a CRLF split across reads) is not a complete line, even when it is empty after trimming. */ const bool terminated = !is.eof(); trim(line); if (line.empty() && terminated) { doubleEmptyLine = true; break; } lines.push_back(line); } if (!doubleEmptyLine) return {}; } bool firstLine = true; for (const std::string &line : lines) { if (firstLine) { firstLine = false; if (!startsWith(line, "GET")) return BadHttpRequest("HTTP request should start with GET."); const std::vector fields = splitToVector(line, ' ', std::numeric_limits::max(), false); if (fields.size() != 3) return BadHttpRequest("HTTP request should include three fields."); result.request = fields.at(1); continue; } std::list fields = split(line, ':', 1); if (fields.size() != 2) { return BadHttpRequest("This does not look like a HTTP request."); } const std::vector fields2(fields.begin(), fields.end()); std::string name = str_tolower(fields2[0]); trim(name); std::string value = fields2[1]; trim(value); std::string value_lower = str_tolower(value); if (name == "upgrade") { std::vector protocols = splitToVector(value_lower, ','); for (std::string &prot : protocols) { trim(prot); if (prot == "websocket") { result.upgrade
= true; } } } else if (name == "connection" && strContains(value_lower, "upgrade")) result.connectionUpgrade = true; else if (name == "sec-websocket-key") result.websocketKey = value; else if (name == "sec-websocket-version") { /* Client-supplied: a non-numeric or out-of-range value makes std::stoi throw std::invalid_argument or std::out_of_range (both std::logic_error); answer with a bad request instead of letting that escape. */ try { result.websocketVersion = std::stoi(value); } catch (const std::logic_error &) { return BadHttpRequest("Invalid Sec-WebSocket-Version header."); } } else if (name == "sec-websocket-protocol" && strContains(value_lower, "mqtt")) { std::vector protocols = splitToVector(value, ','); for(std::string &prot : protocols) { trim(prot); // Return what is requested, which can be 'mqttv3.1' or 'mqtt', or whatever variant. if (strContains(str_tolower(prot), "mqtt")) { result.subprotocol = prot; } } } else if (name == "x-real-ip" && value.length() < 64) { result.xRealIp = value; } } return result; } std::string websocketCloseCodeToString(uint16_t code) { switch (code) { case 1000: return "Normal websocket close"; case 1001: return "Browser navigating away from page"; default: return formatString("Websocket status code %d", code); } } HttpRequest::HttpRequest(const Data &data) : d(data) { } HttpRequest::HttpRequest(BadHttpRequest &&other) : e(std::move(other)) { } const HttpRequest::Data &HttpRequest::value() { if (e) throw *e; return d; } HttpRequest::operator bool() const { return (!e); } ================================================ FILE: http.h ================================================ #ifndef HTTP_H #define HTTP_H #include #include "cirbuf.h" #include "types.h" #include "exceptions.h" class HttpRequest { public: struct Data { std::string request; bool upgrade = false; bool connectionUpgrade = false; std::optional websocketKey; std::optional websocketVersion; std::optional subprotocol; std::optional xRealIp; }; HttpRequest(const Data &data); HttpRequest(BadHttpRequest &&other); const Data &value(); operator bool() const; private: std::optional e; Data d; }; std::string generateWebsocketAcceptString(const std::string &websocketKey); std::string generateInvalidWebsocketVersionHttpHeaders(const int wantedVersion); std::string
generateBadHttpRequestReponse(const std::string &msg); std::string generateWebsocketAnswer(const std::string &acceptString, const std::string &subprotocol); std::string generateRedirect(const std::string &location); std::optional parseHttpHeader(CirBuf &buf); std::string websocketCloseCodeToString(uint16_t code); std::string protocolVersionString(ProtocolVersion p); #endif // HTTP_H ================================================ FILE: iowrapper.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "iowrapper.h" #include #include #include #include #include #include #include "logger.h" #include "client.h" #include "utils.h" #include "exceptions.h" #include "threadglobals.h" #include "settings.h" #include "http.h" /* Marks an SSL write of 'nbytes' as pending; reset() clears that state again. */ IncompleteSslWrite::IncompleteSslWrite(size_t nbytes) : valid(true), nbytes(nbytes) { } void IncompleteSslWrite::reset() { valid = false; nbytes = 0; } bool IncompleteSslWrite::hasPendingWrite() const { return valid; } /* Clears the masking key and its position, the remaining frame byte count and the opcode. */ void IncompleteWebsocketRead::reset() { maskingKeyI = 0; memset(maskingKey,0, 4); frame_bytes_left = 0; opcode = WebsocketOpcode::Unknown; } bool IncompleteWebsocketRead::sillWorkingOnFrame() const { return frame_bytes_left > 0; } /* Returns the masking-key bytes in a repeating cycle of four. */ char IncompleteWebsocketRead::getNextMaskingByte() { return maskingKey[maskingKeyI++ % 4]; } IncompleteWebsocketRead::IncompleteWebsocketRead() { reset(); } /* The websocket buffers get 'initialBufferSize' bytes only for WebsocketMqtt connections, and 0 otherwise. */ IoWrapper::IoWrapper(FmqSsl &&ssl, ConnectionProtocol connectionProtocol, const size_t initialBufferSize, Client *parent) : parentClient(parent), ssl(std::move(ssl)), connectionProtocol(connectionProtocol), websocketPendingBytes(connectionProtocol == ConnectionProtocol::WebsocketMqtt ? initialBufferSize : 0), websocketWriteRemainder(connectionProtocol == ConnectionProtocol::WebsocketMqtt ?
initialBufferSize : 0) { } /* Continues the TLS handshake in the direction matching the connection: SSL_connect for outgoing connections, SSL_accept otherwise. */ void IoWrapper::startOrContinueSslHandshake() { if (parentClient->isOutgoingConnection()) startOrContinueSslConnect(); else startOrContinueSslAccept(); } /* One non-blocking SSL_connect step. WANT_READ/WANT_WRITE returns early with write readiness set accordingly; other failures throw BadClientException at LOG_WARNING; success sets sslAccepted. */ void IoWrapper::startOrContinueSslConnect() { assert(ssl); ERR_clear_error(); int connected = SSL_connect(ssl.get()); if (connected <= 0) { int err = SSL_get_error(ssl.get(), connected); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE); return; } unsigned long error_code = ERR_get_error(); std::array error_buf; ERR_error_string_n(error_code, error_buf.data(), error_buf.size()); std::string errorMsg(error_buf.data()); if (error_code == 0) errorMsg = "Error code was 0. Are you really connecting to a TLS socket?"; // We upgrade the error to WARNING because outgoing connections are under our control and failures // should not be 'normal' in the log, as opposed to incoming clients. throw BadClientException("Problem connecting to SSL socket: " + errorMsg, LOG_WARNING); } parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake sslAccepted = true; } /* One non-blocking SSL_accept step, handled like startOrContinueSslConnect() except that failures throw BadClientException without an explicit log level. */ void IoWrapper::startOrContinueSslAccept() { ERR_clear_error(); int accepted = SSL_accept(ssl.get()); if (accepted <= 0) { int err = SSL_get_error(ssl.get(), accepted); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { parentClient->setReadyForWriting(err == SSL_ERROR_WANT_WRITE); return; } unsigned long error_code = ERR_get_error(); std::array error_buf; ERR_error_string_n(error_code, error_buf.data(), error_buf.size()); std::string errorMsg(error_buf.data()); if (error_code == OPENSSL_WRONG_VERSION_NUMBER) errorMsg = "Wrong protocol version number.
Probably a non-SSL connection on SSL socket."; throw BadClientException("Problem accepting SSL socket: " + errorMsg); } parentClient->setReadyForWriting(false); // Undo write readiness that may have have happened during SSL handshake sslAccepted = true; } bool IoWrapper::getSslReadWantsWrite() const { return this->sslReadWantsWrite; } bool IoWrapper::getSslWriteWantsRead() const { return sslWriteWantsRead; } bool IoWrapper::isSslAccepted() const { return this->sslAccepted; } bool IoWrapper::isSsl() const { return this->ssl; } /* OpenSSL verification callback: accepts everything when the verify mode is SSL_VERIFY_NONE, forces a CERT_CHAIN_TOO_LONG failure beyond depth 50, and logs subject and issuer of every certificate that fails. */ static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { SSL *ssl = static_cast(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); const int mode = SSL_get_verify_mode(ssl); if (mode == SSL_VERIFY_NONE) return 1; std::optional err; const int depth = X509_STORE_CTX_get_error_depth(ctx); /* * Explicity catch long chains, to avoid other random chain errors. */ if (depth > 50) { preverify_ok = 0; err = X509_V_ERR_CERT_CHAIN_TOO_LONG; X509_STORE_CTX_set_error(ctx, err.value()); } if (!preverify_ok) { if (!err) err = X509_STORE_CTX_get_error(ctx); std::array buf; const X509 *err_cert = X509_STORE_CTX_get_current_cert(ctx); X509_NAME_oneline(X509_get_subject_name(err_cert), buf.data(), buf.size()); const std::string subject_name(buf.data()); Logger *logger = Logger::getInstance(); X509_NAME_oneline(X509_get_issuer_name(err_cert), buf.data(), buf.size()); const std::string issuer(buf.data()); logger->log(LOG_ERROR) << "X509 verify error. Num=" << err.value() << ". Error: " << X509_verify_cert_error_string(err.value()) << ". Depth=" << depth << ". Subject = " << subject_name << ".
Issuer = " << issuer; } return preverify_ok; } /* Sets the peer verification mode; does nothing without SSL. Partial wildcards are disallowed in host-name matching, and a non-empty 'hostname' is used both for certificate host-name checking and as the SNI name. */ void IoWrapper::setSslVerify(int mode, const std::string &hostname) { assert(mode == SSL_VERIFY_PEER || mode == SSL_VERIFY_NONE); if (!ssl) return; SSL_set_hostflags(ssl.get(), X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); if (!hostname.empty()) { if (!SSL_set1_host(ssl.get(), hostname.c_str())) throw std::runtime_error("Failed setting hostname of SSL context."); if (SSL_set_tlsext_host_name(ssl.get(), hostname.c_str()) != 1) throw std::runtime_error("Failed setting SNI hostname of SSL context."); } SSL_set_verify(ssl.get(), mode, verify_callback); } /* True while an SSL write is incomplete or websocket output is still buffered. */ bool IoWrapper::hasPendingWrite() const { return incompleteSslWrite.hasPendingWrite() || websocketWriteRemainder.usedBytes() > 0; } /** * @brief IoWrapper::hasProcessedBufferedBytesToRead needs to be used for when event-based IO (epoll) won't inform you there are pending bytes. * @return * * When the system sockets is readable, epoll will say so. But when you buffer that with SSL and/or websockets, it can happen that you * don't fully read that, because the client buffers are full. So, use this function on 'buffer full' conditions to determine that you * need to grow the buffer and read again. */ bool IoWrapper::hasProcessedBufferedBytesToRead() const { bool result = false; if (ssl) result |= SSL_pending(ssl.get()) > 0; /* * Note that this is tecnhically not 100% correct. If the only bytes are part of a header, doing a read will actually * result in 0 bytes. But, for the intended purpose at time of writing (see git), we can get away with this.
*/ if (connectionProtocol == ConnectionProtocol::WebsocketMqtt) result |= websocketPendingBytes.usedBytes() > 0; return result; } WebsocketState IoWrapper::getWebsocketState() const { return websocketState; } X509Manager IoWrapper::getPeerCertificate() const { X509Manager result(this->ssl.get()); return result; } const char *IoWrapper::getSslVersion() const { return SSL_get_version(ssl.get()); } /** * @brief IoWrapper::readHaProxyData Reads the PROXY protocol header in one go. It must fit in one TCP segment. * @param fd * @param addr * * https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt */ /* Returns the client's source address for PROXY commands over IPv4 or IPv6, and an empty optional for LOCAL commands or an unspecified protocol. When the header announces more bytes than the address block, the stage becomes AdditionalBytesPending for readHaProxyAdditionalData(). NOTE(review): the comment above names the function 'readHaProxyData' and documents an 'addr' parameter that does not exist. */ std::optional IoWrapper::readHaProxyHeader(int fd) { assert(mHaProxyStage == HaProxyStage::HeaderPending); struct proxy_hdr_v2 hdr {}; const size_t read = read_ha_proxy_helper(fd, &hdr, sizeof(hdr)); if (read != sizeof(hdr)) throw std::runtime_error("Reading haproxy header resulted in wrong number of bytes."); if (std::memcmp(hdr.sig, "\r\n\r\n\0\r\nQUIT\n", 12) != 0) throw BadClientException("Invalid HAProxy signature."); const uint8_t family = (hdr.fam & 0xF0) >> 4; const uint8_t prot = hdr.fam & 0x0F; const uint8_t ver = (hdr.ver_cmd & 0xF0) >> 4; const uint8_t cmd = hdr.ver_cmd & 0x0F; const uint16_t total_len{ntohs(hdr.len)}; uint16_t addr_len = 0; if (family == 0) addr_len = 0; else if (family == 1) addr_len = sizeof(proxy_ipv4_addr); else if (family == 2) addr_len = sizeof(proxy_ipv6_addr); else if (family == 3) addr_len = sizeof(proxy_unix_addr); else throw BadClientException("Unsupported haproxy address family"); if (total_len < addr_len) throw BadClientException("Bad HAProxy length: length specified is less than the signaled address length"); else if (total_len == addr_len) mHaProxyStage = HaProxyStage::DoneOrNotNeeded; else { mHaProxyStage = HaProxyStage::AdditionalBytesPending; mHaProxyAdditionalBytesLeft = total_len - addr_len; } std::array addr_block{}; read_ha_proxy_helper(fd, addr_block.data(), addr_len); if (ver != 2) throw BadClientException("Only
HAProxy protocol version 2 is supported"); if (cmd == 0) { mHaProxyConnectionType = HaProxyConnectionType::Local; return {}; } if (cmd != 1) throw BadClientException("HAProxy command must be 1"); if (prot == 0) { mHaProxyConnectionType = HaProxyConnectionType::Local; return {}; } mHaProxyConnectionType = HaProxyConnectionType::Remote; if (prot > 2) throw BadClientException("Invalid protocol in HAProxy."); if (family == 1) { struct proxy_ipv4_addr paddr; std::memcpy(&paddr, addr_block.data(), sizeof(paddr)); struct sockaddr_in addr {}; addr.sin_family = AF_INET; addr.sin_addr.s_addr = paddr.src_addr; addr.sin_port = paddr.src_port; return FMQSockaddr(reinterpret_cast(&addr)); } else if (family == 2) { struct proxy_ipv6_addr paddr; std::memcpy(&paddr, addr_block.data(), sizeof(paddr)); struct sockaddr_in6 addr {}; addr.sin6_family = AF_INET6; memcpy(&addr.sin6_addr, paddr.src_addr, sizeof(struct in6_addr)); addr.sin6_port = paddr.src_port; return FMQSockaddr(reinterpret_cast(&addr)); } throw BadClientException("Unsupported haproxy address family"); } /* Reads the PROXY v2 bytes that follow the address block; on EAGAIN it returns early so a later call resumes. Once all bytes are in, the TLVs are parsed and only the SSL CN, if present, is kept. */ void IoWrapper::readHaProxyAdditionalData(int fd) { assert(mHaProxyStage == HaProxyStage::AdditionalBytesPending); while (mHaProxyAdditionalBytesLeft > 0) { std::array buf{}; const size_t readlen{std::min(buf.size(), mHaProxyAdditionalBytesLeft)}; const ssize_t n{read(fd, buf.data(), readlen)}; if (n < 0) { if (errno == EINTR) continue; if (errno == EWOULDBLOCK || errno == EAGAIN) break; check(n); } else if (n == 0) { throw BadClientException("HaProxy client disconnected before additional data could be read"); } FMQ_ENSURE(n <= mHaProxyAdditionalBytesLeft && static_cast(n) <= buf.size()); const size_t oldSize{mHaProxyAdditionalData.size()}; mHaProxyAdditionalData.resize(oldSize + n); std::copy(buf.begin(), buf.begin() + n, mHaProxyAdditionalData.begin() + oldSize); mHaProxyAdditionalBytesLeft -= n; } if (mHaProxyAdditionalBytesLeft == 0) { mHaProxyStage = HaProxyStage::DoneOrNotNeeded; int recurse_counter = 0; const auto
haproxy_tlvs = read_ha_proxy_pp2_tlv(this->mHaProxyAdditionalData, recurse_counter); std::vector().swap(this->mHaProxyAdditionalData); auto pos = haproxy_tlvs.find(PP2_TYPE_SSL); if (pos != haproxy_tlvs.end()) { HaProxySslData data = std::get(pos->second); // At this point, we don't use anything but the CN. mHaProxySslCnName = data.ssl_cn; } } } /* Makes this connection expect a PROXY protocol header before anything else. */ void IoWrapper::setHaProxy() { this->mHaProxyStage = HaProxyStage::HeaderPending; } #ifndef NDEBUG /** * @brief IoWrapper::setFakeUpgraded marks this wrapper as upgraded. This is for fuzzing, so to bypass the sha1 protected handshake * of websockets. */ void IoWrapper::setFakeUpgraded() { websocketState = WebsocketState::Upgraded; } #endif /** * @brief SSL and non-SSL sockets behave differently. For one, reading 0 doesn't mean 'disconnected' with an SSL * socket. This wrapper unifies behavor for the caller. * * @param fd * @param buf * @param nbytes * @param error is an out-argument with the result. * @return */ ssize_t IoWrapper::readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error) { *error = IoWrapResult::Success; if (!ssl) { ssize_t n = read(fd, buf, nbytes); if (n < 0) { if (errno == EINTR) *error = IoWrapResult::Interrupted; else if (errno == EAGAIN || errno == EWOULDBLOCK) *error = IoWrapResult::Wouldblock; else check(n); } else if (n == 0) { *error = IoWrapResult::Disconnected; } return n; } else { this->sslReadWantsWrite = false; ERR_clear_error(); ssize_t n = SSL_read(ssl.get(), buf, nbytes); if (n > 0) return n; int err = SSL_get_error(ssl.get(), n); unsigned long error_code = ERR_get_error(); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { *error = IoWrapResult::Wouldblock; if (err == SSL_ERROR_WANT_WRITE) { sslReadWantsWrite = true; parentClient->setReadyForWriting(true); } return -1; } if (err == SSL_ERROR_ZERO_RETURN) { parentClient->setDisconnectReason("SSL socket close with close_notify"); *error = IoWrapResult::Disconnected; return 0; } #if OPENSSL_VERSION_NUMBER >=
0x30000000L if (err == SSL_ERROR_SSL && ERR_GET_REASON(error_code) == SSL_R_UNEXPECTED_EOF_WHILE_READING) { parentClient->setDisconnectReason("SSL socket close without close_notify"); *error = IoWrapResult::Disconnected; return 0; } #endif if (err == SSL_ERROR_SYSCALL && errno == 0) { // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html "BUGS" why unexpected EOF is seen as SSL_ERROR_SYSCALL. *error = IoWrapResult::Disconnected; parentClient->setDisconnectReason("SSL<3 socket close without close_notify"); return 0; } if (err == SSL_ERROR_SYSCALL) { // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it // implies EINTR is not included? Also, we use non-blocking sockets, which don't // return EINTR. if (errno == EINTR) { *error = IoWrapResult::Interrupted; return -1; } /* NOTE(review): this 'err' shadows the SSL error code above, and strerror() is not guaranteed to be thread-safe. */ char *err = strerror(errno); std::string msg(err); throw BadClientException("SSL read syscall error: " + msg); } std::array error_buf; ERR_error_string_n(error_code, error_buf.data(), error_buf.size()); const std::string errorString(error_buf.data()); ERR_print_errors_cb(logSslError, NULL); throw BadClientException("SSL socket error reading: " + errorString); } } // SSL and non-SSL sockets behave differently. This wrapper unifies behavor for the caller. ssize_t IoWrapper::writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error) { *error = IoWrapResult::Success; ssize_t n = 0; if (!ssl) { // A write on a socket with count=0 is unspecified.
assert(nbytes > 0); n = write(fd, buf, nbytes); if (n < 0) { if (errno == EINTR) *error = IoWrapResult::Interrupted; else if (errno == EAGAIN || errno == EWOULDBLOCK) *error = IoWrapResult::Wouldblock; else check(n); } } else { size_t nbytes_ = nbytes; /* * OpenSSL doc: When a write function call has to be repeated because SSL_get_error(3) returned * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments */ if (this->incompleteSslWrite.hasPendingWrite()) { nbytes_ = this->incompleteSslWrite.nbytes; } // OpenSSL: "You should not call SSL_write() with num=0, it will return an error" assert(nbytes_ > 0); this->sslWriteWantsRead = false; this->incompleteSslWrite.reset(); ERR_clear_error(); n = SSL_write(ssl.get(), buf, nbytes_); if (n <= 0) { int err = SSL_get_error(ssl.get(), n); unsigned long error_code = ERR_get_error(); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { logger->logf(LOG_DEBUG, "SSL Write is incomplete: %d. Will be retried later.", err); *error = IoWrapResult::Wouldblock; IncompleteSslWrite sslAction(nbytes_); this->incompleteSslWrite = sslAction; if (err == SSL_ERROR_WANT_READ) this->sslWriteWantsRead = true; n = 0; } else { if (err == SSL_ERROR_SYSCALL) { // I don't actually know if OpenSSL hides this or passes EINTR on. The docs say // 'Some non-recoverable, fatal I/O error occurred' for SSL_ERROR_SYSCALL, so it // implies EINTR is not included? if (errno == EINTR) *error = IoWrapResult::Interrupted; else { char *err = strerror(errno); std::string msg(err); throw BadClientException(msg); } } std::array error_buf; ERR_error_string_n(error_code, error_buf.data(), error_buf.size()); const std::string errorString(error_buf.data()); ERR_print_errors_cb(logSslError, NULL); throw BadClientException("SSL socket error writing: " + errorString); } } } return n; } /** * @brief Read the fd into buf. For websockets, reads the fd into an intermediate buffer and decodes the result to buf. 
MQTT is already a frames * protocol, so we don't care about websocket frames being incomplete. * @param fd * @param buf * @param nbytes * @param error May still be set when bytes are written. * @return number of bytes read, despite also possibly having an error. This is possible when there was still buffered data when you called it. * * Because of its buffered nature, it can legitimately return a number of bytes read AND set an error. So, the * interface is somewhat different from normal 'read()' syscalls. */ ssize_t IoWrapper::readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error) { if (connectionProtocol < ConnectionProtocol::WebsocketMqtt) { return readOrSslRead(fd, buf, nbytes, error); } ssize_t n = 0; while (websocketPendingBytes.freeSpace() > 0 && (n = readOrSslRead(fd, websocketPendingBytes.headPtr(), websocketPendingBytes.maxWriteSize(), error)) != 0) { if (n > 0) websocketPendingBytes.advanceHead(n); else if (n < 0) { if (*error == IoWrapResult::Interrupted) return n; break; // On other errors, like 'would block' it depends on our 'pending bytes' what we will tell the caller. } // Make sure we either always have enough space for a next loop iteration, or stop reading the fd. if (websocketPendingBytes.freeSpace() == 0) { if (websocketState == WebsocketState::NotUpgraded) { if (websocketPendingBytes.getCapacity() * 2 <= 8192) websocketPendingBytes.doubleCapacity(); else throw BadClientException("Trying to exceed websocket buffer. Probably not valid websocket traffic."); } else { const Settings *settings = ThreadGlobals::getSettings(); if (websocketPendingBytes.getCapacity() * 2 <= settings->clientMaxWriteBufferSize) { websocketPendingBytes.doubleCapacity(); } else { break; } } } } const bool hasWebsocketPendingBytes = websocketPendingBytes.usedBytes() > 0; // When some or all the data has been read, we can continue. 
if (!(*error == IoWrapResult::Wouldblock || *error == IoWrapResult::Success) && !hasWebsocketPendingBytes) return n; // TODO: I guess because of the error condition check this is never > 0? It needs a bit of clarification. // This should never happen, because of the above check, but I'm leaving it in just in case. if (!hasWebsocketPendingBytes) return n; if (websocketState == WebsocketState::NotUpgraded) { try { std::optional req = parseHttpHeader(websocketPendingBytes); if (!req) return 0; const HttpRequest::Data &req_data = req.value().value(); if (parentClient->getAcmeRedirectUrl() && startsWith(req_data.request, "/.well-known/acme-challenge/")) { parentClient->respondWithRedirectURL(req_data.request); return 0; } if (!req_data.connectionUpgrade || !req_data.upgrade) throw BadHttpRequest("HTTP request is not a websocket upgrade request."); if (!req_data.subprotocol) throw BadHttpRequest("HTTP header Sec-WebSocket-Protocol with value 'mqtt' must be present."); if (!req_data.websocketKey) throw BadHttpRequest("No websocket key specified."); if (req_data.websocketVersion != 13) throw BadWebsocketVersionException("Websocket version 13 required."); const std::string acceptString = generateWebsocketAcceptString(req_data.websocketKey.value()); const Settings *settings = ThreadGlobals::getSettings(); const size_t initialBufferSize = settings->clientInitialBufferSize; std::string answer = generateWebsocketAnswer(acceptString, req_data.subprotocol.value()); parentClient->writeText(answer); websocketState = WebsocketState::Upgrading; websocketPendingBytes.reset(); websocketPendingBytes.resetCapacity(initialBufferSize); *error = IoWrapResult::Success; if (req_data.xRealIp) parentClient->setAddr(req_data.xRealIp.value()); return 0; } catch (BadWebsocketVersionException &ex) { std::string response = generateInvalidWebsocketVersionHttpHeaders(13); parentClient->writeText(response); parentClient->setDisconnectReason("Invalid websocket version"); 
parentClient->setDisconnectStage(DisconnectStage::SendPendingAppData); } catch (BadHttpRequest &ex) // Should should also properly deal with attempt at HTTP2 with PRI. { std::string response = generateBadHttpRequestReponse(ex.what()); parentClient->writeText(response); const std::string reason = formatString("Invalid websocket start: %s", ex.what()); parentClient->setDisconnectReason(reason); parentClient->setDisconnectStage(DisconnectStage::SendPendingAppData); } return 0; } n = websocketBytesToReadBuffer(buf, nbytes, error); if (*error != IoWrapResult::Disconnected) *error = websocketPendingBytes.usedBytes() == 0 ? IoWrapResult::Wouldblock : IoWrapResult::WantRead; return n; } /** * @brief IoWrapper::websocketBytesToReadBuffer takes the payload from a websocket packet and puts it in the 'normal' read buffer, the * buffer that contains the MQTT bytes. * @param buf * @param nbytes * @return the number of bytes read. Can be 0 when only websocket meta or empty frames are processed. * * When websocketPendingBytes still has bytes when we return, the following could have happened: * - We don't have enough bytes to know how long the frame is. * - On ping/close, we don't have enough bytes of the frame. */ ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes, IoWrapResult *error) { size_t nbytesRead = 0; int iter_count = 0; auto log_spin = [&](const std::string &id) { logger->log(LOG_ERR) << std::boolalpha << "Websocket spin loop in " << id << " detected. Please report at https://github.com/halfgaar/FlashMQ. Variables: " << "usedBytes=" << websocketPendingBytes.usedBytes() << ". nbytesRead=" << nbytesRead << ". nbytes=" << nbytes << ". sillWorkingOnFrame=" << incompleteWebsocketRead.sillWorkingOnFrame() << ". " << "frameBytesLeft=" << incompleteWebsocketRead.frame_bytes_left << ". opcode=" << std::hex << static_cast(incompleteWebsocketRead.opcode) << "."; throw std::runtime_error("Websocket spin loop detected. 
Please report at https://github.com/halfgaar/FlashMQ with log."); }; while (websocketPendingBytes.usedBytes() > 0 && nbytesRead < nbytes) { if (iter_count++ >= 1000000) log_spin("A"); // This block decodes the header. if (!incompleteWebsocketRead.sillWorkingOnFrame()) { if (websocketPendingBytes.usedBytes() < WEBSOCKET_MIN_HEADER_BYTES_NEEDED) break; const uint8_t byte1 = websocketPendingBytes.peakAhead(0); const uint8_t byte2 = websocketPendingBytes.peakAhead(1); bool masked = !!(byte2 & 0b10000000); uint8_t reserved = (byte1 & 0b01110000) >> 4; WebsocketOpcode opcode = (WebsocketOpcode)(byte1 & 0b00001111); const uint8_t payloadLength = byte2 & 0b01111111; size_t realPayloadLength = payloadLength; uint64_t extendedPayloadLengthLength = 0; uint8_t headerLength = masked ? 6 : 2; if (payloadLength == 126) extendedPayloadLengthLength = 2; else if (payloadLength == 127) extendedPayloadLengthLength = 8; headerLength += extendedPayloadLengthLength; //if (!masked) // throw BadClientException("Client must send masked websocket bytes."); if (reserved != 0) throw BadClientException("Reserved bytes in header must be 0."); if (headerLength > websocketPendingBytes.usedBytes()) return nbytesRead; uint64_t extendedPayloadLength = 0; int i = 2; int shift = extendedPayloadLengthLength * 8; while (shift > 0) { shift -= 8; uint64_t byte {static_cast(websocketPendingBytes.peakAhead(i++))}; extendedPayloadLength += (byte << shift); } if (extendedPayloadLength > 0) realPayloadLength = extendedPayloadLength; if (headerLength > websocketPendingBytes.usedBytes()) return nbytesRead; if (masked) { for (int j = 0; j < 4; j++) { incompleteWebsocketRead.maskingKey[j] = websocketPendingBytes.peakAhead(i++); } } assert(i == headerLength); assert(headerLength <= websocketPendingBytes.usedBytes()); websocketPendingBytes.advanceTail(headerLength); incompleteWebsocketRead.frame_bytes_left = realPayloadLength; incompleteWebsocketRead.opcode = opcode; } if (incompleteWebsocketRead.opcode == 
WebsocketOpcode::Binary) { // The following reads one websocket frame max: it will continue with the previous, or start a new one, which it may or may not finish. const size_t frame_bytes_left = std::min(websocketPendingBytes.usedBytes(), incompleteWebsocketRead.frame_bytes_left); const size_t max_read_size = std::min(frame_bytes_left, nbytes - nbytesRead); FMQ_ENSURE(max_read_size + nbytesRead <= nbytes); char *offset_in_buf = &static_cast(buf)[nbytesRead]; for (size_t x = 0; x < max_read_size; x++) { offset_in_buf[x] = websocketPendingBytes.peakAhead(x) ^ incompleteWebsocketRead.getNextMaskingByte(); } websocketPendingBytes.advanceTail(max_read_size); incompleteWebsocketRead.frame_bytes_left -= max_read_size; nbytesRead += max_read_size; } else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Ping) { // A ping MAY have user data, which needs to be ponged back. Pings contain no MQTT data, so nbytesRead is // not touched, nor are we writing to the client's MQTT buffer. const Settings *settings = ThreadGlobals::getSettings(); // Because these internal websocket frames don't contain bytes for the client, we need to allow them to fit // fully in websocketPendingBytes, otherwise you can get stuck. if (incompleteWebsocketRead.frame_bytes_left > (settings->clientMaxWriteBufferSize / 2)) throw BadClientException("The option 'client_max_write_buffer_size / 2' is lower than the ping frame we're are supposed to pong back. 
Abusing client?"); if (incompleteWebsocketRead.frame_bytes_left > websocketPendingBytes.usedBytes()) break; logger->logf(LOG_DEBUG, "Ponging websocket"); std::vector masked_payload = websocketPendingBytes.readToVector(incompleteWebsocketRead.frame_bytes_left); for (size_t i = 0; i < masked_payload.size(); i++) { masked_payload.at(i) = masked_payload.at(i) ^ incompleteWebsocketRead.getNextMaskingByte(); } websocketWriteRemainder.ensureFreeSpace(masked_payload.size() + WEBSOCKET_MAX_SENDING_HEADER_SIZE); writeAsMuchOfBufAsWebsocketFrame(masked_payload.data(), masked_payload.size(), WebsocketOpcode::Pong); parentClient->setReadyForWriting(true); incompleteWebsocketRead.frame_bytes_left -= masked_payload.size(); } else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Pong) { // See ping comments const Settings *settings = ThreadGlobals::getSettings(); if (incompleteWebsocketRead.frame_bytes_left > (settings->clientMaxWriteBufferSize / 2)) throw BadClientException("The option 'client_max_write_buffer_size / 2' is lower than the pong frame we're getting. Abusing client?"); if (incompleteWebsocketRead.frame_bytes_left > websocketPendingBytes.usedBytes()) break; std::vector payload = websocketPendingBytes.readToVector(incompleteWebsocketRead.frame_bytes_left); logger->log(LOG_DEBUG) << "Received websocket pong (with length " << payload.size() << ")? We never sent a ping..."; incompleteWebsocketRead.frame_bytes_left -= payload.size(); } else if (incompleteWebsocketRead.opcode == WebsocketOpcode::Close) { const Settings *settings = ThreadGlobals::getSettings(); // Because these internal websocket frames don't contain bytes for the client, we need to allow them to fit // fully in websocketPendingBytes, otherwise you can get stuck. 
if (incompleteWebsocketRead.frame_bytes_left > (settings->clientMaxWriteBufferSize / 2)) throw BadClientException("Websocket close frame is too big."); if (incompleteWebsocketRead.frame_bytes_left > websocketPendingBytes.usedBytes()) break; std::string websocketCloseString = "Websocket close without reason code"; if (incompleteWebsocketRead.frame_bytes_left >= 2) { // If there is payload, MUST be a 2-byte unsigned integer (in network byte order) representing a status code with value /code/ defined const uint8_t msb = websocketPendingBytes.peakAhead(0) ^ incompleteWebsocketRead.getNextMaskingByte(); const uint8_t lsb = websocketPendingBytes.peakAhead(1) ^ incompleteWebsocketRead.getNextMaskingByte(); websocketPendingBytes.advanceTail(2); const uint16_t code = msb << 8 | lsb; websocketCloseString = websocketCloseCodeToString(code); } // An actual MQTT disconnect doesn't send websocket close frames, or perhaps after the MQTT // disconnect when it doesn't matter anymore. So, when users close the tab or stuff like that, // we can consider it a closed transport i.e. failed connection. This means will messages // will be sent. parentClient->setDisconnectReason(websocketCloseString); *error = IoWrapResult::Disconnected; // There may be a UTF8 string with a reason in the packet still, but ignoring that for now. incompleteWebsocketRead.reset(); websocketPendingBytes.reset(); } else { // Specs: "MQTT Control Packets MUST be sent in WebSocket binary data frames. If any other type of data frame is // received the recipient MUST close the Network Connection [MQTT-6.0.0-1]". 
std::ostringstream opcode_oss; opcode_oss << "Unsupported websocket frame type: " << std::hex << static_cast(incompleteWebsocketRead.opcode); throw BadClientException(opcode_oss.str()); } if (!incompleteWebsocketRead.sillWorkingOnFrame()) incompleteWebsocketRead.reset(); } assert(nbytesRead <= nbytes); return nbytesRead; } /** * @brief IoWrapper::writeAsMuchOfBufAsWebsocketFrame writes buf or part of buf as websocket frame to websocketWriteRemainder * @param buf * @param nbytes. The amount of bytes. Can be 0, for just an empty websocket frame. * @return */ ssize_t IoWrapper::writeAsMuchOfBufAsWebsocketFrame(const void *buf, const size_t nbytes, WebsocketOpcode opcode) { // We do allow pong frames to generate a zero payload packet, but for binary, that's not necessary. if (nbytes == 0 && opcode == WebsocketOpcode::Binary) return 0; const Settings *settings = ThreadGlobals::getSettings(); websocketWriteRemainder.ensureFreeSpace(nbytes + WEBSOCKET_MAX_SENDING_HEADER_SIZE, settings->clientMaxWriteBufferSize); const uint32_t bytesFree = websocketWriteRemainder.freeSpace(); const size_t bodyBytesAvailable = bytesFree < WEBSOCKET_MAX_SENDING_HEADER_SIZE ? 0 : bytesFree - WEBSOCKET_MAX_SENDING_HEADER_SIZE; const ssize_t nBytesReal = std::min(nbytes, bodyBytesAvailable); // We normally wrap each write in a frame, but if a previous one didn't fit in the system's write buffers, we're still working on it. 
if (websocketWriteRemainder.freeSpace() > WEBSOCKET_MAX_SENDING_HEADER_SIZE) { uint8_t extended_payload_length_num_bytes = 0; uint8_t payload_length = 0; if (nBytesReal < 126) payload_length = nBytesReal; else if (nBytesReal >= 126 && nBytesReal <= 0xFFFF) { payload_length = 126; extended_payload_length_num_bytes = 2; } else if (nBytesReal > 0xFFFF) { payload_length = 127; extended_payload_length_num_bytes = 8; } int x = 0; std::array header; header[x++] = (0b10000000 | static_cast(opcode)); header[x++] = payload_length; const int header_length = x + extended_payload_length_num_bytes; // This block writes the extended payload length. const uint64_t nbytes64 = nBytesReal; for (int z = extended_payload_length_num_bytes - 1; z >= 0; z--) { header[x++] = (nbytes64 >> (z*8)) & 0xFF; } assert(x <= WEBSOCKET_MAX_SENDING_HEADER_SIZE); assert(x == header_length); websocketWriteRemainder.writerange(header.begin(), header.begin() + header_length); websocketWriteRemainder.write(buf, nBytesReal); } return nBytesReal; } /** * @brief write the buffer to the fd, potentially as websocket frame. * @param fd * @param buf * @param nbytes * @param error May still be set when bytes are written. * @return number of bytes written, despite also possibly having an error. * * Because of its buffered nature (for websockets), it can legitimately return a number of bytes written AND set an * error. So, the interface is somewhat different from normal 'write()' syscalls. * * This also means there is no need to do that repeating of the write thing that SSL_write() has when there is * still buffered data. Just obey the 'wouldblock' error. * * Mqtt docs: "A single WebSocket data frame can contain multiple or partial MQTT Control Packets. The receiver * MUST NOT assume that MQTT Control Packets are aligned on WebSocket frame boundaries [MQTT-6.0.0-2]." We * make use of that here, and wrap each write in a frame. 
*/ ssize_t IoWrapper::writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error) { if (websocketState != WebsocketState::Upgraded) { if (connectionProtocol == ConnectionProtocol::WebsocketMqtt && websocketState == WebsocketState::Upgrading) websocketState = WebsocketState::Upgraded; return writeOrSslWrite(fd, buf, nbytes, error); } else { const ssize_t nBytesReal = writeAsMuchOfBufAsWebsocketFrame(buf, nbytes); while (websocketWriteRemainder.usedBytes() > 0) { const ssize_t n = writeOrSslWrite(fd, websocketWriteRemainder.tailPtr(), websocketWriteRemainder.maxReadSize(), error); if (n > 0) websocketWriteRemainder.advanceTail(n); else break; } return nBytesReal; } } void IoWrapper::resetBuffersIfEligible() { const Settings *settings = ThreadGlobals::getSettings(); const size_t initialBufferSize = settings->clientInitialBufferSize; const size_t sz = connectionProtocol == ConnectionProtocol::WebsocketMqtt ? initialBufferSize : 0; websocketPendingBytes.resetCapacityIfEligable(sz); websocketWriteRemainder.resetCapacityIfEligable(sz); } ================================================ FILE: iowrapper.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef IOWRAPPER_H #define IOWRAPPER_H #include #include #include #include #include "forward_declarations.h" #include "logger.h" #include "haproxy.h" #include "cirbuf.h" #include "x509manager.h" #include "fmqsockaddr.h" #include "enums.h" #include "fmqssl.h" #define WEBSOCKET_MIN_HEADER_BYTES_NEEDED 2 #define WEBSOCKET_MAX_SENDING_HEADER_SIZE 10 #define OPENSSL_ERROR_STRING_SIZE 256 // OpenSSL requires at least 256. 
// NOTE(review): 336130315 == 0x1408F10B, which looks like OpenSSL 1.1's packed error code for
// "ssl3_get_record: wrong version number". OpenSSL 3 packs error codes differently -- confirm at the use site.
#define OPENSSL_WRONG_VERSION_NUMBER 336130315

/**
 * @brief Result reported through the 'IoWrapResult *error' out-parameter of the IoWrapper read/write functions.
 */
enum class IoWrapResult
{
    Success = 0,
    Interrupted = 1,
    Wouldblock = 2,
    Disconnected = 3,
    Error = 4,
    WantRead = 5
};

/**
 * @brief Websocket frame opcodes. The explicit values match the RFC 6455 opcode field. 'Unknown' (0xF) is not an
 * RFC opcode and presumably serves as a sentinel.
 */
enum class WebsocketOpcode
{
    Continuation = 0x00,
    Text = 0x1,
    Binary = 0x2,
    Close = 0x8,
    Ping = 0x9,
    Pong = 0xA,
    Unknown = 0xF
};

/**
 * @brief The IncompleteSslWrite struct facilitates the SSL retry
 *
 * OpenSSL doc: "When a write function call has to be repeated because SSL_get_error(3) returned
 * SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, it must be repeated with the same arguments"
 *
 * Note that we use SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, which is why only the byte count is remembered here and
 * not the buffer pointer.
 */
struct IncompleteSslWrite
{
    bool valid = false; // presumably: true while a write has to be repeated
    size_t nbytes = 0;  // size of the write that has to be repeated

    IncompleteSslWrite() = default;
    IncompleteSslWrite(size_t nbytes);

    bool hasPendingWrite() const;

    void reset();
};

/**
 * @brief Progress through an incoming websocket frame whose payload has not been fully consumed yet.
 *
 * NOTE(review): the member descriptions below are inferred from their names; the logic lives in iowrapper.cpp.
 */
struct IncompleteWebsocketRead
{
    size_t frame_bytes_left = 0;  // payload bytes of the current frame still to be read
    char maskingKey[4];           // the frame's 4-byte client masking key
    unsigned int maskingKeyI = 0; // position in maskingKey for the next payload byte
    WebsocketOpcode opcode;

    void reset();
    bool sillWorkingOnFrame() const;
    char getNextMaskingByte();

    IncompleteWebsocketRead();
};

/**
 * @brief Websocket upgrade progress of a connection. IoWrapper::writeWebsocketAndOrSsl() only frames writes
 * once the state is 'Upgraded'.
 */
enum class WebsocketState
{
    NotUpgraded,
    Upgrading,
    Upgraded
};

/**
 * @brief Progress of reading the HAProxy PROXY header: the header itself, then any additional bytes it announces.
 */
enum class HaProxyStage
{
    HeaderPending,
    AdditionalBytesPending,
    DoneOrNotNeeded
};

/**
 * @brief provides a unified wrapper for SSL and websockets to read() and write().
* * */ class IoWrapper { Client *parentClient; FmqSsl ssl; bool sslAccepted = false; IncompleteSslWrite incompleteSslWrite; bool sslReadWantsWrite = false; bool sslWriteWantsRead = false; ConnectionProtocol connectionProtocol; WebsocketState websocketState = WebsocketState::NotUpgraded; CirBuf websocketPendingBytes; IncompleteWebsocketRead incompleteWebsocketRead; CirBuf websocketWriteRemainder; uint16_t mHaProxyAdditionalBytesLeft = 0; HaProxyStage mHaProxyStage = HaProxyStage::DoneOrNotNeeded; HaProxyConnectionType mHaProxyConnectionType = HaProxyConnectionType::None; std::vector mHaProxyAdditionalData; std::optional mHaProxySslCnName; Logger *logger = Logger::getInstance(); ssize_t websocketBytesToReadBuffer(void *buf, const size_t nbytes, IoWrapResult *error); ssize_t readOrSslRead(int fd, void *buf, size_t nbytes, IoWrapResult *error); ssize_t writeOrSslWrite(int fd, const void *buf, size_t nbytes, IoWrapResult *error); ssize_t writeAsMuchOfBufAsWebsocketFrame(const void *buf, const size_t nbytes, WebsocketOpcode opcode = WebsocketOpcode::Binary); void startOrContinueSslConnect(); void startOrContinueSslAccept(); public: IoWrapper(FmqSsl &&ssl, ConnectionProtocol connectionProtocol, const size_t initialBufferSize, Client *parent); ~IoWrapper() = default; void startOrContinueSslHandshake(); bool getSslReadWantsWrite() const; bool getSslWriteWantsRead() const; bool isSslAccepted() const; bool isSsl() const; void setSslVerify(int mode, const std::string &hostname); bool hasPendingWrite() const; bool hasProcessedBufferedBytesToRead() const; ConnectionProtocol getConnectionProtocol() const { return this->connectionProtocol; }; WebsocketState getWebsocketState() const; X509Manager getPeerCertificate() const; const char *getSslVersion() const; const std::optional &getHaProxySslCnName() const { return mHaProxySslCnName; } HaProxyStage getHaProxyStage() const { return mHaProxyStage; } HaProxyConnectionType getHaProxyConnectionType() const { return 
mHaProxyConnectionType; } std::optional readHaProxyHeader(int fd); void readHaProxyAdditionalData(int fd); void setHaProxy(); #ifndef NDEBUG void setFakeUpgraded(); #endif ssize_t readWebsocketAndOrSsl(int fd, void *buf, size_t nbytes, IoWrapResult *error); ssize_t writeWebsocketAndOrSsl(int fd, const void *buf, size_t nbytes, IoWrapResult *error); void resetBuffersIfEligible(); }; #endif // IOWRAPPER_H ================================================ FILE: listener.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include #include "listener.h" #include "utils.h" #include "exceptions.h" #include "logger.h" #include "configfileparser.h" void Listener::isValid() { if (isSsl()) { if (port == 0) { if (connectionProtocol == ConnectionProtocol::WebsocketMqtt) port = 4443; else port = 8883; } if (acmeRedirectURL) { throw ConfigFileException("An SSL listener can't have an acme_redirect_url."); } if (!dropListener()) { ConfigFileParser::checkFileExistsAndReadable("SSL fullchain", sslFullchain, 1024*1024); ConfigFileParser::checkFileExistsAndReadable("SSL privkey", sslPrivkey, 1024*1024); testSsl(sslFullchain, sslPrivkey); } testSslVerifyLocations(clientVerificationCaFile, clientVerificationCaDir, "Loading client_verification_ca_dir/client_verification_ca_file failed."); } else { if (port == 0) { if (connectionProtocol == ConnectionProtocol::AcmeOnly) port = 80; else if (connectionProtocol == ConnectionProtocol::WebsocketMqtt) port = 8080; else port = 1883; } if (dropOnAbsentCertificates) { throw ConfigFileException("Using drop_on_absent_certificate is only valid on SSL listeners; define privkey and fullchain."); } } if (protocol < ListenerProtocol::Unix && !unixSocketPath.empty()) { throw 
ConfigFileException("Specifying 'unix_socket_path' for IP listeners is not allowed."); } if (protocol == ListenerProtocol::Unix) { if (unixSocketPath.empty()) throw ConfigFileException("Option 'unix_socket_path' must be set for unix socket listeners."); if (!inet4BindAddress.empty() || !inet6BindAddress.empty()) throw ConfigFileException("Specifying inet bind addresses is not allowed for unix socket listeners."); if (isSsl()) throw ConfigFileException("TLS on domain sockets is not supported."); if (acmeRedirectURL) throw ConfigFileException("ACME redirect is not support on unix sockets"); } if ((!clientVerificationCaDir.empty() || !clientVerificationCaFile.empty()) && !isSsl()) { throw ConfigFileException("X509 client verification can only be done on TLS listeners."); } if (port <= 0 || port > 65534) { throw ConfigFileException(formatString("Port nr %d is not valid", port)); } if (connectionProtocol == ConnectionProtocol::AcmeOnly && !acmeRedirectURL) { throw ConfigFileException("An ACME listener needs to have an acme_redirect_url"); } if (getX509ClientVerficationMode() > X509ClientVerification::None && haProxyMode >= HaProxyMode::HaProxyClientVerification) { throw ConfigFileException("Client verification can't be done by both FlashMQ and HAProxy."); } } bool Listener::isSsl() const { return (!sslFullchain.empty() || !sslPrivkey.empty()); } bool Listener::isTcpNoDelay() const { return this->tcpNoDelay; } std::string Listener::getProtocolName() const { if (protocol == ListenerProtocol::Unix) { return "unix socket"; } if (connectionProtocol == ConnectionProtocol::AcmeOnly) { return "ACME-only"; } else if (isSsl()) { if (connectionProtocol == ConnectionProtocol::WebsocketMqtt) return "SSL websocket"; else return "SSL TCP"; } else { std::string answer; if (connectionProtocol == ConnectionProtocol::WebsocketMqtt) answer = "non-SSL websocket"; else answer = "non-SSL TCP"; if (acmeRedirectURL) answer.append(" with ACME redirect to ").append(acmeRedirectURL.value()); 
return answer; } return "whoops"; } void Listener::loadCertAndKeyFromConfig() { if (!isSsl()) return; if (!sslctx) { sslctx.emplace(); sslctx->setMinimumTlsVersion(minimumTlsVersion); SSL_CTX_set_mode(sslctx->get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); /* * Session cache requires active shutdown of SSL connections, which we don't have right now. We * might as well just turn the session cache off, at least until we do have session shutdown. */ SSL_CTX_set_session_cache_mode(sslctx->get(), SSL_SESS_CACHE_OFF); } if (SSL_CTX_use_certificate_chain_file(sslctx->get(), sslFullchain.c_str()) != 1) throw std::runtime_error("Loading cert failed. This was after test loading the certificate, so is very unexpected."); if (SSL_CTX_use_PrivateKey_file(sslctx->get(), sslPrivkey.c_str(), SSL_FILETYPE_PEM) != 1) throw std::runtime_error("Loading key failed. This was after test loading the certificate, so is very unexpected."); { const char *ca_file = clientVerificationCaFile.empty() ? nullptr : clientVerificationCaFile.c_str(); const char *ca_dir = clientVerificationCaDir.empty() ? nullptr : clientVerificationCaDir.c_str(); if (ca_file || ca_dir) { if (SSL_CTX_load_verify_locations(sslctx->get(), ca_file, ca_dir) != 1) { ERR_print_errors_cb(logSslError, NULL); throw std::runtime_error("Loading client_verification_ca_dir/client_verification_ca_file failed. 
" "This was after test loading the certificate, so is very unexpected."); } } } } X509ClientVerification Listener::getX509ClientVerficationMode() const { X509ClientVerification result = X509ClientVerification::None; const bool clientCADefined = !clientVerificationCaDir.empty() || !clientVerificationCaFile.empty(); if (clientCADefined) result = X509ClientVerification::X509IsEnough; if (result >= X509ClientVerification::X509IsEnough && clientVerifictionStillDoAuthn) result = X509ClientVerification::X509AndUsernamePassword; return result; } bool Listener::dropListener() const { if (!dropOnAbsentCertificates || !isSsl()) return false; return access(sslPrivkey.c_str(), R_OK) != 0 && access(sslFullchain.c_str(), R_OK) != 0; } std::string Listener::getBindAddress(ListenerProtocol p) { if (p == ListenerProtocol::IPv4) { if (inet4BindAddress.empty()) return "0.0.0.0"; return inet4BindAddress; } if (p == ListenerProtocol::IPv6) { if (inet6BindAddress.empty()) return "::"; return inet6BindAddress; } if (p == ListenerProtocol::Unix) { if (unixSocketPath.empty()) throw std::runtime_error("Listener's unix socket path is empty."); return unixSocketPath; } return ""; } bool Listener::isAllowed(const sockaddr *addr) const { if (!denyList.empty() && std::any_of(denyList.begin(), denyList.end(), [=](const Network &n) { return n.match(addr);})) return false; if (exclusiveAllowList.empty()) return true; return std::any_of(exclusiveAllowList.begin(), exclusiveAllowList.end(), [=](const Network &n) { return n.match(addr);}); } ================================================ FILE: listener.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/

#ifndef LISTENER_H
#define LISTENER_H

#include
#include
#include

#include "sslctxmanager.h"
#include "enums.h"
#include "network.h"

/**
 * @brief Address family of a listener. 'IPv46' makes MainApp::createListenSocket() open both an IPv4 and an
 * IPv6 socket. Listener::isValid() treats every value below 'Unix' as an IP listener, so the order matters.
 */
enum class ListenerProtocol
{
    IPv46,
    IPv4,
    IPv6,
    Unix
};

/**
 * @brief One configured listener: where to bind, which protocol/TLS options apply and which clients are allowed.
 *
 * A 'port' of 0 means "not configured". Listener::isValid() then replaces it with a default based on
 * protocol and TLS.
 */
struct Listener
{
    /*
     * We track this per listener so that if you isolate clients to a specific listener, you have
     * control over on which thread it ends up.
     */
    size_t next_thread_index = 0;

    ListenerProtocol protocol = ListenerProtocol::IPv46;
    std::string inet4BindAddress;
    std::string inet6BindAddress;
    std::string unixSocketPath;
    std::optional unixSocketUser;
    std::optional unixSocketGroup;
    std::optional unixSocketMode;
    int port = 0;
    ConnectionProtocol connectionProtocol = ConnectionProtocol::Mqtt;
    bool tcpNoDelay = false;
    HaProxyMode haProxyMode = HaProxyMode::Off;
    std::string sslFullchain;
    std::string sslPrivkey;
    std::string clientVerificationCaFile;
    std::string clientVerificationCaDir;
    bool clientVerifictionStillDoAuthn = false;
    std::optional sslctx; // created lazily by loadCertAndKeyFromConfig()
    AllowListenerAnonymous allowAnonymous = AllowListenerAnonymous::None;
    std::optional acmeRedirectURL;
    TLSVersion minimumTlsVersion = TLSVersion::TLSv1_1;
    std::optional overloadMode;
    bool dropOnAbsentCertificates = false;
    std::optional maxBufferSize;
    std::vector exclusiveAllowList;
    std::vector denyList;
    std::optional maxQos;
    std::optional mqtt3QoSExceedAction;

    void isValid();
    bool isSsl() const;
    bool isTcpNoDelay() const;
    std::string getProtocolName() const;
    void loadCertAndKeyFromConfig();
    X509ClientVerification getX509ClientVerficationMode() const;
    bool dropListener() const;
    std::string getBindAddress(ListenerProtocol p);
    bool isAllowed(const sockaddr *addr) const;
};

#endif // LISTENER_H
================================================
FILE: lockedsharedptr.h
================================================
#ifndef LOCKEDSHAREDPTR_H
#define LOCKEDSHAREDPTR_H

#include
#include

/**
 * @brief A shared pointer whose assignment, reset and copy-out are each serialized by an internal mutex.
 *
 * NOTE(review): operator= is noexcept, but locking a std::mutex may throw std::system_error. Inside a noexcept
 * function that would end in std::terminate().
 */
template class LockedSharedPtr
{
    std::mutex m;
    std::shared_ptr p;

public:
    LockedSharedPtr& operator=(const std::shared_ptr &r) noexcept
    {
        std::lock_guard l(m);
        p = r;
        return *this;
    }

    void reset()
    {
        std::lock_guard l(m);
        p.reset();
    }

    // Returns a copy, so the caller owns a reference that stays valid after the lock is released.
    std::shared_ptr getCopy()
    {
        std::lock_guard l(m);
        return p;
    }
};

#endif // LOCKEDSHAREDPTR_H
================================================
FILE: lockedweakptr.cpp
================================================
#include "lockedweakptr.h"
================================================
FILE: lockedweakptr.h
================================================
#ifndef LOCKEDWEAKPTR_H
#define LOCKEDWEAKPTR_H

#include
#include

/**
 * @brief A weak pointer whose lock(), expired() and assignment are each serialized by an internal mutex.
 *
 * The converting constructor is not locked (there is no concurrent access during construction).
 * NOTE(review): same noexcept-vs-mutex caveat on operator= as LockedSharedPtr.
 */
template class LockedWeakPtr
{
    std::mutex m;
    std::weak_ptr p;

public:
    LockedWeakPtr() = default;

    LockedWeakPtr(const std::shared_ptr &r) :
        p(r)
    {

    }

    std::shared_ptr lock()
    {
        std::lock_guard l(m);
        return p.lock();
    }

    bool expired()
    {
        std::lock_guard l(m);
        return p.expired();
    }

    LockedWeakPtr& operator=(const std::shared_ptr &r) noexcept
    {
        std::lock_guard l(m);
        p = r;
        return *this;
    }
};

#endif // LOCKEDWEAKPTR_H
================================================
FILE: logger.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it
under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/ #include "logger.h" #include #include #include #include #include "threaddata.h" #include "globals.h" #include "threadglobals.h" #include "utils.h" LogLine::LogLine(std::string &&line, bool alsoToStdOut) : line(std::move(line)), alsoToStdOut(alsoToStdOut) { } LogLine::LogLine(const char *s, size_t len, bool alsoToStdOut) : line(s, len), alsoToStdOut(alsoToStdOut) { } LogLine::LogLine() : alsoToStdOut(true) { } bool LogLine::alsoLogToStdOut() const { return alsoToStdOut; } Logger::Logger() { memset(&linesPending, 1, sizeof(sem_t)); sem_init(&linesPending, 0, 0); start(); } Logger::~Logger() { if (running) quit(); if (ofile.is_open()) { ofile.close(); } sem_close(&linesPending); } std::string_view Logger::getLogLevelString(int level) { switch (level) { case LOG_NONE: return "NONE"; case LOG_INFO: return "INFO"; case LOG_NOTICE: return "NOTICE"; case LOG_WARNING: return "WARNING"; case LOG_ERR: return "ERROR"; case LOG_DEBUG: return "DEBUG"; case LOG_SUBSCRIBE: return "SUBSCRIBE"; case LOG_UNSUBSCRIBE: return "UNSUBSCRIBE"; case LOG_PUBLISH: return "PUBLISH"; default: return "UNKNOWN LOG LEVEL"; } } Logger *Logger::getInstance() { static Logger instance; if (!instance.running) instance.start(); return &instance; } void Logger::logf(int level, const char *str, ...) { va_list valist; va_start(valist, str); this->logf(level, str, valist); va_end(valist); } /** * @brief Logger::log * @param level * @return a StreamToLog, that you're not suppose to name. When you don't, its destructor will log the stream. * * Allows logging like: logger->log(LOG_NOTICE) << "blabla: " << 1 << ".". The advantage is safety (printf crashes), and not forgetting printf arguments. * * Beware though: C++ streams chars as characters. When you have an uint8_t or int8_t that's also a char, and those need to be cast to int first. A good * solution needs to be devised. 
*/ StreamToLog Logger::log(int level) { return StreamToLog(level); } bool Logger::wouldLog(int level) const { return static_cast(level & curLogLevel); } void Logger::queueReOpen() { reload = true; sem_post(&linesPending); } void Logger::reOpen() { reload = false; if (ofile.is_open()) { ofile.close(); } if (logPath.empty()) return; ofile.open(logPath, std::ios::app | std::ios::out); if (!ofile.good()) { ofile.close(); logf(LOG_ERR, "(Re)opening log file '%s' error: %s. Logging to stdout.", logPath.c_str(), strerror(errno)); } } // I want all messages logged during app startup to also show on stdout/err, otherwise failure can look so silent. So, call this when the app started. void Logger::noLongerLogToStd() { if (!logPath.empty()) logf(LOG_INFO, "Switching logging from stdout to logfile '%s'", logPath.c_str()); alsoLogToStd = false; } void Logger::setLogPath(const std::string &path) { this->logPath = path; } /** * @brief Logger::setFlags sets the log level based on a maximum desired level. * @param level Level based on the defines LOG_*. * @param logSubscriptions * @param quiet * * The log levels are mosquitto-compatible, and while the terminology is similar to the syslog standard, they are unfortunately * not in order of verbosity/priority. So, we use our own enum for the config setting, so that DEBUG is indeed more verbose than INFO. * * The subscriptions flag is still set explicitly, because you may want that irrespective of the log level. 
*/ void Logger::setFlags(LogLevel level, bool logSubscriptions, bool logPublishes) { curLogLevel = 0; if (level <= LogLevel::Debug) curLogLevel |= LOG_DEBUG; if (level <= LogLevel::Info) curLogLevel |= LOG_INFO; if (level <= LogLevel::Notice) curLogLevel |= LOG_NOTICE; if (level <= LogLevel::Warning) curLogLevel |= LOG_WARNING; if (level <= LogLevel::Error) curLogLevel |= LOG_ERR; if (logSubscriptions) curLogLevel |= (LOG_UNSUBSCRIBE | LOG_SUBSCRIBE); else curLogLevel &= ~(LOG_UNSUBSCRIBE | LOG_SUBSCRIBE); if (logPublishes) curLogLevel |= LOG_PUBLISH; else curLogLevel &= ~LOG_PUBLISH; } /** * @brief Logger::setFlags is provided for backwards compatability * @param logDebug * @param quiet */ void Logger::setFlags(std::optional logDebug, std::optional quiet) { if (logDebug) { if (logDebug.value()) curLogLevel |= LOG_DEBUG; else curLogLevel &= ~LOG_DEBUG; } // It only makes sense to allow quiet to mute things in backward compatability mode, not enable LOG_NOTICE and LOG_INFO again. if (quiet) { if (quiet.value()) curLogLevel &= ~(LOG_NOTICE | LOG_INFO); } } void Logger::quit() { std::lock_guard locker(startStopMutex); running = false; sem_post(&linesPending); if (writerThread.joinable()) writerThread.join(); } void Logger::start() { std::lock_guard locker(startStopMutex); if (writerThread.joinable()) return; running = true; auto f = std::bind(&Logger::writeLog, this); this->writerThread = std::thread(f, this); pthread_t native = this->writerThread.native_handle(); pthread_setname_np(native, "LogWriter"); } void Logger::writeLog() { maskAllSignalsCurrentThread(); int graceCounter = 0; LogLine line; while(running || (!lines.empty() && graceCounter++ < 1000 )) { sem_wait(&linesPending); if (reload) { reOpen(); } { std::lock_guard locker(logMutex); if (lines.empty()) continue; line = std::move(lines.front()); lines.pop(); } if (ofile.is_open()) { ofile << line.getLine() << std::endl; if (!ofile.good()) { std::cerr << "Writing to log failed. 
Closing file and enabling stdout logger" << std::endl; ofile.close(); } } if (!(ofile.is_open() && ofile.good()) || line.alsoLogToStdOut()) { #ifdef TESTING std::cerr << line.getLine() << std::endl; // the stdout interfers with Qt test XML output, so using stderr. #else std::cout << line.getLine() << std::endl; #endif } } } std::string Logger::getPrefix(int level) { thread_local static auto caller_thread_id = pthread_self(); const auto td = ThreadGlobals::getThreadData(); std::ostringstream oss; const std::string stamp = timestampWithMillis(); oss << "[" << stamp << "] [" << getLogLevelString(level) << "] "; if (td) { oss << "[T " << td->threadnr << "] "; } else if (pthread_equal(caller_thread_id, globals->createdByThread)) { oss << "[main] "; } else { std::array buf{}; if (pthread_getname_np(caller_thread_id, buf.data(), buf.size()) == 0) { std::string tname(buf.data()); oss << "[T " << tname << "] "; } else { oss << "[T custom] "; } } std::string result = oss.str(); return result; } void Logger::logstring(int level, const std::string &str) { if ((level & curLogLevel) == 0) return; std::string s = getPrefix(level); s.append(str); LogLine line(std::move(s), alsoLogToStd); { std::lock_guard locker(logMutex); lines.push(std::move(line)); } sem_post(&linesPending); } void Logger::logf(int level, const char *str, va_list valist) { if ((level & curLogLevel) == 0) return; std::string s = getPrefix(level); s.append(str); const char *logfmtstring = s.c_str(); constexpr const int buf_size = 512; char buf[buf_size + 1]; buf[buf_size] = 0; va_list valist2; va_copy(valist2, valist); const int rc = vsnprintf(buf, buf_size, logfmtstring, valist2); va_end(valist2); if (rc < 0) return; size_t len = std::min(buf_size, strlen(buf)); LogLine line(buf, len, alsoLogToStd); { std::lock_guard locker(logMutex); lines.push(std::move(line)); } sem_post(&linesPending); } int logSslError(const char *str, size_t len, void *u) { (void) u; std::string msg(str, len); Logger *logger = 
Logger::getInstance(); logger->logstring(LOG_ERR, msg); return 0; } StreamToLog::StreamToLog(int level) : level(level) { } StreamToLog::~StreamToLog() { const std::string s = str(); Logger *logger = Logger::getInstance(); logger->logstring(this->level, s); } ================================================ FILE: logger.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef LOGGER_H #define LOGGER_H #include #include #include #include #include #include #include #include "semaphore.h" #include "flashmq_plugin.h" enum class LogLevel { Debug, Info, Notice, Warning, Error, None }; int logSslError(const char *str, size_t len, void *u); /** * @brief Use as a temporary, so don't give a name. This makes the stream gets logged immediately. 
 */
class StreamToLog : public std::ostringstream
{
    int level = LOG_NOTICE;
public:
    StreamToLog(int level);
    ~StreamToLog();
};

/**
 * @brief One formatted line queued for the writer thread, plus whether it must also go to stdout. Move-only.
 */
class LogLine
{
    std::string line;
    bool alsoToStdOut;
public:
    LogLine(std::string &&line, bool alsoToStdOut);
    LogLine(const char *s, size_t len, bool alsoToStdOut);
    LogLine();
    LogLine(const LogLine &other) = delete;
    LogLine(LogLine &&other) = default;
    LogLine &operator=(LogLine &&other) = default;

    bool alsoLogToStdOut() const;
    const std::string &getLine() const { return line;} ;
};

/**
 * @brief Process-wide logger (private constructor, see getInstance()). Lines are queued by any thread and written
 * by a background writer thread.
 *
 * NOTE(review): 'running' and 'reload' are plain bools. They are written by one thread (start()/quit(), queueReOpen())
 * and read by the writer thread in writeLog(). Consider std::atomic<bool> to make that access well-defined.
 */
class Logger
{
    std::string logPath;
    int curLogLevel = LOG_ERR | LOG_WARNING | LOG_NOTICE | LOG_INFO | LOG_SUBSCRIBE | LOG_UNSUBSCRIBE ; // bitmask of enabled LOG_* flags
    std::mutex logMutex;       // guards 'lines'
    std::mutex startStopMutex; // serializes start() and quit()
    std::queue lines;
    sem_t linesPending;        // posted per queued line, and by quit()/queueReOpen(), to wake the writer thread
    std::thread writerThread;
    bool running = true;
    std::fstream ofile;
    bool alsoLogToStd = true;
    bool reload = false;       // set by queueReOpen(); the writer thread then calls reOpen()

    Logger();
    ~Logger();
    static std::string_view getLogLevelString(int level);
    void reOpen();
    void writeLog();
    static std::string getPrefix(int level);

public:
    static Logger *getInstance();
    void logstring(int level, const std::string &str);
    void logf(int level, const char *str, va_list args);
    void logf(int level, const char *str, ...);
    StreamToLog log(int level);
    bool wouldLog(int level) const;
    void queueReOpen();
    void noLongerLogToStd();
    void setLogPath(const std::string &path);
    void setFlags(LogLevel level, bool logSubscriptions, bool logPublishes);
    void setFlags(std::optional logDebug, std::optional quiet);
    void quit();
    void start();
};

#endif // LOGGER_H
================================================
FILE: main.cpp
================================================
#include "fmqmain.h"

// Thin entry point; all startup logic lives in fmqmain().
int main(int argc, char *argv[])
{
    return fmqmain(argc, argv);
}
================================================
FILE: mainapp.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it under the
terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "mainapp.h" #include #include "exceptions.h" #include #include #include #include #include #include #include #include #include #include "logger.h" #include "threadglobals.h" #include "threadloop.h" #include "threadglobals.h" #include "globalstats.h" #include "utils.h" #include "bridgeconfig.h" #include "bridgeinfodb.h" #include "globals.h" #include "fmqssl.h" #include "persistencefunctions.h" #include "sdnotify.h" MainApp::MainApp(const std::string &configFilePath) { globals = Globals(); subscriptionStore = globals->subscriptionStore; epollFdAccept = check(epoll_create(999)); taskEventFd = eventfd(0, EFD_NONBLOCK); confFileParser = std::make_unique(configFilePath); loadConfig(false); this->num_threads = get_nprocs(); if (settings.threadCount > 0) { this->num_threads = settings.threadCount; logger->logf(LOG_NOTICE, "%d threads specified by 'thread_count'.", num_threads); } else { logger->logf(LOG_NOTICE, "%d CPUs are detected, making as many threads. 
Use 'thread_count' setting to override.", num_threads); } if (num_threads <= 0) throw std::runtime_error("Invalid number of CPUs: " + std::to_string(num_threads)); if (!settings.storageDir.empty()) { try { correctBackupDbPermissions(settings.storageDir); } catch (std::exception &ex) {} const std::string retainedDbPath = settings.getRetainedMessagesDBFile(); if (settings.retainedMessagesMode == RetainedMessagesMode::Enabled) subscriptionStore->loadRetainedMessages(settings.getRetainedMessagesDBFile()); else logger->logf(LOG_INFO, "Not loading '%s', because 'retained_messages_mode' is not 'enabled'.", retainedDbPath.c_str()); subscriptionStore->loadSessionsAndSubscriptions(settings.getSessionsDBFile()); } } MainApp::~MainApp() { if (taskEventFd >= 0) close(taskEventFd); if (epollFdAccept >= 0) close(epollFdAccept); globals = Globals(); } void MainApp::doHelp(const char *arg) { puts("FlashMQ - the scalable light-weight MQTT broker"); puts(""); puts("Documentation:"); puts(" 'man 5 flashmq.conf' or https://www.flashmq.org/man/flashmq.conf.5"); puts(" 'man 1 flashmq' or https://www.flashmq.org/man/flashmq.1"); puts(""); puts("Signals:"); puts(" SIGHUP: reload configuration."); puts(" SIGUSR1: reopen log files."); puts(" SIGUSR2: perform malloc_trim(), to try to yield unused heap memory to the OS."); puts(""); printf("Usage: %s [options]\n", arg); puts(""); puts(" -h, --help Print help"); puts(" -c, --config-file Configuration file. 
Default '/etc/flashmq/flashmq.conf'."); puts(" -t, --test-config Test configuration file."); #ifndef NDEBUG puts(" -z, --fuzz-file For fuzzing, provides the bytes that would be sent by a client."); puts(" If the name contains 'web' it will activate websocket mode."); puts(" If the name also contains 'upgrade', it will assume the websocket"); puts(" client is upgrade, and bypass the cryptograhically secured websocket"); puts(" handshake."); #endif puts(" -V, --version Show version"); puts(" -l, --license Show license"); } void MainApp::showLicense() { std::string sse = "without SSE support"; #ifdef __SSE4_2__ sse = "with SSE4.2 support"; #endif printf("FlashMQ Version %s %s\n", FLASHMQ_VERSION, sse.c_str()); puts("Copyright (C) 2021-2026 Wiebe Cazemier."); puts("License OSL3: Open Software License 3.0 ."); puts(""); puts("Author: Wiebe Cazemier "); } std::list MainApp::createListenSocket(const std::shared_ptr &listener, bool do_listen) { std::list result; if (listener->protocol != ListenerProtocol::Unix && listener->port <= 0) return result; std::vector protocols; if (listener->protocol == ListenerProtocol::IPv46) { protocols.push_back(ListenerProtocol::IPv4); protocols.push_back(ListenerProtocol::IPv6); } else { protocols.push_back(listener->protocol); } bool error = false; for (ListenerProtocol p : protocols) { std::string pname; sa_family_t family = AF_UNSPEC; if (p == ListenerProtocol::IPv4) { pname = "IPv4"; family = AF_INET; } else if (p == ListenerProtocol::IPv6) { pname = "IPv6"; family = AF_INET6; } else if (p == ListenerProtocol::Unix) { pname = "unix socket"; family = AF_UNIX; } std::ostringstream logtext; try { if (family == AF_UNIX) { logtext << "Creating unix socket listener on " << listener->unixSocketPath; } else { logtext << "Creating " << pname << " " << listener->getProtocolName() << " "; if (listener->haProxyMode == HaProxyMode::On) logtext << "haproxy "; else if (listener->haProxyMode >= HaProxyMode::HaProxyClientVerification) logtext << 
"haproxy with client verification"; logtext << "listener on [" << listener->getBindAddress(p) << "]:" << listener->port; } BindAddr bindAddr(family, listener->getBindAddress(p), listener->port, listener->unixSocketUser, listener->unixSocketGroup, listener->unixSocketMode); ScopedSocket uniqueListenFd(check(socket(family, SOCK_STREAM, 0)), listener->unixSocketPath, listener); if (p == ListenerProtocol::Unix) { unlink_if_sock(listener->unixSocketPath); } else { // Not needed for now. Maybe I will make multiple accept threads later, with SO_REUSEPORT. int optval = 1; check(setsockopt(uniqueListenFd.get(), SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval))); if (listener->isTcpNoDelay()) { int tcp_nodelay_optval = 1; check(setsockopt(uniqueListenFd.get(), IPPROTO_TCP, TCP_NODELAY, &tcp_nodelay_optval, sizeof(tcp_nodelay_optval))); } } int flags = fcntl(uniqueListenFd.get(), F_GETFL); check(fcntl(uniqueListenFd.get(), F_SETFL, flags | O_NONBLOCK )); bindAddr.bind_socket(uniqueListenFd.get()); uniqueListenFd.setListenMessage(logtext.str()); if (do_listen) { uniqueListenFd.doListen(this->epollFdAccept); } result.push_back(std::move(uniqueListenFd)); } catch (std::exception &ex) { logger->log(LOG_ERR) << logtext.str() << " failed: " << ex.what(); error = true; } } if (error) return std::list(); return result; } void MainApp::wakeUpThread() { uint64_t one = 1; check(write(taskEventFd, &one, sizeof(uint64_t))); } void MainApp::addImmediateTask(std::function f) { bool wakeupNeeded = true; { auto task_queue_locked = taskQueue.lock(); wakeupNeeded = task_queue_locked->empty(); task_queue_locked->push_back(std::move(f)); } if (wakeupNeeded) { wakeUpThread(); } } void MainApp::queueKeepAliveCheckAtAllThreads() { for (ThreadDataOwner &thread : threads) { thread->queueDoKeepAliveCheck(); } } void MainApp::queuePasswordFileReloadAllThreads() { for (ThreadDataOwner &thread : threads) { thread->queuePasswdFileReload(); } } void 
MainApp::queuepluginPeriodicEventAllThreads() { for (ThreadDataOwner &thread : threads) { thread->queuepluginPeriodicEvent(); } } void MainApp::setFuzzFile(const std::string &fuzzFilePath) { this->fuzzFilePath = fuzzFilePath; } /** * @brief MainApp::queuePublishStatsOnDollarTopic publishes the dollar topics, on a thread that has thread local authentication. */ void MainApp::queuePublishStatsOnDollarTopic() { if (!threads.empty()) { std::vector> thread_datas; for(ThreadDataOwner &t : threads) { thread_datas.push_back(t.getThreadData()); } threads.at(0)->queuePublishStatsOnDollarTopic(thread_datas); } } /** * @brief MainApp::saveStateInThread starts a thread for disk IO, because file IO is not async. */ void MainApp::saveStateInThread() { std::list bridgeInfos = BridgeInfoForSerializing::getBridgeInfosForSerializing(this->bridgeConfigs); auto f = std::bind(&saveState, this->settings, bridgeInfos, true); this->bgWorker.addTask(f, true); } void MainApp::queueSendQueuedWills() { if (!threads.empty()) { int threadnr = rand() % threads.size(); std::shared_ptr t = threads[threadnr].getThreadData(); t->queueSendingQueuedWills(); } } void MainApp::waitForWillsQueued() { int i = 0; while(std::any_of(threads.begin(), threads.end(), [](const ThreadDataOwner &t){ return !t->allWillsQueued && t->running; }) && i++ < 5000) { usleep(1000); } } void MainApp::queueRetainedMessageExpiration() { if (!threads.empty()) { int threadnr = rand() % threads.size(); std::shared_ptr t = threads[threadnr].getThreadData(); t->queueRemoveExpiredRetainedMessages(); } } void MainApp::sendBridgesToThreads() { if (threads.empty()) return; size_t i1 = 0; size_t i2 = 0; auto bridge_pos = this->bridgeConfigs.begin(); while (bridge_pos != this->bridgeConfigs.end()) { auto cur = bridge_pos; bridge_pos++; BridgeConfig &bridge = cur->second; std::shared_ptr owner = bridge.owner.lock(); if (!owner) { size_t index = 0; if (bridge.getFmqClientGroupId()) index = i1++; else index = i2++; owner = threads.at(index 
% threads.size()).getThreadData(); bridge.owner = owner; } if (bridge.queueForDelete) { owner->removeBridgeQueued(bridge, "Bridge disappeared from config"); this->bridgeConfigs.erase(cur); } else { std::shared_ptr bridgeState = std::make_shared(bridge); bridgeState->threadData = std::weak_ptr(owner); owner->giveBridge(bridgeState); } } } void MainApp::queueBridgeReconnectAllThreads() { try { for (ThreadDataOwner &thread : threads) { thread->queueBridgeReconnect(); } } catch (std::exception &ex) { Logger *logger = Logger::getInstance(); logger->logf(LOG_ERR, ex.what()); } } void MainApp::queueInternalHeartbeat() { if (threads.empty()) return; auto queue_time = std::chrono::steady_clock::now(); const std::chrono::milliseconds main_loop_drift = drift.getDrift(); if (main_loop_drift > settings.maxEventLoopDrift) { Logger::getInstance()->log(LOG_WARNING) << "Main loop thread drift is " << main_loop_drift.count() << " ms."; } if (this->medianThreadDrift > settings.maxEventLoopDrift) { Logger::getInstance()->log(LOG_WARNING) << "Median thread drift is " << this->medianThreadDrift.count() << " ms."; } drift.update(queue_time); std::vector drifts(threads.size()); std::transform(threads.begin(), threads.end(), drifts.begin(), [] (const ThreadDataOwner &t) { return t->driftCounter.getDrift(); }); const size_t n = drifts.size() / 2; std::nth_element(drifts.begin(), drifts.begin() + n, drifts.end()); this->medianThreadDrift = drifts.at(n); for (ThreadDataOwner &thread : threads) { thread->queueInternalHeartbeat(); } } void MainApp::performAllImmediateTasks() { std::vector> copiedTasks; { auto task_queue_locked = taskQueue.lock(); copiedTasks = std::move(*task_queue_locked); task_queue_locked->clear(); } for(auto &f : copiedTasks) { try { f(); } catch (std::exception &ex) { Logger::getInstance()->log(LOG_ERR) << "Error in MainApp::performAllImmediateTasks: " << ex.what(); } } } std::shared_ptr MainApp::initMainApp(int argc, char *argv[]) { static struct option long_options[] = { 
{"help", no_argument, nullptr, 'h'}, {"config-file", required_argument, nullptr, 'c'}, {"test-config", no_argument, nullptr, 't'}, {"fuzz-file", required_argument, nullptr, 'z'}, {"version", no_argument, nullptr, 'V'}, {"license", no_argument, nullptr, 'l'}, {nullptr, 0, nullptr, 0} }; #ifdef TESTING const std::string defaultConfigFile = "/dummy/flashmq.org"; #else const std::string defaultConfigFile = "/etc/flashmq/flashmq.conf"; #endif std::string configFile; if (access(defaultConfigFile.c_str(), R_OK) == 0) { configFile = defaultConfigFile; } std::string fuzzFile; int option_index = 0; int opt; bool testConfig = false; optind = 1; // allow repeated calls to getopt_long. while((opt = getopt_long(argc, argv, "hc:Vltz:", long_options, &option_index)) != -1) { switch(opt) { case 'c': configFile = optarg; break; case 'l': MainApp::showLicense(); exit(0); case 'V': MainApp::showLicense(); exit(0); case 'z': fuzzFile = optarg; break; case 'h': MainApp::doHelp(argv[0]); exit(16); case 't': testConfig = true; break; case '?': MainApp::doHelp(argv[0]); exit(16); } } if (optind < argc) { throw std::runtime_error("Error: positional arguments given. Did you mean --config-file ?"); } if (testConfig) { try { if (configFile.empty()) { std::cerr << "No config specified (with -c) and the default " << defaultConfigFile << " not found." 
<< std::endl << std::endl; MainApp::doHelp(argv[0]); exit(1); } ConfigFileParser c(configFile); c.loadFile(true); printf("Config '%s' OK\n", configFile.c_str()); exit(0); } catch (ConfigFileException &ex) { std::cerr << ex.what() << std::endl; exit(1); } } std::shared_ptr result(new MainApp(configFile)); result->setSelf(result); result->setFuzzFile(fuzzFile); return result; } void MainApp::start() { assert(mInitializedByThread == pthread_self()); #ifndef NDEBUG #ifndef TESTING if (!getFuzzMode()) { oneInstanceLock.lock(); } #endif #endif #ifdef NDEBUG logger->noLongerLogToStd(); #endif struct epoll_event ev {}; ev.data.fd = taskEventFd; ev.events = EPOLLIN; check(epoll_ctl(this->epollFdAccept, EPOLL_CTL_ADD, taskEventFd, &ev)); #ifndef NDEBUG // I fuzzed using afl-fuzz. You need to compile it with their compiler. if (getFuzzMode()) { // No threads for execution stability/determinism. num_threads = 0; settings.allowAnonymous = true; int fd = open(fuzzFilePath.c_str(), O_RDONLY); assert(fd > 0); int fdnull = open("/dev/null", O_RDWR); assert(fdnull > 0); int fdnull2 = open("/dev/null", O_RDWR); assert(fdnull2 > 0); // TODO: matching for filename patterns doesn't work, because AFL fuzz changes the name. const std::string fuzzFilePathLower = str_tolower(fuzzFilePath); ConnectionProtocol connectionProtocol = strContains(fuzzFilePathLower, "web") ? 
ConnectionProtocol::WebsocketMqtt : ConnectionProtocol::Mqtt; try { const std::string empty; std::vector packetQueueIn; std::vector subtopics; std::shared_ptr pluginLoader = std::make_shared(); std::shared_ptr threaddata = std::make_shared(0, settings, pluginLoader, mSelf); ThreadGlobals::assignThreadData(threaddata); std::shared_ptr client = std::make_shared(ClientType::Normal, fd, threaddata, FmqSsl(), connectionProtocol, HaProxyMode::Off, nullptr, settings, true); std::shared_ptr subscriber = std::make_shared(ClientType::Normal, fdnull, threaddata, FmqSsl(), connectionProtocol, HaProxyMode::Off, nullptr, settings, true); subscriber->setClientProperties(ProtocolVersion::Mqtt311, "subscriber", {}, "subuser", true, 60); subscriber->setAuthenticated(true); std::shared_ptr websocketsubscriber = std::make_shared(ClientType::Normal, fdnull2, threaddata, FmqSsl(), ConnectionProtocol::WebsocketMqtt, HaProxyMode::Off, nullptr, settings, true); websocketsubscriber->setClientProperties(ProtocolVersion::Mqtt311, "websocketsubscriber", {}, "websocksubuser", true, 60); websocketsubscriber->setAuthenticated(true); websocketsubscriber->setFakeUpgraded(); subscriptionStore->registerClientAndKickExistingOne(websocketsubscriber); subtopics = splitTopic("#"); subscriptionStore->addSubscription(websocketsubscriber->getSession(), subtopics, 0, false, false, empty, 0); subscriptionStore->registerClientAndKickExistingOne(subscriber); subscriptionStore->addSubscription(subscriber->getSession(), subtopics, 0, false, false, empty, 0); if (connectionProtocol == ConnectionProtocol::WebsocketMqtt && strContains(fuzzFilePathLower, "upgrade")) { client->setFakeUpgraded(); subscriber->setFakeUpgraded(); } { VectorClearGuard vectorClearGuard(packetQueueIn); client->readFdIntoBuffer(); client->bufferToMqttPackets(packetQueueIn, client); for (MqttPacket &packet : packetQueueIn) { packet.handle(client); } subscriber->writeBufIntoFd(); websocketsubscriber->writeBufIntoFd(); } } catch (ProtocolError 
&ex) { logger->logf(LOG_ERR, "Expected MqttPacket handling error: %s", ex.what()); } running = false; } #endif std::shared_ptr pluginLoader = std::make_shared(); pluginLoader->loadPlugin(settings.pluginPath); std::unordered_map &authOpts = settings.getFlashmqpluginOpts(); pluginLoader->mainInit(authOpts); for (int i = 0; i < num_threads; i++) { threads.emplace_back(i, settings, pluginLoader, mSelf); threads.back().start(); } // Populate the $SYS topics, otherwise you have to wait until the timer expires. if (!threads.empty()) { std::vector> thread_datas; for(ThreadDataOwner &t : threads) { thread_datas.push_back(t.getThreadData()); } threads.front()->queuePublishStatsOnDollarTopic(thread_datas); } this->threadsPendingInit = threads.size(); { auto locked_global_threads = globals->threadDatas.lock(); for (auto &t : threads) { locked_global_threads->push_back(t.getThreadData()); } } sendBridgesToThreads(); queueBridgeReconnectAllThreads(); std::minstd_rand randomish; randomish.seed(get_random_int()); this->bgWorker.start(); std::vector events(128); while (running) { const uint32_t next_task_delay = timed_tasks.getTimeTillNext(); const uint32_t epoll_wait_time = std::min(next_task_delay, 100); int num_fds = epoll_wait(this->epollFdAccept, events.data(), events.size(), epoll_wait_time); if (epoll_wait_time == 0) timed_tasks.performAll(); if (num_fds < 0) { if (errno == EINTR) continue; logger->logf(LOG_ERR, "Waiting for listening socket error: %s", strerror(errno)); } for (int i = 0; i < num_fds; i++) { int cur_fd = events[i].data.fd; try { if (cur_fd != taskEventFd) { std::shared_ptr listener = activeListenSockets[cur_fd].getListener(); if (!listener) continue; std::shared_ptr thread_data = threads[listener->next_thread_index % num_threads].getThreadData(); if (logger->wouldLog(LOG_DEBUG)) logger->log(LOG_DEBUG) << "Accepting connection on thread " << thread_data->threadnr << " on " << listener->getProtocolName(); struct sockaddr_storage addr_mem {}; struct sockaddr 
*addr = reinterpret_cast(&addr_mem); socklen_t len = sizeof(addr_mem); int fd = check(accept(cur_fd, addr, &len)); if (!listener->isAllowed(addr)) { std::ostringstream oss; oss << "Connection from " << sockaddrToString(addr) << " not allowed on " << listener->getProtocolName(); if (timed_tasks.getTaskCount() < 1000) { const uint32_t delay = (randomish() & 0x0FFF) + 1000; oss << ". Closing after " << delay << " ms."; // Queue managed fd so that on SIGHUP, they all get closed when the timers are cleared. A bit clunky to // make a shared pointer, but tasks need copyable things. std::shared_ptr fd_managed = std::make_shared(fd); auto close_f = [fd_managed]() { (void)fd_managed; }; timed_tasks.addTask(close_f, delay); } else { oss << ". Closing now."; close(fd); } logger->log(LOG_NOTICE) << oss.str(); continue; } listener->next_thread_index++; /* * I decided to not use a delayed close mechanism. It has been observed that under overload and clients in a reconnect loop, * you can collect open files up to (a) million(s). By accepting and closing, the hope is we can keep clients at bay from * the thread loops well enough. */ if (this->medianThreadDrift > settings.maxEventLoopDrift || this->drift.getDrift() > settings.maxEventLoopDrift) { const std::string addr_s = sockaddrToString(addr); bool do_close = false; const OverloadMode overload_mode = listener->overloadMode.value_or(settings.overloadMode); if (overload_mode == OverloadMode::CloseNewClients) { if (overloadLogCounter <= OVERLOAD_LOGS_MUTE_AFTER_LINES) { overloadLogCounter++; logger->log(LOG_ERROR) << "[OVERLOAD] FlashMQ seems to be overloaded while accepting new connection(s) from '" << addr_s << ". Closing socket. 
See 'overload_mode' and 'max_event_loop_drift'."; } do_close = true; } else if (overload_mode == OverloadMode::Log) { if (overloadLogCounter <= OVERLOAD_LOGS_MUTE_AFTER_LINES) { overloadLogCounter++; logger->log(LOG_WARNING) << "[OVERLOAD] FlashMQ seems to be overloaded while accepting new connection(s) from '" << addr_s << ". See 'overload_mode' and 'max_event_loop_drift'."; } } else { throw std::runtime_error("Unimplemented OverloadMode"); } if (overloadLogCounter > OVERLOAD_LOGS_MUTE_AFTER_LINES && overloadLogCounter < OVERLOAD_LOGS_MUTE_AFTER_LINES * 2) { overloadLogCounter = OVERLOAD_LOGS_MUTE_AFTER_LINES * 5; logger->log(LOG_WARNING) << "[OVERLOAD] Muting overload logging until it recovers, to avoid log spam and extra load."; } if (do_close) { close(fd); continue; } } else { overloadLogCounter = 0; } FmqSsl clientSSL; if (listener->isSsl()) { if (!listener->sslctx) { logger->log(LOG_ERR) << "Listener is SSL but SSL context is null. Application bug."; close(fd); continue; } clientSSL = FmqSsl(*listener->sslctx); if (!clientSSL) { logger->logf(LOG_ERR, "Problem creating SSL object. Closing client."); close(fd); continue; } clientSSL.set_fd(fd); } // Don't use std::make_shared to avoid the weak pointers keeping the control block in memory. 
std::shared_ptr client = std::shared_ptr(new Client( ClientType::Normal, fd, thread_data, std::move(clientSSL), listener->connectionProtocol, listener->haProxyMode, addr, settings)); if (listener->getX509ClientVerficationMode() != X509ClientVerification::None) { client->setSslVerify(listener->getX509ClientVerficationMode()); } client->setAllowAnonymousOverride(listener->allowAnonymous); client->setAcmeRedirect(listener->acmeRedirectURL); if (listener->maxBufferSize) client->setMaxBufSizeOverride(listener->maxBufferSize.value()); if (listener->maxQos) client->setMaxQos(listener->maxQos.value()); if (listener->mqtt3QoSExceedAction) client->setMqtt3QoSExceedAction(listener->mqtt3QoSExceedAction.value()); thread_data->giveClient(std::move(client)); globals->stats.socketConnects.inc(); } else { uint64_t eventfd_value = 0; check(read(cur_fd, &eventfd_value, sizeof(uint64_t))); if (doConfigReload) { reloadConfig(); } if (doLogFileReOpen) { reopenLogfile(); } if (doQuitAction) { quit(); } if (doMemoryTrim) { memoryTrim(); } performAllImmediateTasks(); } } catch (std::exception &ex) { logger->logf(LOG_ERR, "Problem in main thread: %s", ex.what()); } } } activeListenSockets.clear(); this->bgWorker.stop(); if (settings.willsEnabled) { logger->logf(LOG_DEBUG, "Having all client in all threads send or queue their will."); for(ThreadDataOwner &thread : threads) { thread->queueSendWills(); } waitForWillsQueued(); } logger->logf(LOG_DEBUG, "Having all client in all threads send a disconnect packet and initiate quit."); for(ThreadDataOwner &thread : threads) { thread->queueSendDisconnects(); } oneInstanceLock.unlock(); { logger->logf(LOG_DEBUG, "Waiting for our own quit event to have been queued."); int count = 0; while(std::any_of(threads.begin(), threads.end(), [](const ThreadDataOwner &t){ return t->running; })) { if (count++ >= 30000) break; usleep(1000); } } logger->logf(LOG_DEBUG, "Waiting for threads clean-up functions to finish."); int count = 0; bool waitTimeExpired = 
false; while(std::any_of(threads.begin(), threads.end(), [](ThreadDataOwner &t){ return !t->finished; })) { if (count++ >= 30000) { waitTimeExpired = true; break; } usleep(1000); } if (waitTimeExpired) { logger->logf(LOG_WARNING, "(Some) threads failed to terminate. Program will exit uncleanly. If you're using a plugin, it may not be thread-safe."); } else { for(ThreadDataOwner &thread : threads) { logger->logf(LOG_DEBUG, "Waiting for thread %d to join.", thread->threadnr); thread.waitForQuit(); } } pluginLoader->mainDeinit(settings.getFlashmqpluginOpts()); globals->quitting = true; this->bgWorker.waitForStop(); std::list bridgeInfos = BridgeInfoForSerializing::getBridgeInfosForSerializing(this->bridgeConfigs); saveState(this->settings, bridgeInfos, false); if (errorExit) { throw std::runtime_error("Exiting with error."); } } void MainApp::queueQuit() { this->doQuitAction = true; wakeUpThread(); } void MainApp::quit() { doQuitAction = false; std::lock_guard guard(quitMutex); if (!running) return; Logger *logger = Logger::getInstance(); logger->logf(LOG_NOTICE, "Quitting FlashMQ"); running = false; } void MainApp::setSelf(const std::shared_ptr &self) { assert(this == self.get()); mSelf = self; } bool MainApp::getFuzzMode() const { bool fuzzMode = false; #ifndef NDEBUG fuzzMode = !fuzzFilePath.empty(); #endif return fuzzMode; } void MainApp::setlimits() { rlim_t nofile = settings.rlimitNoFile; logger->log(LOG_INFO) << "Setting rlimit nofile to " << nofile; struct rlimit v = { nofile, nofile }; if (setrlimit(RLIMIT_NOFILE, &v) < 0) { logger->logf(LOG_ERR, "Setting ulimit nofile failed: '%s'. This means the default is used. Note. It's also subject to systemd's 'LimitNOFILE', " "which in turn is maxed to '/proc/sys/fs/nr_open', which can be set like 'sysctl fs.nr_open=15000000'", strerror(errno)); } } /** * @brief MainApp::loadConfig is loaded on app start where you want it to crash, loaded from within try/catch on reload, to allow the program to continue. 
*/ void MainApp::loadConfig(bool reload) { Logger *logger = Logger::getInstance(); logger->log(LOG_NOTICE) << std::boolalpha << "Loading config. Reload: " << reload << "."; // Atomic loading, first test. confFileParser->loadFile(true); confFileParser->loadFile(false); const Settings oldSettings = settings; settings = confFileParser->getSettings(); ThreadGlobals::assignSettings(&settings); if (settings.listeners.empty()) { std::shared_ptr defaultListener = std::make_shared(); // Kind of a quick hack to do this this way. #ifdef TESTING defaultListener->port = 21883; #endif defaultListener->isValid(); settings.listeners.push_back(defaultListener); } listeners = settings.listeners; for (std::shared_ptr &listener : this->listeners) { listener->isValid(); } if (!getFuzzMode()) { activeListenSockets.clear(); bool listenerCreateError = false; for(std::shared_ptr &listener : this->listeners) { listener->loadCertAndKeyFromConfig(); std::list scopedSockets = createListenSocket(listener, reload); if (scopedSockets.empty()) { listenerCreateError = true; continue; } for(ScopedSocket &s : scopedSockets) { activeListenSockets[s.get()] = std::move(s); } } if (listenerCreateError && !reload) { throw std::runtime_error("Some listeners failed."); } } { for (auto &pair : bridgeConfigs) { pair.second.queueForDelete = true; } std::list bridges = loadBridgeInfo(this->settings); for (BridgeConfig &bridge : bridges) { auto pos = this->bridgeConfigs.find(bridge.clientidPrefix); if (pos != this->bridgeConfigs.end()) { logger->log(LOG_NOTICE) << "Assing new config to bridge '" << bridge.clientidPrefix << "' and reconnect if needed."; BridgeConfig &cur = pos->second; std::shared_ptr owner = cur.owner.lock(); std::string clientid = cur.getClientid(); cur = bridge; cur.owner = owner; cur.setClientId(cur.clientidPrefix, clientid); } else { logger->log(LOG_NOTICE) << "Adding bridge '" << bridge.clientidPrefix << "'."; this->bridgeConfigs[bridge.clientidPrefix] = bridge; } } for (auto &pair : 
bridgeConfigs) { if (pair.second.queueForDelete) { logger->log(LOG_NOTICE) << "Queueing bridge '" << pair.first << "' for removal, because it disappeared from config."; } } // On first load, the start() function will take care of it. if (reload) { sendBridgesToThreads(); queueBridgeReconnectAllThreads(); } } logger->setLogPath(settings.logPath); logger->queueReOpen(); logger->setFlags(settings.logLevel, settings.logSubscriptions, settings.logPublishes); logger->setFlags(settings.logDebug, settings.quiet); setlimits(); for (ThreadDataOwner &thread : threads) { thread->queueReload(settings); } if (reload) reloadTimers(&oldSettings); } void MainApp::reloadConfig() { doConfigReload = false; Logger *logger = Logger::getInstance(); try { loadConfig(true); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error reloading config: %s", ex.what()); } } void MainApp::reopenLogfile() { doLogFileReOpen = false; Logger *logger = Logger::getInstance(); logger->logf(LOG_NOTICE, "Reopening log files"); logger->queueReOpen(); logger->logf(LOG_NOTICE, "Log files reopened"); } void MainApp::reloadTimers(const Settings *old_settings) { /* * Avoid postponing timed events when people continously send sighup signals. * * TODO: better method, that is not susceptible to forgetting adding settings here when more timers can change. */ if (old_settings && settings.pluginTimerPeriod == old_settings->pluginTimerPeriod && settings.saveStateInterval == old_settings->saveStateInterval) { logger->log(LOG_NOTICE) << "Timer config not changed. Not re-adding timers."; return; } if (old_settings) logger->log(LOG_NOTICE) << "Settings impacting timed events changed. 
Re-adding timers."; else logger->log(LOG_NOTICE) << "Adding timers"; timed_tasks.clear(); if (settings.pluginTimerPeriod > 0) { auto fpluginPeriodicEvent = std::bind(&MainApp::queuepluginPeriodicEventAllThreads, this); timed_tasks.addTask(fpluginPeriodicEvent, settings.pluginTimerPeriod * 1000, true); } { auto fSaveState = std::bind(&MainApp::saveStateInThread, this); timed_tasks.addTask(fSaveState, settings.saveStateInterval.count() * 1000, true); } { auto f = std::bind(&MainApp::queueCleanup, this); //const uint64_t derrivedSessionCheckInterval = std::max((settings.expireSessionsAfterSeconds)*1000*2, 600000); //const uint64_t sessionCheckInterval = std::min(derrivedSessionCheckInterval, 86400000); uint32_t interval = 10000; #ifdef TESTING interval = 1000; #endif timed_tasks.addTask(f, interval, true); } { uint32_t interval = 1846849; // prime auto f = std::bind(&MainApp::queuePurgeSubscriptionTree, this); timed_tasks.addTask(f, interval, true); } { uint32_t interval = 3949193; // prime #ifdef TESTING interval = 500; #endif auto f = std::bind(&MainApp::queueRetainedMessageExpiration, this); timed_tasks.addTask(f, interval, true); } { auto fKeepAlive = std::bind(&MainApp::queueKeepAliveCheckAtAllThreads, this); timed_tasks.addTask(fKeepAlive, 5000, true); } { auto fPasswordFileReload = std::bind(&MainApp::queuePasswordFileReloadAllThreads, this); timed_tasks.addTask(fPasswordFileReload, 2000, true); } { auto fPublishStats = std::bind(&MainApp::queuePublishStatsOnDollarTopic, this); timed_tasks.addTask(fPublishStats, 10000, true); } { auto fSendPendingWills = std::bind(&MainApp::queueSendQueuedWills, this); timed_tasks.addTask(fSendPendingWills, 2000, true); } { auto fInternalHeartbeat = std::bind(&MainApp::queueInternalHeartbeat, this); timed_tasks.addTask(fInternalHeartbeat, HEARTBEAT_INTERVAL, true); } { uint32_t interval = 5000; #ifdef TESTING interval = 1000; #endif auto fReconnectBridges = std::bind(&MainApp::queueBridgeReconnectAllThreads, this); 
timed_tasks.addTask(fReconnectBridges, interval, true); } } /** * @brief MainApp::queueConfigReload is called by a signal handler, and it was observed that it should not do anything that allocates memory, * to avoid locking itself when another signal is received. */ void MainApp::queueConfigReload() { doConfigReload = true; wakeUpThread(); } void MainApp::queueReopenLogFile() { doLogFileReOpen = true; wakeUpThread(); } void MainApp::queueCleanup() { if (!threads.empty()) { int threadnr = rand() % threads.size(); std::shared_ptr t = threads[threadnr].getThreadData(); t->queueRemoveExpiredSessions(); } } void MainApp::queuePurgeSubscriptionTree() { if (!threads.empty()) { int threadnr = rand() % threads.size(); std::shared_ptr t = threads[threadnr].getThreadData(); t->queuePurgeSubscriptionTree(); } } void MainApp::queueMemoryTrim() { doMemoryTrim = true; wakeUpThread(); } void MainApp::memoryTrim() { doMemoryTrim = false; Logger *logger = Logger::getInstance(); logger->log(LOG_NOTICE) << "Initiating malloc_trim(0). Main thread will not be able to accept new connections while it's running."; const auto a = std::chrono::steady_clock::now(); const int result = malloc_trim(0); const auto b = std::chrono::steady_clock::now(); const std::chrono::microseconds dur = std::chrono::duration_cast(b - a); std::string sresult = "Unknown result from malloc_trim(0)."; if (result == 0) sresult = "Result was 0, so no memory was returned to the system."; else if (result == 1) sresult = "Result was 1, so memory was returned to the system."; logger->log(LOG_NOTICE) << "Operation malloc_trim(0) done. " << sresult << " Duration was " << dur.count() << " µs."; } void MainApp::queueThreadInitDecrement() { auto f = [this]() { if (this->threadsPendingInit == 0) return; this->threadsPendingInit--; if (this->threadsPendingInit != 0) return; Logger::getInstance()->log(LOG_NOTICE) << "Threads initialized. 
Creating listeners and timers."; try { for (auto &pair : activeListenSockets) { pair.second.doListen(this->epollFdAccept); } reloadTimers(nullptr); started = true; notify_ready(); } catch (std::exception &ex) { Logger::getInstance()->log(LOG_ERROR) << "Error creating listeners and timers : " << ex.what(); errorExit = true; queueQuit(); } }; addImmediateTask(f); } ================================================ FILE: mainapp.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef MAINAPP_H #define MAINAPP_H #include #include #include #include #include #include #include #include #include #include #include #include "threaddata.h" #include "subscriptionstore.h" #include "configfileparser.h" #include "scopedsocket.h" #include "oneinstancelock.h" #include "bridgeinfodb.h" #include "backgroundworker.h" #include "driftcounter.h" #include "globals.h" class MainApp { #ifdef TESTING friend class MainAppInThread; friend class MainAppAsFork; #endif std::weak_ptr mSelf; int num_threads = 0; bool started = false; bool running = true; std::vector threads; std::shared_ptr subscriptionStore; std::unique_ptr confFileParser; int epollFdAccept = -1; int taskEventFd = -1; bool doConfigReload = false; bool doLogFileReOpen = false; bool doQuitAction = false; bool errorExit = false; bool doMemoryTrim = false; size_t threadsPendingInit = 0; QueuedTasks timed_tasks; MutexOwned>> taskQueue; uint overloadLogCounter = 0; DriftCounter drift; std::chrono::milliseconds medianThreadDrift = std::chrono::milliseconds(0); Settings settings; std::list> listeners; std::unordered_map activeListenSockets; std::unordered_map bridgeConfigs; std::mutex quitMutex; std::string fuzzFilePath; OneInstanceLock oneInstanceLock; const pthread_t 
mInitializedByThread = pthread_self(); Logger *logger = Logger::getInstance(); BackgroundWorker bgWorker; void setSelf(const std::shared_ptr &self); bool getFuzzMode() const; void setlimits(); void loadConfig(bool reload); void reloadConfig(); void reopenLogfile(); void reloadTimers(const Settings *old_settings); static void doHelp(const char *arg); static void showLicense(); std::list createListenSocket(const std::shared_ptr &listener, bool do_listen); void wakeUpThread(); void addImmediateTask(std::function f); void queueKeepAliveCheckAtAllThreads(); void queuePasswordFileReloadAllThreads(); void queuepluginPeriodicEventAllThreads(); void setFuzzFile(const std::string &fuzzFilePath); void queuePublishStatsOnDollarTopic(); void saveStateInThread(); void queueSendQueuedWills(); void waitForWillsQueued(); void queueRetainedMessageExpiration(); void sendBridgesToThreads(); void queueBridgeReconnectAllThreads(); void queueInternalHeartbeat(); void performAllImmediateTasks(); MainApp(const std::string &configFilePath); public: MainApp(const MainApp &rhs) = delete; MainApp(MainApp &&rhs) = delete; ~MainApp(); static std::shared_ptr initMainApp(int argc, char *argv[]); void start(); void queueQuit(); void quit(); bool getStarted() const {return started;} static void testConfig(); void queueConfigReload(); void queueReopenLogFile(); void queueCleanup(); void queuePurgeSubscriptionTree(); void queueMemoryTrim(); void memoryTrim(); void queueThreadInitDecrement(); }; #endif // MAINAPP_H ================================================ FILE: man/.gitignore ================================================ ================================================ FILE: man/Makefile ================================================ GROFF_TARGETS := flashmq.1 flashmq.conf.5 HTML5_TARGETS := flashmq.1.html flashmq.conf.5.html all: $(GROFF_TARGETS) $(HTML5_TARGETS) $(GROFF_TARGETS): %: flashmq-docbook5-refentry-to-manpage.xsl %.dbk5 xsltproc --xinclude --stringparam creation-date "$(shell 
date +'%b %d %Y')" $^ > $@ %.html: flashmq-docbook5-refentry-to-html5.xsl %.dbk5 xsltproc --xinclude --stringparam dbk5.reference $(CURDIR)/reference.dbk5 $^ > $@ ================================================ FILE: man/README.md ================================================ To render a man page: ``` man -l flashmq.1 man -l flashmq.conf.5 ``` To enable color on Debian derivatives: ``` env GROFF_SGR=1 man -l flashmq.1 env GROFF_SGR=1 man -l flashmq.conf.5 ``` ================================================ FILE: man/docbook5-refentry-xslt/docbook5-refentry-to-html5.xsl ================================================ # # <xsl:choose> <xsl:when test="dbk:refnamediv/dbk:refdescriptor"> <!-- “When none of the `refname`s is appropriate, [the optional] `refdescriptor` is used to specify the topic name.” --> <xsl:value-of select="dbk:refnamediv/dbk:refdescriptor"/> </xsl:when> <xsl:when test="dbk:refmeta"> <!-- To give this precedence over `dbk:refname` slightly violates std. “processing expectations”: https://tdg.docbook.org/tdg/5.0/refnamediv.html --> <xsl:value-of select="dbk:refmeta/dbk:refentrytitle"/> <xsl:if test="dbk:refmeta/dbk:manvolnum"> <xsl:text> (</xsl:text> <xsl:value-of select="dbk:refmeta/dbk:manvolnum"/> <xsl:text>)</xsl:text> </xsl:if> </xsl:when> <xsl:otherwise> <xsl:value-of select="dbk:refnamediv/dbk:refname[1]"/> </xsl:otherwise> </xsl:choose> <xsl:text> – </xsl:text> <xsl:value-of select="dbk:refnamediv/dbk:refpurpose"/>

( )

Synopsis

      
        
        
           
          
        
        
           
          
        
      
      
    
{ [ } ] ... |

      
        
        
           
          
        
        
           
          
        
      
      
    

refentry/info/biblioid[@class='uri'] missing in source DocBook for man: ( )
  • , DocBook element unrecognized by XSLT: < >
    ================================================ FILE: man/docbook5-refentry-xslt/docbook5-refentry-to-manpage.xsl ================================================ docbook5-to-man.xsl 7 docbook5-to-man.xsl XSLT sheet to transform a DocBook 5.x document to a groff man page xsltproc path/to/docbook5-to-man.xsl path/to/source-doc.dbk > man-page.1 Dependencies docbook5-to-man.xsl requires a XSLT 1.0 processor that understands the following EXSLT extensions: XSLT parameters <xsl:param name="current-date" select="date:date-time()"/> Any string that's understood by https://exslt.github.io/date/functions/format-date/index.html <xsl:param name="date-format" select="yyyy-MM-dd"/> https://exslt.github.io/date/functions/format-date/index.html https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html .if \n(.g .ds T> \\F[\n[.fam]] .color .de URL \\$2 \(la\\$1\(ra\\$3 .. .if \n(.g .mso www.tmac .TH " " "" "" .SH NAME \- .SH SYNOPSIS 'nh .fi .ad l 'in \fB\m[green] \m[]\fR \kx .if (\nx>(\n(.l/2)) .nr x (\n(.l/5) 'in \n(.iu+\nxu .br ... \m[green] \m[] ... | [ ] { } \m[green] \m[] \fI\m[green] \m[]\fR .SH .TP \*(T<\fB \fR\*(T> .TQ \m[green] \m[] .RS .RE .TP 0.2i \(bu , .PP .nf .in +7 .in .fi \m[blue] \m[] \(lB\fI\m[blue] \m[]\fR\(rB \fB \fR \[lq] \[rq] \m[blue] \fB man: \fR \m[] ( ) \fI \fR \fI\m[blue] \m[]\fR \fI\m[cyan] \m[]\fR \fB\m[green] \fR\m[] \fB\m[green] \m[]\fR \fB \fR \fI \fR \fB \fR \fB\m[red] \m[]\fR <\m[blue] \m[]> \fB \fR DocBook element unrecognized by XSLT: < > collapse trim collapse trim db2m:escape-car() function called with more than 1 character. \[rs] \[dq] db2m:has-only-whitespace() function called with non-text node. ================================================ FILE: man/flashmq-docbook5-refentry-to-html5.xsl ================================================ numeric alphanumeric
    ≥ v
    ================================================ FILE: man/flashmq-docbook5-refentry-to-manpage.xsl ================================================ \m[blue] \m[] \m[yellow] \m[] { \m[green] \m[] \m[cyan] \m[magenta] \m[blue] \m[] \m[default] \m[] ================================================ FILE: man/flashmq.1 ================================================ .if \n(.g .ds T< \\FC .if \n(.g .ds T> \\F[\n[.fam]] .color .de URL \\$2 \(la\\$1\(ra\\$3 .. .if \n(.g .mso www.tmac .TH flashmq 1 "Jan 17 2026" "" "" .SH NAME flashmq \- A fast light-weight scalable MQTT server .SH SYNOPSIS 'nh .fi .ad l \fB\m[green]flashmq\m[]\fR \kx .if (\nx>(\n(.l/2)) .nr x (\n(.l/5) 'in \n(.iu+\nxu [\fB-c\fR | \fB--config-file\fR \fI\m[green]config_file_path\m[]\fR] | [\fB-t\fR | \fB--test-config\fR] | .br [\fB-h\fR | \fB--help\fR] | .br [\fB-v\fR | \fB--version\fR] | .br [\fB-l\fR | \fB--license\fR] .SH DESCRIPTION FlashMQ is a MQTT 3.1, 3.1.1 and 5 broker designed to be light-weight and handle millions of clients and/or messages. .SH SIGNALS .TP \*(T<\fB\fB\m[red]SIGHUP\m[]\fR\fR\*(T> Reload config file and reopen log files. Listeners are recreated. Bridges are reconnected (when their config has changed), added or removed as necessary. .TP \*(T<\fB\fB\m[red]SIGUSR1\m[]\fR\fR\*(T> Reopen log files. Use this in log rotation. .TP \*(T<\fB\fB\m[red]SIGUSR2\m[]\fR\fR\*(T> It has been observed that certain implementations of malloc have a high memory use while still having many free blocks. See for instance the libc mailing list discussion \m[blue]\[lq]Abnormal memory usage with glibc 2.31 related to thread cache and trimming strategy\[rq]\m[] \(lB\fI\m[blue]https://sourceware.org/pipermail/libc-help/2020-September/005457.html\m[]\fR\(rB. This can be exacerbated by continuous thread recreation, because of how "memory arenas" are managed. FlashMQ has a fixed number of threads, but that may not be true for loaded plugins. 
Sending a \fB\m[red]SIGUSR2\m[]\fR will cause FlashMQ to call \fImalloc_trim(0)\fR, possibly resulting in it giving memory back to the operating system. The action is mostly pretty fast, but if not, the main loop will block during the operation, blocking the ability to accept new connections. The worker threads themselves will keep running and keep serving clients. Use, or need, of this feature on a regular basis is questionable, but at least it can help in a pinch. .SH COMMAND-LINE ARGUMENTS .TP \*(T<\fB\m[green]--help\m[] | \m[green]-h\m[]\fR\*(T> Print help with synopsis. .TP \*(T<\fB\m[green]--version\m[] | \m[green]-v\m[]\fR\*(T> Print FlashMQ version details. .TP \*(T<\fB\m[green]--license\m[] | \m[green]-l\m[]\fR\*(T> Print FlashMQ license. .TP \*(T<\fB\m[green]--config-file\m[] | \m[green]-c\m[] \fI\m[cyan]config_file_path\m[]\fR\fR\*(T> \fB\m[green]flashmq\m[]\fR will read the config file from the given \fI\m[cyan]config_file_path\m[]\fR. Without this option, the default \fI\m[cyan]config_file_path\m[]\fR is \fI\m[blue]/etc/flashmq/flashmq.conf\m[]\fR. See the \m[blue]\fBflashmq.conf\fR(5)\m[] manual page for the format of this file. .TP \*(T<\fB\m[green]--test-config\m[] | \m[green]-t\m[]\fR\*(T> Test the configuration, without starting the daemon. .SH AUTHOR Wiebe Cazemier <\m[blue]contact@flashmq.org\m[]>. .SH SEE ALSO \m[blue]man:flashmq.conf\fR(5)\m[], \m[blue]https://www.flashmq.org/\m[] .SH COLOPHON The sources for the FlashMQ manual pages are maintained in \m[blue]DocBook 5.2\m[] \(lB\fI\m[blue]https://tdg.docbook.org/tdg/5.2/\m[]\fR\(rB XML files. The transformation to the multiple destination file formats is done using a bunch of XSLT 1.0 sheets, contributed to this project by Rowan van der Molen. The groff source of this man-page has ANSI-color support for the terminal. However, Debian-derived Linux distributions turn off groff color support by default. To override this, set the \fBGROFF_SGR\fR environment variable to \fI1\fR. 
================================================ FILE: man/flashmq.1.dbk5 ================================================ https://www.flashmq.org/man/flashmq.1 flashmq 1 flashmq A fast light-weight scalable MQTT server flashmq -c --config-file config_file_path -t --test-config -h --help -v --version -l --license Description FlashMQ is a MQTT 3.1, 3.1.1 and 5 broker designed to be light-weight and handle millions of clients and/or messages. Signals SIGHUP Reload config file and reopen log files. Listeners are recreated. Bridges are reconnected (when their config has changed), added or removed as necessary. SIGUSR1 Reopen log files. Use this in log rotation. SIGUSR2 It has been observed that certain implementations of malloc have a high memory use while still having many free blocks. See for instance the libc mailing list discussion Abnormal memory usage with glibc 2.31 related to thread cache and trimming strategy. This can be exacerbated by continuous thread recreation, because of how "memory arenas" are managed. FlashMQ has a fixed number of threads, but that may not be true for loaded plugins. Sending a SIGUSR2 will cause FlashMQ to call malloc_trim(0), possibly resulting in it giving memory back to the operating system. The action is mostly pretty fast, but if not, the main loop will block during the operation, blocking the ability to accept new connections. The worker threads themselves will keep running and keep serving clients. Use, or need, of this feature on a regular basis is questionable, but at least it can help in a pinch. Command-line arguments | Print help with synopsis. | Print FlashMQ version details. | Print FlashMQ license. | config_file_path flashmq will read the config file from the given config_file_path. Without this option, the default config_file_path is /etc/flashmq/flashmq.conf. See the flashmq.conf5 manual page for the format of this file. | Test the configuration, without starting the daemon. Author Wiebe Cazemier contact@flashmq.org. 
See also flashmq.conf 5 https://www.flashmq.org/ ================================================ FILE: man/flashmq.1.html ================================================ flashmq (1) – A fast light-weight scalable MQTT server

    flashmq (1)

    A fast light-weight scalable MQTT server

    Synopsis

    flashmq [-c | --config-file config_file_path] [-t | --test-config] |
         [-h | --help] |
         [-v | --version] |
         [-l | --license]

    Description

    FlashMQ is an MQTT 3.1, 3.1.1 and 5 broker designed to be light-weight and handle millions of clients and/or messages.

    Signals#

    SIGHUP#

    Reload config file and reopen log files. Listeners are recreated. Bridges are reconnected (when their config has changed), added or removed as necessary.

    SIGUSR1#

    Reopen log files. Use this in log rotation.

    SIGUSR2#

    It has been observed that certain implementations of malloc have a high memory use while still having many free blocks. See for instance the libc mailing list discussion Abnormal memory usage with glibc 2.31 related to thread cache and trimming strategy. This can be exacerbated by continuous thread recreation, because of how "memory arenas" are managed. FlashMQ has a fixed number of threads, but that may not be true for loaded plugins. Sending a SIGUSR2 will cause FlashMQ to call malloc_trim(0), possibly resulting in it giving memory back to the operating system.

    The action is mostly pretty fast, but if not, the main loop will block during the operation, blocking the ability to accept new connections. The worker threads themselves will keep running and keep serving clients.

    Use, or need, of this feature on a regular basis is questionable, but at least it can help in a pinch.

    Command-line arguments#

    --help | -h#

    Print help with synopsis.

    --version | -v#

    Print FlashMQ version details.

    --license | -l#

    Print FlashMQ license.

    --config-file | -c config_file_path#

    flashmq will read the config file from the given config_file_path.

    Without this option, the default config_file_path is /etc/flashmq/flashmq.conf.

    See the flashmq.conf(5) manual page for the format of this file.

    --test-config | -t#

    Test the configuration, without starting the daemon.

    Author

    Wiebe Cazemier contact@flashmq.org.

    See also

    flashmq.conf(5), https://www.flashmq.org/

    Colophon#

    The sources for the FlashMQ manual pages are maintained in DocBook 5.2 XML files. The transformation to the multiple destination file formats is done using a bunch of XSLT 1.0 sheets, contributed to this project by Rowan van der Molen.

    ================================================ FILE: man/flashmq.conf.5 ================================================ .if \n(.g .ds T< \\FC .if \n(.g .ds T> \\F[\n[.fam]] .color .de URL \\$2 \(la\\$1\(ra\\$3 .. .if \n(.g .mso www.tmac .TH flashmq.conf 5 "Jan 28 2026" "" "" .SH NAME flashmq.conf \- FlashMQ configuration file format .SH SYNOPSIS 'nh .fi The \fI\m[blue]flashmq.conf\m[]\fR file is the configuration used for configuring the FlashMQ MQTT broker. .SH CONFIG LOCATION By default, the \fB\m[green]flashmq\m[]\fR daemon expects to find its configuration file at \fI\m[blue]/etc/flashmq/flashmq.conf\m[]\fR, but this can be overridden using the \fB\m[green]--config-file\fR\m[] command-line argument; see the \m[blue]\fBflashmq\fR(1)\m[] man page for details. Using the \fB\fB\m[green]include_dir\fR\m[]\fR parameter in your config file, you can load all the \fI\m[blue]*.conf\m[]\fR files from that given directory. .SH FILE FORMAT To set a parameter, its name must appear on a single line, followed by one or more potentially quoted arguments. .PP .nf .in +7 \m[green]parameter-name1 \m[]\m[cyan]parameter-value\m[] \m[green]parameter-name2 \m[]\m[magenta]'value with spaces'\m[default] \m[green]parameter-name3 \m[]\m[magenta]\[dq]escaped \m[blue]\[rs]\[dq]\m[] quote\[dq]\m[default] \m[green]multi-value-param \m[]\m[cyan]one\m[] \m[magenta]'two'\m[default] \m[magenta]\[dq]three\[dq]\m[default] \m[magenta]\[dq]with ' char\[dq]\m[default] \m[green]\m[] .in .fi Quoted values are the same as unquoted values when they don't need it. They are necessary for when argument values have spaces, for instance. When setting boolean values, \fIyes\fR/\fIno\fR, \fItrue\fR/\fIfalse\fR and \fIon\fR/\fIoff\fR can all be used. To configure the listeners, use \fB\m[green]listen\fR\m[] blocks, defined by \fI{\fR and \fI}\fR. See \fBEXAMPLE LISTENERS\fR for details. Lines beginning with the hash character (“\fI#\fR”) and empty lines are ignored.
Thus, a line can be commented out by prepending a “\fI#\fR” to it. .SH GLOBAL PARAMETERS .TP \*(T<\fB\m[green]plugin\m[] \fI\m[blue]/path/to/plugin.so\m[]\fR\fR\*(T> FlashMQ supports an ELF shared object (\fI\m[blue].so\m[]\fR file) plugin interface to add functionality, authorization and authentication, because it’s hard to provide a declarative mechanism that works for everybody. See \fI\m[blue]flashmq_plugin.h\m[]\fR for the API and its documentation. It’s written in C++ for ease of passing FlashMQ internals without conversion to C, but you can basically just use a C++ compiler and program like it was C; the C++ constructs are simple. FlashMQ will auto-detect which plugin interface you’re trying to load (Mosquitto version 2 or FlashMQ native). Keep in mind that each thread initializes the plugin, in line with multi-core programming (minimize shared data and interaction between threads). You could use static variables with thread synchronization if you really want to. And of course, any Mosquitto plugin that uses global and/or static variables instead of initializing memory in its \fBinit()\fR method, will not be thread-safe and won’t work. You can only have one plugin active, but you can combine it with \fB\fB\m[green]mosquitto_password_file\fR\m[]\fR and \fB\fB\m[green]mosquitto_acl_file\fR\m[]\fR. The password and ACL file take precedence, and on a ‘deny’, will not ask the plugin. .TP \*(T<\fB\m[green]plugin_opt_*\m[] \fI\m[cyan]value\m[]\fR\fR\*(T> Options passed to the plugin \fBinit()\fR function. .TP \*(T<\fB\m[green]plugin_serialize_init\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> There could be circumstances where the plugin code is mostly thread-safe, but not on initialization. Libmysqlclient for instance, needs a one-time initialization. To add to the confusion, Qt hides that away. The plugin should preferably be written with proper synchronization like that, but as a last resort, you can use this to synchronize initialization.
Default value: \fIfalse\fR .TP \*(T<\fB\m[green]plugin_serialize_auth_checks\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> Like \fB\m[green]plugin_serialize_init\fR\m[], but then for all login and ACL checks. This option may be dropped at some point, because it negates much of the multi-core design. One may as well run with only one thread then. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]plugin_timer_period\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> The FlashMQ auth plugin interface has an optional function that is called periodically with this interval, in seconds. This can be used to refresh state, commit data, etc. Setting a value of 0 disables it. You can enable and disable this timer with a config reload. See \fI\m[blue]flashmq_plugin.h\m[]\fR for details. Default value: \fI60\fR .TP \*(T<\fB\m[green]log_file\m[] \fI\m[blue]/path/to/flashmq.log\m[]\fR\fR\*(T> This configuration parameter sets the path to FlashMQ's log file. If you omit this option from the config file, the output will go to stdout. .TP \*(T<\fB\m[green]log_level\m[] \fI\m[cyan]debug\m[]\fR|\fI\m[cyan]info\m[]\fR|\fI\m[cyan]notice\m[]\fR|\fI\m[cyan]warning\m[]\fR|\fI\m[cyan]error\m[]\fR|\fI\m[cyan]none\m[]\fR\fR\*(T> Set the log level to the specified level and above. That means \fI\m[cyan]notice\m[]\fR will log \fI\m[cyan]notice\m[]\fR, \fI\m[cyan]warning\m[]\fR and \fI\m[cyan]error\m[]\fR. Use this setting over the deprecated \fB\m[green]log_debug\fR\m[] and \fB\m[green]quiet\fR\m[]. If you do have those directives, they override the \fB\m[green]log_level\fR\m[], for backwards compatibility reasons. Default value: \fIinfo\fR. .TP \*(T<\fB\m[green]log_debug\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> Debug logging obviously creates a lot of log noise, so should only be done to diagnose problems. Deprecated. Use \fB\m[green]log_level\fR\m[] instead.
Default value: \fIfalse\fR .TP \*(T<\fB\m[green]log_subscriptions\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> Default value: \fIfalse\fR .TP \*(T<\fB\m[green]log_publishes\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> Both publishes and QoS actions related to publishes are logged at this level. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]allow_unsafe_clientid_chars\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> If you have topics with client IDs in it, people can possibly manipulate your ACL checking by saying their client ID is 'John+foobar'. Audit your security before you allow this. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]allow_unsafe_username_chars\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> If you have topics with usernames in it, people can possibly manipulate your ACL checking by saying their username is 'John+foobar'. Audit your security before you allow this. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]max_string_length\m[] \fI\m[cyan]bytes\m[]\fR\fR\*(T> The max length of almost all strings encoded in MQTT packets. This includes topics, user properties, client IDs, etc. Some fields, like passwords, are excluded. Payloads are also excluded (as they are not strings). The reason is to avoid abuse. The default should be enough for most deployments. If the limit is tripped, the error in the log will clearly say so, and mention \fB\m[green]max_string_length\fR\m[]. Default value: \fI4096\fR .TP \*(T<\fB\m[green]max_packet_size\m[] \fI\m[cyan]bytes\m[]\fR\fR\*(T> MQTT packets have a maximum size of about 256 MB. This memory will (temporarily) be allocated upon arrival of such packets, so there may be cause to set it lower. This option works in conjunction with \fB\m[green]client_max_write_buffer_size\fR\m[] to limit memory use. 
Default value: \fI268435461\fR .TP \*(T<\fB\m[green]client_max_write_buffer_size\m[] \fI\m[cyan]bytes\m[]\fR\fR\*(T> The client's write buffer is where packets are stored before the event loop has the chance to flush them out. Any time a client's connection is bad and bytes can't be flushed, this buffer fills. So, there's good reason to limit this to something sensible. A good indication is your average packet size (or better yet, a configured \fB\m[green]max_packet_size\fR\m[]) multiplied by the number of packets you want to be able to buffer. Despite the name, this setting also controls the read buffer's maximum size. Having a bigger read buffer may be necessary when you know your clients receive many packets. Note that it's an approximate value and not a hard limit. Buffer sizes only grow by powers of two, and buffers are always allowed to grow to make room for ping packets. Additionally, upon arrival of large packets (up to \fB\m[green]max_packet_size\fR\m[] bytes), room will be made up to twice their size. So, you may also want to reduce \fB\m[green]max_packet_size\fR\m[] from the default. Default value is \fI1048576\fR (1 MB) .TP \*(T<\fB\m[green]client_initial_buffer_size\m[] \fI\m[cyan]bytes\m[]\fR\fR\*(T> The buffers for reading and writing, also for websockets when relevant, start out with a particular size and double when they need to grow. If you know your clients send bulks of a particular size, it helps to set this to match, to avoid constant memory reallocation. The default value is set conservatively, for scenarios with millions of clients. After buffers have grown, they are eventually reset to their original size when possible. Also see \fB\m[green]client_max_write_buffer_size\fR\m[] and \fB\m[green]max_packet_size\fR\m[]. Value must be a power of two.
Default value: \fI1024\fR .TP \*(T<\fB\m[green]mosquitto_password_file\m[] \fI\m[blue]/foo/bar/mosquitto_password_file\m[]\fR\fR\*(T> File with usernames and hashed+salted passwords as generated by Mosquitto's \fB\m[green]mosquitto_passwd\m[]\fR. Mosquitto up to version 1.6 uses the sha512 algorithm. Newer versions use sha512-pbkdf2. Both are supported. .TP \*(T<\fB\m[green]mosquitto_acl_file\m[] \fI\m[blue]/foo/bar/mosquitto_acl_file\m[]\fR\fR\*(T> ACL (access control lists) for users, anonymous users and patterns expandable with %u (username) and %c (clientid). Format is Mosquitto's acl_file. .TP \*(T<\fB\m[green]allow_anonymous\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> This option can be overridden on a per-listener basis; see \fB\fB\m[green]listener.allow_anonymous\fR\m[]\fR. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]zero_byte_username_is_anonymous\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> The proper way to signal an anonymous client is by setting the 'username present' flag in the CONNECT packet to 0, which in MQTT3 also demands the absence of a password. However, there are also clients out there that set the 'username present' flag to 1 and then give an empty username. This is an undesirable situation, because it means there are two ways to identify an anonymous client. Anonymous clients are not authenticated against a loaded plugin when \fB\m[green]allow_anonymous\fR\m[] is true. With this option enabled, that means users with an empty string as username also aren't. With this option disabled, clients connecting with an empty username will be rejected with 'bad username or password' as MQTT error code. The default is to be unambiguous, but this can be overridden with this option. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]rlimit_nofile\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> The general Linux default of \fI1024\fR can be overridden. Note: \fIsystemd\fR blocks you from setting it, so it needs to be set on the unit.
The default systemd unit file sets \fB\m[green]LimitNOFILE=infinity\fR\m[]. You may also need to set \fB\m[green]sysctl -w fs.file-max=10000000\fR\m[] Default value: \fI1000000\fR .TP \*(T<\fB\m[green]expire_sessions_after_seconds\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> Expire sessions after this time. Setting to 0 disables it and is (MQTT3) standard-compliant. But, existing sessions cause load on the server (because they cost memory and are still subscribers), so keeping sessions after any client that connects with a random ID doesn't make sense. Default value: \fI1209600\fR .TP \*(T<\fB\m[green]quiet\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> Don't log LOG_INFO and LOG_NOTICE. This is useful when you have a lot of foot traffic, because otherwise the log gets filled with connect/disconnect notices. Deprecated. Use \fB\m[green]log_level\fR\m[] instead. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]storage_dir\m[] \fI\m[blue]/path/to/dir\m[]\fR\fR\*(T> Location to store sessions, subscriptions and retained messages. Not specifying this will turn off persistence. .TP \*(T<\fB\m[green]save_state_interval\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> The interval at which the state is saved, if enabled with \fB\m[green]storage_dir\fR\m[]. This setting is also applied on reload. Default: 3623 .TP \*(T<\fB\m[green]persistence_data_to_save\m[] \fI\m[cyan]all|sessions_and_subscriptions|retained_messages|bridge_info\m[]\fR\fR\*(T> When a \fB\m[green]storage_dir\fR\m[] is defined, specify which information should be saved. The arguments are a list of one or more of the specified options, and can be negated with a !. Meaning like: \fIall !retained_messages\fR. The order is relevant. If you end with \fIall\fR, it overrides the previous. The option \fIbridge_info\fR causes a small bit of meta data of the bridges to be saved. See the bridge option \fB\fB\m[green]use_saved_clientid\fR\m[]\fR. 
The default is \fIall\fR, but as soon as this option is defined, it's cleared, and the desired list has to be composed. .TP \*(T<\fB\m[green]max_qos_msg_pending_per_client\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> .TQ \*(T<\fB\m[green]max_qos_bytes_pending_per_client\m[] \fI\m[cyan]bytes\m[]\fR\fR\*(T> There is a limit to how many QoS packets can be stored in a session, so you can define a maximum amount of messages and bytes. If any of these is exceeded, the packet is dropped. Note that changing \fB\m[green]max_qos_msg_pending_per_client\fR\m[] only takes effect for new clients (also when picking up existing sessions). This is largely due to it being part of the MQTT5 connection handshake and is supposed to be adhered to. Defaults: .RS .TP 0.2i \(bu max_qos_msg_pending_per_client 512 .TP 0.2i \(bu max_qos_bytes_pending_per_client 65536 .RE .TP \*(T<\fB\m[green]max_qos\m[] \fI\m[cyan]qos_value\m[]\fR\fR\*(T> The maximum QoS value FlashMQ will allow clients to use for subscriptions, publishes and wills. Subscriptions will be downgraded to this value. Publishes and wills will cause a disconnect for MQTT5 clients, and the action configured with \fB\fB\m[green]mqtt3_qos_exceed_action\fR\m[]\fR for MQTT3 clients. The value is updated on config reload, but only for new clients. The reason is that the max QoS is part of the handshake in MQTT5, and clients will be surprised (and cause disconnects) if that changes during the lifetime of a connection. For consistency, the MQTT3 clients behave the same. This setting is also available per listener. Default value: \fI2\fR .TP \*(T<\fB\m[green]mqtt3_qos_exceed_action\m[] \fI\m[cyan]drop|disconnect\m[]\fR\fR\*(T> Unlike MQTT5, MQTT3 doesn't have defined behavior for exceeding a configured maximum QoS (see \fB\fB\m[green]max_qos\fR\m[]\fR ). This allows the behavior to be controlled. In case of \fIdrop\fR the client gets an acknowledgement as if the publish has succeeded, to avoid stale in-flight packets. 
This setting is also available per listener. Default value: \fIdisconnect\fR .TP \*(T<\fB\m[green]max_incoming_topic_alias_value\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> Is communicated towards MQTT5 clients. It is then up to them to decide to set them or not. Changing this setting and reloading the config only has effect on new clients, because existing clients would otherwise exceed the limit they think applies. Default value: \fI65535\fR .TP \*(T<\fB\m[green]max_outgoing_topic_alias_value\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> FlashMQ will make this many aliases per MQTT5 client, if they ask for aliases (with the connect property \fB\m[green]TopicAliasMaximum\fR\m[]). Default value: \fI65535\fR .TP \*(T<\fB\m[green]thread_count\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> If you want to have a different number of worker threads than CPUs, you can set this value. Typically you don't need to set this. Default value: \fI\m[blue]auto-detect\m[]\fR .TP \*(T<\fB\m[green]wills_enabled\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> When disabled, the server will not set last will and testament specified by connecting clients. Default value: \fI\m[blue]true\m[]\fR .TP \*(T<\fB\m[green]retained_messages_mode\m[] \fI\m[cyan]enabled\m[]\fR|\fI\m[cyan]enabled_without_persistence\m[]\fR|\fI\m[cyan]enabled_without_retaining\m[]\fR|\fI\m[cyan]downgrade\m[]\fR|\fI\m[cyan]drop\m[]\fR|\fI\m[cyan]disconnect_with_error\m[]\fR\fR\*(T> Retained messages can be a strain on the server you may not need. You can set various ways of dealing with them: \fI\m[blue]enabled\m[]\fR. This is normal operation. \fI\m[blue]enabled_without_persistence\m[]\fR. Like 'normal', except it won't store them to disk if \fB\m[green]storage_dir\fR\m[] is defined. \fI\m[blue]enabled_without_retaining\m[]\fR. This somewhat counter-intuitive sounding mode is like \fB\m[green]downgrade\fR\m[], except that the 'retain' flag is not removed.
This allows MQTT5 subscribers that subscribe with 'retain as published' to see which messages were originally sent as retained. It's just that FlashMQ won't retain them. \fI\m[blue]downgrade\m[]\fR. The retain flag is removed and the message is treated like a normal publish. \fI\m[blue]drop\m[]\fR. Messages with retain set are dropped. \fI\m[blue]disconnect_with_error\m[]\fR. Disconnect clients who try to set them. Default value: \fI\m[blue]enabled\m[]\fR .TP \*(T<\fB\m[green]expire_retained_messages_after_seconds\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> Use this to limit the lifetime of retained messages. Without this, the number of retained messages may never decrease. Default value: \fI\m[blue]4294967295\m[]\fR .TP \*(T<\fB\m[green]retained_messages_delivery_limit\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> Deprecated. .TP \*(T<\fB\m[green]retained_messages_node_limit\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> When clients place a subscription, they will get the retained messages matching that subscription. Even though traversing the retained message tree is deprioritized in favor of other traffic, it will still cause CPU load until it's done. If you have a tree with millions of nodes and clients subscribe to \fI#\fR, this is potentially unwanted. You can use this setting to limit how many nodes of the retained tree are traversed. Note that the topic \fIone/two/three\fR is three nodes, and each node doesn't necessarily need to contain a message. Default value: \fI\m[blue]4294967295\m[]\fR .TP \*(T<\fB\m[green]set_retained_message_defer_timeout\m[] \fI\m[cyan]milliseconds\m[]\fR\fR\*(T> The time after which FlashMQ will fall back to (b)locking vs queued mode for setting retained messages. 0, the default, disables queued mode altogether. It's disabled by default because it can incur some extra CPU and memory overhead. Each retained message lives in a node in a tree. The topic 'one/two/three' is three nodes. When a node in that tree does not exist yet, it needs to be created.
This requires a write lock on the tree. At this point, other threads reading from or writing to the retained message tree need to wait. This can cause a compounding blocking effect, especially if many threads do it at once. This feature is to favor server responsiveness vs the speed at which retained messages become available in the server. It is primarily useful for when you have a lot of retained messages on different/changing topics. If at first a retained message can't be set, the action to do so will be retried in the event loop, asynchronously. This setting determines the maximum amount of time to defer setting a retained message, after which it will fall back to using locks. Also see \fB\m[green]set_retained_message_defer_timeout_spread\fR\m[]. Default value: \fI\m[blue]0\m[]\fR .TP \*(T<\fB\m[green]set_retained_message_defer_timeout_spread\m[] \fI\m[cyan]milliseconds\m[]\fR\fR\*(T> For \fB\m[green]set_retained_message_defer_timeout\fR\m[], the amount of random spread between 0 and this value for the timeout. This spreads out locking over time, reducing contention. Default value: \fI\m[blue]1000\m[]\fR .TP \*(T<\fB\m[green]retained_message_node_lifetime\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> The grace period after which a retained message node is eligible for deletion. The topic 'one/two/three' is three nodes, and if that topic had a message, it would be contained in 'three'. FlashMQ will periodically clear out retained message nodes that have no message anymore. This is required to save memory. But, when you receive retained messages on the same topics repeatedly, it may be beneficial to keep the nodes around, to avoid the need for locks to recreate them. If you know that retained messages come and go within a certain period, it's beneficial to set this value so that no unnecessary node destruction and creation takes place.
Default value: \fI\m[blue]0\m[]\fR .TP \*(T<\fB\m[green]subscription_node_lifetime\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> The grace period after which a subscription node is eligible for deletion. The subscription 'one/two/three' is three nodes. FlashMQ will periodically clear out nodes in the subscription tree that have no entries anymore. This is required to save memory. But, when clients place the same subscriptions repeatedly, it may be beneficial to keep the nodes around, to avoid the need for locks to recreate them. If you know that certain subscription patterns come and go within a certain period, it's beneficial to set this value so that no unnecessary node destruction and creation takes place. Default value: \fI\m[blue]3600\m[]\fR .TP \*(T<\fB\m[green]websocket_set_real_ip_from\m[] \fI\m[cyan]inet4_address\m[]\fR|\fI\m[cyan]inet6_address\m[]\fR\fR\*(T> HTTP proxies in front of the websocket listeners can set the \fI\m[cyan]X-Real-IP\m[]\fR header to identify the original connecting client. With \fB\m[green]websocket_set_real_ip_from\fR\m[] you can mark IP networks as trusted. By default, clients are not trusted, to avoid spoofing. You can repeat the option to allow for multiple addresses. Valid notations are \fI\m[cyan]1.2.3.4\m[]\fR, \fI\m[cyan]1.2.3.4/16\m[]\fR, \fI\m[cyan]1.2.0.0/16\m[]\fR, \fI\m[cyan]2a01:1337::1\m[]\fR, \fI\m[cyan]2a01:1337::1/64\m[]\fR, etc. The header \fI\m[cyan]X-Forwarded-For\m[]\fR is not used, because that's designed to contain a list of addresses, if applicable. As a side note about using a proxy on your listener: you can only have an absolute max of 65535 connections to an IP+port combination (and the practical limit is lower). If you need more, you have to set up multiple listeners. This can be multiple IP addresses, or simply multiple ports.
.TP \*(T<\fB\m[green]shared_subscription_targeting\m[] \fI\m[cyan]round_robin\m[]\fR|\fI\m[cyan]sender_hash\m[]\fR|\fI\m[cyan]first\m[]\fR\fR\*(T> When having multiple subscribers on a shared subscription (like '$share/myshare/jane/doe'), select how the messages should be distributed over the subscribers. \fI\m[cyan]round_robin\m[]\fR. Select the next subscriber for each message. There is still some amount of randomness to it because the counter for this is not thread safe. Using an atomic/mutexed counter for it would just be too slow to justify. \fI\m[cyan]sender_hash\m[]\fR. Selects a receiver deterministically based on the hash of the client ID of the sender. The selected subscriber will depend on how many subscribers there are, so if some disconnect, the distribution will change. Moreover, the selection may also change when FlashMQ cleans up empty spaces in the list of shared subscribers. \fI\m[cyan]first\m[]\fR. Selects the first subscriber in the list. This mode can be useful for fallback. When one client disappears, the other will seamlessly take over. Default: \fI\m[cyan]round_robin\m[]\fR .TP \*(T<\fB\m[green]minimum_wildcard_subscription_depth\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> Defines the minimum level of the first wildcard topic filter (\fB\m[green]#\fR\m[] and \fB\m[green]+\fR\m[]). In a topic filter like \fB\m[green]sensors/temperature/#\fR\m[], that is 2. If you specify 2, a subscription to \fB\m[green]sensors/#\fR\m[] will be denied. Remember that only MQTT 3.1.1 and newer actually notify the client of the denial in the sub-ack packet. The reason you may want to limit it, is performance. If you have a base message load of 100,000 messages per second, each client subscribing to \fB\m[green]#\fR\m[] causes that many permission checks per second. If you have 100 clients doing that, there will be 10 million permission checks per second. 
Default: \fI\m[cyan]0\m[]\fR .TP \*(T<\fB\m[green]max_topic_split_depth\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> Defines the maximum number of components a topic/filter string can have. For example, \fIone/two/three\fR has three. The reason is to avoid abuse. The default should be enough for most deployments. If the limit is tripped, the error in the log will clearly say so, and mention \fB\m[green]max_topic_split_depth\fR\m[]. Default: \fI\m[cyan]128\m[]\fR .TP \*(T<\fB\m[green]wildcard_subscription_deny_mode\m[] \fI\m[cyan]deny_all\m[]\fR|\fI\m[cyan]deny_retained_only\m[]\fR\fR\*(T> For \fB\m[green]minimum_wildcard_subscription_depth\fR\m[], specify what you want to deny. Trying to give a client all retained messages can cause quite some load, so only denying the retained messages upon receiving a broad wildcard subscription can be useful if you have a low enough general message volume, but a high number of retained messages. Default: \fIdeny_all\fR .TP \*(T<\fB\m[green]overload_mode\m[] \fI\m[cyan]log\m[]\fR|\fI\m[cyan]close_new_clients\m[]\fR\fR\*(T> Define the action to perform when the value defined with \fB\m[green]max_event_loop_drift\fR\m[] is exceeded. When a server is (re)started, and hundreds of thousands of clients connect, the SSL handshaking and authenticating can be so heavy that it doesn't get to clients in time. They will then reconnect and try again, and get stuck in a loop. This option is to mitigate that. With \fIclose_new_clients\fR, new clients will be closed immediately after connecting while the server is overloaded. This will allow the worker threads to process the new clients in a controlled manner. For really large deployments, this can be augmented with extra rate limiting in iptables, or other firewalls. 
A stateless method is preferred, like: \fIiptables -I INPUT -p tcp -m multiport --dports 8883,1883 --syn -m hashlimit --hashlimit-name newmqttconns --hashlimit-above 10000/second --hashlimit-burst 15000 -j DROP\fR The current default is \fIlog\fR, but that will likely change in the future. Default: \fIlog\fR .TP \*(T<\fB\m[green]max_event_loop_drift\m[] \fI\m[cyan]milliseconds\m[]\fR\fR\*(T> For \fB\m[green]overload_mode\fR\m[], the maximum permissible thread drift before the overload action is taken. The drift values considered are those of the main loop, in which clients are accepted, and the median of all worker threads. Default: \fI2000\fR .TP \*(T<\fB\m[green]include_dir\m[] \fI\m[cyan]/path/to/dir\m[]\fR\fR\*(T> Load *.conf files from the specified directory, to merge with the main configuration file. An error is generated when the directory is not there. This is to protect against running incorrect configurations by accident, when the dir has been renamed, for example. .TP \*(T<\fB\m[green]subscription_identifiers_enabled\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> Subscription identifiers allow clients to see which subscription was responsible for a message. Publish messages will contain the identifier included in the original subscription. Enabling will prevent FlashMQ from using optimizations involving packet reuse, because the packets are unique per client when it contains a subscription identifier. Therefore you may want to assess the performance difference in high message volume deployments. As per the spec, clients sending subscription identifiers when the server reported the feature as unavailable will cause them to be disconnected. This has the side effect that changing this setting on a running server will disconnect clients when they send a subscription with an identifier in it. 
This was chosen as behavior over the alternatives, because of simplicity and operator control (otherwise it can't be turned off at all for existing clients). Default value: \fItrue\fR .SH LISTEN PARAMETERS Listen parameters can only be used within \fIlisten { }\fR blocks. .TP \*(T<\fB\m[green]port\m[]\fR\*(T> The default port depends on the \fB\m[green]protocol\fR\m[] parameter and whether or not \fB\m[green]fullchain\fR\m[] and \fB\m[green]privkey\fR\m[] parameters are supplied: .RS .TP 0.2i \(bu For unencrypted MQTT, the default port is \fI1883\fR .TP 0.2i \(bu For encrypted MQTT, the default port is \fI8883\fR .TP 0.2i \(bu For plain HTTP websockets, the default port is \fI8080\fR .TP 0.2i \(bu For encrypted HTTPS websockets, the default port is \fI4443\fR .RE .TP \*(T<\fB\m[green]protocol\m[] \fI\m[cyan]mqtt\m[]\fR|\fI\m[cyan]websockets\m[]\fR|\fI\m[cyan]acme\m[]\fR\fR\*(T> This is a required parameter. For \fIacme\fR, see \fBacme_redirect_url\fR. .TP \*(T<\fB\m[green]inet_protocol\m[] \fI\m[cyan]ip4_ip6\m[]\fR|\fI\m[cyan]ip4\m[]\fR|\fI\m[cyan]ip6\m[]\fR|\fI\m[cyan]unix\m[]\fR\fR\*(T> When using \fIunix\fR, a \fB\m[green]unix_socket_path\fR\m[] is required. Default: \fIip4_ip6\fR .TP \*(T<\fB\m[green]inet4_bind_address\m[] \fI\m[cyan]inet4address\m[]\fR\fR\*(T> Default: 0.0.0.0 .TP \*(T<\fB\m[green]inet6_bind_address\m[] \fI\m[cyan]inet6address\m[]\fR\fR\*(T> Default: ::0 .TP \*(T<\fB\m[green]unix_socket_path\m[] \fI\m[cyan]path\m[]\fR\fR\*(T> When using \fIunix\fR for \fB\m[green]inet_protocol\fR\m[], the file path of the socket. FlashMQ will remove pre-existing socket files if they already exist. .TP \*(T<\fB\m[green]unix_socket_user\m[] \fI\m[cyan]user\m[]\fR\fR\*(T> Use this to set the owner of the socket. It will always be attempted to set it, but a warning may be logged if not successful. Users may be specified as numeric or names. A config test will not verify the existence of users, for portability. 
.TP \*(T<\fB\m[green]unix_socket_group\m[] \fI\m[cyan]group\m[]\fR\fR\*(T> Use this to set the group of the socket. It will always be attempted to set it, but a warning may be logged if not successful. Groups may be specified as numeric or names. A config test will not verify the existence of groups, for portability. .TP \*(T<\fB\m[green]unix_socket_mode\m[] \fI\m[cyan]mode\m[]\fR\fR\*(T> Use this to specify the permission mode of the unix socket, like \fI600\fR. .TP \*(T<\fB\m[green]fullchain\m[] \fI\m[cyan]/foobar/server.crt\m[]\fR\fR\*(T> Specifying a chain makes the listener SSL, and also requires the \fB\m[green]privkey\fR\m[] to be set. .TP \*(T<\fB\m[green]privkey\m[] \fI\m[cyan]/foobar/server.key\m[]\fR\fR\*(T> Specifying a private key makes the listener SSL, and also requires the \fB\m[green]fullchain\fR\m[] to be set. .TP \*(T<\fB\m[green]minimum_tls_version\m[] \fI\m[cyan]tlsv1.1\m[]\fR|\fI\m[cyan]tlsv1.2\m[]\fR|\fI\m[cyan]tlsv1.3\m[]\fR\fR\*(T> Set minimum supported TLS version for TLS listeners. Note that setting this value low may not actually enable that protocol version if OpenSSL won't support it (anymore). The TLS version clients use is logged. Default: \fI\m[cyan]tlsv1.1\m[]\fR .TP \*(T<\fB\m[green]client_verification_ca_file\m[] \fI\m[cyan]/foobar/client_authority.crt\m[]\fR\fR\*(T> Clients can be authenticated using X509 certificates, and the username taken from the CN (common name) field. Use this directive to specify the certificate authority you trust. Specifying this or \fB\m[green]client_verification_ca_dir\fR\m[] will require the listener to be TLS. .TP \*(T<\fB\m[green]client_verification_ca_dir\m[] \fI\m[cyan]/foobar/dir_with_certificates\m[]\fR\fR\*(T> Clients can be authenticated using X509 certificates, and the username taken from the CN (common name) field. Use this directive to specify the dir containing certificate authorities you trust. Note that the filename requirements are dictated by OpenSSL.
Use the utility \fB\m[green]openssl rehash /path/to/dir\m[]\fR. Specifying this or \fB\m[green]client_verification_ca_file\fR\m[] will require the listener to be TLS. .TP \*(T<\fB\m[green]client_verification_still_do_authn\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> When using X509 client authentication with \fB\m[green]client_verification_ca_file\fR\m[] or \fB\m[green]client_verification_ca_dir\fR\m[], the username will not be checked with a user database or a plugin by default. Set this option to \fItrue\fR to override that. .TP \*(T<\fB\m[green]acme_redirect_url\m[] \fI\m[cyan]http://example.com/\m[]\fR\fR\*(T> This allows an ACME (automated certificate management environment) challenge to be redirected elsewhere. This allows decoupling of the certificate creation from the host(s) that run FlashMQ. This can either be configured on a dedicated listener with \fB\m[green]protocol\fR\m[] \fIacme\fR, or multiplexed on a non-SSL \fImqtt\fR or \fIwebsockets\fR listener. .TP \*(T<\fB\m[green]drop_on_absent_certificate\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> When both \fB\m[green]privkey\fR\m[] and \fB\m[green]fullchain\fR\m[] are absent, don't create this listener. This can help in situations where you don't have the certificate and key yet, but you are expecting them. .TP \*(T<\fB\m[green]only_allow_from\m[] \fI\m[cyan]inet4_address\m[]\fR|\fI\m[cyan]inet6_address\m[]\fR\fR\*(T> When set, restricts the listener to the source address/network given. You can repeat the option to allow multiple addresses/networks. Valid notations are \fI\m[cyan]1.2.3.4\m[]\fR, \fI\m[cyan]1.2.3.4/16\m[]\fR, \fI\m[cyan]1.2.0.0/16\m[]\fR, \fI\m[cyan]2a01:1337::1\m[]\fR, \fI\m[cyan]2a01:1337::1/64\m[]\fR, etc. .TP \*(T<\fB\m[green]deny_from\m[] \fI\m[cyan]inet4_address\m[]\fR|\fI\m[cyan]inet6_address\m[]\fR\fR\*(T> Block connections from this address or network. You can repeat the option multiple times. 
Valid notations are \fI\m[cyan]1.2.3.4\m[]\fR, \fI\m[cyan]1.2.3.4/16\m[]\fR, \fI\m[cyan]1.2.0.0/16\m[]\fR, \fI\m[cyan]2a01:1337::1\m[]\fR, \fI\m[cyan]2a01:1337::1/64\m[]\fR, etc. .TP \*(T<\fB\m[green]allow_anonymous\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> This allows you to override the \fBglobal \fB\m[green]allow_anonymous\fR\m[]\fR setting on the listener level. .TP \*(T<\fB\m[green]overload_mode\m[] \fI\m[cyan]log\m[]\fR|\fI\m[cyan]close_new_clients\m[]\fR\fR\*(T> This allows you to override the \fBglobal \fB\m[green]overload_mode\fR\m[]\fR setting on the listener level. .TP \*(T<\fB\m[green]haproxy\m[] \fI\m[cyan]off\m[]\fR|\fI\m[cyan]on\m[]\fR|\fI\m[cyan]client_verification\m[]\fR|\fI\m[cyan]client_verification_with_authn\m[]\fR\fR\*(T> The nature of running something behind a proxy is that you would lose certain information or abilities. For that, HAProxy allows sending information in a header frame. With the \fB\m[green]haproxy\fR\m[] option, you can configure the following: \fIoff\fR and \fIon\fR configure basic functionality: reading the client's source address from the haproxy frame. This address will be shown in the logs and plugin functions. \fIclient_verification\fR and \fIclient_verification_with_authn\fR will require haproxy to send the \fIPP2_SUBTYPE_SSL_CN\fR field when using mTLS (mutual TLS), i.e. client authentication. When using the \fIclient_verification_with_authn\fR option, FlashMQ's internal authentication will still happen (with the username from the certificate). This can be convenient when you have a plugin, so that you can still take action in the login hook. Make sure this listener is private, firewalled or use \fB\m[green]only_allow_from\fR\m[], otherwise anybody can set a different source address. Note that HAProxy's server health checks only started using the 'local' specifier as of version 2.4. This means earlier versions will pretend to be a client and break the connection, causing log spam.
As a side note about using a proxy on your listener: you can only have an absolute max of 65535 connections to an IP+port combination (and the practical limit is lower). If you need more, you have to set up multiple listeners. This can be multiple IP addresses, or simply multiple ports. The following is an example HAProxy backend config for the features described: .PP .nf .in +7 \m[green]backend \m[]\m[cyan]backend_server\m[] \m[green] mode \m[]\m[cyan]tcp\m[] \m[green] timeout \m[]\m[cyan]client\m[] \m[cyan]2m\m[] \m[green] timeout \m[]\m[cyan]connect\m[] \m[cyan]10s\m[] \m[green] timeout \m[]\m[cyan]server\m[] \m[cyan]2m\m[] \m[green] server \m[]\m[cyan]server1\m[] \m[cyan]::1:1885\m[] \m[cyan]send-proxy-v2\m[] \m[cyan]send-proxy-v2-ssl-cn\m[] .in .fi See \m[blue]haproxy.org\m[] \(lB\fI\m[blue]http://www.haproxy.org/\m[]\fR\(rB. .TP \*(T<\fB\m[green]tcp_nodelay\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> \fB\m[green]tcp_nodelay\fR\m[] will cause the \fITCP_NODELAY\fR option to be set for the listener's socket(s), and therefore for all clients accepted on that listener. \fITCP_NODELAY\fR is an OS TCP-layer option that will cause messages written by FlashMQ to the socket to be flushed immediately, without letting Nagle's algorithm (the default) collect small outgoing TCP packets into bigger packets. Forgoing Nagle's algorithm by setting \fB\m[green]tcp_nodelay\fR\m[] to \fI\m[cyan]true\m[]\fR \fBmay\fR decrease latency, at the likely cost of some network efficiency. Default: \fI\m[cyan]false\m[]\fR .TP \*(T<\fB\m[green]max_buffer_size\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> Override the \fB\m[green]client_max_write_buffer_size\fR\m[] for this listener. This is especially useful when this listener is the receiving side of a bridge, because these clients will likely see more traffic. .TP \*(T<\fB\m[green]max_qos\m[] \fI\m[cyan]qos_value\m[]\fR\fR\*(T> Per-listener counterpart of \fB\fB\m[green]max_qos\fR\m[]\fR.
.TP \*(T<\fB\m[green]mqtt3_qos_exceed_action\m[] \fI\m[cyan]drop|disconnect\m[]\fR\fR\*(T> Per-listener counterpart of \fB\fB\m[green]mqtt3_qos_exceed_action\fR\m[]\fR. .SH EXAMPLE LISTENERS .PP .nf .in +7 \m[yellow]listen \m[]{ \m[green] protocol \m[]\m[cyan]mqtt\m[] \m[green] inet_protocol \m[]\m[cyan]ip4_ip6\m[] \m[green] inet4_bind_address \m[]\m[cyan]127.0.0.1\m[] \m[green] inet6_bind_address \m[]\m[cyan]::1\m[] \m[green] fullchain \m[]\m[cyan]/foobar/server.crt\m[] \m[green] privkey \m[]\m[cyan]/foobar/server.key\m[] \m[blue] # default = 8883\m[] \m[green] port \m[]\m[cyan]8883\m[] \m[green]\m[]} \m[yellow]listen \m[]{ \m[green] protocol \m[]\m[cyan]mqtt\m[] \m[green] fullchain \m[]\m[cyan]/foobar/server.crt\m[] \m[green] privkey \m[]\m[cyan]/foobar/server.key\m[] \m[green] client_verification_ca_file \m[]\m[cyan]/foobar/client_authority.crt\m[] \m[green] client_verification_still_do_authn \m[]\m[cyan]false\m[] \m[green]\m[]} \m[yellow]listen \m[]{ \m[green] protocol \m[]\m[cyan]mqtt\m[] \m[green] inet_protocol \m[]\m[cyan]ip4\m[] \m[blue] # default = 1883\m[] \m[green] port \m[]\m[cyan]1883\m[] \m[green]\m[]} \m[yellow]listen \m[]{ \m[green] protocol \m[]\m[cyan]websockets\m[] \m[green] fullchain \m[]\m[cyan]/foobar/server.crt\m[] \m[green] privkey \m[]\m[cyan]/foobar/server.key\m[] \m[blue] # default = 4443\m[] \m[green] port \m[]\m[cyan]4443\m[] \m[green]\m[]} \m[yellow]listen \m[]{ \m[green] protocol \m[]\m[cyan]websockets\m[] \m[blue] # default = 8080\m[] \m[green] port \m[]\m[cyan]8080\m[] \m[green]\m[]} \m[yellow]listen \m[]{ \m[green] port \m[]\m[cyan]2883\m[] \m[green] haproxy \m[]\m[cyan]on\m[] \m[green]\m[]} .in .fi .SH BRIDGE CONFIGURATION Bridges can be defined inside \fIbridge { }\fR blocks. A bridge is essentially just an outgoing connection to another server with loop-detection and retain flag relaying. It is not a form of clustering, although with careful design, it can be deployed to achieve some sort of load balancing. 
Note that normally (unless \fB\m[green]connection_count\fR\m[] is set) one bridge is one connection, and because FlashMQ's threading model is that clients are serviced by one selected thread only, a bridge has the potential to saturate a thread, if it's heavily loaded. You can improve that with \fB\m[green]connection_count\fR\m[]. Bridges are dynamically created, removed or changed upon config reload. When a bridge configuration changes, it will disconnect and reconnect. .TP \*(T<\fB\m[green]address\m[] \fI\m[cyan]address\m[]\fR\fR\*(T> The DNS name, IPv4 or IPv6 address of the server you want to connect to. .TP \*(T<\fB\m[green]port\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> The default port depends on the \fB\m[green]tls\fR\m[] option, either 1883 or 8883. .TP \*(T<\fB\m[green]inet_protocol\m[] \fI\m[cyan]ip4_ip6/ip4/ip6\m[]\fR\fR\*(T> Default: \fIip4_ip6\fR .TP \*(T<\fB\m[green]tls\m[] \fI\m[cyan]off/on/unverified\m[]\fR\fR\*(T> Set TLS mode. The value \fB\m[green]unverified\fR\m[] means the x509 chain is not verified. .TP \*(T<\fB\m[green]minimum_tls_version\m[] \fI\m[cyan]tlsv1.1\m[]\fR|\fI\m[cyan]tlsv1.2\m[]\fR|\fI\m[cyan]tlsv1.3\m[]\fR\fR\*(T> Set minimum supported TLS version the bridge will negotiate with the other side. Note that setting this value low may not actually enable that protocol version if OpenSSL won't support it (anymore). Default: \fI\m[cyan]tlsv1.1\m[]\fR .TP \*(T<\fB\m[green]fullchain\m[] \fI\m[cyan]/foobar/bridge.crt\m[]\fR\fR\*(T> With TLS enabled, specifying a chain makes the bridge connection authenticate to the remote broker using a public certificate, and also requires the \fB\m[green]privkey\fR\m[] to be set. .TP \*(T<\fB\m[green]privkey\m[] \fI\m[cyan]/foobar/bridge.key\m[]\fR\fR\*(T> With TLS enabled, specifying a private key makes the bridge connection to the remote broker use that key, and also requires the \fB\m[green]fullchain\fR\m[] to be set.
.TP \*(T<\fB\m[green]ca_file\m[] \fI\m[cyan]path\m[]\fR\fR\*(T> File to be used for x509 certificate chain validation. .TP \*(T<\fB\m[green]ca_dir\m[] \fI\m[cyan]path\m[]\fR\fR\*(T> Directory containing certificates for x509 certificate chain validation. .TP \*(T<\fB\m[green]protocol_version\m[] \fI\m[cyan]mqtt3.1\m[]\fR|\fI\m[cyan]mqtt3.1.1\m[]\fR|\fI\m[cyan]mqtt5\m[]\fR\fR\*(T> Default: \fImqtt3.1.1\fR .TP \*(T<\fB\m[green]bridge_protocol_bit\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> An unofficial standard is to set the most significant bit of the protocol version byte to 1 to signal the connection is a bridge. This allows the other side to alter its behavior slightly. However, this is not always supported, so you can disable this if you get disconnected for reporting an invalid protocol version. This setting has no effect when using MQTT5, because the behavior it influences is done with subscription options. Default: \fItrue\fR .TP \*(T<\fB\m[green]keepalive\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> The time between sending ping packets to the other side. Default: \fI60\fR .TP \*(T<\fB\m[green]clientid_prefix\m[] \fI\m[cyan]prefix\m[]\fR\fR\*(T> The prefix of the randomly generated client ID. Client IDs cannot be explicitly set for security reasons. See \m[blue]\[lq]Understanding clean session and clean start\[rq]\m[] \(lB\fI\m[blue]https://www.flashmq.org/2022/11/26/understanding-clean-session-and-clean-start/\m[]\fR\(rB. Default: \fIfmqbridge\fR .TP \*(T<\fB\m[green]publish\m[] \fI\m[cyan]filter\m[]\fR \fI\m[cyan]qos\m[]\fR\fR\*(T> Messages matching this filter will be published to the other side. Examples: \fI#\fR or \fIsport/tennis/#\fR. This option can be repeated several times. The QoS value should be seen as the QoS value of the internal subscription causing outgoing messages. Messages that are relayed have this QoS level at most.
Default: \fI0\fR .TP \*(T<\fB\m[green]subscribe\m[] \fI\m[cyan]filter\m[]\fR \fI\m[cyan]qos\m[]\fR\fR\*(T> Subscriptions for this filter are placed on the other side. Examples: \fI#\fR or \fIsport/tennis/#\fR. This option can be repeated several times. The QoS value is like any subscription at a server. Messages received by the other end will be given this QoS level at most. Default: \fI0\fR .TP \*(T<\fB\m[green]local_username\m[] \fI\m[cyan]username\m[]\fR\fR\*(T> Username as seen by the local FlashMQ plugin or ACL checks. This is not always necessary. .TP \*(T<\fB\m[green]remote_username\m[] \fI\m[cyan]username\m[]\fR\fR\*(T> Username sent to the remote connection. .TP \*(T<\fB\m[green]remote_password\m[] \fI\m[cyan]password\m[]\fR\fR\*(T> Password sent to the remote connection. .TP \*(T<\fB\m[green]remote_clean_start\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> In MQTT3, this means 'clean session', meaning the remote server removes any existing session with the same ID on (re)connect, and destroys it immediately on disconnect. If you want reusable sessions that survive disconnects, set this to false. If you also want to pick up remote sessions on FlashMQ restart, set \fB\m[green]use_saved_clientid\fR\m[] to true. In MQTT5, this option only influences reconnection behavior. It essentially has no effect on the first connect, because the client ID is random and will always be new (except when you set \fB\m[green]use_saved_clientid\fR\m[]). But when set to true, any reconnects, which do use the already generated client ID, will destroy the session and in-flight messages will be lost. Also see \m[blue]understanding clean session and clean start\m[] \(lB\fI\m[blue]https://www.flashmq.org/2022/11/26/understanding-clean-session-and-clean-start/\m[]\fR\(rB.
Default value: \fItrue\fR .TP \*(T<\fB\m[green]local_clean_start\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> In MQTT3 mode, this means 'clean session' and means the session is removed upon disconnect. If you want to reuse sessions on reconnect, set this to false. Any new start of FlashMQ will give you a new client ID, so it will always be a fresh session, except if you set \fB\m[green]use_saved_clientid\fR\m[]. In MQTT5 mode, this only has effect on start, where any existing local session is removed if found. If you want the session to be removed immediately on disconnect, set \fB\m[green]local_session_expiry_interval\fR\m[] to 0. Also see \m[blue]understanding clean session and clean start\m[] \(lB\fI\m[blue]https://www.flashmq.org/2022/11/26/understanding-clean-session-and-clean-start/\m[]\fR\(rB. Default value: \fItrue\fR .TP \*(T<\fB\m[green]remote_session_expiry_interval\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> Is only used in MQTT5 mode and determines the number of seconds after which the session can be removed from the remote server. Default value: \fI0\fR .TP \*(T<\fB\m[green]local_session_expiry_interval\m[] \fI\m[cyan]seconds\m[]\fR\fR\*(T> Determines when a local session without an active client will be removed, in both MQTT3 and MQTT5 mode. Note that in MQTT3 mode, the session is removed on disconnect when \fB\m[green]local_clean_start\fR\m[] is true. Default value: \fI0\fR .TP \*(T<\fB\m[green]remote_retain_available\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> MQTT5 allows a server to tell a client it doesn't support retained messages, or has them disabled. When using MQTT3, use this option to achieve the same. Messages will not be relayed with 'retained as published' and the retained messages that are normally sent on a matching subscription are not sent.
Default value: \fItrue\fR .TP \*(T<\fB\m[green]use_saved_clientid\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> When you want your bridges to resume local and remote sessions after restart, set this to true and set \fB\m[green]remote_clean_start\fR\m[], \fB\m[green]local_clean_start\fR\m[], \fB\m[green]remote_session_expiry_interval\fR\m[] and \fB\m[green]local_session_expiry_interval\fR\m[] accordingly. It only has effect when you have set a \fB\m[green]storage_dir\fR\m[] and include \fIbridge_info\fR in \fB\m[green]persistence_data_to_save\fR\m[]. It is important to fully understand the clean session / clean start behavior and the role the client ID plays in that. The primary goal of sessions is to survive link disconnects. Configuring a fixed client ID and using that each time an MQTT client starts is often an anti-pattern, because most clients, like actual IoT devices, start fresh upon restart and don't store their sessions (with in-flight packets, etc) to disk. FlashMQ does store them on disk, however, so it can be used legitimately. However, you can run into unexpected situations. For instance, you will get your existing subscriptions from the session too. So, if you remove a \fB\m[green]subscribe\fR\m[] line from your bridge configuration and restart, it will actually have no effect, because the server on the other side still has that subscription in the session. See \m[blue]understanding clean session and clean start\m[] \(lB\fI\m[blue]https://www.flashmq.org/2022/11/26/understanding-clean-session-and-clean-start/\m[]\fR\(rB for details. Default value: \fIfalse\fR .TP \*(T<\fB\m[green]max_outgoing_topic_aliases\m[] \fI\m[cyan]amount\m[]\fR\fR\*(T> If you want FlashMQ to initiate topic aliases for this bridge, set this to a non-zero value. Note that it's capped at the value the remote side gives in the CONNACK packet, so it only works if the other side permits it.
Default: \fI0\fR .TP \*(T<\fB\m[green]max_incoming_topic_aliases\m[] \fI\m[cyan]amount\m[]\fR\fR\*(T> If you want to accept topic aliases for this bridge, set this to a non-zero value. The value is set in the CONNECT packet to inform the remote side of the wish. It's not guaranteed that the other side will actually make aliases. Default: \fI0\fR .TP \*(T<\fB\m[green]tcp_nodelay\m[] \fI\m[cyan]true\m[]\fR|\fI\m[cyan]false\m[]\fR\fR\*(T> \fB\m[green]tcp_nodelay\fR\m[] will cause the \fITCP_NODELAY\fR option to be set for the client socket that is used to connect to the other end of the bridge. See the documentation for the \fB\fB\m[green]tcp_nodelay\fR\m[]\fR \fBlistener\fR parameter for further elaboration. Default: \fI\m[cyan]false\m[]\fR .TP \*(T<\fB\m[green]local_prefix\m[] \fI\m[cyan]prefix\m[]\fR\fR\*(T> Prefixes can be used to remap topics to and from the other end of the bridge. This makes it possible to insert a topic tree into the topic tree on another server, like a shared one. When a message comes in, the \fB\m[green]remote_prefix\fR\m[] is stripped from the topic, and the \fB\m[green]local_prefix\fR\m[] is added. The resulting topic is used for authorization 'write' checking. When a message goes out, the opposite happens: the \fB\m[green]local_prefix\fR\m[] is stripped and the \fB\m[green]remote_prefix\fR\m[] is added. However, this time, the original topic is used for authorization 'read' checking. The prefixes aren't applied to the \fB\m[green]subscribe\fR\m[] and \fB\m[green]publish\fR\m[] bridge options. You'll have to include the prefix in the subscriptions you configure. This is so that you can have multiple subscriptions to the other end, and only have the prefix applied to the relevant one(s). Messages that come in and go out that don't match the prefixes are sent and received unchanged. Prefix removal isn't applied to topics that match the prefix exactly.
This is to avoid \fIone/two/three/\fR (which has a legal empty string as its last subtopic) becoming an empty string (which is illegal). If you define a prefix, it is required to end with a \fI/\fR. It's valid to have only a local or remote prefix. .TP \*(T<\fB\m[green]remote_prefix\m[] \fI\m[cyan]prefix\m[]\fR\fR\*(T> See \fB\fB\m[green]local_prefix\fR\m[]\fR. .TP \*(T<\fB\m[green]connection_count\m[] \fI\m[cyan]number\m[]\fR|\fI\m[cyan]auto\m[]\fR\fR\*(T> Normally a bridge has one TCP connection to the other side. This means that dealing with bridge traffic is limited to one thread, which also applies to the remote side if that is also FlashMQ. With this option, you can make a bridge have multiple connections, and share the traffic load over them using MQTT5 'shared subscriptions'. When you specify a \fB\m[green]publish\fR\m[] or \fB\m[green]subscribe\fR\m[] path of \fIone/two/three\fR, the topic is adjusted to \fI$share/RANDOM/one/two/three\fR to create a shared subscription so that load is balanced. With FlashMQ, load balancing is especially important for the side processing publishes (because each received publish packet means subscribers have to be looked up, auth checked, etc). The \fB\m[green]shared_subscription_targeting\fR\m[] mode is automatically set to \fIsender_hash\fR. This is required to ensure sequential message relaying (to retain ordering), and it's also better in plugin code when messages from one source are kept to one thread, as it would be under normal circumstances. If FlashMQ is also on the receiving end of these load balanced connections, it uses an extra feature to group clients of one bridge together to enhance loop detection. Normally MQTT5 supports the 'no-local' subscription option for that, but the standard states that is not allowed for shared subscriptions. FlashMQ uses 'user properties' to communicate the group they belong to, so that we can still do this kind of loop detection.
This allows you to specify overlapping paths in the \fB\m[green]subscribe\fR\m[] and \fB\m[green]publish\fR\m[] options of a bridge. Note that both ends should use FlashMQ version 1.22.0 or higher. When the target server indeed is also FlashMQ, it may be smart to create a dedicated listener, for several reasons. One is that incoming connections are given to threads in a sequential order per listener, which ensures the best spread over worker threads. It also allows you to set \fB\m[green]overload_mode\fR\m[], \fB\m[green]max_buffer_size\fR\m[] and possibly \fB\m[green]only_allow_from\fR\m[] and \fB\m[green]deny_from\fR\m[] differently for that listener. You can specify a number of connections, or \fIauto\fR for one connection per CPU. FlashMQ's load is mostly on the receiver of messages, so \fIauto\fR will likely be a good choice if most of your message load from the other side is incoming, vs outgoing. Otherwise it's best matched to the other side's number of CPUs. Default: \fI\m[cyan]1\m[]\fR .TP \*(T<\fB\m[green]max_buffer_size\m[] \fI\m[cyan]number\m[]\fR\fR\*(T> Override the \fB\m[green]client_max_write_buffer_size\fR\m[] for this bridge. When \fB\m[green]connection_count\fR\m[] is used, this size applies to each connection individually. Bridges typically have more traffic than single clients, in which case it makes sense to increase this. .SH EXAMPLE BRIDGE .PP .nf .in +7 \m[yellow]bridge \m[]{ \m[green] address \m[]\m[cyan]demo.flashmq.org\m[] \m[green] publish \m[]\m[cyan]send/this\m[] \m[green] subscribe \m[]\m[cyan]receive/this\m[] \m[green] local_username \m[]\m[cyan]my_local_user\m[] \m[green] remote_username \m[]\m[cyan]my_remote_user\m[] \m[green] remote_password \m[]\m[cyan]my_remote_pass\m[] \m[green] bridge_protocol_bit \m[]\m[cyan]false\m[] \m[green] tls \m[]\m[cyan]on\m[] \m[green] ca_file \m[]\m[cyan]/path/to/ca.crt\m[] \m[green]\m[]} .in .fi .SH AUTHOR Wiebe Cazemier <\m[blue]contact@flashmq.org\m[]>.
.SH SEE ALSO \m[blue]man:flashmq\fR(1)\m[] \m[blue]https://www.flashmq.org/\m[] .SH COLOPHON The sources for the FlashMQ manual pages are maintained in \m[blue]DocBook 5.2\m[] \(lB\fI\m[blue]https://tdg.docbook.org/tdg/5.2/\m[]\fR\(rB XML files. The transformation to the multiple destination file formats is done using a bunch of XSLT 1.0 sheets, contributed to this project by Rowan van der Molen. The groff source of this man-page has ANSI-color support for the terminal. However, Debian-derived Linux distributions turn off groff color support by default. To override this, set the \fBGROFF_SGR\fR environment variable to \fI1\fR. ================================================ FILE: man/flashmq.conf.5.dbk5 ================================================ https://www.flashmq.org/man/flashmq.conf.5 flashmq.conf 5 flashmq.conf FlashMQ configuration file format The flashmq.conf file is the configuration used for configuring the FlashMQ MQTT broker. Config location By default, the flashmq daemon expects to find its configuration file at /etc/flashmq/flashmq.conf, but this can be overridden using the command-line argument; see the flashmq1 man page for details. Using the parameter in your config file, you can load all the *.conf files from that given directory. File format To set a parameter, its name must appear on a single line, followed by one or more potentially quoted arguments. Quoted values are the same as unquoted values when they don't need it. They are necessary when argument values contain spaces, for instance. When setting boolean values, yes/no, true/false and on/off can all be used. To configure the listeners, use blocks, defined by { and }. See EXAMPLE LISTENERS for details. Lines beginning with the hash character (“#”) and empty lines are ignored. Thus, a line can be commented out by prepending a “#” to it.
Global parameters /path/to/plugin.so FlashMQ supports an ELF shared object (.so file) plugin interface to add functionality, authorization and authentication, because it’s hard to provide a declarative mechanism that works for everybody. See flashmq_plugin.h for the API and its documentation. It’s written in C++ for ease of passing FlashMQ internals without conversion to C, but you can basically just use a C++ compiler and program like it was C; the C++ constructs are simple. FlashMQ will auto-detect which plugin interface you’re trying to load (Mosquitto version 2 or FlashMQ native). Keep in mind that each thread initializes the plugin, in line with multi-core programming (minimize shared data and interaction between threads). You could use static variables with thread synchronization if you really want to. And of course, any Mosquitto plugin that uses global and/or static variables instead of initializing memory in its init() method will not be thread-safe and won’t work. You can only have one plugin active, but you can combine it with and . The password and ACL file take precedence, and on a ‘deny’, will not ask the plugin. value Options passed to the plugin init() function. true|false There could be circumstances where the plugin code is mostly thread-safe, but not on initialization. Libmysqlclient, for instance, needs a one-time initialization. To add to the confusion, Qt hides that away. The plugin should preferably be written with proper synchronization like that, but as a last resort, you can use this to synchronize initialization. Default value: false true|false Like , but then for all login and ACL checks. This option may be dropped at some point, because it negates much of the multi-core design. One may as well run with only one thread then. Default value: false seconds The FlashMQ auth plugin interface has an optional function that is called periodically, at this interval in seconds. This can be used to refresh state, commit data, etc.
Setting a value of 0 disables it. You can enable and disable this timer with a config reload. See flashmq_plugin.h for details. Default value: 60 /path/to/flashmq.log This configuration parameter sets the path to FlashMQ's log file. If you omit this option from the config file, the output will go to stdout. debug|info|notice|warning|error|none Set the log level to the specified level and above. That means notice will log notice, warning and error. Use this setting over the deprecated and . If you do have those directives, they override the , for backwards compatibility reasons. Default value: info. true|false Debug logging obviously creates a lot of log noise, so it should only be enabled to diagnose problems. Deprecated. Use instead. Default value: false true|false Default value: false true|false Both publishes and QoS actions related to publishes are logged at this level. Default value: false true|false If you have topics with client IDs in them, people can possibly manipulate your ACL checking by saying their client ID is 'John+foobar'. Audit your security before you allow this. Default value: false true|false If you have topics with usernames in them, people can possibly manipulate your ACL checking by saying their username is 'John+foobar'. Audit your security before you allow this. Default value: false bytes The max length of almost all strings encoded in MQTT packets. This includes topics, user properties, client IDs, etc. Some fields, like passwords, are excluded. Payloads are also excluded (as they are not strings). The reason is to avoid abuse. The default should be enough for most deployments. If the limit is tripped, the error in the log will clearly say so, and mention . Default value: 4096 bytes MQTT packets have a maximum size of about 256 MB. This memory will (temporarily) be allocated upon arrival of such packets, so there may be cause to set it lower. This option works in conjunction with to limit memory use.
Default value: 268435461 bytes The client's write buffer is where packets are stored before the event loop has the chance to flush them out. Any time a client's connection is bad and bytes can't be flushed, this buffer fills. So, there's good reason to limit this to something sensible. A good indication is your average packet size (or better yet, a configured ) multiplied by the number of packets you want to be able to buffer. Despite the name, this setting also controls the read buffer's maximum size. Having a bigger read buffer may be necessary when you know your clients receive many packets. Note that it's an approximate value and not a hard limit. Buffer sizes only grow by powers of two, and buffers are always allowed to grow to make room for ping packets. Additionally, upon arrival of large packets (up to bytes), room will be made up to twice their size. So, you may also want to reduce from the default. Default value is 1048576 (1 MB) bytes The buffers for reading and writing, also for websockets when relevant, start out with a particular size and double when they need to grow. If you know your clients send bulks of a particular size, it helps to set this to match, to avoid constant memory reallocation. The default value is set conservatively, for scenarios with millions of clients. After buffers have grown, they are eventually reset to their original size when possible. Also see and . Value must be a power of two. Default value: 1024 /foo/bar/mosquitto_password_file File with usernames and hashed+salted passwords as generated by Mosquitto's mosquitto_passwd. Mosquitto up to version 1.6 uses the sha512 algorithm. Newer versions use sha512-pbkdf2. Both are supported. /foo/bar/mosquitto_acl_file ACL (access control lists) for users, anonymous users and patterns expandable with %u (username) and %c (clientid). Format is Mosquitto's acl_file. true|false This option can be overridden on a per-listener basis; see .
Default value: false true|false The proper way to signal an anonymous client is by setting the 'username present' flag in the CONNECT packet to 0, which in MQTT3 also demands the absence of a password. However, there are also clients out there that set the 'username present' flag to 1 and then give an empty username. This is an undesirable situation, because it means there are two ways to identify an anonymous client. Anonymous clients are not authenticated against a loaded plugin when is true. With this option enabled, that means users with an empty string as username also aren't. With this option disabled, clients connecting with an empty username will be rejected with 'bad username or password' as the MQTT error code. The default is to be unambiguous, but this can be overridden with this option. Default value: false number The general Linux default of 1024 can be overridden. Note: systemd blocks you from setting it, so it needs to be set on the unit. The default systemd unit file sets . You may also need to set Default value: 1000000 seconds Expire sessions after this time. Setting to 0 disables it and is (MQTT3) standard-compliant. But, existing sessions cause load on the server (because they cost memory and are still subscribers), so keeping sessions after any client that connects with a random ID doesn't make sense. Default value: 1209600 true|false Don't log LOG_INFO and LOG_NOTICE. This is useful when you have a lot of foot traffic, because otherwise the log gets filled with connect/disconnect notices. Deprecated. Use instead. Default value: false /path/to/dir Location to store sessions, subscriptions and retained messages. Not specifying this will turn off persistence. seconds The interval at which the state is saved, if enabled with . This setting is also applied on reload. Default: 3623 all|sessions_and_subscriptions|retained_messages|bridge_info When a is defined, specify which information should be saved.
The arguments are a list of one or more of the specified options, and can be negated with a !. Meaning like: all !retained_messages. The order is relevant. If you end with all, it overrides the previous. The option bridge_info causes a small bit of meta data of the bridges to be saved. See the bridge option . The default is all, but as soon as this option is defined, it's cleared, and the desired list has to be composed. number bytes There is a limit to how many QoS packets can be stored in a session, so you can define a maximum amount of messages and bytes. If any of these is exceeded, the packet is dropped. Note that changing only takes effect for new clients (also when picking up existing sessions). This is largely due to it being part of the MQTT5 connection handshake and is supposed to be adhered to. Defaults: max_qos_msg_pending_per_client 512 max_qos_bytes_pending_per_client 65536 qos_value The maximum QoS value FlashMQ will allow clients to use for subscriptions, publishes and wills. Subscriptions will be downgraded to this value. Publishes and wills will cause a disconnect for MQTT5 clients, and the action configured with for MQTT3 clients. The value is updated on config reload, but only for new clients. The reason is that the max QoS is part of the handshake in MQTT5, and clients will be surprised (and cause disconnects) if that changes during the lifetime of a connection. For consistency, the MQTT3 clients behave the same. This setting is also available per listener. Default value: 2 drop|disconnect Unlike MQTT5, MQTT3 doesn't have defined behavior for exceeding a configured maximum QoS (see ). This allows the behavior to be controlled. In case of drop the client gets an acknowledgement as if the publish has succeeded, to avoid stale in-flight packets. This setting is also available per listener. Default value: disconnect number Is communicated towards MQTT5 clients. It is then up to them to decide to set them or not. 
Changing this setting and reloading the config only has effect on new clients, because existing clients would otherwise exceed the limit they think applies. Default value: 65535 number FlashMQ will make this many aliases per MQTT5 client, if they ask for aliases (with the connect property ). Default value: 65535 number If you want to have a different number of worker threads than CPUs, you can set this value. Typically you don't need to set this. Default value: auto-detect true|false When disabled, the server will not set the last will and testament specified by connecting clients. Default value: true enabled|enabled_without_persistence|enabled_without_retaining|downgrade|drop|disconnect_with_error Retained messages can be a strain on the server you may not need. You can set various ways of dealing with them: enabled. This is normal operation. enabled_without_persistence. Like enabled, except it won't store them to disk if is defined. enabled_without_retaining. This somewhat counter-intuitive sounding mode is like , except that the 'retain' flag is not removed. This allows MQTT5 subscribers that subscribe with 'retain as published' to see which messages were originally sent as retained. It's just that FlashMQ won't retain them. downgrade. The retain flag is removed and treated like a normal publish. drop. Messages with retain set are dropped. disconnect_with_error. Disconnect clients who try to set them. Default value: enabled seconds Use this to limit the lifetime of retained messages. Without this, the number of retained messages may never decrease. Default value: 4294967295 number Deprecated. number When clients place a subscription, they will get the retained messages matching that subscription. Even though traversing the retained message tree is deprioritized in favor of other traffic, it will still cause CPU load until it's done. If you have a tree with millions of nodes and clients subscribe to #, this is potentially unwanted.
You can use this setting to limit how many nodes of the retained tree are traversed. Note that the topic one/two/three is three nodes, and each node doesn't necessarily need to contain a message. Default value: 4294967295 milliseconds The time after which FlashMQ will fall back to (b)locking vs queued mode for setting retained messages. 0, the default, disables queued mode altogether. It's disabled by default because it can incur some extra CPU and memory overhead. Each retained message lives in a node in a tree. The topic 'one/two/three' is three nodes. When a node in that tree does not exist yet, it needs to be created. This requires a write lock on the tree. At this point, other threads reading from or writing to the retained message tree need to wait. This can cause a compounding blocking effect, especially if many threads do it at once. This feature favors server responsiveness over the speed at which retained messages become available in the server. It is primarily useful when you have a lot of retained messages on different/changing topics. If at first a retained message can't be set, the action to do so will be retried in the event loop, asynchronously. This setting determines the maximum amount of time to defer setting a retained message, after which it will fall back to using locks. Also see Default value: 0 milliseconds For , the amount of random spread between 0 and this value for the timeout. This spreads out locking over time, reducing contention. Default value: 1000 seconds The grace period after which a retained message node is eligible for deletion. The topic 'one/two/three' is three nodes, and if that topic had a message, it would be contained in 'three'. FlashMQ will periodically clear out retained message nodes that have no message anymore. This is required to save memory. But, when you receive retained messages on the same topics repeatedly, it may be beneficial to keep the nodes around, to avoid the need for locks to recreate them.
If you know that retained messages come and go within a certain period, it's beneficial to set this value so that no unnecessary node destruction and creation takes place. Default value: 0 seconds The grace period after which a subscription node is eligible for deletion. The subscription 'one/two/three' is three nodes. FlashMQ will periodically clear out nodes in the subscription tree that have no entries anymore. This is required to save memory. But, when clients place the same subscriptions repeatedly, it may be beneficial to keep the nodes around, to avoid the need for locks to recreate them. If you know that certain subscription patterns come and go within a certain period, it's beneficial to set this value so that no unnecessary node destruction and creation takes place. Default value: 3600 inet4_address|inet6_address HTTP proxies in front of the websocket listeners can set the X-Real-IP header to identify the original connecting client. With you can mark IP networks as trusted. By default, clients are not trusted, to avoid spoofing. You can repeat the option to allow for multiple addresses. Valid notations are 1.2.3.4, 1.2.3.4/16, 1.2.0.0/16, 2a01:1337::1, 2a01:1337::1/64, etc. The header X-Forwarded-For is not used, because that's designed to contain a list of addresses, if applicable. As a side note about using a proxy on your listener: you can only have an absolute max of 65535 connections to an IP+port combination (and the practical limit is lower). If you need more, you have to set up multiple listeners. This can be multiple IP addresses, or simply multiple ports. round_robin|sender_hash|first When having multiple subscribers on a shared subscription (like '$share/myshare/jane/doe'), select how the messages should be distributed over the subscribers. round_robin. Select the next subscriber for each message. There is still some amount of randomness to it because the counter for this is not thread safe.
Using an atomic/mutexed counter for it would just be too slow to justify. sender_hash. Selects a receiver deterministically based on the hash of the client ID of the sender. The selected subscriber will depend on how many subscribers there are, so if some disconnect, the distribution will change. Moreover, the selection may also change when FlashMQ cleans up empty spaces in the list of shared subscribers. first. Selects the first subscriber in the list. This mode can be useful for fallback. When one client disappears, the other will seamlessly take over. Default: round_robin number Defines the minimum level of the first wildcard topic filter ( and ). In a topic filter like , that is 2. If you specify 2, a subscription to will be denied. Remember that only MQTT 3.1.1 and newer actually notify the client of the denial in the sub-ack packet. The reason you may want to limit it, is performance. If you have a base message load of 100,000 messages per second, each client subscribing to causes that many permission checks per second. If you have 100 clients doing that, there will be 10 million permission checks per second. Default: 0 number Defines the maximum number of components a topic/filter string can have. For example, one/two/three has three. The reason is to avoid abuse. The default should be enough for most deployments. If the limit is tripped, the error in the log will clearly say so, and mention . Default: 128 deny_all|deny_retained_only For , specify what you want to deny. Trying to give a client all retained messages can cause quite some load, so only denying the retained messages upon receiving a broad wildcard subscription can be useful if you have a low enough general message volume, but a high number of retained messages. Default: deny_all log|close_new_clients Define the action to perform when the value defined with is exceeded. 
When a server is (re)started, and hundreds of thousands of clients connect, the SSL handshaking and authenticating can be so heavy that it doesn't get to clients in time. They will then reconnect and try again, and get stuck in a loop. This option is to mitigate that. With close_new_clients, new clients will be closed immediately after connecting while the server is overloaded. This will allow the worker threads to process the new clients in a controlled manner. For really large deployments, this can be augmented with extra rate limiting in iptables, or other firewalls. A stateless method is preferred, like: iptables -I INPUT -p tcp -m multiport --dports 8883,1883 --syn -m hashlimit --hashlimit-name newmqttconns --hashlimit-above 10000/second --hashlimit-burst 15000 -j DROP The current default is log, but that will likely change in the future. Default: log milliseconds For , the maximum permissible thread drift before the overload action is taken. The drift values considered are those of the main loop, in which clients are accepted, and the median of all worker threads. Default: 2000 /path/to/dir Load *.conf files from the specified directory, to merge with the main configuration file. An error is generated when the directory is not there. This is to protect against running incorrect configurations by accident, when the dir has been renamed, for example. true|false Subscription identifiers allow clients to see which subscription was responsible for a message. Publish messages will contain the identifier included in the original subscription. Enabling will prevent FlashMQ from using optimizations involving packet reuse, because the packets are unique per client when it contains a subscription identifier. Therefore you may want to assess the performance difference in high message volume deployments. As per the spec, clients sending subscription identifiers when the server reported the feature as unavailable will cause them to be disconnected. 
This has the side effect that changing this setting on a running server will disconnect clients when they send a subscription with an identifier in it. This behavior was chosen over the alternatives because of simplicity and operator control (otherwise it can't be turned off at all for existing clients). Default value: true Listen parameters Listen parameters can only be used within listen { } blocks. The default port depends on the parameter and whether or not and parameters are supplied: For unencrypted MQTT, the default port is 1883 For encrypted MQTT, the default port is 8883 For plain HTTP websockets, the default port is 8080 For encrypted HTTPS websockets, the default port is 4443 mqtt|websockets|acme This is a required parameter. For acme, see acme_redirect_url. ip4_ip6|ip4|ip6|unix When using unix, a is required. Default: ip4_ip6 inet4address Default: 0.0.0.0 inet6address Default: ::0 path When using unix for , the file path of the socket. FlashMQ will remove the socket file if it already exists. user Use this to set the owner of the socket. FlashMQ will always attempt to set it, but a warning may be logged if not successful. Users may be specified as numeric IDs or names. A config test will not verify the existence of users, for portability. group Use this to set the group of the socket. FlashMQ will always attempt to set it, but a warning may be logged if not successful. Groups may be specified as numeric IDs or names. A config test will not verify the existence of groups, for portability. mode Use this to specify the permission mode of the unix socket, like 600. /foobar/server.crt Specifying a chain makes the listener SSL, and also requires the to be set. /foobar/server.key Specifying a private key makes the listener SSL, and also requires the to be set. tlsv1.1|tlsv1.2|tlsv1.3 Set the minimum supported TLS version for TLS listeners. Note that setting this value low may not actually enable that protocol version if OpenSSL won't support it (anymore).
The TLS version clients use is logged. Default: tlsv1.1 /foobar/client_authority.crt Clients can be authenticated using X509 certificates, and the username taken from the CN (common name) field. Use this directive to specify the certificate authority you trust. Specifying this or will require the listener to be TLS. /foobar/dir_with_certificates Clients can be authenticated using X509 certificates, and the username taken from the CN (common name) field. Use this directive to specify the dir containing certificate authorities you trust. Note that the filename requirements are dictated by OpenSSL. Use the utility openssl rehash /path/to/dir. Specifying this or will require the listener to be TLS. true|false When using X509 client authentication with or , the username will not be checked with a user database or a plugin by default. Set this option to true to override that. http://example.com/ This allows an ACME (automated certificate management environment) challenge to be redirected elsewhere. This allows decoupling of the certificate creation from the host(s) that run FlashMQ. This can either be configured on a dedicated listener with acme, or multiplexed on a non-SSL mqtt or websockets listener. true|false When both and are absent, don't create this listener. This can help in situations where you don't have the certificate and key yet, but you are expecting them. inet4_address|inet6_address When set, restricts the listener to the source address/network given. You can repeat the option to allow multiple addresses/networks. Valid notations are 1.2.3.4, 1.2.3.4/16, 1.2.0.0/16, 2a01:1337::1, 2a01:1337::1/64, etc. inet4_address|inet6_address Block connections from this address or network. You can repeat the option multiple times. Valid notations are 1.2.3.4, 1.2.3.4/16, 1.2.0.0/16, 2a01:1337::1, 2a01:1337::1/64, etc. true|false This allows you to override the global setting on the listener level. 
log|close_new_clients This allows you to override the global setting on the listener level. off|on|client_verification|client_verification_with_authn The nature of running something behind a proxy is that you would lose certain information or abilities. For that, HAProxy allows sending information in a header frame. With the option, you can configure the following: off and on configure basic functionality: reading the client's source address from the haproxy frame. This address will be shown in the logs and plugin functions. client_verification and client_verification_with_authn will require haproxy to send the PP2_SUBTYPE_SSL_CN field when using mTLS (mutual TLS), i.e. client authentication. When using the client_verification_with_authn option, FlashMQ's internal authentication will still happen (with the username from the certificate). This can be convenient when you have a plugin, so that you can still take action in the login hook. Make sure this listener is private, firewalled or use , otherwise anybody can set a different source address. Note that HAProxy's server health checks only started using the 'local' specifier as of version 2.4. This means earlier versions will pretend to be a client and break the connection, causing log spam. As a side note about using a proxy on your listener: you can only have an absolute max of 65535 connections to an IP+port combination (and the practical limit is lower). If you need more, you have to set up multiple listeners. This can be multiple IP addresses, or simply multiple ports. The following is an example HAProxy backend config for the features described: See haproxy.org. true|false will cause the TCP_NODELAY option to be set for the listener's socket(s), and therefore for all clients accepted on that listener.
TCP_NODELAY is an OS TCP-layer option that will cause messages written by FlashMQ to the socket to be flushed immediately, without letting Nagle's algorithm (the default) collect small outgoing TCP packets into bigger packets. Foregoing Nagle's algorithm by setting to true may decrease latency, at the likely cost of some network efficiency. Default: false number Override the for this listener. This is especially useful when this listener is the receiving side of a bridge, because these clients will likely see more traffic. qos_value Per-listener counterpart of . drop|disconnect Per-listener counterpart of . Example listeners Bridge configuration Bridges can be defined inside bridge { } blocks. A bridge is essentially just an outgoing connection to another server with loop-detection and retain flag relaying. It is not a form of clustering, although with careful design, it can be deployed to achieve some sort of load balancing. Note that normally (unless is set) one bridge is one connection, and because FlashMQ's threading model is that clients are serviced by one selected thread only, a bridge has the potential to saturate a thread, if it's heavily loaded. You can improve that with . Bridges are dynamically created, removed or changed upon config reload. When a bridge configuration changes, it will disconnect and reconnect. address The DNS name, IPv4 or IPv6 address of the server you want to connect to. number The default port depends on the option, either 1883 or 8883. ip4_ip6/ip4/ip6 Default: ip4_ip6 off/on/unverified Set TLS mode. The value means the x509 chain is not verified. tlsv1.1|tlsv1.2|tlsv1.3 Set the minimum supported TLS version the bridge will negotiate with the other side. Note that setting this value low may not actually enable that protocol version if OpenSSL won't support it (anymore).
Default: tlsv1.1 /foobar/bridge.crt With TLS enabled, specifying a chain makes the bridge connection authenticate to the remote broker using a public certificate, and also requires the to be set. /foobar/bridge.key With TLS enabled, specifying a private key makes the bridge connection to the remote broker use that key, and also requires the to be set. path File to be used for x509 certificate chain validation. path Directory containing certificates for x509 certificate chain validation. mqtt3.1|mqtt3.1.1|mqtt5 Default: mqtt3.1.1 true|false An unofficial standard is to set the most significant bit of the protocol version byte to 1 to signal the connection is a bridge. This allows the other side to alter its behavior slightly. However, this is not always supported, so you can disable this if you get disconnected for reporting an invalid protocol version. This setting has no effect when using MQTT5, because the behavior it influences is done with subscription options. Default: true seconds The time between sending ping packets to the other side. Default: 60 prefix The prefix of the randomly generated client ID. Client IDs cannot be explicitly set for security reasons. See Understanding clean session and clean start. Default: fmqbridge filter qos Messages matching this filter will be published to the other side. Examples: # or sport/tennis/#. This option can be repeated several times. The QoS value should be seen as the QoS value of the internal subscription causing outgoing messages. Messages that are relayed have this QoS level at most. Default: 0 filter qos A subscription for this filter is placed at the other side. Examples: # or sport/tennis/#. This option can be repeated several times. The QoS value is like any subscription at a server. Messages received by the other end will be given this QoS level at most. Default: 0 username Username as seen by the local FlashMQ plugin or ACL checks. This is not always necessary. username Username sent to the remote connection.
password Password sent to the remote connection. true|false In MQTT3, this means 'clean session', meaning the remote server removes any existing session with the same ID on (re)connect, and destroys it immediately on disconnect. If you want reusable sessions that survive disconnects, set this to false. If you also want to pick up remote sessions on FlashMQ restart, set to true. In MQTT5, this option only influences reconnection behavior. It essentially has no effect on the first connect, because the client ID is random and will always be new (except when you set ). But when set to true, any reconnects, which do use the already generated client ID, will destroy the session and in-flight messages will be lost. Also see understanding clean session and clean start. Default value: true true|false In MQTT3 mode, this means 'clean session' and means the session is removed upon disconnect. If you want to reuse sessions on reconnect, set this to false. Any new start of FlashMQ will give you a new client ID, so it will always be a fresh session, except if you set . In MQTT5 mode, this only has effect on start, where any existing local session is removed if found. If you want the session to be removed immediately on disconnect, set to 0. Also see understanding clean session and clean start. Default value: true seconds Is only used in MQTT5 mode and determines the number of seconds after which the session can be removed from the remote server. Default value: 0 seconds Determines when a local session without an active client will be removed, in both MQTT3 and MQTT5 mode. Note that in MQTT3 mode, the session is removed on disconnect when is true. Default value: 0 true|false MQTT5 allows a server to tell a client it doesn't support retained messages, or has them disabled. When using MQTT3, use this option to achieve the same. Messages will not be relayed with 'retained as published' and the retained messages that are normally sent on a matching subscription are not sent.
Default value: true true|false When you want your bridges to resume local and remote sessions after restart, set this to true and set , , and accordingly. It only has effect when you have set a and include bridge_info in . It is important to fully understand the clean session / clean start behavior and the role the client ID plays in that. The primary goal of sessions is to survive link disconnects. Configuring a fixed client ID and using it each time an MQTT client starts is often an anti-pattern, because most clients, like actual IoT devices, start fresh upon restart and don't store their sessions (with in-flight packets, etc) to disk. FlashMQ does store them on disk, however, so it can be used legitimately. However, you can run into unexpected situations. For instance, you will get your existing subscriptions from the session too. So, if you remove a line from your bridge configuration and restart, it will actually have no effect, because the server on the other side still has that subscription in the session. See understanding clean session and clean start for details. Default value: false amount If you want FlashMQ to initiate topic aliases for this bridge, set this to a non-zero value. Note that it's capped at the value the remote side gives in the CONNACK packet, so it only works if the other side permits it. Default: 0 amount If you want to accept topic aliases for this bridge, set this to a non-zero value. The value is set in the CONNECT packet to inform the remote side of the wish. It's not guaranteed that the other side will actually make aliases. Default: 0 true|false will cause the TCP_NODELAY option to be set for the client socket that is used to connect to the other end of the bridge. See the documentation for the listener parameter for further elaboration. Default: false prefix Prefixes can be used to remap topics to and from the other end of the bridge.
This makes it possible to insert a topic tree into the topic tree on another server, like a shared one. When a message comes in, the is stripped from the topic, and the is added. The resulting topic is used for authorization 'write' checking. When a message goes out, the opposite happens: the is stripped and the is added. However, this time, the original topic is used for authorization 'read' checking. The prefixes aren't applied to the and bridge options. You'll have to include the prefix in the subscriptions you configure. This is so that you can have multiple subscriptions to the other end, and only have the prefix applied to the relevant one(s). Incoming and outgoing messages that don't match the prefixes are sent and received unchanged. Prefix removal isn't applied to topics that match the prefix exactly. This is to avoid one/two/three/ (which has a legal empty string as its last subtopic) becoming an empty string (which is illegal). If you define a prefix, it is required to end with a /. It's valid to have only a local or remote prefix. prefix See . number|auto Normally a bridge has one TCP connection to the other side. This means that dealing with bridge traffic is limited to one thread, which also applies to the remote side if that is also FlashMQ. With this option, you can make a bridge have multiple connections, and share the traffic load over them using MQTT5 'shared subscriptions'. When you specify a or path of one/two/three, the topic is adjusted to $share/RANDOM/one/two/three to create a shared subscription so that load is balanced. With FlashMQ, load balancing is especially important for the side processing publishes (because each received publish packet means subscribers have to be looked up, auth checked, etc). The mode is automatically set to sender_hash.
This is required to ensure sequential message relaying (to retain ordering), and it's also better in plugin code when messages from one source are kept to one thread, as it would be under normal circumstances. If FlashMQ is also on the receiving end of these load balanced connections, it uses an extra feature to group clients of one bridge together to enhance loop detection. Normally MQTT5 supports the 'no-local' subscription option for that, but the standard states that is not allowed for shared subscriptions. FlashMQ uses 'user properties' to communicate the group they belong to, so that we can still do this kind of loop detection. This allows you to specify overlapping paths in the and options of a bridge. Note that both ends should use FlashMQ version 1.22.0 or higher. When the target server indeed is also FlashMQ, it may be smart to create a dedicated listener, for several reasons. One is that incoming connections are given to threads in a sequential order per listener; this ensures the best spread over worker threads. It also allows you to set , and possibly and differently for that listener. You can specify an amount of connections, or auto for one connection per CPU. FlashMQ's load is mostly on the receiver of messages, so auto will likely be a good choice if most of your message load from the other side is incoming, vs outgoing. Otherwise it's best matched to the other side's number of CPUs. Default: 1 number Override the for this bridge. When is used, this size applies to each connection individually. Bridges typically have more traffic than single clients, in which case it makes sense to increase this. Example bridge Author Wiebe Cazemier contact@flashmq.org. See also flashmq 1 https://www.flashmq.org/ ================================================ FILE: man/flashmq.conf.5.html ================================================ flashmq.conf (5) – FlashMQ configuration file format

    flashmq.conf (5)

    FlashMQ configuration file format

    Synopsis

    The flashmq.conf file is the configuration used for configuring the FlashMQ MQTT broker.

    Config location#

    By default, the flashmq daemon expects to find its configuration file at /etc/flashmq/flashmq.conf, but this can be overridden using the --config-file command-line argument; see the flashmq(1) man page for details.

    Using the include_dir parameter in your config file, you can load all the *.conf files from that given directory.

    File format#

    To set a parameter, its name must appear on a single line, followed by one or more potentially quoted arguments.

    parameter-name1 parameter-value
    parameter-name2 'value with spaces'
    parameter-name3 "escaped \" quote"
    multi-value-param one 'two' "three" "with ' char"
          

    Quoted values are the same as unquoted values when they don't need it. They are necessary for when argument values have spaces, for instance.

    When setting boolean values, yes/no, true/false and on/off can all be used.

    To configure the listeners, use listen blocks, defined by { and }. See EXAMPLE LISTENERS for details.

    Lines beginning with the hash character (“#”) and empty lines are ignored. Thus, a line can be commented out by prepending a “#” to it.

    Global parameters#

    plugin /path/to/plugin.so#

    FlashMQ supports an ELF shared object (.so file) plugin interface to add functionality, authorization and authentication, because it’s hard to provide a declarative mechanism that works for everybody. See flashmq_plugin.h for the API and its documentation. It’s written in C++ for ease of passing FlashMQ internals without conversion to C, but you can basically just use a C++ compiler and program like it was C; the C++ constructs are simple.

    FlashMQ will auto-detect which plugin interface you’re trying to load (Mosquitto version 2 or FlashMQ native). Keep in mind that each thread initializes the plugin, in line with multi-core programming (minimize shared data and interaction between threads). You could use static variables with thread synchronization if you really want to. And of course, any Mosquitto plugin that uses global and/or static variables instead of initializing memory in its init() method, will not be thread-safe and won’t work.

    You can only have one plugin active, but you can combine it with mosquitto_password_file and mosquitto_acl_file. The password and ACL file take precedence, and on a ‘deny’, will not ask the plugin.

    plugin_opt_* value#

    Options passed to the plugin init() function.

    plugin_serialize_init true|false#

    There could be circumstances where the plugin code is mostly thread-safe, but not on initialization. Libmysqlclient for instance, needs a one-time initialization. To add to the confusion, Qt hides that away.

    The plugin should preferably be written with proper synchronization like that, but as a last resort, you can use this to synchronize initialization.

    Default value: false

    plugin_serialize_auth_checks true|false#

    Like plugin_serialize_init, but then for all login and ACL checks.

    This option may be dropped at some point, because it negates much of the multi-core design. One may as well run with only one thread then.

    Default value: false

    plugin_timer_period seconds#

    The FlashMQ auth plugin interface has an optional function that is called periodically this amount of seconds. This can be used to refresh state, commit data, etc.

    Setting a value of 0 disables it. You can enable and disable this timer with a config reload.

    See flashmq_plugin.h for details.

    Default value: 60

    log_file /path/to/flashmq.log#

    This configuration parameter sets the path to FlashMQ's log file. If you omit this option from the config file, the output will go to stdout.

    ≥ v1.11.0
    log_level debug|info|notice|warning|error|none#

    Set the log level to specified level and above. That means notice will log notice, warning and error.

    Use this setting over the deprecated log_debug and quiet. If you do have those directives, they override the log_level, for backwards compatibility reasons.

    Default value: info.

    log_debug true|false#

    Debug logging obviously creates a lot of log noise, so should only be done to diagnose problems.

    Deprecated. Use log_level instead.

    Default value: false

    log_subscriptions true|false#

    Default value: false

    ≥ v1.25.0
    log_publishes true|false#

    Both publishes and QoS actions related to publishes are logged at this level.

    Default value: false

    allow_unsafe_clientid_chars true|false#

    If you have topics with client IDs in it, people can possibly manipulate your ACL checking by saying their client ID is 'John+foobar'. Audit your security before you allow this.

    Default value: false

    allow_unsafe_username_chars true|false#

    If you have topics with usernames in it, people can possibly manipulate your ACL checking by saying their username is 'John+foobar'. Audit your security before you allow this.

    Default value: false

    ≥ v1.26.0
    max_string_length bytes#

    The max length of almost all strings encoded in MQTT packets. This includes topics, user properties, client IDs, etc. Some fields, like passwords, are excluded. Payloads are also excluded (as they are not strings).

    The reason is to avoid abuse. The default should be enough for most deployments.

    If the limit is tripped, the error in the log will clearly say so, and mention max_string_length.

    Default value: 4096

    max_packet_size bytes#

    MQTT packets have a maximum size of about 256 MB. This memory will (temporarily) be allocated upon arrival of such packets, so there may be cause to set it lower.

    This option works in conjunction with client_max_write_buffer_size to limit memory use.

    Default value: 268435461

    ≥ v1.4.4
    client_max_write_buffer_size bytes#

    The client's write buffer is where packets are stored before the event loop has the chance to flush them out. Any time a client's connection is bad and bytes can't be flushed, this buffer fills. So, there's good reason to limit this to something sensible. A good indication is your average packet size (or better yet, a configured max_packet_size) multiplied by the amount of packets you want to be able to buffer.

    Despite the name, this setting also controls the read buffer's maximum size. Having a bigger read buffer may be necessary when you know your clients receive many packets.

    Note that it's an approximate value and not a hard limit. Buffer sizes only grow by powers of two, and buffers are always allowed to grow to make room for ping packets. Additionally, upon arrival of large packets (up to max_packet_size bytes), room will be made up to twice their size. So, you may also want to reduce max_packet_size from the default.

    Default value is 1048576 (1 MB)

    client_initial_buffer_size bytes#

    The buffers for reading and writing, also for websockets when relevant, start out with a particular size and double when they need to grow. If you know your clients send bulks of a particular size, it helps to set this to match, to avoid constant memory reallocation. The default value is set conservatively, for scenarios with millions of clients.

    After buffers have grown, they are eventually reset to their original size when possible.

    Also see client_max_write_buffer_size and max_packet_size.

    Value must be a power of two.

    Default value: 1024

    mosquitto_password_file /foo/bar/mosquitto_password_file#

    File with usernames and hashed+salted passwords as generated by Mosquitto's mosquitto_passwd.

    Mosquitto up to version 1.6 uses the sha512 algorithm. Newer versions use sha512-pbkdf2. Both are supported.

    mosquitto_acl_file /foo/bar/mosquitto_acl_file#

    ACL (access control lists) for users, anonymous users and patterns expandable with %u (username) and %c (clientid). Format is Mosquitto's acl_file.

    allow_anonymous true|false#

    This option can be overridden on a per-listener basis; see listener.allow_anonymous.

    Default value: false

    ≥ v1.11.0
    zero_byte_username_is_anonymous true|false#

    The proper way to signal an anonymous client is by setting the 'username present' flag in the CONNECT packet to 0, which in MQTT3 also demands the absence of a password. However, there are also clients out there that set the 'username present' flag to 1 and then give an empty username. This is an undesirable situation, because it means there are two ways to identify an anonymous client.

    Anonymous clients are not authenticated against a loaded plugin when allow_anonymous is true. With this option enabled, that means users with empty string as usernames also aren't.

    With this option disabled, clients connecting with an empty username will be rejected with 'bad username or password' as MQTT error code.

    The default is to be unambiguous, but this can be overridden with this option.

    Default value: false

    rlimit_nofile number#

    The general Linux default of 1024 can be overridden. Note: systemd blocks you from setting it, so it needs to be set on the unit. The default systemd unit file sets LimitNOFILE=infinity. You may also need to set sysctl -w fs.file-max=10000000

    Default value: 1000000

    expire_sessions_after_seconds seconds#

    Expire sessions after this time. Setting to 0 disables it and is (MQTT3) standard-compliant. But, existing sessions cause load on the server (because they cost memory and are still subscribers), so keeping sessions after any client that connects with a random ID doesn't make sense.

    Default value: 1209600

    quiet true|false#

    Don't log LOG_INFO and LOG_NOTICE. This is useful when you have a lot of foot traffic, because otherwise the log gets filled with connect/disconnect notices.

    Deprecated. Use log_level instead.

    Default value: false

    storage_dir /path/to/dir#

    Location to store sessions, subscriptions and retained messages. Not specifying this will turn off persistence.

    ≥ v1.15.1
    save_state_interval seconds#

    The interval at which the state is saved, if enabled with storage_dir.

    This setting is also applied on reload.

    Default: 3623

    ≥ v1.24.0
    persistence_data_to_save all|sessions_and_subscriptions|retained_messages|bridge_info#

    When a storage_dir is defined, specify which information should be saved. The arguments are a list of one or more of the specified options, and can be negated with a !. Meaning like: all !retained_messages. The order is relevant. If you end with all, it overrides the previous.

    The option bridge_info causes a small bit of meta data of the bridges to be saved. See the bridge option use_saved_clientid.

    The default is all, but as soon as this option is defined, it's cleared, and the desired list has to be composed.

    max_qos_msg_pending_per_client number
    max_qos_bytes_pending_per_client bytes#

    There is a limit to how many QoS packets can be stored in a session, so you can define a maximum amount of messages and bytes. If any of these is exceeded, the packet is dropped.

    Note that changing max_qos_msg_pending_per_client only takes effect for new clients (also when picking up existing sessions). This is largely due to it being part of the MQTT5 connection handshake and is supposed to be adhered to.

    Defaults:

    • max_qos_msg_pending_per_client 512

    • max_qos_bytes_pending_per_client 65536

    ≥ v1.23.0
    max_qos qos_value#

    The maximum QoS value FlashMQ will allow clients to use for subscriptions, publishes and wills. Subscriptions will be downgraded to this value. Publishes and wills will cause a disconnect for MQTT5 clients, and the action configured with mqtt3_qos_exceed_action for MQTT3 clients.

    The value is updated on config reload, but only for new clients. The reason is that the max QoS is part of the handshake in MQTT5, and clients will be surprised (and cause disconnects) if that changes during the lifetime of a connection. For consistency, the MQTT3 clients behave the same.

    This setting is also available per listener.

    Default value: 2

    ≥ v1.23.0
    mqtt3_qos_exceed_action drop|disconnect#

    Unlike MQTT5, MQTT3 doesn't have defined behavior for exceeding a configured maximum QoS (see max_qos ). This allows the behavior to be controlled.

    In case of drop the client gets an acknowledgement as if the publish has succeeded, to avoid stale in-flight packets.

    This setting is also available per listener.

    Default value: disconnect

    ≥ v1.4.2
    max_incoming_topic_alias_value number#

    This value is communicated to MQTT5 clients. It is then up to them to decide whether to use topic aliases or not.

    Changing this setting and reloading the config only has effect on new clients, because existing clients would otherwise exceed the limit they think applies.

    Default value: 65535

    ≥ v1.4.2
    max_outgoing_topic_alias_value number#

    FlashMQ will make this many aliases per MQTT5 client, if they ask for aliases (with the connect property TopicAliasMaximum).

    Default value: 65535

    thread_count number#

    If you want to have a different number of worker threads than CPUs, you can set this value. Typically you don't need to set this.

    Default value: auto-detect

    wills_enabled true|false#

    When disabled, the server will not set last will and testament specified by connecting clients.

    Default value: true

    retained_messages_mode enabled|enabled_without_persistence|enabled_without_retaining|downgrade|drop|disconnect_with_error#

    Retained messages can be a strain on the server you may not need. You can set various ways of dealing with them:

    enabled. This is normal operation.

    enabled_without_persistence. Like enabled, except it won't store them to disk if storage_dir is defined.

    enabled_without_retaining. This somewhat counter-intuitive sounding mode is like downgrade, except that the 'retain' flag is not removed. This allows MQTT5 subscribers that subscribe with 'retain as published' to see which messages were originally sent as retained. It's just that FlashMQ won't retain them.

    downgrade. The retain flag is removed and treated like a normal publish.

    drop. Messages with retain set are dropped.

    disconnect_with_error. Disconnect clients who try to set them.

    Default value: enabled

    expire_retained_messages_after_seconds seconds#

    Use this to limit the life time of retained messages. Without this, the amount of retained messages may never decrease.

    Default value: 4294967295

    ≥ v1.6.7
    retained_messages_delivery_limit number#

    Deprecated.

    ≥ v1.8.4
    retained_messages_node_limit number#

    When clients place a subscription, they will get the retained messages matching that subscription. Even though traversing the retained message tree is deprioritized in favor of other traffic, it will still cause CPU load until it's done. If you have a tree with millions of nodes and clients subscribe to #, this is potentially unwanted. You can use this setting to limit how many nodes of the retained tree are traversed.

    Note that the topic one/two/three is three nodes, and each node doesn't necessarily need to contain a message.

    Default value: 4294967295

    ≥ v1.14.0
    set_retained_message_defer_timeout milliseconds#

    The time after which FlashMQ will fall back to (b)locking vs queued mode for setting retained messages. 0, the default, disables queued mode altogether. It's disabled by default because it can incur some extra CPU and memory overhead.

    Each retained message lives in a node in a tree. The topic 'one/two/three' is three nodes. When a node in that tree does not exist yet, it needs to be created. This requires a write lock on the tree. At this point, other threads reading from or writing to the retained message tree need to wait. This can cause a compounding blocking effect, especially if many threads do it at once.

    This feature is to favor server responsiveness vs the speed at which retained messages become available in the server. It is primarily useful for when you have a lot of retained messages on different/changing topics. If at first a retained message can't be set, the action to do so will be retried in the event loop, asynchronously.

    This setting determines the maximum amount of time to defer setting a retained message, after which it will fall back to using locks.

    Also see set_retained_message_defer_timeout_spread

    Default value: 0

    ≥ v1.14.0
    set_retained_message_defer_timeout_spread milliseconds#

    For set_retained_message_defer_timeout, the amount of random spread between 0 and this value for the timeout. This spreads out locking over time, reducing contention.

    Default value: 1000

    ≥ v1.14.0
    retained_message_node_lifetime seconds#

    The grace period after which a retained message node is eligible for deletion. The topic 'one/two/three' is three nodes, and if that topic had a message, it would be contained in 'three'.

    FlashMQ will periodically clear out retained message nodes that have no message anymore. This is required to save memory. But, when you receive retained messages on the same topics repeatedly, it may be beneficial to keep the nodes around, to avoid the need for locks to recreate them. If you know that retained messages come and go within a certain period, it's beneficial to set this value so that no unnecessary node destruction and creation takes place.

    Default value: 0

    ≥ v1.15.1
    subscription_node_lifetime seconds#

    The grace period after which a subscription node is eligible for deletion. The subscription 'one/two/three' is three nodes.

    FlashMQ will periodically clear out nodes in the subscription tree that have no entries anymore. This is required to save memory. But, when clients place the same subscriptions repeatedly, it may be beneficial to keep the nodes around, to avoid the need for locks to recreate them. If you know that certain subscription patterns come and go within a certain period, it's beneficial to set this value so that no unnecessary node destruction and creation takes place.

    Default value: 3600

    ≥ v1.2.0
    websocket_set_real_ip_from inet4_address|inet6_address#

    HTTP proxies in front of the websocket listeners can set the X-Real-IP header to identify the original connecting client. With websocket_set_real_ip_from you can mark IP networks as trusted. By default, clients are not trusted, to avoid spoofing.

    You can repeat the option to allow for multiple addresses. Valid notations are 1.2.3.4, 1.2.3.4/16, 1.2.0.0/16, 2a01:1337::1, 2a01:1337::1/64, etc.

    The header X-Forwarded-For is not used, because that's designed to contain a list of addresses, if applicable.

    As a side note about using a proxy on your listener; you can only have an absolute max of 65535 connections to an IP+port combination (and the practical limit is lower). If you need more, you have to set up multiple listeners. This can be multiple IP addresses, or simply multiple ports.

    ≥ v1.2.0
    shared_subscription_targeting round_robin|sender_hash|first#

    When having multiple subscribers on a shared subscription (like '$share/myshare/jane/doe'), select how the messages should be distributed over the subscribers.

    round_robin. Select the next subscriber for each message. There is still some amount of randomness to it because the counter for this is not thread safe. Using an atomic/mutexed counter for it would just be too slow to justify.

    sender_hash. Selects a receiver deterministically based on the hash of the client ID of the sender. The selected subscriber will depend on how many subscribers there are, so if some disconnect, the distribution will change. Moreover, the selection may also change when FlashMQ cleans up empty spaces in the list of shared subscribers.

    first. Selects the first subscriber in the list. This mode can be useful for fallback. When one client disappears, the other will seamlessly take over.

    Default: round_robin

    ≥ v1.9.0
    minimum_wildcard_subscription_depth number#

    Defines the minimum level of the first wildcard topic filter (# and +). In a topic filter like sensors/temperature/#, that is 2. If you specify 2, a subscription to sensors/# will be denied. Remember that only MQTT 3.1.1 and newer actually notify the client of the denial in the sub-ack packet.

    The reason you may want to limit it, is performance. If you have a base message load of 100,000 messages per second, each client subscribing to # causes that many permission checks per second. If you have 100 clients doing that, there will be 10 million permission checks per second.

    Default: 0

    ≥ v1.26.0
    max_topic_split_depth number#

    Defines the maximum number of components a topic/filter string can have. For example, one/two/three has three.

    The reason is to avoid abuse. The default should be enough for most deployments.

    If the limit is tripped, the error in the log will clearly say so, and mention max_topic_split_depth.

    Default: 128

    ≥ v1.9.0
    wildcard_subscription_deny_mode deny_all|deny_retained_only#

    For minimum_wildcard_subscription_depth, specify what you want to deny. Trying to give a client all retained messages can cause quite some load, so only denying the retained messages upon receiving a broad wildcard subscription can be useful if you have a low enough general message volume, but a high number of retained messages.

    Default: deny_all

    ≥ v1.12.0
    overload_mode log|close_new_clients#

    Define the action to perform when the value defined with max_event_loop_drift is exceeded.

    When a server is (re)started, and hundreds of thousands of clients connect, the SSL handshaking and authenticating can be so heavy that it doesn't get to clients in time. They will then reconnect and try again, and get stuck in a loop. This option is to mitigate that. With close_new_clients, new clients will be closed immediately after connecting while the server is overloaded. This will allow the worker threads to process the new clients in a controlled manner.

    For really large deployments, this can be augmented with extra rate limiting in iptables, or other firewalls. A stateless method is preferred, like: iptables -I INPUT -p tcp -m multiport --dports 8883,1883 --syn -m hashlimit --hashlimit-name newmqttconns --hashlimit-above 10000/second --hashlimit-burst 15000 -j DROP

    The current default is log, but that will likely change in the future.

    Default: log

    ≥ v1.12.0
    max_event_loop_drift milliseconds#

    For overload_mode, the maximum permissible thread drift before the overload action is taken.

    The drift values considered are those of the main loop, in which clients are accepted, and the median of all worker threads.

    Default: 2000

    ≥ v1.7.0
    include_dir /path/to/dir#

    Load *.conf files from the specified directory, to merge with the main configuration file.

    An error is generated when the directory is not there. This is to protect against running incorrect configurations by accident, when the dir has been renamed, for example.

    ≥ v1.18.0
    subscription_identifiers_enabled true|false#

    Subscription identifiers allow clients to see which subscription was responsible for a message. Publish messages will contain the identifier included in the original subscription.

    Enabling will prevent FlashMQ from using optimizations involving packet reuse, because the packets are unique per client when they contain a subscription identifier. Therefore you may want to assess the performance difference in high message volume deployments.

    As per the spec, clients sending subscription identifiers when the server reported the feature as unavailable will cause them to be disconnected. This has the side effect that changing this setting on a running server will disconnect clients when they send a subscription with an identifier in it. This was chosen as behavior over the alternatives, because of simplicity and operator control (otherwise it can't be turned off at all for existing clients).

    Default value: true

    Listen parameters

    Listen parameters can only be used within listen { } blocks.

    port#

    The default port depends on the protocol parameter and whether or not fullchain and privkey parameters are supplied:

    • For unencrypted MQTT, the default port is 1883

    • For encrypted MQTT, the default port is 8883

    • For plain HTTP websockets, the default port is 8080

    • For encrypted HTTPS websockets, the default port is 4443

    protocol mqtt|websockets|acme#

    This is a required parameter.

    For acme, see acme_redirect_url.

    inet_protocol ip4_ip6|ip4|ip6|unix#

    When using unix, a unix_socket_path is required.

    Default: ip4_ip6

    inet4_bind_address inet4address#

    Default: 0.0.0.0

    inet6_bind_address inet6address#

    Default: ::0

    ≥ v1.22.0
    unix_socket_path path#

    When using unix for inet_protocol, the file path of the socket.

    FlashMQ will remove pre-existing socket files if they already exist.

    ≥ v1.23.0
    unix_socket_user user#

    Use this to set the owner of the socket. It will always be attempted to set it, but a warning may be logged if not successful. Users may be specified as numeric or names.

    A config test will not verify the existence of users, for portability.

    ≥ v1.23.0
    unix_socket_group group#

    Use this to set the group of the socket. It will always be attempted to set it, but a warning may be logged if not successful. Groups may be specified as numeric or names.

    A config test will not verify the existence of groups, for portability.

    ≥ v1.23.0
    unix_socket_mode mode#

    Use this to specify the permission mode of the unix socket, like 600.

    fullchain /foobar/server.crt#

    Specifying a chain makes the listener SSL, and also requires the privkey to be set.

    privkey /foobar/server.key#

    Specifying a private key makes the listener SSL, and also requires the fullchain to be set.

    ≥ v1.20.0
    minimum_tls_version tlsv1.1|tlsv1.2|tlsv1.3#

    Set minimum supported TLS version for TLS listeners. Note that setting this value low may not actually enable that protocol version if OpenSSL won't support it (anymore).

    The TLS version clients use is logged.

    Default: tlsv1.1

    ≥ v1.8.0
    client_verification_ca_file /foobar/client_authority.crt#

    Clients can be authenticated using X509 certificates, and the username taken from the CN (common name) field. Use this directive to specify the certificate authority you trust.

    Specifying this or client_verification_ca_dir will require the listener to be TLS.

    ≥ v1.8.0
    client_verification_ca_dir /foobar/dir_with_certificates#

    Clients can be authenticated using X509 certificates, and the username taken from the CN (common name) field. Use this directive to specify the dir containing certificate authorities you trust.

    Note that the filename requirements are dictated by OpenSSL. Use the utility openssl rehash /path/to/dir.

    Specifying this or client_verification_ca_file will require the listener to be TLS.

    ≥ v1.8.0
    client_verification_still_do_authn true|false#

    When using X509 client authentication with client_verification_ca_file or client_verification_ca_dir, the username will not be checked with a user database or a plugin by default. Set this option to true to override that.

    ≥ v1.22.0
    acme_redirect_url http://example.com/#

    This allows an ACME (automated certificate management environment) challenge to be redirected elsewhere. This allows decoupling of the certificate creation from the host(s) that run FlashMQ.

    This can either be configured on a dedicated listener with protocol acme, or multiplexed on a non-SSL mqtt or websockets listener.

    ≥ v1.22.0
    drop_on_absent_certificate true|false#

    When both privkey and fullchain are absent, don't create this listener. This can help in situations where you don't have the certificate and key yet, but you are expecting them.

    ≥ v1.22.0
    only_allow_from inet4_address|inet6_address#

    When set, restricts the listener to the source address/network given.

    You can repeat the option to allow multiple addresses/networks. Valid notations are 1.2.3.4, 1.2.3.4/16, 1.2.0.0/16, 2a01:1337::1, 2a01:1337::1/64, etc.

    ≥ v1.22.0
    deny_from inet4_address|inet6_address#

    Block connections from this address or network.

    You can repeat the option multiple times. Valid notations are 1.2.3.4, 1.2.3.4/16, 1.2.0.0/16, 2a01:1337::1, 2a01:1337::1/64, etc.

    ≥ v1.10.0
    allow_anonymous true|false#

    This allows you to override the global allow_anonymous setting on the listener level.

    ≥ v1.21.0
    overload_mode log|close_new_clients#

    This allows you to override the global overload_mode setting on the listener level.

    ≥ v1.25.0
    haproxy off|on|client_verification|client_verification_with_authn#

    The nature of running something behind a proxy is that you would lose certain information or abilities. For that, HAProxy allows sending information in a header frame. With the haproxy option, you can configure the following:

    off and on configure basic functionality: reading the client's source address from the haproxy frame. This address will be shown in the logs and plugin functions.

    client_verification and client_verification_with_authn will require haproxy to send the PP2_SUBTYPE_SSL_CN field when using mTLS (mutual TLS), i.e. client authentication. When using the client_verification_with_authn option, FlashMQ's internal authentication will still happen (with the username from the certificate). This can be convenient when you have a plugin, so that you can still take action in the login hook.

    Make sure this listener is private, firewalled or use only_allow_from, otherwise anybody can set a different source address.

    Note that HAProxy's server health checks only started using the 'local' specifier as of version 2.4. This means earlier versions will pretend to be a client and break the connection, causing log spam.

    As a side note about using a proxy on your listener; you can only have an absolute max of 65535 connections to an IP+port combination (and the practical limit is lower). If you need more, you have to set up multiple listeners. This can be multiple IP addresses, or simply multiple ports.

    The following is an example HAProxy backend config for the features described:

    backend backend_server
            mode tcp
            timeout client  2m
            timeout connect 10s
            timeout server  2m
            server server1 ::1:1885 send-proxy-v2 send-proxy-v2-ssl-cn 

    See haproxy.org.

    ≥ v1.13.0
    tcp_nodelay true|false#

    tcp_nodelay will cause the TCP_NODELAY option to be set for the listener's socket(s), and therefore for all clients accepted on that listener.

    TCP_NODELAY is an OS TCP-layer option that will cause messages written by FlashMQ to the socket to be flushed immediately, without letting Nagle's algorithm (the default) collect small outgoing TCP packets into bigger packets.

    Forgoing Nagle's algorithm by setting tcp_nodelay to true may decrease latency, at the likely cost of some network efficiency.

    Default: false

    ≥ v1.22.0
    max_buffer_size number#

    Override the client_max_write_buffer_size for this listener. This is especially useful when this listener is the receiving side of a bridge, because these clients will likely see more traffic.

    ≥ v1.23.0
    max_qos qos_value#

    Per-listener counterpart of max_qos.

    ≥ v1.23.0
    mqtt3_qos_exceed_action drop|disconnect#

    Per-listener counterpart of mqtt3_qos_exceed_action.

    Example listeners#

    listen {
      protocol mqtt
      inet_protocol ip4_ip6
      inet4_bind_address 127.0.0.1
      inet6_bind_address ::1
      fullchain /foobar/server.crt
      privkey /foobar/server.key
    
      # default = 8883
      port 8883
    }
    listen {
      protocol mqtt
      fullchain /foobar/server.crt
      privkey /foobar/server.key
      client_verification_ca_file /foobar/client_authority.crt
      client_verification_still_do_authn false
    }
    listen {
      protocol mqtt
      inet_protocol ip4
    
      # default = 1883
      port 1883
    }
    listen {
      protocol websockets
      fullchain /foobar/server.crt
      privkey /foobar/server.key
    
      # default = 4443
      port 4443
    }
    listen {
      protocol websockets
    
      # default = 8080
      port 8080
    }
    listen {
      port 2883
      haproxy on
    }

    Bridge configuration#

    Bridges can be defined inside bridge { } blocks. A bridge is essentially just an outgoing connection to another server with loop-detection and retain flag relaying. It is not a form of clustering, although with careful design, it can be deployed to achieve some sort of load balancing. Note that normally (unless connection_count is set) one bridge is one connection, and because FlashMQ's threading model is that clients are serviced by one selected thread only, a bridge has the potential to saturate a thread, if it's heavily loaded. You can improve that with connection_count.

    Bridges are dynamically created, removed or changed upon config reload. When a bridge configuration changes, it will disconnect and reconnect.

    ≥ v1.7.0
    address address#

    The DNS name, IPv4 or IPv6 address of the server you want to connect to.

    ≥ v1.7.0
    port number#

    The default port depends on the tls option, either 1883 or 8883.

    ≥ v1.7.0
    inet_protocol ip4_ip6/ip4/ip6#

    Default: ip4_ip6

    ≥ v1.7.0
    tls off/on/unverified#

    Set TLS mode. The value unverified means the x509 chain is not verified.

    ≥ v1.20.0
    minimum_tls_version tlsv1.1|tlsv1.2|tlsv1.3#

    Set the minimum TLS version the bridge will negotiate with the other side. Note that setting this value low may not actually enable that protocol version if OpenSSL doesn't support it (anymore).

    Default: tlsv1.1

    fullchain /foobar/bridge.crt#

    With TLS enabled, specifying a chain makes the bridge connection authenticate to the remote broker using a public certificate, and also requires the privkey to be set.

    privkey /foobar/bridge.key#

    With TLS enabled, specifying a private key makes the bridge connection to remote broker use that key, and also requires the fullchain to be set.

    ≥ v1.7.0
    ca_file path#

    File to be used for x509 certificate chain validation.

    ≥ v1.7.0
    ca_dir path#

    Directory containing certificates for x509 certificate chain validation.

    ≥ v1.7.0
    protocol_version mqtt3.1|mqtt3.1.1|mqtt5#

    Default: mqtt3.1.1

    ≥ v1.7.0
    bridge_protocol_bit true|false#

    An unofficial standard is to set the most significant bit of the protocol version byte to 1 to signal the connection is a bridge. This allows the other side to alter its behavior slightly. However, this is not always supported, so you can disable this if you get disconnected for reporting an invalid protocol version.

    This setting has no effect when using MQTT5, because the behavior it influences is done with subscription options.

    Default: true

    ≥ v1.7.0
    keepalive seconds#

    The time between sending ping packets to the other side.

    Default: 60

    ≥ v1.7.0
    clientid_prefix prefix#

    The prefix of the randomly generated client ID. Client IDs cannot be explicitly set for security reasons. See Understanding clean session and clean start.

    Default: fmqbridge

    ≥ v1.7.0
    publish filter qos#

    Messages matching this filter will be published to the other side. Examples: # or sport/tennis/#. This option can be repeated several times.

    The QoS value should be seen as the QoS value of the internal subscription causing outgoing messages. Messages that are relayed have this QoS level at most.

    Default: 0

    ≥ v1.7.0
    subscribe filter qos#

    Subscriptions for this filter are placed at the other side. Examples: # or sport/tennis/#. This option can be repeated several times.

    The QoS value is like any subscription at a server. Messages received by the other end will be given this QoS level at most.

    Default: 0

    ≥ v1.7.0
    local_username username#

    Username as seen by the local FlashMQ's plugin or ACL checks. This is not always necessary.

    ≥ v1.7.0
    remote_username username#

    Username sent to the remote connection.

    ≥ v1.7.0
    remote_password password#

    Password sent to the remote connection.

    ≥ v1.7.0
    remote_clean_start true|false#

    In MQTT3, this means 'clean session', meaning the remote server removes any existing session with the same ID on (re)connect, and destroys it immediately on disconnect. If you want reusable sessions that survive disconnects, set this to false. If you also want to pick up remote sessions on FlashMQ restart, set use_saved_clientid to true.

    In MQTT5, this option only influences reconnection behavior. It essentially has no effect on the first connect, because the client ID is random and will always be new (except when you set use_saved_clientid). But when set to true, any reconnects, which do use the already generated client ID, will destroy the session and in-flight messages will be lost.

    Also see understanding clean session and clean start.

    Default value: true

    ≥ v1.7.0
    local_clean_start true|false#

    In MQTT3 mode, this means 'clean session' and means the session is removed upon disconnect. If you want to reuse sessions on reconnect, set this to false. Any new start of FlashMQ will give you a new client ID, so it will always be a fresh session, unless you set use_saved_clientid.

    In MQTT5 mode, this only has effect on start, where any existing local session is removed if found. If you want the session to be removed immediately on disconnect, set local_session_expiry_interval to 0.

    Also see understanding clean session and clean start.

    Default value: true

    ≥ v1.7.0
    remote_session_expiry_interval seconds#

    Only used in MQTT5 mode; determines the number of seconds after which the session can be removed from the remote server.

    Default value: 0

    ≥ v1.7.0
    local_session_expiry_interval seconds#

    Determines when a local session without an active client will be removed, in both MQTT3 and MQTT5 mode. Note that in MQTT3 mode, the session is removed on disconnect when local_clean_start is true.

    Default value: 0

    ≥ v1.7.0
    remote_retain_available true|false#

    MQTT5 allows a server to tell a client it doesn't support retained messages, or has it disabled. When using MQTT3, use this option to achieve the same.

    When set to false, messages will not be relayed with 'retained as published', and the retained messages that are normally sent on a matching subscription are not sent.

    Default value: true

    ≥ v1.7.0
    use_saved_clientid true|false#

    When you want your bridges to resume local and remote sessions after restart, set this to true and set remote_clean_start, local_clean_start, remote_session_expiry_interval and local_session_expiry_interval accordingly. It only has effect when you have set a storage_dir and include bridge_info in persistence_data_to_save.

    It is important to fully understand the clean session / clean start behavior and the role the client ID plays in that. The primary goal of sessions is to survive link disconnects. Configuring a fixed client ID and using that each time an MQTT client starts is often an anti-pattern, because most clients, like actual IoT devices, start fresh upon restart and don't store their sessions (with in-flight packets, etc.) to disk. FlashMQ does store them on disk, however, so it can be used legitimately. Still, you can run into unexpected situations. For instance, you will get your existing subscriptions from the session too. So, if you remove a subscribe line from your bridge configuration and restart, it will actually have no effect, because the server on the other side still has that subscription in the session.

    See understanding clean session and clean start for details.

    Default value: false

    ≥ v1.7.0
    max_outgoing_topic_aliases amount#

    If you want FlashMQ to initiate topic aliases for this bridge, set this to a non-zero value. Note that it's capped at the value the remote side gives in the CONNACK packet, so it only works if the other side permits it.

    Default: 0

    ≥ v1.7.0
    max_incoming_topic_aliases amount#

    If you want to accept topic aliases for this bridge, set this to a non-zero value. The value is set in the CONNECT packet to inform the remote side of the wish. It's not guaranteed that the other side will actually make aliases.

    Default: 0

    ≥ v1.13.0
    tcp_nodelay true|false#

    tcp_nodelay will cause the TCP_NODELAY option to be set for the client socket that is used to connect to the other end of the bridge.

    See the documentation for the tcp_nodelay listener parameter for further elaboration.

    Default: false

    ≥ v1.19.0
    local_prefix prefix#

    Prefixes can be used to remap topics to and from the other end of the bridge. This makes it possible to insert a topic tree into the topic tree on another server, like a shared one.

    When a message comes in, the remote_prefix is stripped from the topic, and the local_prefix is added. The resulting topic is used for authorization 'write' checking.

    When a message goes out, the opposite happens: the local_prefix is stripped and the remote_prefix is added. However, this time, the original topic is used for authorization 'read' checking.

    The prefixes aren't applied to the subscribe and publish bridge options. You'll have to include the prefix in the subscriptions you configure. This is so that you can have multiple subscriptions to the other end, and only have the prefix applied to the relevant one(s). Incoming and outgoing messages that don't match the prefixes are relayed unchanged.

    Prefix removal isn't applied to topics that match the prefix exactly. This is to avoid one/two/three/ (which has a legal empty string as last subtopic) becoming an empty string (which is illegal).

    If you define a prefix, it is required to end with a /. It's valid to have only a local or remote prefix.

    ≥ v1.19.0
    remote_prefix prefix#

    See local_prefix.

    ≥ v1.22.0
    connection_count number|auto#

    Normally a bridge has one TCP connection to the other side. This means that dealing with bridge traffic is limited to one thread, which also applies to the remote side if that is also FlashMQ. With this option, you can make a bridge have multiple connections, and share the traffic load over them using MQTT5 'shared subscriptions'.

    When you specify a publish or subscribe path of one/two/three, the topic is adjusted to $share/RANDOM/one/two/three to create a shared subscription so that load is balanced. With FlashMQ, load balancing is especially important for the side processing publishes (because each received publish packet means subscribers have to be looked up, auth checked, etc).

    The shared_subscription_targeting mode is automatically set to sender_hash. This is required to ensure sequential message relaying (to retain ordering), and it's also better in plugin code when messages from one source are kept to one thread, as it would be under normal circumstances.

    If FlashMQ is also on the receiving end of these load balanced connections, it uses an extra feature to group clients of one bridge together to enhance loop detection. Normally MQTT5 supports the 'no-local' subscription option for that, but the standard states that is not allowed for shared subscriptions. FlashMQ uses 'user properties' to communicate the group they belong to, so that we can still do this kind of loop detection. This allows you to specify overlapping paths in the subscribe and publish options of a bridge. Note that both ends should use FlashMQ version 1.22.0 or higher.

    When the target server is also FlashMQ, it may be smart to create a dedicated listener, for several reasons. One is that incoming connections are given to threads in a sequential order per listener, which ensures the best spread over worker threads. It also allows you to set overload_mode, max_buffer_size and possibly only_allow_from and deny_from differently for that listener.

    You can specify a number of connections, or auto for one connection per CPU. FlashMQ's load is mostly on the receiver of messages, so auto will likely be a good choice if most of your message load from the other side is incoming rather than outgoing. Otherwise, it's best matched to the other side's number of CPUs.

    Default: 1

    ≥ v1.22.0
    max_buffer_size number#

    Override the client_max_write_buffer_size for this bridge. When connection_count is used, this size applies to each connection individually.

    Bridges typically have more traffic than single clients, in which case it makes sense to increase this.

    Example bridge#

    bridge {
        address demo.flashmq.org
        publish send/this
        subscribe receive/this
        local_username my_local_user
        remote_username my_remote_user
        remote_password my_remote_pass
        bridge_protocol_bit false
        tls on
        ca_file /path/to/ca.crt
    }

    Author

    Wiebe Cazemier contact@flashmq.org.

    See also

    flashmq(1)

    https://www.flashmq.org/

    Colophon#

    The sources for the FlashMQ manual pages are maintained in DocBook 5.2 XML files. The transformation to the multiple destination file formats is done using a bunch of XSLT 1.0 sheets, contributed to this project by Rowan van der Molen.

    ================================================ FILE: man/refentry.colophon.dbk5 ================================================ Colophon The sources for the FlashMQ manual pages are maintained in DocBook 5.2 XML files. The transformation to the multiple destination file formats is done using a bunch of XSLT 1.0 sheets, contributed to this project by Rowan van der Molen. The groff source of this man-page has ANSI-color support for the terminal. However, Debian-derived Linux distributions turn off groff color support by default. To override this, set the GROFF_SGR environment variable to 1. ================================================ FILE: man/reference.dbk5 ================================================ ================================================ FILE: mosquittoauthoptcompatwrap.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include "mosquittoauthoptcompatwrap.h" mosquitto_auth_opt::mosquitto_auth_opt(const std::string &key, const std::string &value) { this->key = strdup(key.c_str()); this->value = strdup(value.c_str()); } mosquitto_auth_opt::mosquitto_auth_opt(mosquitto_auth_opt &&other) { this->key = other.key; this->value = other.value; other.key = nullptr; other.value = nullptr; } mosquitto_auth_opt::mosquitto_auth_opt(const mosquitto_auth_opt &other) { if (other.key) this->key = strdup(other.key); if (other.value) this->value = strdup(other.value); } mosquitto_auth_opt::~mosquitto_auth_opt() { free(key); key = nullptr; free(value); value = nullptr; } mosquitto_auth_opt &mosquitto_auth_opt::operator=(const mosquitto_auth_opt &other) { free(key); key = nullptr; free(value); value = nullptr; if (other.key) this->key = strdup(other.key); if (other.value) this->value = strdup(other.value); return *this; } AuthOptCompatWrap::AuthOptCompatWrap(const std::unordered_map &authOpts) { for(auto &pair : authOpts) { mosquitto_auth_opt opt(pair.first, pair.second); optArray.push_back(std::move(opt)); } } ================================================ FILE: mosquittoauthoptcompatwrap.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef MOSQUITTOAUTHOPTCOMPATWRAP_H #define MOSQUITTOAUTHOPTCOMPATWRAP_H #include #include #include #include /** * @brief The mosquitto_auth_opt struct is a resource managed class of auth options, compatible with passing as arguments to Mosquitto * auth plugins. * * It's fully assignable and copyable. 
*/ struct mosquitto_auth_opt { char *key = nullptr; char *value = nullptr; mosquitto_auth_opt(const std::string &key, const std::string &value); mosquitto_auth_opt(mosquitto_auth_opt &&other); mosquitto_auth_opt(const mosquitto_auth_opt &other); ~mosquitto_auth_opt(); mosquitto_auth_opt& operator=(const mosquitto_auth_opt &other); }; /** * @brief The AuthOptCompatWrap struct contains a vector of mosquitto auth options, with a head pointer and count which can be passed to * Mosquitto auth plugins. */ struct AuthOptCompatWrap { std::vector optArray; AuthOptCompatWrap(const std::unordered_map &authOpts); AuthOptCompatWrap(const AuthOptCompatWrap &other) = default; AuthOptCompatWrap(AuthOptCompatWrap &&other) = delete; AuthOptCompatWrap() = default; struct mosquitto_auth_opt *head() { return &optArray[0]; } int size() { return optArray.size(); } AuthOptCompatWrap &operator=(const AuthOptCompatWrap &other) = default; }; #endif // MOSQUITTOAUTHOPTCOMPATWRAP_H ================================================ FILE: mqtt5properties.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include "mqtt5properties.h" #include #include #include #include #include "exceptions.h" /* Mqtt5PropertyBuilder serializes MQTT5 properties into its 'bytes' buffer; the public write methods append properties, each encoded as an identifier byte followed by its value. */ Mqtt5PropertyBuilder::Mqtt5PropertyBuilder() { bytes.reserve(128); } /* Size of the whole property section: the length prefix (a VariableByteInt refreshed here from bytes.size()) plus the property bytes themselves. */ size_t Mqtt5PropertyBuilder::getLength() { length = bytes.size(); return length.getLen() + bytes.size(); } /* The property-length prefix, refreshed from the current buffer size. */ const VariableByteInt &Mqtt5PropertyBuilder::getVarInt() { length = bytes.size(); return length; } const std::vector &Mqtt5PropertyBuilder::getBytes() const { return bytes; } void Mqtt5PropertyBuilder::writeServerKeepAlive(uint16_t val) { writeUint16(Mqtt5Properties::ServerKeepAlive, val); } void Mqtt5PropertyBuilder::writeSessionExpiry(uint32_t val) { writeUint32(Mqtt5Properties::SessionExpiryInterval, val); } void Mqtt5PropertyBuilder::writeReceiveMax(uint16_t val) { writeUint16(Mqtt5Properties::ReceiveMaximum, val); } void Mqtt5PropertyBuilder::writeRetainAvailable(uint8_t val) { writeUint8(Mqtt5Properties::RetainAvailable, val); } void Mqtt5PropertyBuilder::writeMaxPacketSize(uint32_t val) { writeUint32(Mqtt5Properties::MaximumPacketSize, val); } void Mqtt5PropertyBuilder::writeAssignedClientId(const std::string &clientid) { writeStr(Mqtt5Properties::AssignedClientIdentifier, clientid); } void Mqtt5PropertyBuilder::writeMaxTopicAliases(uint16_t val) { writeUint16(Mqtt5Properties::TopicAliasMaximum, val); } /* Only 0 or 1 may be advertised; the assert checks this in builds without NDEBUG. */ void Mqtt5PropertyBuilder::writeMaxQoS(uint8_t qos) { assert(qos < 2); writeUint8(Mqtt5Properties::MaximumQoS, qos); } void Mqtt5PropertyBuilder::writeWildcardSubscriptionAvailable(uint8_t val) { writeUint8(Mqtt5Properties::WildcardSubscriptionAvailable, val); } void Mqtt5PropertyBuilder::writeSubscriptionIdentifier(uint32_t val) { writeVariableByteInt(Mqtt5Properties::SubscriptionIdentifier, val); } void Mqtt5PropertyBuilder::writeSubscriptionIdentifiersAvailable(uint8_t val) { writeUint8(Mqtt5Properties::SubscriptionIdentifierAvailable, val); } void Mqtt5PropertyBuilder::writeSharedSubscriptionAvailable(uint8_t val) { writeUint8(Mqtt5Properties::SharedSubscriptionAvailable, val); } void
Mqtt5PropertyBuilder::writeContentType(const std::string &format) { writeStr(Mqtt5Properties::ContentType, format); } void Mqtt5PropertyBuilder::writePayloadFormatIndicator(uint8_t val) { writeUint8(Mqtt5Properties::PayloadFormatIndicator, val); } void Mqtt5PropertyBuilder::writeMessageExpiryInterval(uint32_t val) { writeUint32(Mqtt5Properties::MessageExpiryInterval, val); } void Mqtt5PropertyBuilder::writeResponseTopic(const std::string &str) { writeStr(Mqtt5Properties::ResponseTopic, str); } void Mqtt5PropertyBuilder::writeUserProperties(const std::vector> &properties) { for (auto &p : properties) { writeUserProperty(p.first, p.second); } } /* Throws ProtocolError once too many user properties are written. NOTE(review): the post-increment compared with '> 50' accepts 51 properties and throws on the 52nd, which doesn't quite match the "more than 50" message. */ void Mqtt5PropertyBuilder::writeUserProperty(const std::string &key, const std::string &value) { if (this->userPropertyCount++ > 50) throw ProtocolError("Trying to set more than 50 user properties. Likely a bad actor.", ReasonCodes::ImplementationSpecificError); write2Str(Mqtt5Properties::UserProperty, key, value); } void Mqtt5PropertyBuilder::writeCorrelationData(const std::string &correlationData) { writeStr(Mqtt5Properties::CorrelationData, correlationData); } void Mqtt5PropertyBuilder::writeTopicAlias(const uint16_t id) { writeUint16(Mqtt5Properties::TopicAlias, id); } void Mqtt5PropertyBuilder::writeAuthenticationMethod(const std::string &method) { writeStr(Mqtt5Properties::AuthenticationMethod, method); } void Mqtt5PropertyBuilder::writeAuthenticationData(const std::string &data) { writeStr(Mqtt5Properties::AuthenticationData, data); } void Mqtt5PropertyBuilder::writeWillDelay(uint32_t delay) { writeUint32(Mqtt5Properties::WillDelayInterval, delay); } /* Appends the identifier byte followed by x in big-endian byte order (5 bytes in total). */ void Mqtt5PropertyBuilder::writeUint32(Mqtt5Properties prop, const uint32_t x) { size_t pos = bytes.size(); const size_t newSize = pos + 5; bytes.resize(newSize); const uint8_t a = static_cast(x >> 24); const uint8_t b = static_cast(x >> 16); const uint8_t c = static_cast(x >> 8); const uint8_t d = static_cast(x); bytes[pos++] = static_cast(prop); bytes[pos++] = a;
bytes[pos++] = b; bytes[pos++] = c; bytes[pos] = d; } /* Appends the identifier byte followed by x in big-endian byte order (3 bytes in total). */ void Mqtt5PropertyBuilder::writeUint16(Mqtt5Properties prop, const uint16_t x) { size_t pos = bytes.size(); const size_t newSize = pos + 3; bytes.resize(newSize); const uint8_t a = static_cast(x >> 8); const uint8_t b = static_cast(x); bytes[pos++] = static_cast(prop); bytes[pos++] = a; bytes[pos] = b; } /* Appends the identifier byte followed by the single byte x. */ void Mqtt5PropertyBuilder::writeUint8(Mqtt5Properties prop, const uint8_t x) { size_t pos = bytes.size(); const size_t newSize = pos + 2; bytes.resize(newSize); bytes[pos++] = static_cast(prop); bytes[pos] = x; } /* Appends the identifier byte, a big-endian 16-bit length and the raw string bytes. Strings longer than 65535 bytes are rejected before the buffer is touched. The memcpy is skipped for an empty string, because pos then equals bytes.size() and &bytes[pos] would index past the end of the vector. */ void Mqtt5PropertyBuilder::writeStr(Mqtt5Properties prop, const std::string &str) { if (str.length() > 65535) throw ProtocolError("String too long.", ReasonCodes::MalformedPacket); const uint16_t strlen = str.length(); size_t pos = bytes.size(); const size_t newSize = pos + strlen + 3; bytes.resize(newSize); const uint8_t a = static_cast(strlen >> 8); const uint8_t b = static_cast(strlen); bytes[pos++] = static_cast(prop); bytes[pos++] = a; bytes[pos++] = b; if (strlen > 0) std::memcpy(&bytes[pos], str.c_str(), strlen); } /* Appends the identifier byte followed by two length-prefixed strings (the key/value pair of a user property). Both lengths are validated before the buffer is resized, so a throw leaves 'bytes' unchanged. Empty strings skip the memcpy, since &bytes[pos] may then be one past the end of the vector. */ void Mqtt5PropertyBuilder::write2Str(Mqtt5Properties prop, const std::string &one, const std::string &two) { if (one.length() > 0xFFFF || two.length() > 0xFFFF) throw ProtocolError("String too long.", ReasonCodes::MalformedPacket); size_t pos = bytes.size(); const size_t newSize = pos + one.length() + two.length() + 5; bytes.resize(newSize); bytes[pos++] = static_cast(prop); std::array strings; strings[0] = &one; strings[1] = &two; for (const std::string *str : strings) { const uint16_t strlen = str->length(); const uint8_t a = static_cast(strlen >> 8); const uint8_t b = static_cast(strlen); bytes[pos++] = a; bytes[pos++] = b; if (strlen > 0) std::memcpy(&bytes[pos], str->c_str(), strlen); pos += strlen; } } /* Appends the identifier byte followed by the VariableByteInt encoding of val. */ void Mqtt5PropertyBuilder::writeVariableByteInt(Mqtt5Properties prop, const uint32_t val) { const VariableByteInt x(val); size_t pos = bytes.size(); const size_t newSize = pos + x.getLen() + 1; bytes.resize(newSize); bytes[pos++] =
static_cast(prop); std::memcpy(&bytes[pos], x.data(), x.getLen()); } ================================================ FILE: mqtt5properties.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef MQTT5PROPERTIES_H #define MQTT5PROPERTIES_H #include #include "types.h" #include "variablebyteint.h" /* Incrementally encodes MQTT5 properties into a byte buffer; userPropertyCount limits how many user properties may be written. */ class Mqtt5PropertyBuilder { std::vector bytes; VariableByteInt length; int userPropertyCount = 0; void writeUint32(Mqtt5Properties prop, const uint32_t x); void writeUint16(Mqtt5Properties prop, const uint16_t x); void writeUint8(Mqtt5Properties prop, const uint8_t x); void writeStr(Mqtt5Properties prop, const std::string &str); void write2Str(Mqtt5Properties prop, const std::string &one, const std::string &two); void writeVariableByteInt(Mqtt5Properties prop, const unsigned int val); public: Mqtt5PropertyBuilder(); size_t getLength(); const VariableByteInt &getVarInt(); const std::vector &getBytes() const; void writeServerKeepAlive(uint16_t val); void writeSessionExpiry(uint32_t val); void writeReceiveMax(uint16_t val); void writeRetainAvailable(uint8_t val); void writeMaxPacketSize(uint32_t val); void writeAssignedClientId(const std::string &clientid); void writeMaxTopicAliases(uint16_t val); void writeMaxQoS(uint8_t qos); void writeWildcardSubscriptionAvailable(uint8_t val); void writeSubscriptionIdentifier(uint32_t val); void writeSubscriptionIdentifiersAvailable(uint8_t val); void writeSharedSubscriptionAvailable(uint8_t val); void writeContentType(const std::string &format); void writePayloadFormatIndicator(uint8_t val); void writeMessageExpiryInterval(uint32_t val); void writeResponseTopic(const std::string &str); void writeUserProperties(const std::vector> &properties); void writeUserProperty(const
std::string &key, const std::string &value); void writeCorrelationData(const std::string &correlationData); void writeTopicAlias(const uint16_t id); void writeAuthenticationMethod(const std::string &method); void writeAuthenticationData(const std::string &data); void writeWillDelay(uint32_t delay); }; #endif // MQTT5PROPERTIES_H ================================================ FILE: mqttpacket.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "mqttpacket.h" #include #include #include #include #include "globals.h" #include "threaddata.h" #include "threadglobals.h" #include "utils.h" #include "subscriptionstore.h" #include "acksender.h" // constructor for parsing incoming packets MqttPacket::MqttPacket(std::vector &&packet_bytes, size_t fixed_header_length, std::shared_ptr &sender) : bites(std::move(packet_bytes)), fixed_header_length(fixed_header_length) { if (bites.size() < MQTT_HEADER_LENGH) // All calling contexts prevent this, but just making sure. throw ProtocolError("Packet is smaller than minimum length.", ReasonCodes::MalformedPacket); if (bites.size() > sender->getMaxIncomingPacketSize()) throw ProtocolError("Incoming packet size exceeded.", ReasonCodes::PacketTooLarge); protocolVersion = sender->getProtocolVersion(); first_byte = bites[0]; unsigned char _packetType = (first_byte & 0xF0) >> 4; packetType = (PacketType)_packetType; pos += fixed_header_length; externallyReceived = true; } /* Serializes a CONNACK: acknowledge-flags byte, return code and, for MQTT5, the properties. */ MqttPacket::MqttPacket(const ConnAck &connAck) : bites(connAck.getLengthWithoutFixedHeader()) { packetType = PacketType::CONNACK; first_byte = static_cast(packetType) << 4; writeByte(connAck.session_present & 0b00000001); // all connect-ack flags are 0, except session-present.
[MQTT-3.2.2.1] writeByte(connAck.return_code); if (connAck.protocol_version >= ProtocolVersion::Mqtt5) { // TODO: don't include the reason string and user properties when it would increase the CONACK packet beyond the max packet size as determined by client. // We don't send those at all momentarily, so there is no logic to prevent it. writeProperties(connAck.propertyBuilder); } calculateRemainingLength(); } /* Serializes a SUBACK: packet id, MQTT5 properties (MQTT5 only), then one reason/return code byte per entry in subAck.responses. */ MqttPacket::MqttPacket(const SubAck &subAck) : bites(subAck.getLengthWithoutFixedHeader()) { packetType = PacketType::SUBACK; first_byte = static_cast(packetType) << 4; writeUint16(subAck.packet_id); if (subAck.protocol_version >= ProtocolVersion::Mqtt5) { // TODO: don't include the reason string and user properties when it would increase the SUBACK packet beyond the max packet size as determined by client. // We don't send those at all momentarily, so there is no logic to prevent it. writeProperties(subAck.propertyBuilder); } std::vector returnList; returnList.reserve(subAck.responses.size()); for (ReasonCodes code : subAck.responses) { returnList.push_back(static_cast(code)); } write(returnList); calculateRemainingLength(); } /* Serializes an UNSUBACK: packet id and, for MQTT5 only, the properties followed by the reason codes. */ MqttPacket::MqttPacket(const UnsubAck &unsubAck) : bites(unsubAck.getLengthWithoutFixedHeader()) { packetType = PacketType::UNSUBACK; first_byte = static_cast(packetType) << 4; writeUint16(unsubAck.packet_id); if (unsubAck.protocol_version >= ProtocolVersion::Mqtt5) { writeProperties(unsubAck.propertyBuilder); for(const ReasonCodes &rc : unsubAck.reasonCodes) { writeByte(static_cast(rc)); } } calculateRemainingLength(); } /* Convenience overload: takes QoS, topic alias, skip-topic flag and subscription identifier from the publish itself, without a topic override. */ MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish) : MqttPacket(protocolVersion, _publish, _publish.qos, _publish.topicAlias, _publish.skipTopic, _publish.subscriptionIdentifier, std::optional()) { } /** * @brief Construct a packet for a specific protocol version.
* @param protocolVersion Protocol version of the receiving client. It is required here, and not taken from the Publish object, because
 *        publishes don't have a protocol version until they are destined for a specific client.
 * @param _publish The publish to serialize. Its payload is written into the byte array. client_id, username and retain are always
 *        copied into publishData; the MQTT5-only fields (expiry, correlation data, response topic, content type, payload format,
 *        user properties) are only copied when protocolVersion is MQTT5 or higher.
 * @param _qos QoS for this packet (takes the place of _publish.qos). Stored in publishData.qos and shifted into bits 1-2 of the
 *        first byte. It is not masked, so callers must pass a valid QoS (0-2); larger values would set other header bits.
 * @param _topic_alias Topic alias for this packet, stored in publishData.topicAlias. Whether it ends up on the wire depends on
 *        publishData.getPropertyBuilder(), which is not visible here.
 * @param _skip_topic When true, publishData.topic is not assigned, and only the 2-byte topic length field is counted in the size.
 * @param subscriptionIdentifier Only used for MQTT5; stored in publishData.subscriptionIdentifier before the properties are built.
 * @param topic_override When it holds a value (and the topic is not skipped), it is used instead of _publish.topic.
 * @throws ProtocolError when the resulting topic is longer than 0xFFFF bytes.
 *
 * Important to note here is that there are two concepts: writing the byte array for sending to clients, and setting the data in
 * publishData. The latter will only have stuff important for internal logic. In other words, it won't contain the payload.
 *
 * The extra parameters are for overriding certain properties of the publish, because the receiving client wants it differently.
 * Use the other overload if you just want the publish object's data.
 */
MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint8_t _qos, const uint16_t _topic_alias,
                       const bool _skip_topic, const uint32_t subscriptionIdentifier, const std::optional &topic_override)
{
    // The parameter shadows the member of the same name; from here on both hold the same value.
    this->protocolVersion = protocolVersion;
    this->publishData.client_id = _publish.client_id;
    this->publishData.username = _publish.username;
    this->publishData.skipTopic = _skip_topic;
    this->publishData.qos = _qos;
    this->publishData.retain = _publish.retain;
    this->publishData.topicAlias = _topic_alias;
    this->packetType = PacketType::PUBLISH;

    if (!this->publishData.skipTopic)
        this->publishData.topic = topic_override.value_or(_publish.topic);

    // The topic gets a 2-byte length field (see the length calculation below), so it can't be longer than 0xFFFF.
    if (this->publishData.topic.length() > 0xFFFF)
    {
        throw ProtocolError("Topic path too long.", ReasonCodes::ProtocolError);
    }

    // First byte of the fixed header: packet type in the high nibble, QoS in bits 1-2, retain in bit 0.
    // NOTE(review): the QoS is not masked here; an out-of-range value would leak into bit 3 (DUP) or higher.
    first_byte = static_cast(packetType) << 4;
    first_byte |= (this->publishData.qos << 1);
    first_byte |= (static_cast(publishData.retain) & 0b00000001);

    // Stays empty for pre-MQTT5 clients; getPropertyBuilder() may presumably also leave it empty, which the
    // length calculation below accounts for.
    std::optional property_builder;

    // Properties only exist in MQTT5. For older protocols, publishData keeps just the fields set above.
    if (protocolVersion >= ProtocolVersion::Mqtt5)
    {
        // Presumably the remaining time-to-live at the moment this packet is built, rather than the original
        // interval (judging by getCurrentTimeToExpire()) — TODO confirm.
        if (_publish.expireInfo)
            this->publishData.setExpireAfter(_publish.expireInfo->getCurrentTimeToExpire().count());

        this->publishData.correlationData = _publish.correlationData;
        this->publishData.responseTopic = _publish.responseTopic;
        this->publishData.contentType = _publish.contentType;
        this->publishData.payloadUtf8 = _publish.payloadUtf8;
        this->publishData.userProperties = _publish.userProperties;
        this->publishData.subscriptionIdentifier = subscriptionIdentifier;

        property_builder = this->publishData.getPropertyBuilder();
    }

    size_t len = 0;

    // Calculate length of topic, optional packet id, properties and payload. The fixed header is not counted here;
    // presumably calculateRemainingLength() takes care of it — TODO confirm.
    {
        len += 2; // topic string length field
        if (!this->publishData.skipTopic)
            len += this->publishData.topic.length();
        len += _publish.payload.length();
        if (this->publishData.qos)
            len += 2; // packet identifier
        if (protocolVersion >= ProtocolVersion::Mqtt5)
            len += property_builder ? property_builder->getLength() : 1; // presumably a single zero property-length byte when empty
    }

    bites.resize(len);

    writeString(publishData.topic);

    if (publishData.qos)
    {
        // Reserve the space for the packet id, which will be assigned later.
        packet_id_pos = pos;
        std::array zero{};
        write(zero);
    }

    if (protocolVersion >= ProtocolVersion::Mqtt5)
        writeProperties(property_builder);

    // Remember where the payload lives in the byte array.
    payloadStart = pos;
    payloadLen = _publish.payload.length();
    write(_publish.payload);

    calculateRemainingLength();

    // Debug check that the write position ended exactly at the end of bites.
    assert(pos == bites.size());
}

/**
 * @brief Construct a publish response packet of type pubAck.packet_type (the TODO below mentions PUBACK/PUBREL/PUBCOMP).
 * @param pubAck Supplies the packet type, protocol version, packet id and, when needsReasonCode() is true, the reason code.
 *
 * Unlike most constructors in this file, this one writes the fixed header into bites itself (bites is sized with
 * getLengthIncludingFixedHeader()) and does not call calculateRemainingLength(). The remaining length is written as a
 * single byte and fixed_header_length is hard-coded to 2, so this assumes getRemainingLength() fits in one byte.
 *
 * NOTE(review): unlike the other constructors here, packetType is not assigned; confirm no caller relies on it for these packets.
 */
MqttPacket::MqttPacket(const PubResponse &pubAck) :
    bites(pubAck.getLengthIncludingFixedHeader())
{
    this->protocolVersion = pubAck.protocol_version;

    fixed_header_length = 2;
    // MQTT requires the PUBREL fixed-header flag bits to be 0b0010; the other response types use 0.
    const uint8_t firstByteDefaultBits = pubAck.packet_type == PacketType::PUBREL ? 0b0010 : 0;
    this->first_byte = (static_cast(pubAck.packet_type) << 4) | firstByteDefaultBits;
    writeByte(first_byte);
    writeByte(pubAck.getRemainingLength());
    // Remember where the packet id is stored, presumably so it can be located later as with PUBLISH.
    this->packet_id_pos = this->pos;
    writeUint16(pubAck.packet_id);

    if (pubAck.needsReasonCode())
    {
        // TODO: don't include the reason string and user properties when it would increase the PUBACK/PUBREL/PUBCOMP packet beyond the max packet size as determined by client.
        // We don't send those at all momentarily, so there is no logic to prevent it.
        writeByte(static_cast(pubAck.reason_code));
    }
}

/**
 * @brief Constructor to create a disconnect packet. In normal server mode, only MQTT5 is supposed to do that (MQTT3 has no concept of
 * a server-initiated DISCONNECT packet). But, we also use it in the test client.
* @param disconnect */ MqttPacket::MqttPacket(const Disconnect &disconnect) : bites(disconnect.getLengthWithoutFixedHeader()) { this->protocolVersion = disconnect.protocolVersion; packetType = PacketType::DISCONNECT; first_byte = static_cast(packetType) << 4; if (this->protocolVersion >= ProtocolVersion::Mqtt5) { writeByte(static_cast(disconnect.reasonCode)); writeProperties(disconnect.propertyBuilder); } calculateRemainingLength(); } MqttPacket::MqttPacket(const Auth &auth) : bites(auth.getLengthWithoutFixedHeader()), protocolVersion(ProtocolVersion::Mqtt5), packetType(PacketType::AUTH) { first_byte = static_cast(packetType) << 4; writeByte(static_cast(auth.reasonCode)); writeProperties(auth.propertyBuilder); calculateRemainingLength(); } MqttPacket::MqttPacket(const Connect &connect) : protocolVersion(connect.protocolVersion), packetType(PacketType::CONNECT) { first_byte = static_cast(packetType) << 4; const std::string_view magicString = connect.getMagicString(); std::optional properties; // If absent, the other side has to assume 0. if (connect.sessionExpiryInterval) non_optional(properties)->writeSessionExpiry(connect.sessionExpiryInterval); // We tell the other side they can send us topics with aliases, if set. 
if (connect.maxIncomingTopicAliasValue) non_optional(properties)->writeMaxTopicAliases(connect.maxIncomingTopicAliasValue); if (connect.authenticationMethod) non_optional(properties)->writeAuthenticationMethod(connect.authenticationMethod.value()); if (connect.authenticationData) non_optional(properties)->writeAuthenticationData(connect.authenticationData.value()); if (connect.fmq_client_group_id) non_optional(properties)->writeUserProperty(FMQ_CLIENT_GROUP_ID, connect.fmq_client_group_id.value()); std::optional will_properties; if (connect.will && this->protocolVersion >= ProtocolVersion::Mqtt5) will_properties = connect.will->getPropertyBuilder(); size_t len = 0; // Calculate length { len += connect.clientid.length() + 2; len += magicString.length(); len += 6; // header stuff, lengths, keep-alive if (this->protocolVersion >= ProtocolVersion::Mqtt5) len += properties ? properties->getLength() : 1; if (connect.will) { if (this->protocolVersion >= ProtocolVersion::Mqtt5) len += will_properties ? will_properties->getLength() : 1; len += connect.will->topic.length() + 2; len += connect.will->payload.length() + 2; } if (connect.username.has_value()) len += connect.username->size() + 2; if (connect.password.has_value()) len += connect.password->size() + 2; } bites.resize(len); writeString(magicString); uint8_t protocolVersionByte = static_cast(protocolVersion); if (connect.bridgeProtocolBit && protocolVersion <= ProtocolVersion::Mqtt311) // MQTT5 uses subscription options for it. 
protocolVersionByte |= 0x80; writeByte(protocolVersionByte); uint8_t flags = connect.clean_start << 1; flags |= static_cast(connect.username.has_value()) << 7; flags |= static_cast(connect.password.has_value()) << 6; if (connect.will) { flags |= 4; flags |= (connect.will->qos << 3); flags |= (connect.will->retain << 5); } writeByte(flags); writeUint16(connect.keepalive); if (connect.protocolVersion >= ProtocolVersion::Mqtt5) { writeProperties(properties); } writeString(connect.clientid); if (connect.will) { if (connect.protocolVersion >= ProtocolVersion::Mqtt5) { writeProperties(will_properties); } writeString(connect.will->topic); writeString(connect.will->payload); } if (connect.username.has_value()) writeString(connect.username.value()); if (connect.password.has_value()) writeString(connect.password.value()); calculateRemainingLength(); assert(pos == bites.size()); } MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const uint16_t packetId, const uint32_t subscriptionIdentifier, const std::vector &subscriptions) : packet_id(packetId), protocolVersion(protocolVersion), packetType(PacketType::SUBSCRIBE) { assert(!subscriptions.empty()); first_byte = static_cast(packetType) << 4; first_byte |= 2; // required reserved bit std::optional properties; if (protocolVersion >= ProtocolVersion::Mqtt5) { if (subscriptionIdentifier > 0) non_optional(properties)->writeSubscriptionIdentifier(subscriptionIdentifier); } size_t len = 0; for (const Subscribe &sub : subscriptions) { // Calculate length len += sub.topic.size() + 2; len += 1; // requested QoS } len += 2; // packet id if (protocolVersion >= ProtocolVersion::Mqtt5) len += properties ? 
properties->getLength() : 1; bites.resize(len); writeUint16(packetId); if (protocolVersion >= ProtocolVersion::Mqtt5) { writeProperties(properties); } for (const Subscribe &subscribe : subscriptions) { writeString(subscribe.topic); if (protocolVersion < ProtocolVersion::Mqtt5) { writeByte(subscribe.qos); } else { SubscriptionOptionsByte options(subscribe.qos, subscribe.noLocal, subscribe.retainAsPublished, subscribe.retainHandling); writeByte(options.b); } } calculateRemainingLength(); } MqttPacket::MqttPacket(const ProtocolVersion protocolVersion, const uint16_t packetId, const std::vector &unsubs) : packet_id(packetId), protocolVersion(protocolVersion), packetType(PacketType::UNSUBSCRIBE) { #ifndef TESTING throw NotImplementedException("Code is only for testing."); #endif first_byte = static_cast(packetType) << 4; first_byte |= 2; // required reserved bit std::optional properties; if (protocolVersion >= ProtocolVersion::Mqtt5) { // no properties we support (yet?) } size_t len = 0; for (const Unsubscribe &unsub : unsubs) { len += unsub.topic.size() + 2; } len += 2; // packet id if (protocolVersion >= ProtocolVersion::Mqtt5) { len += 1; } bites.resize(len); writeUint16(packetId); if (protocolVersion >= ProtocolVersion::Mqtt5) { writeProperties(properties); } for (const Unsubscribe &unsub : unsubs) { writeString(unsub.topic); } calculateRemainingLength(); } void MqttPacket::bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender) { if (!sender) return; if (!sender->getAuthenticated() && sender->getConnectionProtocol() < ConnectionProtocol::WebsocketMqtt && sender->getAcmeRedirectUrl()) { if (sender->tryAcmeRedirect()) return; else if (sender->getConnectionProtocol() == ConnectionProtocol::AcmeOnly) throw BadClientException("Non-ACME request on ACME-only listener"); } while (buf.usedBytes() >= MQTT_HEADER_LENGH) { // Determine the packet length by decoding the variable length int remaining_length_i = 1; // index of 'remaining length' 
field is one after start. uint fixed_header_length = 1; size_t multiplier = 1; size_t packet_length = 0; unsigned char encodedByte = 0; do { fixed_header_length++; if (fixed_header_length > 5) throw ProtocolError("Packet signifies more than 5 bytes in variable length header. Invalid.", ReasonCodes::MalformedPacket); // This happens when you only don't have all the bytes that specify the remaining length. if (fixed_header_length > buf.usedBytes()) return; encodedByte = buf.peakAhead(remaining_length_i++); packet_length += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128*128*128*128) throw ProtocolError("Malformed Remaining Length.", ReasonCodes::MalformedPacket); } while ((encodedByte & 128) != 0); packet_length += fixed_header_length; if (sender && !sender->getAuthenticated() && packet_length >= 1024*1024) { throw ProtocolError("An unauthenticated client sends a packet of 1 MB or bigger? Probably it's just random bytes.", ReasonCodes::ProtocolError); } const uint32_t size_limit = std::min(sender->getMaxIncomingPacketSize(), ABSOLUTE_MAX_PACKET_SIZE); if (packet_length > size_limit) { std::ostringstream oss; oss << "Packet size " << packet_length << " exceeds the server limit of " << size_limit << " bytes"; throw ProtocolError(oss.str(), ReasonCodes::PacketTooLarge); } if (packet_length <= buf.usedBytes()) { std::vector packet_bytes = buf.readToVector(packet_length); packetQueueIn.emplace_back(std::move(packet_bytes), fixed_header_length, sender); } else break; } } HandleResult MqttPacket::handle(std::shared_ptr &sender) { // For clients that send packets before they even receive a connack. if (protocolVersion == ProtocolVersion::None) protocolVersion = sender->getProtocolVersion(); // It may be a stale client. This is especially important for when a session is picked up by another client. The old client // may still have stale data in the buffer, causing action on the session otherwise. 
if (sender->getDisconnectStage() > DisconnectStage::NotInitiated) return HandleResult::Done; if (packetType == PacketType::Reserved) throw ProtocolError("Packet type 0 specified, which is reserved and invalid.", ReasonCodes::MalformedPacket); if (!sender->getAuthenticated()) { if (packetType == PacketType::AUTH && sender->getExtendedAuthenticationMethod().empty()) { throw ProtocolError("You can't initiate first (extended) authentication with an AUTH packet.", ReasonCodes::ProtocolError); } if (!(packetType == PacketType::CONNECT || packetType == PacketType::AUTH || packetType == PacketType::DISCONNECT || packetType == PacketType::CONNACK)) { if (sender->getAsyncAuthenticating()) return HandleResult::Defer; exceptionOnNonMqtt(this->bites); if (sender->preAuthPacketCounter++ > 200) throw ProtocolError("Too many pre-auth packets dropped", ReasonCodes::ProtocolError); logger->log(LOG_WARNING) << "Unapproved packet type (" << packetTypeToString(packetType) << ") from non-authenticated client " << sender->repr() << ". 
Dropping packet."; return HandleResult::Done; } } if (packetType == PacketType::PUBLISH) handlePublish(sender); else if (packetType == PacketType::PUBACK) handlePubAck(sender); else if (packetType == PacketType::PUBREC) handlePubRec(sender); else if (packetType == PacketType::PUBREL) handlePubRel(sender); else if (packetType == PacketType::PUBCOMP) handlePubComp(sender); else if (packetType == PacketType::PINGREQ) sender->writePingResp(); else if (packetType == PacketType::SUBSCRIBE) handleSubscribe(sender); else if (packetType == PacketType::UNSUBSCRIBE) handleUnsubscribe(sender); else if (packetType == PacketType::SUBACK) handleSubAck(sender); else if (packetType == PacketType::CONNECT) handleConnect(sender); else if (packetType == PacketType::DISCONNECT) handleDisconnect(sender); else if (packetType == PacketType::CONNACK) handleConnAck(sender); else if (packetType == PacketType::AUTH) handleExtendedAuth(sender); return HandleResult::Done; } ConnectData MqttPacket::parseConnectData(std::shared_ptr &sender) { if (this->packetType != PacketType::CONNECT) throw std::runtime_error("Packet must be connect packet."); setPosToDataStart(); ConnectData result; uint16_t variable_header_length = readTwoBytesToUInt16(); if (!(variable_header_length == 4 || variable_header_length == 6)) { throw ProtocolError("Invalid variable header length. Garbage?", ReasonCodes::MalformedPacket); } const Settings &settings = *ThreadGlobals::getSettings(); const std::string_view magic_marker = readBytes(variable_header_length); const uint8_t protocolVersionByte = readUint8(); result.protocol_level_byte = protocolVersionByte & 0x7F; result.bridge = protocolVersionByte & 0x80; // Unofficial, defacto, way of specifying that. MQTT5 uses subscription options for it. 
if (magic_marker == "MQTT") { if (result.protocol_level_byte == 0x04) protocolVersion = ProtocolVersion::Mqtt311; if (result.protocol_level_byte == 0x05) protocolVersion = ProtocolVersion::Mqtt5; } else if (magic_marker == "MQIsdp" && result.protocol_level_byte == 0x03) { protocolVersion = ProtocolVersion::Mqtt31; } else { throw ProtocolError("Packet contains invalid MQTT marker.", ReasonCodes::MalformedPacket); } // Even though we're still parsing, setting this helps the exception handler to make decisions. sender->setProtocolVersion(this->protocolVersion); char flagByte = readByte(); bool reserved = !!(flagByte & 0b00000001); if (reserved) throw ProtocolError("Protocol demands reserved flag in CONNECT is 0", ReasonCodes::MalformedPacket); bool user_name_flag = static_cast(flagByte & 0b10000000); result.password_flag = !!(flagByte & 0b01000000); result.will_retain = !!(flagByte & 0b00100000); result.will_qos = (flagByte & 0b00011000) >> 3; result.will_flag = !!(flagByte & 0b00000100); result.clean_start = !!(flagByte & 0b00000010); if (result.will_qos > 2) throw ProtocolError("Invalid QoS for will.", ReasonCodes::MalformedPacket); result.keep_alive = readTwoBytesToUInt16(); if (protocolVersion == ProtocolVersion::Mqtt5) { /* * MQTT5: "If the Session Expiry Interval is absent the value 0 is used. If it is set to 0, or is absent, * the Session ends when the Network Connection is closed." 
*/ result.session_expire = 0; result.keep_alive = std::max(result.keep_alive, 5); const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; std::array pcounts; pcounts.fill(0); while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::SessionExpiryInterval: if (pcounts[0]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.session_expire = std::min(readFourBytesToUint32(), settings.getExpireSessionAfterSeconds()); break; case Mqtt5Properties::ReceiveMaximum: if (pcounts[1]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.client_receive_max = std::min(readTwoBytesToUInt16(), result.client_receive_max); break; case Mqtt5Properties::MaximumPacketSize: if (pcounts[2]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.max_outgoing_packet_size = std::min(readFourBytesToUint32(), result.max_outgoing_packet_size); break; case Mqtt5Properties::TopicAliasMaximum: if (pcounts[3]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.max_outgoing_topic_aliases = std::min(readTwoBytesToUInt16(), settings.maxOutgoingTopicAliasValue); break; case Mqtt5Properties::RequestResponseInformation: { if (pcounts[4]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); const uint8_t x = readUint8(); if (x > 1) throw ProtocolError(propertyToString(prop) + " must be 0 or 1", ReasonCodes::ProtocolError); result.request_response_information = static_cast(x); break; } case Mqtt5Properties::RequestProblemInformation: { if (pcounts[5]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", 
ReasonCodes::ProtocolError); const uint8_t x = readUint8(); if (x > 1) throw ProtocolError(propertyToString(prop) + " must be 0 or 1", ReasonCodes::ProtocolError); result.request_problem_information = static_cast(x); break; } case Mqtt5Properties::UserProperty: { // We (ab)use the publishData for the user properties, because it's there. std::string key = readBytesToString(settings.maxStringLength); std::string val = readBytesToString(settings.maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } case Mqtt5Properties::AuthenticationMethod: { if (pcounts[6]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.authenticationMethod = readBytesToString(settings.maxStringLength); break; } case Mqtt5Properties::AuthenticationData: { if (pcounts[7]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.authenticationData = readBytesToString(std::numeric_limits::max(), false); break; } default: throw ProtocolError("Invalid connect property.", ReasonCodes::ProtocolError); } } result.fmq_client_group_id = publishData.getFirstUserProperty(FMQ_CLIENT_GROUP_ID); if (result.fmq_client_group_id.has_value() && result.fmq_client_group_id.value().size() > 24) throw ProtocolError("FMQ Client group ID can't be longer than 24 chars.", ReasonCodes::ImplementationSpecificError); } if (result.authenticationMethod.empty() && !result.authenticationData.empty()) throw ProtocolError("Including authentication data when there is no authentication method is not allowed", ReasonCodes::ProtocolError); if (result.client_receive_max == 0 || result.max_outgoing_packet_size == 0) { throw ProtocolError("Receive max or max outgoing packet size can't be 0.", ReasonCodes::ProtocolError); } result.client_id = readBytesToString(settings.maxStringLength); if (result.will_flag) { result.willpublish.qos = result.will_qos; 
result.willpublish.retain = result.will_retain; if (result.will_retain) { if (settings.retainedMessagesMode == RetainedMessagesMode::DisconnectWithError) throw ProtocolError("Option 'retained_messages_mode' set to 'disconnect_with_error' and received a will with retain.", ReasonCodes::RetainNotSupported); else if (settings.retainedMessagesMode == RetainedMessagesMode::Downgrade) { result.willpublish.retain = false; result.will_retain = false; } else if (settings.retainedMessagesMode == RetainedMessagesMode::Drop) result.will_flag = false; // This will make us not pick up later, and we still parse the bytes from the packet. } result.willpublish.client_id = result.client_id; if (protocolVersion == ProtocolVersion::Mqtt5) { const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; std::array pcounts; pcounts.fill(0); while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::WillDelayInterval: if (pcounts[0]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.willpublish.will_delay = readFourBytesToUint32(); break; case Mqtt5Properties::PayloadFormatIndicator: if (pcounts[1]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.willpublish.payloadUtf8 = static_cast(readByte()); break; case Mqtt5Properties::ContentType: { if (pcounts[2]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.willpublish.contentType = readBytesToString(settings.maxStringLength); break; } case Mqtt5Properties::ResponseTopic: { if (pcounts[3]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.willpublish.responseTopic = readBytesToString(settings.maxStringLength, true, true); if 
(result.willpublish.responseTopic->empty()) throw ProtocolError("Response topic in will cannot be empty", ReasonCodes::ProtocolError); break; } case Mqtt5Properties::MessageExpiryInterval: { if (pcounts[4]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); const uint32_t expiresAfter = readFourBytesToUint32(); result.willpublish.setExpireAfter(expiresAfter); break; } case Mqtt5Properties::CorrelationData: { if (pcounts[5]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.willpublish.correlationData = readBytesToString(settings.maxStringLength, false); break; } case Mqtt5Properties::UserProperty: { std::string key = readBytesToString(settings.maxStringLength); std::string val = readBytesToString(settings.maxStringLength); result.willpublish.addUserProperty(std::move(key), std::move(val)); break; } default: throw ProtocolError("Invalid will property in connect.", ReasonCodes::ProtocolError); } } } result.willpublish.topic = readBytesToString(settings.maxStringLength, true, true); if (result.willpublish.topic.empty()) { logger->log(LOG_WARNING) << "Empty will topic is not allowed. 
Dropping will for client " << sender->repr() << "."; result.will_flag = false; } uint16_t will_payload_length = readTwoBytesToUInt16(); result.willpublish.payload = readBytes(will_payload_length); if (result.willpublish.payloadUtf8 && !isValidUtf8Generic(result.willpublish.payload)) { throw ProtocolError("Will payload announced as UTF8, but it's not valid.", ReasonCodes::PayloadFormatInvalid); } } else { if (result.will_retain) throw ProtocolError("Will retain bit can't be set without will.", ReasonCodes::ProtocolError); if (result.will_qos != 0) throw ProtocolError("Will QoS must be 0 when there is no will.", ReasonCodes::ProtocolError); } if (user_name_flag) { result.username = readBytesToString(settings.maxStringLength, true, false, ReasonCodes::BadUserNameOrPassword); if (result.username.value().empty()) { if (settings.zeroByteUsernameIsAnonymous) result.username.reset(); else throw ProtocolError("Attempting anonymous login with zero byte username. See config option 'zero_byte_username_is_anonymous'.", ReasonCodes::BadUserNameOrPassword); } } if (result.username) { result.willpublish.username = result.username.value(); if (!settings.allowUnsafeUsernameChars && containsDangerousCharacters(result.username.value())) throw ProtocolError(formatString("Username '%s' contains unsafe characters and 'allow_unsafe_username_chars' is false.", result.username.value().c_str()), ReasonCodes::BadUserNameOrPassword); } if (result.password_flag) { if (this->protocolVersion <= ProtocolVersion::Mqtt311 && !user_name_flag) { throw ProtocolError("MQTT 3.1.1: If the User Name Flag is set to 0, the Password Flag MUST be set to 0.", ReasonCodes::MalformedPacket); } result.password = readBytesToString(std::numeric_limits::max(), false); } return result; } ConnAckData MqttPacket::parseConnAckData() { if (this->packetType != PacketType::CONNACK) throw std::runtime_error("Packet must be connack packet."); const Settings &settings = *ThreadGlobals::getSettings(); setPosToDataStart(); 
// Continuation of MqttPacket::parseConnAckData(); the function header and its setPosToDataStart() call are on the preceding line.
    // The first variable-header byte holds the acknowledge flags; only bit 0 (session present) is used.
    ConnAckData result;
    const uint8_t flagByte = readByte();
    result.sessionPresent = flagByte & 0x01;
    result.reasonCode = static_cast(readUint8());

    // Properties are only parsed for MQTT5 CONNACK packets.
    if (protocolVersion == ProtocolVersion::Mqtt5)
    {
        const size_t proplen = decodeVariableByteIntAtPos();
        const size_t prop_end_at = pos + proplen;

        // Per-property occurrence counters, used to reject properties that appear more than once.
        std::array pcounts;
        pcounts.fill(0);

        while (withinBound(prop_end_at))
        {
            const Mqtt5Properties prop = static_cast(readUint8());

            switch (prop)
            {
            case Mqtt5Properties::SessionExpiryInterval:
                if (pcounts[0]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                // Like several properties below, the received value can only lower what 'result' already holds.
                result.session_expire = std::min(readFourBytesToUint32(), result.session_expire);
                break;
            case Mqtt5Properties::ReceiveMaximum:
            {
                if (pcounts[1]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                const uint16_t val{readTwoBytesToUInt16()};
                if (val == 0)
                    throw ProtocolError("In CONNACK: ReceiveMax can't be 0", ReasonCodes::ProtocolError);
                result.client_receive_max = std::min(val, result.client_receive_max);
                break;
            }
            case Mqtt5Properties::MaximumQoS:
            {
                if (pcounts[2]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                const uint8_t val {readUint8()};
                if (val > 2)
                    throw ProtocolError("In CONNACK: QoS must be <= 2", ReasonCodes::ProtocolError);
                result.max_qos = std::min(val, result.max_qos);
                break;
            }
            case Mqtt5Properties::RetainAvailable:
            {
                if (pcounts[3]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                const uint8_t val{readUint8()};
                if (val > 1)
                    throw ProtocolError("In CONNACK: RetainAvailable must be <= 1", ReasonCodes::ProtocolError);
                result.retained_available = static_cast(val);
                break;
            }
            case Mqtt5Properties::MaximumPacketSize:
            {
                if (pcounts[4]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                const uint32_t val {readFourBytesToUint32()};
                if (val == 0)
                    throw ProtocolError("In CONNACK: MaximumPacketSize must be > 0", ReasonCodes::ProtocolError);
                result.max_outgoing_packet_size = std::min(val, result.max_outgoing_packet_size);
                break;
            }
            case Mqtt5Properties::AssignedClientIdentifier:
                if (pcounts[5]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.assigned_client_id = readBytesToString(settings.maxStringLength);
                break;
            case Mqtt5Properties::TopicAliasMaximum:
                if (pcounts[6]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                // Capped by the locally configured maximum.
                result.max_outgoing_topic_aliases = std::min(readTwoBytesToUInt16(), settings.maxOutgoingTopicAliasValue);
                break;
            case Mqtt5Properties::ReasonString:
            {
                if (pcounts[7]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                const std::string reason = readBytesToString(settings.maxStringLength);
                logger->logf(LOG_NOTICE, "ConnAck reason string: %s", reason.c_str());
                break;
            }
            case Mqtt5Properties::UserProperty:
            {
                // User properties may repeat, so they are not counted. They are read only to advance past them; key and value are unused.
                std::string key = readBytesToString(settings.maxStringLength);
                std::string value = readBytesToString(settings.maxStringLength);
                break;
            }
            case Mqtt5Properties::WildcardSubscriptionAvailable:
                if (pcounts[8]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                readByte(); // Value is read but ignored.
                break;
            case Mqtt5Properties::SubscriptionIdentifierAvailable:
                if (pcounts[9]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                readByte(); // Value is read but ignored.
                break;
            case Mqtt5Properties::SharedSubscriptionAvailable:
            {
                if (pcounts[10]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                const uint8_t val{readUint8()};
                if (val > 1)
                    throw ProtocolError("In CONNACK: SharedSubscriptionAvailable must be <= 1", ReasonCodes::ProtocolError);
                result.shared_subscriptions_available = static_cast(val);
                break;
            }
            case Mqtt5Properties::ServerKeepAlive:
                if (pcounts[11]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.keep_alive = readTwoBytesToUInt16();
                break;
            case Mqtt5Properties::ResponseInformation:
                if (pcounts[12]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.response_information = readBytesToString(settings.maxStringLength);
                break;
            case Mqtt5Properties::ServerReference:
                if (pcounts[13]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.server_reference = readBytesToString(settings.maxStringLength);
                break;
            case Mqtt5Properties::AuthenticationMethod:
                if (pcounts[14]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.authMethod = readBytesToString(settings.maxStringLength);
                break;
            case Mqtt5Properties::AuthenticationData:
                if (pcounts[15]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.authData = readBytesToString();
                break;
            default:
                throw ProtocolError("Invalid connack property.", ReasonCodes::ProtocolError);
            }
        }
    }

    return result;
}

/**
 * Handles an incoming CONNECT packet.
 *
 * The steps are:
 * - parse the packet;
 * - reject unsupported protocol versions and invalid client IDs with a CONNACK;
 * - determine the username (from the packet, the client certificate or haproxy);
 * - stage the will and the CONNACK;
 * - start authentication, which may complete asynchronously.
 *
 * Throws ProtocolError if the client already sent a CONNECT on this connection.
 */
void MqttPacket::handleConnect(std::shared_ptr &sender)
{
    if (sender->hasConnectPacketSeen())
        throw ProtocolError("Client already sent a CONNECT.", ReasonCodes::ProtocolError);

    std::shared_ptr subscriptionStore = globals->subscriptionStore;

    auto &threadData = ThreadGlobals::getThreadData();
    Authentication &authentication = threadData->authentication;

    threadData->mqttConnectCounter.inc();

    ConnectData connectData = parseConnectData(sender);

    sender->setHasConnectPacketSeen();
    sender->setProtocolVersion(this->protocolVersion);
sender->setClientType(connectData.bridge ? ClientType::Mqtt3DefactoBridge : ClientType::Normal); if (this->protocolVersion == ProtocolVersion::None) { logger->logf(LOG_ERR, "Rejecting because of invalid protocol version: %s", sender->repr().c_str()); // The specs are unclear when to use the version 3 codes or version 5 codes when you don't know which protocol version to speak. ProtocolVersion fuzzyProtocolVersion = connectData.protocol_level_byte < 0x05 ? ProtocolVersion::Mqtt31 : ProtocolVersion::Mqtt5; ConnAck connAck(fuzzyProtocolVersion, ReasonCodes::UnsupportedProtocolVersion); MqttPacket response(connAck); sender->setDisconnectStage(DisconnectStage::SendPendingAppData); sender->writeMqttPacket(response); sender->setDisconnectReason("Unsupported protocol version"); return; } const Settings &settings = *ThreadGlobals::getSettings(); bool validClientId = true; // Check for wildcard chars in case the client_id ever appears in topics. if (!settings.allowUnsafeClientidChars && containsDangerousCharacters(connectData.client_id)) { logger->logf(LOG_ERR, "ClientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", connectData.client_id.c_str()); validClientId = false; } else if (protocolVersion < ProtocolVersion::Mqtt5 && !connectData.clean_start && connectData.client_id.empty()) { logger->logf(LOG_ERR, "ClientID empty and clean start 0, which is incompatible below MQTTv5."); validClientId = false; } else if (protocolVersion < ProtocolVersion::Mqtt311 && connectData.client_id.empty()) { logger->logf(LOG_ERR, "Empty clientID. 
Connect with protocol 3.1.1 or higher to have one generated securely."); validClientId = false; } if (!validClientId) { ConnAck connAck(protocolVersion, ReasonCodes::ClientIdentifierNotValid); MqttPacket response(connAck); sender->setDisconnectReason("Invalid clientID"); sender->setDisconnectStage(DisconnectStage::SendPendingAppData); sender->writeMqttPacket(response); return; } bool clientIdGenerated = false; if (connectData.client_id.empty()) { connectData.client_id = getSecureRandomString(23); clientIdGenerated = true; } sender->setClientId(connectData.client_id); std::string username = connectData.username ? connectData.username.value() : ""; if (sender->getX509ClientVerification() > X509ClientVerification::None && sender->getHaProxyMode() >= HaProxyMode::HaProxyClientVerification) { throw std::runtime_error("Clients can't be verified by both haproxy and our own x509 check."); } if (sender->getX509ClientVerification() > X509ClientVerification::None) { std::optional certificateUsername = sender->getUsernameFromPeerCertificate(); if (!certificateUsername || certificateUsername.value().empty()) throw ProtocolError("Client certificate did not provider username", ReasonCodes::BadUserNameOrPassword); username = certificateUsername.value(); } if (sender->getHaProxyMode() == HaProxyMode::HaProxyClientVerification || sender->getHaProxyMode() == HaProxyMode::HaProxyClientVerficiationWithAuthn) { const std::optional &haproxy_username = sender->getHaProxySslCnName(); if (!haproxy_username || haproxy_username.value().empty()) throw BadClientException("Haproxy client configured to provide username as SSL CN, but none provided."); username = haproxy_username.value(); } sender->setClientProperties(protocolVersion, connectData.client_id, connectData.fmq_client_group_id, username, true, connectData.keep_alive, connectData.max_outgoing_packet_size, connectData.max_outgoing_topic_aliases); if (settings.willsEnabled && connectData.will_flag) { if (connectData.will_qos > 
sender->getMaxQos()) { if (sender->getProtocolVersion() >= ProtocolVersion::Mqtt5 || sender->getMqtt3QoSExceedAction() == Mqtt3QoSExceedAction::Disconnect) { throw ProtocolError("QoS exceeds configured maximum", ReasonCodes::QosNotSupported); } } else { sender->stageWill(std::move(connectData.willpublish)); } } // Stage connack, for immediate or delayed use when auth succeeds. { bool sessionPresent = false; std::shared_ptr existingSession; if (protocolVersion >= ProtocolVersion::Mqtt311 && !connectData.clean_start) { existingSession = subscriptionStore->lockSession(connectData.client_id); if (existingSession && !existingSession->getDestroyOnDisconnect()) sessionPresent = true; } std::unique_ptr connAck = std::make_unique(protocolVersion, ReasonCodes::Success, sessionPresent); if (protocolVersion >= ProtocolVersion::Mqtt5) { const uint8_t sender_max_qos = sender->getMaxQos(); connAck->propertyBuilder = std::make_shared(); connAck->propertyBuilder->writeSessionExpiry(connectData.session_expire); connAck->propertyBuilder->writeReceiveMax(settings.maxQosMsgPendingPerClient); if (sender_max_qos < 2) connAck->propertyBuilder->writeMaxQoS(sender_max_qos); connAck->propertyBuilder->writeRetainAvailable(settings.retainedMessagesMode <= RetainedMessagesMode::EnabledWithoutPersistence); connAck->propertyBuilder->writeMaxPacketSize(sender->getMaxIncomingPacketSize()); if (clientIdGenerated) connAck->propertyBuilder->writeAssignedClientId(connectData.client_id); connAck->propertyBuilder->writeMaxTopicAliases(sender->getMaxIncomingTopicAliasValue()); connAck->propertyBuilder->writeWildcardSubscriptionAvailable(1); connAck->propertyBuilder->writeSubscriptionIdentifiersAvailable(static_cast(settings.subscriptionIdentifierEnabled)); connAck->propertyBuilder->writeSharedSubscriptionAvailable(1); connAck->propertyBuilder->writeServerKeepAlive(connectData.keep_alive); if (!connectData.authenticationMethod.empty()) { 
connAck->propertyBuilder->writeAuthenticationMethod(connectData.authenticationMethod); } } sender->stageConnack(std::move(connAck)); } sender->setRegistrationData(connectData.clean_start, connectData.client_receive_max, connectData.session_expire); AuthResult authResult = AuthResult::login_denied; std::string authReturnData; bool allowAnonymous = settings.allowAnonymous; if (sender->getAllowAnonymousOverride() != AllowListenerAnonymous::None) { allowAnonymous = sender->getAllowAnonymousOverride() == AllowListenerAnonymous::Yes; } if (!connectData.username && connectData.authenticationMethod.empty() && allowAnonymous) { authResult = AuthResult::success; } else if (sender->getX509ClientVerification() == X509ClientVerification::X509IsEnough || sender->getHaProxyMode() == HaProxyMode::HaProxyClientVerification) { // The client will have been kicked out already if the certificate is not valid, so we can just approve it. authResult = AuthResult::success; } else if (connectData.authenticationMethod.empty()) { authResult = authentication.loginCheck(connectData.client_id, username, connectData.password, getUserProperties(), sender, allowAnonymous); } else { sender->setExtendedAuthenticationMethod(connectData.authenticationMethod); authResult = authentication.extendedAuth(connectData.client_id, ExtendedAuthStage::Auth, connectData.authenticationMethod, connectData.authenticationData, getUserProperties(), authReturnData, sender->getMutableUsername(), sender); } if (authResult != AuthResult::async) { threadData->continuationOfAuthentication(sender, authResult, connectData.authenticationMethod, authReturnData); } else { sender->setAsyncAuthenticating(); } } void MqttPacket::handleConnAck(std::shared_ptr &sender) { if (!sender->isOutgoingConnection()) return; if (sender->hasConnectPacketSeen()) throw ProtocolError("Client already sent a CONNACK.", ReasonCodes::ProtocolError); const Settings *settings = ThreadGlobals::getSettings(); const ConnAckData data = parseConnAckData(); if 
// Continuation of MqttPacket::handleConnAck(): the preceding line ends with the 'if' keyword this condition belongs to.
    (data.reasonCode != ReasonCodes::Success)
    {
        throw ProtocolError("MQTT connect reject: " + reasonCodeToString(data.reasonCode), data.reasonCode);
    }

    if (!settings->allowUnsafeClientidChars && containsDangerousCharacters(data.assigned_client_id))
    {
        const std::string error = formatString("Assigned clientID '%s' has + or # in the id and 'allow_unsafe_clientid_chars' is false.", data.assigned_client_id.c_str());
        throw ProtocolError(error, ReasonCodes::ImplementationSpecificError);
    }

    sender->setAuthenticated(true);

    std::shared_ptr bridgeState = sender->getBridgeState();

    // Should be impossible.
    if (!bridgeState)
        return;

    bridgeState->resetReconnectCounter();

    logger->logf(LOG_NOTICE, "Bridge '%s' connection established. Subscribing to topics.", sender->repr().c_str());

    std::shared_ptr store = globals->subscriptionStore;

    std::shared_ptr session = bridgeState->session->lock();

    // Should be impossible.
    if (!session)
        return;

    // When the remote side reports no existing session, the local session's QoS data is reset.
    if (protocolVersion >= ProtocolVersion::Mqtt311 && !data.sessionPresent)
        session->resetQoSData();

    // A non-zero keep-alive from the CONNACK (ServerKeepAlive property) takes precedence over the configured one.
    const uint16_t keepalive = data.keep_alive ? data.keep_alive : bridgeState->c.keepalive;
    const uint16_t effectiveMaxOutgoingTopicAliases = std::min(data.max_outgoing_topic_aliases, bridgeState->c.maxOutgoingTopicAliases);
    const bool realRetainedAvailable = data.retained_available && bridgeState->c.remoteRetainAvailable;
    sender->setClientProperties(true, keepalive, data.max_outgoing_packet_size, effectiveMaxOutgoingTopicAliases, realRetainedAvailable);
    session->setSessionProperties(data.client_receive_max, bridgeState->c.localSessionExpiryInterval, bridgeState->c.localCleanStart, bridgeState->c.protocolVersion);

    {
        std::vector subscriptions;

        // This resubscribes also when there is already a session with subscriptions remotely, but that is required when you change QoS levels, for instance. It
        // will not unsubscribe, so it will add to the existing subscriptions.
        // Note that this will also get you retained messages again.
        for(const BridgeTopicPath &sub : bridgeState->c.subscribes)
        {
            // Never request more than the remote's announced maximum QoS.
            const uint8_t real_qos = std::min(data.max_qos, sub.qos);
            logger->log(LOG_DEBUG) << "Bridge '" << sender->repr() << "' subscribing remotely to '" << sub.topic << "', QoS=" << static_cast(real_qos) << ".";
            subscriptions.emplace_back(sub.topic, real_qos);

            /*
             * 'No local' is not allowed for shared subscriptions in MQTT5. However, when the other side is also FlashMQ,
             * that behavior can be achieved with the 'fmq_client_group_id' user property.
             */
            if (!startsWith(sub.topic, "$share/"))
                subscriptions.back().noLocal = true;

            subscriptions.back().retainAsPublished = true;
        }

        // All remote subscriptions are sent in one SUBSCRIBE packet.
        if (!subscriptions.empty())
        {
            MqttPacket subPacket(this->getProtocolVersion(), session->getNextPacketIdLocked(), 0, subscriptions);
            sender->writeMqttPacketAndBlameThisClient(subPacket);
        }
    }

    // It doesn't matter if we do this every time on connack; a client can only be subscribed once per pattern.
    // Note that this will also send all locally retained messages. I think that's the best approach: the state of retained
    // messages needs to be synced; they may have been removed remotely while disconnected, for instance.
    for(const BridgeTopicPath &pub : bridgeState->c.publishes)
    {
        logger->log(LOG_DEBUG) << "Bridge '" << sender->repr() << "' subscribing locally to '" << pub.topic << "', QoS=" << static_cast(pub.qos) << ".";
        std::vector subtopics = splitTopic(pub.topic);
        std::string shareName;
        std::string _;
        parseSubscriptionShare(subtopics, shareName, _);
        const bool no_local = shareName.empty(); // See above about no-local.
        store->addSubscription(session, subtopics, pub.qos, no_local, true, shareName, 0);
    }

    ThreadGlobals::getThreadData()->publishBridgeState(bridgeState, true, {});

    session->sendAllPendingQosData();
}

/**
 * Parses an MQTT5 AUTH packet into its reason code, authentication method and data.
 * User properties are stored in publishData.
 *
 * Throws std::runtime_error if called on a non-AUTH packet. Throws ProtocolError on reserved
 * flag bits, pre-MQTT5 protocol versions, duplicate or unknown properties, or authentication
 * data without a method.
 */
AuthPacketData MqttPacket::parseAuthData()
{
    if (this->packetType != PacketType::AUTH)
        throw std::runtime_error("Packet must be an AUTH packet.");

    if (first_byte & 0b1111)
        throw ProtocolError("AUTH packet first 4 bits should be 0.", ReasonCodes::MalformedPacket);

    if (this->protocolVersion < ProtocolVersion::Mqtt5)
        throw ProtocolError("AUTH packet needs MQTT5 or higher", ReasonCodes::ProtocolError);

    const Settings *settings = ThreadGlobals::getSettings();

    setPosToDataStart();

    AuthPacketData result;
    // NOTE(review): MQTT5 allows an AUTH with remaining length 0 (implied reason code 0x00 Success); this read is
    // unconditional, unlike the atEnd() guard in parseDisconnectData. Confirm readUint8() behavior at end of data.
    result.reasonCode = static_cast(readUint8());

    if (!atEnd())
    {
        const size_t proplen = decodeVariableByteIntAtPos();
        const size_t prop_end_at = pos + proplen;

        // Per-property occurrence counters, used to reject properties that appear more than once.
        std::array pcounts;
        pcounts.fill(0);

        while (withinBound(prop_end_at))
        {
            const Mqtt5Properties prop = static_cast(readUint8());

            switch (prop)
            {
            case Mqtt5Properties::AuthenticationMethod:
                if (pcounts[0]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.method = readBytesToString(settings->maxStringLength);
                break;
            case Mqtt5Properties::AuthenticationData:
                if (pcounts[1]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                result.data = readBytesToString(std::numeric_limits::max(), false);
                break;
            case Mqtt5Properties::ReasonString:
                if (pcounts[2]++ > 0)
                    throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError);
                readBytesToString(settings->maxStringLength); // Read to skip it; the value is not stored.
                break;
            case Mqtt5Properties::UserProperty:
            {
                // We (ab)use the publishData for the user properties, because it's there.
std::string key = readBytesToString(settings->maxStringLength); std::string val = readBytesToString(settings->maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } default: throw ProtocolError("Invalid property in auth packet.", ReasonCodes::ProtocolError); } } } if (result.method.empty() && !result.data.empty()) throw ProtocolError("Including authentication data when there is no authentication method is not allowed", ReasonCodes::ProtocolError); return result; } void MqttPacket::handleExtendedAuth(std::shared_ptr &sender) { AuthPacketData data = parseAuthData(); if (data.method != sender->getExtendedAuthenticationMethod()) throw ProtocolError("Client continued with another authentication method that it started with.", ReasonCodes::ProtocolError); ExtendedAuthStage authStage = ExtendedAuthStage::None; switch(data.reasonCode) { case ReasonCodes::ContinueAuthentication: authStage = ExtendedAuthStage::Continue; break; case ReasonCodes::ReAuthenticate: authStage = ExtendedAuthStage::Reauth; break; default: throw ProtocolError(formatString("Invalid reason code '%d' in auth packet", static_cast(data.reasonCode)), ReasonCodes::MalformedPacket); } if (authStage == ExtendedAuthStage::Reauth && !sender->getAuthenticated()) { throw ProtocolError("Trying to reauth when client was not authenticated.", ReasonCodes::ProtocolError); } Authentication &authentication = ThreadGlobals::getThreadData()->authentication; std::string returnData; const AuthResult authResult = authentication.extendedAuth(sender->getClientId(), authStage, data.method, data.data, getUserProperties(), returnData, sender->getMutableUsername(), sender); if (authResult != AuthResult::async) { ThreadGlobals::getThreadData()->continuationOfAuthentication(sender, authResult, data.method, returnData); } else { sender->setAsyncAuthenticating(); } } DisconnectData MqttPacket::parseDisconnectData() { if (this->packetType != PacketType::DISCONNECT) throw std::runtime_error("Packet 
must be disconnect packet."); if (first_byte & 0b1111) throw ProtocolError("Disconnect packet first 4 bits should be 0.", ReasonCodes::MalformedPacket); const Settings *settings = ThreadGlobals::getSettings(); setPosToDataStart(); DisconnectData result; if (this->protocolVersion >= ProtocolVersion::Mqtt5) { if (!atEnd()) result.reasonCode = static_cast(readUint8()); if (!atEnd()) { const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; std::array pcounts; pcounts.fill(0); while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::SessionExpiryInterval: { if (pcounts[0]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); const Settings *settings = ThreadGlobals::getSettings(); const uint32_t session_expire = std::min(readFourBytesToUint32(), settings->getExpireSessionAfterSeconds()); result.session_expiry_interval = session_expire; result.session_expiry_interval_set = true; break; } case Mqtt5Properties::ReasonString: { if (pcounts[1]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.reasonString = readBytesToString(settings->maxStringLength); break; } case Mqtt5Properties::ServerReference: { if (pcounts[2]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); readBytesToString(settings->maxStringLength); break; } case Mqtt5Properties::UserProperty: { // We (ab)use the publishData for the user properties, because it's there. 
std::string key = readBytesToString(settings->maxStringLength); std::string val = readBytesToString(settings->maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } default: throw ProtocolError("Invalid property in disconnect.", ReasonCodes::ProtocolError); } } } } return result; } void MqttPacket::handleDisconnect(std::shared_ptr &sender) { if (!sender) return; DisconnectData data = parseDisconnectData(); std::string disconnectReason = "MQTT Disconnect received (reason '" + reasonCodeToString(data.reasonCode) + "')"; if (!data.reasonString.empty()) disconnectReason += data.reasonString; if (data.session_expiry_interval_set) { const std::shared_ptr session = sender->getSession(); if (session) session->setSessionExpiryInterval(data.session_expiry_interval); } logger->logf(LOG_NOTICE, "Client '%s' cleanly disconnecting", sender->repr().c_str()); sender->setDisconnectReason(disconnectReason); sender->setDisconnectStage(DisconnectStage::Now); if (data.reasonCode == ReasonCodes::Success) sender->clearWill(); ThreadGlobals::getThreadData()->removeClientQueued(sender); } void MqttPacket::handleSubscribe(std::shared_ptr &sender) { const char firstByteFirstNibble = (first_byte & 0x0F); if (firstByteFirstNibble != 2) throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.", ReasonCodes::MalformedPacket); const uint16_t packet_id = readTwoBytesToUInt16(); if (packet_id == 0) { throw ProtocolError("Packet ID 0 when subscribing is invalid.", ReasonCodes::MalformedPacket); // [MQTT-2.3.1-1] } const Settings *settings = ThreadGlobals::getSettings(); uint32_t subscription_identifier = 0; if (protocolVersion == ProtocolVersion::Mqtt5) { const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; std::array pcounts; pcounts.fill(0); while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::SubscriptionIdentifier: { 
/* * This is per-spec, but when you change the setting in a running server, clients that will have already received * 'subscription identifiers enabled' in the CONNACK won't know that. On the other hand, by keep allowing * existing clients to use them, a sysop is out of control. */ if (!ThreadGlobals::getSettings()->subscriptionIdentifierEnabled) throw ProtocolError("Subscription identifiers are disabled.", ReasonCodes::ProtocolError); if (pcounts[0]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); subscription_identifier = decodeVariableByteIntAtPos(); if (subscription_identifier == 0) throw ProtocolError("Subscription identifier can't be 0", ReasonCodes::ProtocolError); break; } case Mqtt5Properties::UserProperty: { // We (ab)use the publishData for the user properties, because it's there. std::string key = readBytesToString(settings->maxStringLength); std::string val = readBytesToString(settings->maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } default: throw ProtocolError("Invalid subscribe property.", ReasonCodes::ProtocolError); } } } Authentication &authentication = ThreadGlobals::getThreadData()->authentication; std::forward_list deferredSubscribes; std::list subs_reponse_codes; while (remainingAfterPos() > 0) { std::string topic = readBytesToString(settings->maxStringLength, true); const uint8_t qos_byte = readUint8(); if (this->protocolVersion < ProtocolVersion::Mqtt5 && qos_byte > 2) { throw ProtocolError("QoS value in subscribe is higher than 2", ReasonCodes::MalformedPacket); } const SubscriptionOptionsByte options(qos_byte); uint8_t qos = options.getQos(); const bool mqtt3bridge = sender->getClientType() == ClientType::Mqtt3DefactoBridge; const bool noLocal = mqtt3bridge || options.getNoLocal(); const bool retainedAsPublished = mqtt3bridge || options.getRetainAsPublished(); const RetainHandling retainHandling = 
options.getRetainHandling(); if (logger->wouldLog(LOG_SUBSCRIBE)) { logger->log(LOG_SUBSCRIBE) << "Client '" << sender->repr() << "' subscribing to '" << topic << "' with QoS " << static_cast(qos) << ", no_local=" << noLocal << ", retain_as_published=" << retainedAsPublished << "."; } std::vector subtopics = splitTopic(topic); if (authentication.alterSubscribe(sender->getClientId(), topic, subtopics, qos, getUserProperties())) subtopics = splitTopic(topic); if (topic.empty()) throw ProtocolError("Subscribe topic is empty.", ReasonCodes::MalformedPacket); if (!isValidSubscribePath(topic)) throw ProtocolError(formatString("Invalid subscribe path: %s", topic.c_str()), ReasonCodes::MalformedPacket); if (qos > 2) throw ProtocolError("QoS is greater than 2, and/or reserved bytes in QoS field are not 0.", ReasonCodes::MalformedPacket); qos = std::min(qos, sender->getMaxQos()); std::string shareName; parseSubscriptionShare(subtopics, shareName, topic); if (!shareName.empty() && noLocal) { throw ProtocolError("It is a Protocol Error to set the No Local bit to 1 on a Shared Subscription", ReasonCodes::ProtocolError); } const Settings *settings = ThreadGlobals::getSettings(); AuthResult authResult = AuthResult::success; if (settings->minimumWildcardSubscriptionDepth > 0 && getFirstWildcardDepth(subtopics) < settings->minimumWildcardSubscriptionDepth) { std::string action_text = ""; if (settings->wildcardSubscriptionDenyMode == WildcardSubscriptionDenyMode::DenyRetainedOnly) { authResult = AuthResult::success_without_retained_delivery; action_text = "Denying retained messages"; } else { authResult = AuthResult::acl_denied; action_text = "Denying subscription"; } logger->log(LOG_WARNING) << "Wildcard subscription too broad. " << action_text << ". Topic: '" << topic << "'. 
Client: " << sender->repr(); } if (authResult == AuthResult::success || authResult == AuthResult::success_without_retained_delivery) { const AuthResult newAuthResult = authentication.aclCheck( sender->getClientId(), sender->getUsername(), topic, subtopics, shareName, std::string_view(), AclAccess::subscribe, qos, false, std::optional(), std::optional(), {}, {}, getUserProperties()); // We don't allow upgrading back to success. This gets too complicated between having no additional ACL, having an ACL file, and/or a plugin. if (newAuthResult != AuthResult::success) authResult = newAuthResult; } if (authResult == AuthResult::success || authResult == AuthResult::success_without_retained_delivery || authResult == AuthResult::success_but_drop) { const uint32_t subscr_id = settings->subscriptionIdentifierEnabled ? subscription_identifier : 0; deferredSubscribes.emplace_front(topic, subtopics, qos, noLocal, retainedAsPublished, shareName, authResult, subscr_id, retainHandling); subs_reponse_codes.push_back(static_cast(qos)); } else { if (logger->wouldLog(LOG_SUBSCRIBE)) logger->logf(LOG_SUBSCRIBE, "Client '%s' subscribe to '%s' denied or failed.", sender->repr().c_str(), topic.c_str()); // We can't not send an ack, because if there are multiple subscribes, you'd send fewer acks back, losing sync. ReasonCodes return_code = sender->getProtocolVersion() >= ProtocolVersion::Mqtt311 ? ReasonCodes::NotAuthorized : static_cast(qos); subs_reponse_codes.push_back(return_code); } } // MQTT-3.8.3-3 if (subs_reponse_codes.empty()) { throw ProtocolError("No topics specified to subscribe to.", ReasonCodes::MalformedPacket); } SubAck subAck(this->protocolVersion, packet_id, subs_reponse_codes); MqttPacket response(subAck); sender->writeMqttPacket(response); std::shared_ptr session = sender->getSession(); // Adding the subscription will also send publishes for retained messages, so that's why we're doing it at the end. 
for(const SubscriptionTuple &tup : deferredSubscribes) { if (tup.authResult == AuthResult::success_but_drop) continue; auto store = globals->subscriptionStore; const AddSubscriptionType add_type = store->addSubscription( session, tup.subtopics, tup.qos, tup.noLocal, tup.retainAsPublished, tup.shareName, tup.subscriptionIdentifier); if (tup.authResult == AuthResult::success && tup.shareName.empty()) { if ((tup.retainHandling == RetainHandling::SendRetainedMessagesAtSubscribe) || (tup.retainHandling == RetainHandling::SendRetainedMessagesAtNewSubscribeOnly && add_type == AddSubscriptionType::NewSubscription) ) { store->giveClientRetainedMessages(session, tup.subtopics, tup.qos, tup.subscriptionIdentifier); } } } } void MqttPacket::handleSubAck(std::shared_ptr &sender) { if (!sender->isOutgoingConnection()) return; const SubAckData data = parseSubAckData(); if (logger->wouldLog(LOG_SUBSCRIBE)) { logger->log(LOG_SUBSCRIBE) << "SUBACK received from '" << sender->getClientId() << "' with id " << data.packet_id << " having " << data.subAckCodes.size() << " ack(s) in it."; } std::shared_ptr bridgeState = sender->getBridgeState(); // Should be impossible. if (!bridgeState) return; std::shared_ptr session = bridgeState->session->lock(); // Should be impossible. 
if (!session) return; session->increaseFlowControlQuotaLocked(); } void MqttPacket::handleUnsubscribe(std::shared_ptr &sender) { const char firstByteFirstNibble = (first_byte & 0x0F); if (firstByteFirstNibble != 2) throw ProtocolError("First LSB of first byte is wrong value for subscribe packet.", ReasonCodes::MalformedPacket); const Settings *settings = ThreadGlobals::getSettings(); const uint16_t packet_id = readTwoBytesToUInt16(); if (packet_id == 0) { throw ProtocolError("Packet ID 0 when unsubscribing is invalid.", ReasonCodes::ProtocolError); // [MQTT-2.3.1-1] } if (protocolVersion == ProtocolVersion::Mqtt5) { const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::UserProperty: { // We (ab)use the publishData for the user properties, because it's there. std::string key = readBytesToString(settings->maxStringLength); std::string val = readBytesToString(settings->maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } default: throw ProtocolError("Invalid unsubscribe property.", ReasonCodes::ProtocolError); } } } int numberOfUnsubs = 0; std::shared_ptr session = sender->getSession(); while (remainingAfterPos() > 0) { numberOfUnsubs++; const std::string topic = readBytesToString(settings->maxStringLength); if (topic.empty()) throw ProtocolError("Unsubscribe topic is empty.", ReasonCodes::MalformedPacket); if (!isValidSubscribePath(topic)) throw ProtocolError("Unsubscribe topic is invalid: " + topic, ReasonCodes::MalformedPacket); std::vector subtopics = splitTopic(topic); std::string shareName; std::string topic_without_sharename = topic; parseSubscriptionShare(subtopics, shareName, topic_without_sharename); globals->subscriptionStore->removeSubscription(session, subtopics, shareName); const Authentication &auth = 
ThreadGlobals::getThreadData()->authentication; auth.onUnsubscribe(session, sender->getClientId(), sender->getUsername(), topic_without_sharename, subtopics, shareName, getUserProperties()); if (logger->wouldLog(LOG_UNSUBSCRIBE)) logger->logf(LOG_UNSUBSCRIBE, "Client '%s' unsubscribed from '%s'", sender->repr().c_str(), topic.c_str()); } // MQTT-3.10.3-2 if (numberOfUnsubs == 0) { throw ProtocolError("No topics specified to unsubscribe to.", ReasonCodes::MalformedPacket); } UnsubAck unsubAck(sender->getProtocolVersion(), packet_id, numberOfUnsubs); MqttPacket response(unsubAck); sender->writeMqttPacket(response); } void MqttPacket::parsePublishData(std::shared_ptr &sender) { assert(externallyReceived); setPosToDataStart(); publishData.retain = (first_byte & 0b00000001); const bool duplicate = !!(first_byte & 0b00001000); publishData.qos = (first_byte & 0b00000110) >> 1; if (publishData.qos > 2) throw ProtocolError("QoS 3 is a protocol violation.", ReasonCodes::MalformedPacket); if (publishData.qos == 0 && duplicate) throw ProtocolError("Duplicate flag is set for QoS 0 packet. 
This is illegal.", ReasonCodes::MalformedPacket); publishData.username = sender->getUsername(); publishData.client_id = sender->getClientId(); const Settings *settings = ThreadGlobals::getSettings(); publishData.topic = readBytesToString(settings->maxStringLength, true, true); if (publishData.qos) { packet_id_pos = pos; packet_id = readTwoBytesToUInt16(); if (packet_id == 0) { throw ProtocolError("Packet ID 0 when publishing is invalid.", ReasonCodes::MalformedPacket); // [MQTT-2.3.1-1] } } if (this->protocolVersion >= ProtocolVersion::Mqtt5 ) { const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; std::array pcounts; pcounts.fill(0); while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::PayloadFormatIndicator: if (pcounts[0]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); publishData.payloadUtf8 = static_cast(readByte()); break; case Mqtt5Properties::MessageExpiryInterval: if (pcounts[1]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); publishData.setExpireAfter(readFourBytesToUint32()); break; case Mqtt5Properties::TopicAlias: { // For when we use packets has helpers without a senser (like loading packets from disk). // Logically, this should never trip because there can't be aliases in such packets, but including // a check to be sure. 
if (!sender) break; if (pcounts[2]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); const uint16_t alias_id = readTwoBytesToUInt16(); if (alias_id == 0) throw ProtocolError("Topic alias ID 0 is invalid.", ReasonCodes::TopicAliasInvalid); this->dontReuseBites = true; if (publishData.topic.empty()) { publishData.topic = sender->getTopicAlias(alias_id); } else { sender->setTopicAlias(alias_id, publishData.topic); } // Just making clear we don't want to store the alias in the publish object. It has lost its meaning from this point on. assert(this->publishData.topicAlias == 0); break; } case Mqtt5Properties::ResponseTopic: { if (pcounts[3]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); publishData.responseTopic = readBytesToString(settings->maxStringLength, true, true); if (publishData.responseTopic->empty()) throw ProtocolError("Response topic cannot be empty", ReasonCodes::ProtocolError); break; } case Mqtt5Properties::CorrelationData: { if (pcounts[4]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); publishData.correlationData = readBytesToString(settings->maxStringLength, false); break; } case Mqtt5Properties::UserProperty: { std::string key = readBytesToString(settings->maxStringLength); std::string val = readBytesToString(settings->maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } case Mqtt5Properties::SubscriptionIdentifier: { if (pcounts[5]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); dontReuseBites = true; #ifndef TESTING if (sender->getClientType() != ClientType::LocalBridge) throw ProtocolError("Subscription identifiers cannot be sent to servers.", ReasonCodes::ProtocolError); // We don't store it, because it should not propagate. 
decodeVariableByteIntAtPos(); #else publishData.subscriptionIdentifierTesting = decodeVariableByteIntAtPos(); #endif break; } case Mqtt5Properties::ContentType: { if (pcounts[6]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); publishData.contentType = readBytesToString(settings->maxStringLength, true, false); break; } default: throw ProtocolError("Invalid property in publish.", ReasonCodes::ProtocolError); } } } if (publishData.topic.empty()) throw ProtocolError("Empty publish topic", ReasonCodes::ProtocolError); payloadLen = remainingAfterPos(); payloadStart = pos; // Not using SIMD UTF8 checker because that requires making a copy, and requires being able to deal with large strings. if (publishData.payloadUtf8 && !isValidUtf8Generic(getPayloadView())) { throw ProtocolError("Payload announced as UTF8, but it's not valid.", ReasonCodes::PayloadFormatInvalid); } } void MqttPacket::handlePublish(std::shared_ptr &sender) { parsePublishData(sender); if (__builtin_expect(logger->wouldLog(LOG_PUBLISH), 0)) { const bool duplicate = !!(first_byte & 0b00001000); logger->log(LOG_PUBLISH) << "Publish received from '" << sender->repr() << "'. Size: " << bites.size() << ". Topic: '" << publishData.topic << "'. QoS=" << static_cast(publishData.qos) << ". Retain=" << publishData.retain << ". Dup=" << duplicate << ". Alias=" << publishData.topicAlias << "."; } ThreadGlobals::getThreadData()->receivedMessageCounter.inc(); /* * Topic prefixing. Currently, remote prefix is removed and local prefixes is added before any ACL checks are * done. That seems to make the most sense. 
*/ { const std::optional &remote_prefix = sender->getRemotePrefix(); const std::optional &local_prefix = sender->getLocalPrefix(); bool topic_changed = false; // The size check is to prevent making "remote/haha/" into "" when the remote_prefix is "remote/haha/" if (remote_prefix && startsWith(publishData.topic, *remote_prefix) && publishData.topic.size() > remote_prefix->size()) { topic_changed = true; publishData.topic.erase(0, remote_prefix->length()); } if (local_prefix) { topic_changed = true; publishData.topic = *local_prefix + publishData.topic; } if (topic_changed) { dontReuseBites = true; publishData.resplitTopic(); if (publishData.topic.empty()) throw ProtocolError("Empty publish topic", ReasonCodes::ProtocolError); } } Authentication &authentication = ThreadGlobals::getThreadData()->authentication; const Settings *settings = ThreadGlobals::getSettings(); // Working with a local copy because the subscribing action will modify this->packet_id. See the PublishCopyFactory. const uint16_t _packet_id = this->packet_id; // Stage the ack, with the proper ID. 
AckSender ackSender(this->publishData.qos, this->packet_id, this->protocolVersion, sender); if (publishData.retain && settings->retainedMessagesMode == RetainedMessagesMode::DisconnectWithError) { throw ProtocolError("Retained messages not supported and 'retained_messages_mode' set to 'disconnect_with_error'.", ReasonCodes::RetainNotSupported); } if (publishData.qos > sender->getMaxQos()) { if (sender->getProtocolVersion() >= ProtocolVersion::Mqtt5 || sender->getMqtt3QoSExceedAction() == Mqtt3QoSExceedAction::Disconnect) { throw ProtocolError("QoS exceeds configured maximum", ReasonCodes::QosNotSupported); } else if (sender->getMqtt3QoSExceedAction() == Mqtt3QoSExceedAction::Drop) { ackSender.sendNow(); } } else if (publishData.qos == 2 && sender->getSession()->incomingQoS2MessageIdInTransit(_packet_id)) { ackSender.setAckCode(ReasonCodes::PacketIdentifierInUse); } else { // Doing this before the authentication on purpose, so when the publish is not allowed, the QoS control packets are allowed and can finish. if (publishData.qos == 2) sender->getSession()->addIncomingQoS2MessageId(_packet_id); const uint8_t qos_org = this->publishData.qos; const bool retain_org = this->publishData.retain; const bool altered = authentication.alterPublish( this->publishData.client_id, this->publishData.topic, this->publishData.getSubtopics(), getPayloadView(), this->publishData.qos, this->publishData.retain, this->publishData.correlationData, this->publishData.responseTopic, this->publishData.contentType, this->publishData.getUserProperties()); if (altered) this->publishData.resplitTopic(); if (retain_org != publishData.retain) setRetain(publishData.retain); // Don't look at 'retain', because a changed retain bit doesn't alter the byte layout of the original packet. 
if (altered || qos_org != this->publishData.qos) this->dontReuseBites = true; const AuthResult authResult = authentication.aclCheck(this->publishData, getPayloadView()); if (authResult == AuthResult::success || authResult == AuthResult::success_without_setting_retained) { if (publishData.retain) { if (authResult == AuthResult::success && settings->retainedMessagesMode <= RetainedMessagesMode::EnabledWithoutPersistence) { publishData.payload = getPayloadCopy(); globals->subscriptionStore->trySetRetainedMessages(publishData, publishData.getSubtopics()); } else if (settings->retainedMessagesMode == RetainedMessagesMode::Downgrade) { publishData.retain = false; bites[0] &= 0b11111110; first_byte = bites[0]; } } if (!publishData.retain || settings->retainedMessagesMode <= RetainedMessagesMode::Downgrade) { // Set dup flag to 0, because that must not be propagated [MQTT-3.3.1-3]. bites[0] &= 0b11110111; first_byte = bites[0]; PublishCopyFactory factory(this); ackSender.sendNow(); globals->subscriptionStore->queuePacketAtSubscribers(factory, sender->getClientId(), sender->getFmqClientGroupId()); } } else if (authResult == AuthResult::success_but_drop_publish) { ackSender.sendNow(); } else { ackSender.setAckCode(ReasonCodes::NotAuthorized); } } #ifndef NDEBUG // Protection against using the altered packet id (because we change the incoming byte array for each subscriber). this->packet_id = 0; this->publishData.qos = 0; if (publishData.qos > 0) this->setPacketId(0); #endif } void MqttPacket::parsePubAckData() { setPosToDataStart(); this->publishData.qos = 1; this->packet_id = readTwoBytesToUInt16(); // TODO: if we ever parse the reason code and use it to make decisions, check for validity. But, as of yet, // checking validity would just add overhead for no reason. 
if (this->packet_id == 0) throw ProtocolError("QoS packets must have packet ID > 0.", ReasonCodes::ProtocolError); } void MqttPacket::handlePubAck(std::shared_ptr &sender) { parsePubAckData(); sender->getSession()->clearQosMessage(packet_id, true); } PubRecData MqttPacket::parsePubRecData() { setPosToDataStart(); this->publishData.qos = 2; this->packet_id = readTwoBytesToUInt16(); if (this->packet_id == 0) throw ProtocolError("QoS packets must have packet ID > 0.", ReasonCodes::ProtocolError); PubRecData result; if (!atEnd()) { result.reasonCode = static_cast(readUint8()); if (result.reasonCode > ReasonCodes::Success) { switch(result.reasonCode) { case ReasonCodes::Success: case ReasonCodes::NoMatchingSubscribers: case ReasonCodes::UnspecifiedError: case ReasonCodes::ImplementationSpecificError: case ReasonCodes::NotAuthorized: case ReasonCodes::TopicNameInvalid: case ReasonCodes::PacketIdentifierInUse: case ReasonCodes::QuotaExceeded: case ReasonCodes::PayloadFormatInvalid: { break; } default: throw ProtocolError("Invalid reason code in PUBREC", ReasonCodes::ProtocolError); } } } return result; } /** * @brief MqttPacket::handlePubRec handles QoS 2 'publish received' packets. The publisher receives these. */ void MqttPacket::handlePubRec(std::shared_ptr &sender) { PubRecData data = parsePubRecData(); const bool publishTerminatesHere = data.reasonCode >= ReasonCodes::UnspecifiedError; const bool foundAndRemoved = sender->getSession()->clearQosMessage(packet_id, publishTerminatesHere); // "If it has sent a PUBREC with a Reason Code of 0x80 or greater, the receiver MUST treat any subsequent PUBLISH packet // that contains that Packet Identifier as being a new Application Message." if (!publishTerminatesHere) { sender->getSession()->addOutgoingQoS2MessageId(packet_id); // MQTT5: "[The sender] MUST send a PUBREL packet when it receives a PUBREC packet from the receiver with a Reason Code value less than 0x80" const ReasonCodes reason = foundAndRemoved ? 
ReasonCodes::Success : ReasonCodes::PacketIdentifierNotFound; PubResponse pubRel(this->protocolVersion, PacketType::PUBREL, reason, packet_id); MqttPacket response(pubRel); sender->writeMqttPacket(response); } } void MqttPacket::parsePubRelData() { // MQTT-3.6.1-1, but why do we care, and only care for certain control packets? if ((first_byte & 0b00001111) != 0b00000010) throw ProtocolError("PUBREL first byte LSB must be 0010.", ReasonCodes::MalformedPacket); setPosToDataStart(); this->publishData.qos = 2; this->packet_id = readTwoBytesToUInt16(); } /** * @brief MqttPacket::handlePubRel handles QoS 2 'publish release'. The publisher sends these. */ void MqttPacket::handlePubRel(std::shared_ptr &sender) { parsePubRelData(); if (this->packet_id == 0) throw ProtocolError("QoS packets must have packet ID > 0.", ReasonCodes::ProtocolError); const bool foundAndRemoved = sender->getSession()->removeIncomingQoS2MessageId(packet_id); const ReasonCodes reason = foundAndRemoved ? ReasonCodes::Success : ReasonCodes::PacketIdentifierNotFound; PubResponse pubcomp(this->protocolVersion, PacketType::PUBCOMP, reason, packet_id); MqttPacket response(pubcomp); sender->writeMqttPacket(response); } void MqttPacket::parsePubComp() { setPosToDataStart(); this->publishData.qos = 2; this->packet_id = readTwoBytesToUInt16(); if (this->packet_id == 0) throw ProtocolError("QoS packets must have packet ID > 0.", ReasonCodes::ProtocolError); } /** * @brief MqttPacket::handlePubComp handles QoS 2 'publish complete'. The publisher receives these. 
*/ void MqttPacket::handlePubComp(std::shared_ptr &sender) { parsePubComp(); sender->getSession()->removeOutgoingQoS2MessageId(packet_id); } SubAckData MqttPacket::parseSubAckData() { if (this->packetType != PacketType::SUBACK) throw std::runtime_error("Packet must be suback packet."); setPosToDataStart(); const Settings *settings = ThreadGlobals::getSettings(); SubAckData result; result.packet_id = readTwoBytesToUInt16(); this->packet_id = result.packet_id; if (this->protocolVersion >= ProtocolVersion::Mqtt5 ) { const size_t proplen = decodeVariableByteIntAtPos(); const size_t prop_end_at = pos + proplen; std::array pcounts; pcounts.fill(0); while (withinBound(prop_end_at)) { const Mqtt5Properties prop = static_cast(readUint8()); switch (prop) { case Mqtt5Properties::ReasonString: if (pcounts[0]++ > 0) throw ProtocolError("Can't specify " + propertyToString(prop) + " more than once", ReasonCodes::ProtocolError); result.reasonString = readBytesToString(settings->maxStringLength); break; case Mqtt5Properties::UserProperty: { // We (ab)use the publishData for the user properties, because it's there. std::string key = readBytesToString(settings->maxStringLength); std::string val = readBytesToString(settings->maxStringLength); this->publishData.addUserProperty(std::move(key), std::move(val)); break; } default: throw ProtocolError("Invalid property in suback.", ReasonCodes::ProtocolError); } } } // payload starts here while (!atEnd()) { uint8_t code = readByte(); result.subAckCodes.push_back(code); } return result; } void MqttPacket::calculateRemainingLength() { assert(fixed_header_length == 0); // because you're not supposed to call this on packet that we already know the length of. 
this->remainingLength = bites.size(); } void MqttPacket::setPosToDataStart() { this->pos = this->fixed_header_length; } bool MqttPacket::atEnd() const { assert(pos <= bites.size()); return pos >= bites.size(); } bool MqttPacket::withinBound(const size_t limit) const { if (pos > limit) throw ProtocolError("Out of bounds", ReasonCodes::MalformedPacket); return pos < limit; } void MqttPacket::setPacketId(uint16_t packet_id) { assert(fixed_header_length == 0 || first_byte == bites[0]); assert(packet_id_pos > 0); assert(packetType == PacketType::PUBLISH); assert(publishData.qos > 0); this->packet_id = packet_id; pos = packet_id_pos; writeUint16(packet_id); } uint16_t MqttPacket::getPacketId() const { assert(publishData.qos > 0); return packet_id; } // If I read the specs correctly, the DUP flag is merely for show. It doesn't control anything? void MqttPacket::setDuplicate() { assert(packetType == PacketType::PUBLISH); assert(publishData.qos > 0); assert(fixed_header_length == 0 || first_byte == bites[0]); first_byte |= 0b00001000; if (fixed_header_length > 0) { pos = 0; writeByte(first_byte); } } /** * @brief MqttPacket::getPayloadCopy takes part of the vector of bytes and returns it as a string. * @return */ std::string MqttPacket::getPayloadCopy() const { std::string payload(getPayloadView()); return payload; } std::string_view MqttPacket::getPayloadView() const { assert(payloadStart > 0); assert(pos <= bites.size()); FMQ_ENSURE(payloadStart + payloadLen <= bites.size()); std::string_view payload(bites.data() + payloadStart, payloadLen); return payload; } uint8_t MqttPacket::getFixedHeaderLength() const { size_t result = this->fixed_header_length; if (result == 0) { result++; // first byte it always there. 
result += remainingLength.getLen(); } return result; } size_t MqttPacket::getSizeIncludingNonPresentHeader() const { size_t total = bites.size(); if (fixed_header_length == 0) { total += getFixedHeaderLength(); } return total; } void MqttPacket::setQos(const uint8_t new_qos) { // You can't change to a QoS level that would remove the packet identifier. assert((publishData.qos == 0 && new_qos == 0) || (publishData.qos > 0 && new_qos > 0)); assert(new_qos > 0 && packet_id_pos > 0); publishData.qos = new_qos; first_byte &= 0b11111001; first_byte |= (publishData.qos << 1); if (fixed_header_length > 0) { pos = 0; writeByte(first_byte); } } const std::string &MqttPacket::getTopic() const { return this->publishData.topic; } const std::vector &MqttPacket::getSubtopics() { return this->publishData.getSubtopics(); } bool MqttPacket::containsFixedHeader() const { return fixed_header_length > 0; } std::string_view MqttPacket::readBytes(size_t length) { if (pos + length > bites.size()) throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); std::string_view result(bites.data() + pos, length); pos += length; return result; } char MqttPacket::readByte() { if (pos + 1 > bites.size()) throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); char b = bites[pos++]; return b; } uint8_t MqttPacket::readUint8() { char r = readByte(); return static_cast(r); } void MqttPacket::writeByte(char b) { if (pos + 1 > bites.size()) throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket); bites[pos++] = b; } void MqttPacket::writeUint16(uint16_t x) { if (pos + 2 > bites.size()) throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket); const uint8_t a = static_cast(x >> 8); const uint8_t b = static_cast(x); bites[pos++] = a; bites[pos++] = b; } void MqttPacket::writeProperties(Mqtt5PropertyBuilder &properties) { writeVariableByteInt(properties.getVarInt()); const 
std::vector &b = properties.getBytes(); write(b); } void MqttPacket::writeVariableByteInt(const VariableByteInt &v) { write(v); } void MqttPacket::writeString(const std::string &s) { writeUint16(s.length()); write(s); } void MqttPacket::writeString(std::string_view s) { writeUint16(s.length()); write(s); } uint16_t MqttPacket::readTwoBytesToUInt16() { if (pos + 2 > bites.size()) throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); uint8_t a = bites[pos]; uint8_t b = bites[pos+1]; uint16_t i = a << 8 | b; pos += 2; return i; } uint32_t MqttPacket::readFourBytesToUint32() { if (pos + 4 > bites.size()) throw ProtocolError("Invalid packet: header specifies invalid length.", ReasonCodes::MalformedPacket); const uint8_t a = bites[pos++]; const uint8_t b = bites[pos++]; const uint8_t c = bites[pos++]; const uint8_t d = bites[pos++]; uint32_t i = (a << 24) | (b << 16) | (c << 8) | d; return i; } size_t MqttPacket::remainingAfterPos() { assert(pos <= bites.size()); return bites.size() - pos; } size_t MqttPacket::decodeVariableByteIntAtPos() { uint64_t multiplier = 1; size_t value = 0; uint8_t encodedByte = 0; do { if (pos >= bites.size()) throw ProtocolError("Variable byte int length goes out of packet. 
Corrupt.", ReasonCodes::MalformedPacket); encodedByte = bites[pos++]; value += (encodedByte & 127) * multiplier; multiplier *= 128; if (multiplier > 128*128*128*128) throw ProtocolError("Malformed Remaining Length.", ReasonCodes::MalformedPacket); } while ((encodedByte & 128) != 0); return value; } std::string MqttPacket::readBytesToString(const uint16_t maxLength, bool validateUtf8, bool alsoCheckInvalidPublishChars, const ReasonCodes reasonCode) { assert(maxLength > 0); const uint16_t len = readTwoBytesToUInt16(); if (len > maxLength) throw std::runtime_error("Client sent string longer than 'max_string_length'"); std::string result(readBytes(len)); if (validateUtf8) { if (!isValidUtf8(result, alsoCheckInvalidPublishChars)) { logger->logf(LOG_DEBUG, "Data of invalid UTF-8 string or publish topic: %s", result.c_str()); throw ProtocolError("Invalid UTF8 string detected, or invalid publish characters.", reasonCode); } } return result; } std::vector> *MqttPacket::getUserProperties() const { return this->publishData.getUserProperties(); } const std::optional &MqttPacket::getCorrelationData() const { return this->publishData.correlationData; } const std::optional &MqttPacket::getResponseTopic() const { return this->publishData.responseTopic; } const std::optional &MqttPacket::getContentType() const { return this->publishData.contentType; } const std::optional > MqttPacket::getExpiresAt() const { return publishData.expiresAt(); } bool MqttPacket::getRetain() const { return (first_byte & 0b00000001); } void MqttPacket::setRetain(bool val) { first_byte &= 0b11111110; first_byte |= static_cast(val); if (fixed_header_length > 0) { pos = 0; writeByte(first_byte); } } const Publish &MqttPacket::getPublishData() { if (payloadLen > 0 && publishData.payload.empty()) publishData.payload = getPayloadCopy(); return publishData; } bool MqttPacket::biteArrayCannotBeReused() const { assert(packetType == PacketType::PUBLISH); assert(this->externallyReceived); 
assert(this->publishData.topicAlias == 0); // The topic alias should not be stored in the incoming publish. return this->dontReuseBites; } void MqttPacket::readIntoBuf(CirBuf &buf) const { assert(packetType != PacketType::PUBLISH || (first_byte & 0b00000110) >> 1 == publishData.qos); assert(publishData.qos == 0 || packet_id > 0); if (!containsFixedHeader()) { buf.write(first_byte); remainingLength.readIntoBuf(buf); } else { assert(bites.data()[0] == first_byte); } buf.writerange(bites.begin(), bites.end()); } SubscriptionTuple::SubscriptionTuple(const std::string &topic, const std::vector &subtopics, uint8_t qos, bool noLocal, bool retainAsPublished, const std::string &shareName, const AuthResult authResult, const uint32_t subscriptionIdentifier, const RetainHandling retainHandling) : topic(topic), subtopics(subtopics), qos(qos), noLocal(noLocal), retainAsPublished(retainAsPublished), shareName(shareName), authResult(authResult), subscriptionIdentifier(subscriptionIdentifier), retainHandling(retainHandling) { } ================================================ FILE: mqttpacket.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef MQTTPACKET_H #define MQTTPACKET_H #include #include #include #include "forward_declarations.h" #include "types.h" #include "cirbuf.h" #include "logger.h" #include "variablebyteint.h" #include "mqtt5properties.h" #include "packetdatatypes.h" #include "exceptions.h" enum class HandleResult { Done, Defer }; /** * @brief The MqttPacket class represents incoming and outgoing packets. * * Be sure to understand the 'externallyReceived' member. See in-code documentation. 
*/ class MqttPacket { #ifdef TESTING friend class MainTests; #endif std::vector bites; Publish publishData; size_t fixed_header_length = 0; // if 0, this packet does not contain the bytes of the fixed header. VariableByteInt remainingLength; char first_byte = 0; size_t pos = 0; size_t packet_id_pos = 0; uint16_t packet_id = 0; ProtocolVersion protocolVersion = ProtocolVersion::None; size_t payloadStart = 0; size_t payloadLen = 0; bool dontReuseBites = false; // It's important to understand that this class is used for incoming packets as well as new outgoing packets. When we create // new outgoing packets, we generally know exactly who it's for and the information is only stored in this->bites. So, the // publishData and fields like hasTopicAlias are invalid in those cases. bool externallyReceived = false; Logger *logger = Logger::getInstance(); std::string_view readBytes(size_t length); char readByte(); uint8_t readUint8(); void writeByte(char b); void writeUint16(uint16_t x); template void write(T &src) { if (pos + src.size() > bites.size()) throw ProtocolError("Exceeding packet size", ReasonCodes::MalformedPacket); std::copy(src.begin(), src.end(), bites.begin() + pos); pos += src.size(); } template void writeProperties(T properties) { if (!properties) { writeByte(0); return; } writeProperties(*properties); } void writeProperties(Mqtt5PropertyBuilder &properties); void writeVariableByteInt(const VariableByteInt &v); void writeString(const std::string &s); void writeString(std::string_view s); uint16_t readTwoBytesToUInt16(); uint32_t readFourBytesToUint32(); size_t remainingAfterPos(); size_t decodeVariableByteIntAtPos(); std::string readBytesToString( const uint16_t maxLength = std::numeric_limits::max(), bool validateUtf8 = true, bool alsoCheckInvalidPublishChars = false, const ReasonCodes reasonCode = ReasonCodes::MalformedPacket); void calculateRemainingLength(); void setPosToDataStart(); bool atEnd() const; bool withinBound(const size_t limit) const; #ifndef 
TESTING // In production, I want to be sure I don't accidentally copy packets, because it's slow. MqttPacket(const MqttPacket &other) = delete; #endif public: #ifdef TESTING // In testing I need to copy packets for administrative purposes. MqttPacket(const MqttPacket &other) = default; #endif PacketType packetType = PacketType::Reserved; MqttPacket(std::vector &&packet_bytes, size_t fixed_header_length, std::shared_ptr &sender); // Constructor for parsing incoming packets. MqttPacket(MqttPacket &&other) = default; // Constructor for outgoing packets. These may not allocate room for the fixed header, because we don't (always) know the length in advance. MqttPacket(const ConnAck &connAck); MqttPacket(const SubAck &subAck); MqttPacket(const UnsubAck &unsubAck); MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish); MqttPacket(const ProtocolVersion protocolVersion, const Publish &_publish, const uint8_t _qos, const uint16_t _topic_alias, const bool _skip_topic, const uint32_t subscriptionIdentifier, const std::optional &topic_override); MqttPacket(const PubResponse &pubAck); MqttPacket(const Disconnect &disconnect); MqttPacket(const Auth &auth); MqttPacket(const Connect &connect); MqttPacket(const ProtocolVersion protocolVersion, const uint16_t packetId, const uint32_t subscriptionIdentifier, const std::vector &subscriptions); MqttPacket(const ProtocolVersion protocolVersion, const uint16_t packetId, const std::vector &unsubs); static void bufferToMqttPackets(CirBuf &buf, std::vector &packetQueueIn, std::shared_ptr &sender); HandleResult handle(std::shared_ptr &sender); AuthPacketData parseAuthData(); ConnectData parseConnectData(std::shared_ptr &sender); ConnAckData parseConnAckData(); void handleConnect(std::shared_ptr &sender); void handleConnAck(std::shared_ptr &sender); void handleExtendedAuth(std::shared_ptr &sender); DisconnectData parseDisconnectData(); void handleDisconnect(std::shared_ptr &sender); void handleSubscribe(std::shared_ptr 
&sender); void handleSubAck(std::shared_ptr &sender); void handleUnsubscribe(std::shared_ptr &sender); void handlePing(std::shared_ptr &sender); void parsePublishData(std::shared_ptr &sender); void handlePublish(std::shared_ptr &sender); void parsePubAckData(); void handlePubAck(std::shared_ptr &sender); PubRecData parsePubRecData(); void handlePubRec(std::shared_ptr &sender); void parsePubRelData(); void handlePubRel(std::shared_ptr &sender); void parsePubComp(); void handlePubComp(std::shared_ptr &sender); SubAckData parseSubAckData(); uint8_t getFixedHeaderLength() const; size_t getSizeIncludingNonPresentHeader() const; uint8_t getQos() const { return publishData.qos; } void setQos(const uint8_t new_qos); ProtocolVersion getProtocolVersion() const { return protocolVersion;} const std::string &getTopic() const; const std::vector &getSubtopics(); bool containsFixedHeader() const; void setPacketId(uint16_t packet_id); uint16_t getPacketId() const; void setDuplicate(); void readIntoBuf(CirBuf &buf) const; std::string getPayloadCopy() const; std::string_view getPayloadView() const; bool getRetain() const; void setRetain(bool val); const Publish &getPublishData(); bool biteArrayCannotBeReused() const; std::vector> *getUserProperties() const; const std::optional &getCorrelationData() const; const std::optional &getResponseTopic() const; const std::optional &getContentType() const; const std::optional> getExpiresAt() const; }; struct SubscriptionTuple { const std::string topic; const std::vector subtopics; const uint8_t qos; const bool noLocal; const bool retainAsPublished; const std::string shareName; const AuthResult authResult; const uint32_t subscriptionIdentifier = 0; const RetainHandling retainHandling = RetainHandling::SendRetainedMessagesAtSubscribe; SubscriptionTuple(const std::string &topic, const std::vector &subtopics, uint8_t qos, bool noLocal, bool retainAsPublished, const std::string &shareName, const AuthResult authResult, const uint32_t 
subscriptionIdentifier, const RetainHandling retainHandling); }; #endif // MQTTPACKET_H ================================================ FILE: mutexowned.h ================================================ #ifndef MUTEXOWNED_H #define MUTEXOWNED_H #include class __attribute__((visibility("default"))) MutexOwnedObjectNull : public std::exception { public: virtual const char* what() const noexcept override { return "MutexOwnedObjectNull"; } }; template class MutexLocked { std::unique_lock l; T *d = nullptr; public: MutexLocked() = default; MutexLocked(T &other, std::mutex &m) : l(m), d(&other) { } MutexLocked(T &other, std::mutex &m, std::try_to_lock_t try_to_lock) : l(m, try_to_lock), d(&other) { } MutexLocked(T &other, std::mutex &m, std::defer_lock_t defer_lock) : l(m, defer_lock), d(&other) { } MutexLocked(const MutexLocked &other) = delete; MutexLocked &operator=(const MutexLocked &other) = delete; MutexLocked &operator=(MutexLocked &&other) noexcept { l = std::move(other.l); d = other.d; other.d = nullptr; return *this; } MutexLocked(MutexLocked &&other) noexcept : l(std::move(other.l)), d(other.d) { other.d = nullptr; } ~MutexLocked() { d = nullptr; } T &operator*() { if (!d) throw MutexOwnedObjectNull(); return *d; } T *operator->() { if (!d) throw MutexOwnedObjectNull(); return d; } void reset() { d = nullptr; if (l.owns_lock()) l.unlock(); } const std::unique_lock &get_lock() { return l; } }; template class MutexOwned { std::mutex m; T d; public: template MutexOwned(Args... args) : d(args...) 
{ } ~MutexOwned() { m.lock(); } MutexLocked lock() { MutexLocked r(d, m); return r; } MutexLocked lock(std::try_to_lock_t try_to_lock) { MutexLocked r(d, m, try_to_lock); return r; } MutexLocked lock(std::defer_lock_t defer_lock) { MutexLocked r(d, m, defer_lock); return r; } }; #endif // MUTEXOWNED_H ================================================ FILE: network.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "network.h" #include "utils.h" #include #include Network::Network(const std::string &network) { if (strContains(network, ".")) { struct sockaddr_in tmpaddr{}; tmpaddr.sin_family = AF_INET; int maskbits = inet_net_pton(AF_INET, network.c_str(), &tmpaddr.sin_addr, sizeof(struct in_addr)); if (maskbits < 0) throw std::runtime_error(formatString("Network '%s' is not a valid network notation.", network.c_str())); std::memcpy(this->data.data(), &tmpaddr, sizeof(tmpaddr)); const uint64_t bits {0xFFFFFFFFu}; uint32_t _netmask {static_cast( (bits << (32 - maskbits)) & 0xFFFFFFFFu )}; this->in_mask = htonl(_netmask); this->family = AF_INET; } else if (strContains(network, ":")) { // Why does inet_net_pton not support AF_INET6...? 
struct sockaddr_in6 tmpaddr{}; tmpaddr.sin6_family = AF_INET6; std::vector parts = splitToVector(network, '/', 2, false); std::string &addrPart = parts[0]; int maskbits = 128; if (parts.size() == 2) { const std::string &maskstring = parts[1]; const bool invalid_chars = std::any_of(maskstring.begin(), maskstring.end(), [](const char &c) { return c < '0' || c > '9'; }); if (invalid_chars || maskstring.length() > 3) throw std::runtime_error(formatString("Mask '%s' is not valid", maskstring.c_str())); maskbits = std::stoi(maskstring); } if (inet_pton(AF_INET6, addrPart.c_str(), &tmpaddr.sin6_addr) != 1) { throw std::runtime_error(formatString("Network '%s' is not a valid network notation.", network.c_str())); } std::memcpy(this->data.data(), &tmpaddr, sizeof(tmpaddr)); if (maskbits > 128 || maskbits < 0) { throw std::runtime_error(formatString("Network '%s' is not a valid network notation.", network.c_str())); } int m = maskbits; size_t i = 0; const uint64_t x{0xFFFFFFFF00000000}; while (m > 0) { int shift_remainder = std::min(m, 32); const uint32_t b {static_cast( (x >> shift_remainder) & 0xFFFFFFFFu )}; in6_mask.at(i++) = htonl(b); m -= 32; } for (size_t i = 0; i < 4; i++) { network_addr_relevant_bits.__in6_u.__u6_addr32[i] = tmpaddr.sin6_addr.__in6_u.__u6_addr32[i] & in6_mask.at(i); } this->family = AF_INET6; } else { throw std::runtime_error(formatString("Network '%s' is not a valid network notation.", network.c_str())); } } bool Network::match(const sockaddr *addr) const { const sa_family_t fam_arg = getFamilyFromSockAddr(addr); if (this->family != fam_arg) return false; if (this->family == AF_INET) { struct sockaddr_in tmp_this; struct sockaddr_in tmp_arg; std::memcpy(&tmp_this, this->data.data(), sizeof(tmp_this)); std::memcpy(&tmp_arg, addr, sizeof(tmp_arg)); return (tmp_this.sin_addr.s_addr & this->in_mask) == (tmp_arg.sin_addr.s_addr & this->in_mask); } else if (this->family == AF_INET6) { struct sockaddr_in6 tmp_arg; std::memcpy(&tmp_arg, addr, 
sizeof(tmp_arg)); struct in6_addr arg_addr_relevant_bits; for (size_t i = 0; i < 4; i++) { arg_addr_relevant_bits.__in6_u.__u6_addr32[i] = tmp_arg.sin6_addr.__in6_u.__u6_addr32[i] & in6_mask.at(i); } uint8_t matches[4]; for (int i = 0; i < 4; i++) { matches[i] = arg_addr_relevant_bits.__in6_u.__u6_addr32[i] == network_addr_relevant_bits.__in6_u.__u6_addr32[i]; } return (matches[0] & matches[1] & matches[2] & matches[3]); } return false; } bool Network::match(const sockaddr_in *addr) const { const struct sockaddr *_addr = reinterpret_cast(addr); return match(_addr); } bool Network::match(const sockaddr_in6 *addr) const { const struct sockaddr *_addr = reinterpret_cast(addr); return match(_addr); } ================================================ FILE: network.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/

#ifndef NETWORK_H
#define NETWORK_H

#include
#include
#include

/**
 * @brief An IPv4 or IPv6 network (address plus prefix length) that socket addresses can be matched against.
 *
 * Built from a textual notation; see the constructor in network.cpp for the accepted forms.
 */
class Network
{
    // Raw bytes of the parsed sockaddr_in / sockaddr_in6; the constructor memcpy's the address struct in here.
    std::array data{};

    // AF_INET or AF_INET6 once constructed; match() returns false for addresses of another family.
    sa_family_t family = AF_UNSPEC;

    // IPv4 netmask, in network byte order.
    uint32_t in_mask = 0;

    // IPv6 netmask as four 32-bit words, each in network byte order.
    std::array in6_mask{};

    // The IPv6 network address with the mask already applied, so match() only needs to mask its argument.
    struct in6_addr network_addr_relevant_bits{};

public:
    Network(const std::string &network);

    // True when 'addr' has this network's address family and lies within the network.
    bool match(const struct sockaddr *addr) const ;
    bool match(const struct sockaddr_in *addr) const ;
    bool match(const struct sockaddr_in6 *addr) const;
};

#endif // NETWORK_H


================================================
FILE: nocopy.cpp
================================================
#include "nocopy.h"


================================================
FILE: nocopy.h
================================================
#ifndef NOCOPY_H
#define NOCOPY_H

#include

/**
 * @brief Holds an optional value that is deliberately NOT carried over by copying.
 *
 * A copy-constructed NoCopy starts empty; copy assignment leaves the target's own value untouched; moving is
 * deleted. NOTE(review): presumably used as a member so the enclosing class keeps its implicit copy
 * operations while this value is excluded from them — verify against users.
 */
template class NoCopy
{
    std::optional data;

public:
    NoCopy() = default;

    // Intentionally does not copy the value: the new object starts empty.
    NoCopy(const NoCopy &other)
    {
        (void) other;
    }

    NoCopy(NoCopy &&other) = delete;

    // Intentionally keeps the current value instead of taking the other's.
    NoCopy& operator=(const NoCopy &other)
    {
        (void)other;
        return *this;
    }

    // Stores a copy of 'other' as the value.
    NoCopy& operator=(const T &other)
    {
        data = other;
        return *this;
    }

    // True when a value has been assigned.
    operator bool() const
    {
        return data.operator bool();
    }

    // Forwards to std::optional::value(), which throws std::bad_optional_access when no value is present.
    const T& value() const
    {
        return data.value();
    }
};

#endif // NOCOPY_H


================================================
FILE: oneinstancelock.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it
under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/ #include "oneinstancelock.h" #include #include "utils.h" OneInstanceLock::OneInstanceLock() { std::string dir("/tmp"); char *d = getenv("HOME"); if (d != NULL && d[0] == '/') { dir = std::string(d); } lockFilePath = dir + "/.FlashMQ.lock"; } OneInstanceLock::~OneInstanceLock() { unlock(); } void OneInstanceLock::lock() { fd = open(lockFilePath.c_str(), O_RDWR | O_CREAT, 0600); if (fd < 0) throw std::runtime_error(formatString("Can't create '%s': %s", lockFilePath.c_str(), strerror(errno))); struct flock fl; fl.l_start = 0; fl.l_len = 0; fl.l_type = F_WRLCK; fl.l_whence = SEEK_SET; if (fcntl(fd, F_SETLK, &fl) < 0) { throw std::runtime_error("Can't acquire lock: another instance is already running?"); } } void OneInstanceLock::unlock() { if (fd > 0) { close(fd); fd = 0; if (!lockFilePath.empty()) { if (unlink(lockFilePath.c_str()) < 0) { logger->log(LOG_ERR) << "Can't delete '" << lockFilePath << "': " << strerror(errno); } lockFilePath.clear(); } } } ================================================ FILE: oneinstancelock.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef ONEINSTANCELOCK_H #define ONEINSTANCELOCK_H #include #include #include #include #include "logger.h" class OneInstanceLock { int fd = -1; std::string lockFilePath; Logger *logger = Logger::getInstance(); public: OneInstanceLock(); ~OneInstanceLock(); void lock(); void unlock(); }; #endif // ONEINSTANCELOCK_H ================================================ FILE: packetdatatypes.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "packetdatatypes.h" #include "threadglobals.h" #include "settings.h" ConnectData::ConnectData() { const Settings *settings = ThreadGlobals::getSettings(); client_receive_max = settings->maxQosMsgPendingPerClient; session_expire = settings->getExpireSessionAfterSeconds(); max_outgoing_packet_size = settings->maxPacketSize; } ConnAckData::ConnAckData() { const Settings *settings = ThreadGlobals::getSettings(); client_receive_max = settings->maxQosMsgPendingPerClient; session_expire = settings->getExpireSessionAfterSeconds(); max_outgoing_packet_size = settings->maxPacketSize; } ================================================ FILE: packetdatatypes.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/

#ifndef PACKETDATATYPES_H
#define PACKETDATATYPES_H

#include

#include "mqtt5properties.h"

/**
 * @brief Fields of a CONNECT packet: fixed header flags, properties and payload (see the section comments).
 *
 * The three property fields without an in-class initializer are set by ConnectData::ConnectData() in
 * packetdatatypes.cpp, from the current thread's settings.
 */
struct ConnectData
{
    uint8_t protocol_level_byte = 0;
    bool bridge = false;

    // Flags
    bool password_flag = false;
    bool will_retain = false;
    uint8_t will_qos = false; // NOTE(review): a QoS number initialised from a bool literal; 0 would state the intent.
    bool will_flag = false;
    bool clean_start = false;

    uint16_t keep_alive = 0;

    // Content from properties
    uint16_t client_receive_max;
    uint32_t session_expire;
    uint32_t max_outgoing_packet_size;
    uint16_t max_outgoing_topic_aliases = 0; // Default MUST BE 0, meaning server won't initiate aliases;
    bool request_response_information = false;
    bool request_problem_information = false;
    std::string authenticationMethod;
    std::string authenticationData;
    std::optional fmq_client_group_id;

    // Content from Payload
    std::string client_id;
    WillPublish willpublish;
    std::optional username;
    std::string password;

    Mqtt5PropertyBuilder builder;

    ConnectData();
};

/**
 * @brief Fields of a CONNACK packet.
 *
 * client_receive_max and max_outgoing_packet_size have no in-class initializer; ConnAckData::ConnAckData() in
 * packetdatatypes.cpp sets them (and session_expire) from the current thread's settings.
 */
struct ConnAckData
{
    // Flags
    bool sessionPresent = false;

    uint32_t session_expire = 0;
    uint16_t client_receive_max;
    uint8_t max_qos = 2;
    uint32_t max_outgoing_packet_size;
    uint16_t max_outgoing_topic_aliases = 0; // Default MUST BE 0, meaning we won't initiate aliases unless the other side says we can.
    std::string assigned_client_id;
    uint16_t keep_alive = 0;
    std::string response_information;
    std::string server_reference;
    bool shared_subscriptions_available = true;
    bool retained_available = true;
    ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result;
    std::string authMethod;
    std::string authData;

    ConnAckData();
};

// Fields of an AUTH packet.
struct AuthPacketData
{
    std::string method;
    std::string data;
    ReasonCodes reasonCode = ReasonCodes::ImplementationSpecificError; // default something that is never a parse result;
};

// Fields of a DISCONNECT packet; session_expiry_interval is only meaningful when session_expiry_interval_set is true.
struct DisconnectData
{
    ReasonCodes reasonCode = ReasonCodes::Success;
    std::string reasonString;

    bool session_expiry_interval_set = false;
    uint32_t session_expiry_interval = 0;
};

// Fields of a SUBACK packet.
struct SubAckData
{
    uint16_t packet_id; // NOTE(review): no initializer, unlike the other fields in this file; consider '= 0'.
    std::string reasonString;
    std::vector subAckCodes;
};

// Fields of a PUBREC packet.
struct PubRecData
{
    ReasonCodes reasonCode = ReasonCodes::Success; // Default when not specified, or MQTT3;
};

#endif // PACKETDATATYPES_H


================================================
FILE: persistencefile.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it
under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/ #include "persistencefile.h" #include #include #include #include #include #include #include #include #include #include #include "utils.h" #include "logger.h" PersistenceFile::PersistenceFile(const std::string &filePath) : digestContext(EVP_MD_CTX_new(), EVP_MD_CTX_free), buf(1024*1024) { if (!filePath.empty() && filePath[filePath.size() - 1] == '/') throw std::runtime_error("Target file can't contain trailing slash."); this->filePath = filePath; this->filePathTemp = formatString("%s.newfile.%s", filePath.c_str(), getSecureRandomString(8).c_str()); this->filePathCorrupt = formatString("%s.corrupt.%s", filePath.c_str(), getSecureRandomString(8).c_str()); std::vector d1(filePath.length() + 1, 0); std::copy(filePath.begin(), filePath.end(), d1.begin()); this->dirPath = std::string(dirname(d1.data())); } PersistenceFile::~PersistenceFile() { try { closeFile(); } catch(std::exception &ex) { Logger::getInstance()->logf(LOG_WARNING, ex.what()); } if (f != nullptr) { fclose(f); // fclose was already attempted and error handled if it was possible. In case an early fault happend, we need to still make sure. f = nullptr; } } /** * @brief RetainedMessagesDB::hashFile hashes the data after the headers and writes the hash in the header. Uses SHA512. 
* */ void PersistenceFile::writeCheck(const void *ptr, size_t size, size_t n, FILE *s) { if (fwrite(ptr, size, n, s) != n) { throw std::runtime_error(formatString("Error writing: %s", strerror(errno))); } } ssize_t PersistenceFile::readCheck(void *ptr, size_t size, size_t n, FILE *stream) { size_t nread = fread(ptr, size, n, stream); if (nread != n) { if (feof(f)) return -1; throw std::runtime_error(formatString("Error reading: %s", strerror(errno))); } return nread; } void PersistenceFile::hashFile() { logger->logf(LOG_DEBUG, "Calculating and saving hash of '%s'.", filePath.c_str()); fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); unsigned int output_len = 0; unsigned char md_value[EVP_MAX_MD_SIZE]; std::memset(md_value, 0, EVP_MAX_MD_SIZE); EVP_MD_CTX_reset(digestContext.get()); EVP_DigestInit_ex(digestContext.get(), sha512, NULL); while (!feof(f)) { size_t n = fread(buf.data(), 1, buf.size(), f); EVP_DigestUpdate(digestContext.get(), buf.data(), n); } EVP_DigestFinal_ex(digestContext.get(), md_value, &output_len); if (output_len != HASH_SIZE) throw std::runtime_error("Impossible: calculated hash size wrong length"); fseek(f, MAGIC_STRING_LENGH, SEEK_SET); writeCheck(md_value, output_len, 1, f); } void PersistenceFile::verifyHash() { fseek(f, 0, SEEK_END); const size_t size = ftell(f); if (size < TOTAL_HEADER_SIZE) throw std::runtime_error(formatString("File '%s' is too small for it even to contain a header.", filePath.c_str())); unsigned char md_from_disk[HASH_SIZE]; std::memset(md_from_disk, 0, HASH_SIZE); fseek(f, MAGIC_STRING_LENGH, SEEK_SET); readCheck(md_from_disk, 1, HASH_SIZE, f); unsigned int output_len = 0; unsigned char md_value[EVP_MAX_MD_SIZE]; std::memset(md_value, 0, EVP_MAX_MD_SIZE); EVP_MD_CTX_reset(digestContext.get()); EVP_DigestInit_ex(digestContext.get(), sha512, NULL); while (!feof(f)) { size_t n = fread(buf.data(), 1, buf.size(), f); EVP_DigestUpdate(digestContext.get(), buf.data(), n); } EVP_DigestFinal_ex(digestContext.get(), md_value, 
&output_len); if (output_len != HASH_SIZE) throw std::runtime_error("Impossible: calculated hash size wrong length"); if (std::memcmp(md_from_disk, md_value, output_len) != 0) { fclose(f); f = nullptr; if (rename(filePath.c_str(), filePathCorrupt.c_str()) == 0) { throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Moved aside to '%s'.", filePath.c_str(), filePathCorrupt.c_str())); } else { throw std::runtime_error(formatString("File '%s' is corrupt: hash mismatch. Tried to move aside, but that failed: '%s'.", filePath.c_str(), strerror(errno))); } } logger->logf(LOG_DEBUG, "Hash of '%s' correct", filePath.c_str()); } /** * @brief PersistenceFile::makeSureBufSize grows the buffer if n is bigger. * @param n in bytes. * * Remember that when you're dealing with fields that are sized in MQTT by 16 bit ints, like topic paths, the buffer will always be big enough, because it's 1 MB. */ void PersistenceFile::makeSureBufSize(size_t n) { if (n > buf.size()) buf.resize(n); } void PersistenceFile::writeInt64(const int64_t val) { unsigned char buf[8]; // Write big-endian int shift = 56; int i = 0; while (shift >= 0) { unsigned char wantedByte = val >> shift; buf[i++] = wantedByte; shift -= 8; } writeCheck(buf, 1, 8, f); } void PersistenceFile::writeUint32(const uint32_t val) { unsigned char buf[4]; // Write big-endian int shift = 24; int i = 0; while (shift >= 0) { unsigned char wantedByte = val >> shift; buf[i++] = wantedByte; shift -= 8; } writeCheck(buf, 1, 4, f); } void PersistenceFile::writeUint16(const uint16_t val) { unsigned char buf[2]; // Write big-endian int shift = 8; int i = 0; while (shift >= 0) { unsigned char wantedByte = val >> shift; buf[i++] = wantedByte; shift -= 8; } writeCheck(buf, 1, 2, f); } void PersistenceFile::writeUint8(const uint8_t val) { writeCheck(&val, 1, 1, f); } void PersistenceFile::writeString(const std::string &s) { writeUint32(s.size()); writeCheck(s.c_str(), 1, s.size(), f); } void 
PersistenceFile::writeOptionalString(const std::optional &s) { if (!s) { writeUint8(0); return; } writeUint8(1); writeString(s.value()); } int64_t PersistenceFile::readInt64(bool &eofFound) { if (readCheck(buf.data(), 1, 8, f) < 0) eofFound = true; unsigned char *buf_ = buf.data(); const uint64_t val1 = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); const uint64_t val2 = ((buf_[4]) << 24) | ((buf_[5]) << 16) | ((buf_[6]) << 8) | (buf_[7]); const int64_t val = (val1 << 32) | val2; return val; } uint32_t PersistenceFile::readUint32(bool &eofFound) { if (readCheck(buf.data(), 1, 4, f) < 0) eofFound = true; uint32_t val; unsigned char *buf_ = buf.data(); val = ((buf_[0]) << 24) | ((buf_[1]) << 16) | ((buf_[2]) << 8) | (buf_[3]); return val; } uint16_t PersistenceFile::readUint16(bool &eofFound) { if (readCheck(buf.data(), 1, 2, f) < 0) eofFound = true; uint16_t val; unsigned char *buf_ = buf.data(); val = ((buf_[0]) << 8) | (buf_[1]); return val; } uint8_t PersistenceFile::readUint8(bool &eofFound) { uint8_t val; if (readCheck(&val, 1, 1, f) < 0) eofFound = true; return val; } std::string PersistenceFile::readString(bool &eofFound) { const uint32_t size = readUint32(eofFound); if (size > 0xFFFF) throw std::runtime_error("In MQTT world, strings are never longer than 65535 bytes."); makeSureBufSize(size); readCheck(buf.data(), 1, size, f); std::string result = make_string(buf, 0, size); return result; } std::optional PersistenceFile::readOptionalString(bool &eofFound) { uint8_t x = readUint8(eofFound); if (!x) return {}; std::optional result = readString(eofFound); return result; } /** * @brief RetainedMessagesDB::openWrite doesn't explicitely name a file version (v1, etc), because we always write the current definition. 
*/ void PersistenceFile::openWrite(const std::string &versionString) { if (versionString.size() - 1 >= MAGIC_STRING_LENGH) throw std::runtime_error("Version string length must be shorter than " + std::to_string(MAGIC_STRING_LENGH)); if (openMode != FileMode::unknown) throw std::runtime_error("File is already open."); int fd = creat(filePathTemp.c_str(), S_IRUSR | S_IWUSR); if (fd < 0) throw std::runtime_error("Creating " + filePathTemp + " failed"); close(fd); f = fopen(filePathTemp.c_str(), "w+b"); if (f == nullptr) { throw std::runtime_error(formatString("Can't open '%s': %s", filePathTemp.c_str(), strerror(errno))); } openMode = FileMode::write; writeCheck(buf.data(), 1, MAGIC_STRING_LENGH, f); rewind(f); writeCheck(versionString.c_str(), 1, versionString.length(), f); fseek(f, MAGIC_STRING_LENGH, SEEK_SET); writeCheck(buf.data(), 1, HASH_SIZE, f); } void PersistenceFile::openRead(const std::string &expected_version_string) { if (openMode != FileMode::unknown) throw std::runtime_error("File is already open."); f = fopen(filePath.c_str(), "rb"); if (f == nullptr) throw PersistenceFileCantBeOpened(formatString("Can't open '%s': %s.", filePath.c_str(), strerror(errno)).c_str()); openMode = FileMode::read; verifyHash(); rewind(f); readCheck(buf.data(), 1, MAGIC_STRING_LENGH, f); const auto magic_string_0_pos = std::find(buf.begin(), buf.end(), 0); if (magic_string_0_pos == buf.end()) throw std::runtime_error("Error reading version string."); detectedVersionString = std::string(buf.begin(), magic_string_0_pos); // In case people want to downgrade, the old file is still there. if (detectedVersionString != expected_version_string) { const std::string copy_file_path = filePath + "." + detectedVersionString; try { const uint64_t file_size = getFileSize(filePath); const uint64_t free_space = getFreeSpace(filePath); if (free_space > file_size * 3) { logger->log(LOG_NOTICE) << "File version change detected. 
Copying '" << filePath << "' to '" << copy_file_path << "' to support downgrading FlashMQ."; int fd = creat(copy_file_path.c_str(), S_IRUSR | S_IWUSR); if (fd < 0) throw std::runtime_error("Creating " + copy_file_path + " failed"); close(fd); std::ifstream src(filePath, std::ios::binary); std::ofstream dst(copy_file_path, std::ios::binary); src.exceptions(std::ifstream::failbit | std::ifstream::badbit); dst.exceptions(std::ifstream::failbit | std::ifstream::badbit); dst << src.rdbuf(); } else { logger->log(LOG_NOTICE) << "File version change detected, but not copying '" << filePath << "' to '" << copy_file_path << "' because disk space is running low."; } } catch (std::exception &ex) { logger->log(LOG_ERROR) << "Backing up to '" << copy_file_path << "' failed."; } } fseek(f, TOTAL_HEADER_SIZE, SEEK_SET); } void PersistenceFile::dontSaveTmpFile() { this->discard = true; } void PersistenceFile::closeFile() { if (!f) return; if (openMode == FileMode::write) { if (!discard) hashFile(); if (fflush(f) != 0) { std::string msg(strerror(errno)); throw std::runtime_error(formatString("Flush of '%s' failed: %s.", this->filePathTemp.c_str(), msg.c_str())); } fsync(f->_fileno); } if (f != nullptr) { FILE *f2 = f; f = nullptr; if (fclose(f2) < 0) { std::string msg(strerror(errno)); throw std::runtime_error(formatString("Close of '%s' failed: %s.", this->filePathTemp.c_str(), msg.c_str())); } } if (openMode == FileMode::write && !filePathTemp.empty() && ! 
filePath.empty()) { if (discard) { unlink(filePathTemp.c_str()); } else { if (rename(filePathTemp.c_str(), filePath.c_str()) < 0) throw std::runtime_error(formatString("Saving '%s' failed: rename of temp file to target failed with: %s", filePath.c_str(), strerror(errno))); int dir_fd = open(this->dirPath.c_str(), O_RDONLY); if (dir_fd > 0) { fsync(dir_fd); close(dir_fd); } } } } const std::string &PersistenceFile::getFilePath() const { return this->filePath; } ================================================ FILE: persistencefile.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef PERSISTENCEFILE_H #define PERSISTENCEFILE_H #include #include #include #include #include #include #include #include #include #include "logger.h" #define MAGIC_STRING_LENGH 32 #define HASH_SIZE 64 #define TOTAL_HEADER_SIZE (MAGIC_STRING_LENGH + HASH_SIZE) /** * @brief The PersistenceFileCantBeOpened class should be thrown when a non-fatal file-not-found error happens. 
*/
class PersistenceFileCantBeOpened : public std::runtime_error
{
public:
    PersistenceFileCantBeOpened(const std::string &msg) : std::runtime_error(msg) {}
};

/**
 * @brief A file with a fixed header followed by data written and read through the protected helpers.
 *
 * The header is a MAGIC_STRING_LENGH-byte version string followed by a HASH_SIZE-byte SHA512 digest of the
 * rest of the file. Writes go to a temporary file that closeFile() renames over the target. openRead() checks
 * the digest before anything is read. The virtual destructor and protected helpers suggest it is meant to be
 * derived from.
 */
class PersistenceFile
{
    std::string filePath;
    std::string filePathTemp; // Randomly named sibling that writes go to.
    std::string filePathCorrupt; // Where a file failing hash verification is moved to.
    std::string dirPath; // Directory of filePath; fsync'ed after the rename in closeFile().
    bool discard = false; // Set by dontSaveTmpFile(): closeFile() deletes the temporary file instead of renaming it.

    std::unique_ptr digestContext;
    const EVP_MD *sha512 = EVP_sha512();

    void hashFile();
    void verifyHash();

protected:
    enum class FileMode
    {
        unknown,
        read,
        write
    };

    FILE *f = nullptr;
    std::vector buf; // Scratch buffer for reads and hashing; 1 MiB initially, grown by makeSureBufSize().
    FileMode openMode = FileMode::unknown;
    std::string detectedVersionString; // Version string read from the header by openRead().
    Logger *logger = Logger::getInstance();

    void makeSureBufSize(size_t n);

    // NOTE(review): the double-underscore parameter names are reserved identifiers; plain names would be safer.
    void writeCheck(const void *__restrict __ptr, size_t __size, size_t __n, FILE *__restrict __s);
    ssize_t readCheck(void *__restrict ptr, size_t size, size_t n, FILE *__restrict stream);

    // Big-endian integer and length-prefixed string writers.
    void writeInt64(const int64_t val);
    void writeUint32(const uint32_t val);
    void writeUint16(const uint16_t val);
    void writeUint8(const uint8_t val);
    void writeString(const std::string &s);
    void writeOptionalString(const std::optional &s);

    // Readers matching the writers above; they set eofFound when the file ends prematurely.
    int64_t readInt64(bool &eofFound);
    uint32_t readUint32(bool &eofFound);
    uint16_t readUint16(bool &eofFound);
    uint8_t readUint8(bool &eofFound);
    std::string readString(bool &eofFound);
    std::optional readOptionalString(bool &eofFound);

public:
    PersistenceFile(const std::string &filePath);
    virtual ~PersistenceFile();

    void openWrite(const std::string &versionString);
    void openRead(const std::string &expected_version_string);
    void dontSaveTmpFile();

    void closeFile();

    const std::string &getFilePath() const;
};

#endif // PERSISTENCEFILE_H


================================================
FILE: persistencefunctions.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2025 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it
under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/ #include "persistencefunctions.h" #include "logger.h" #include "globals.h" #include "globber.h" /** * @brief saveState saves sessions and such to files. It's run in the main thread, but also dedicated threads. For that, * reason, it's a static method to reduce the risk of accidental use of data without locks. * @param settings A local settings, copied from a std::bind copy when running in a thread, because of thread safety. * @param bridgeInfos is a list of objects already prepared from the original bridge configs, to avoid concurrent access. */ void saveState(const Settings &settings, const std::list &bridgeInfos, bool in_background) { Logger *logger = Logger::getInstance(); try { if (settings.storageDir.empty()) return; if (settings.persistenceDataToSave.hasNone()) return; std::shared_ptr subscriptionStore = globals->subscriptionStore; if (settings.persistenceDataToSave.hasFlagSet(PersistenceDataToSave::RetainedMessages) && settings.retainedMessagesMode == RetainedMessagesMode::Enabled) { const std::string retainedDBPath = settings.getRetainedMessagesDBFile(); subscriptionStore->saveRetainedMessages(retainedDBPath, in_background); } if (settings.persistenceDataToSave.hasFlagSet(PersistenceDataToSave::SessionsAndSubscriptions)) { const std::string sessionsDBPath = settings.getSessionsDBFile(); subscriptionStore->saveSessionsAndSubscriptions(sessionsDBPath); } if (settings.persistenceDataToSave.hasFlagSet(PersistenceDataToSave::BridgeInfo)) { saveBridgeInfo(settings.getBridgeNamesDBFile(), bridgeInfos); } logger->logf(LOG_NOTICE, "Saving states done"); } catch(std::exception &ex) { logger->logf(LOG_ERR, "Error saving state: %s", ex.what()); } } void saveBridgeInfo(const std::string &filePath, const std::list &bridgeInfos) { Logger *logger = Logger::getInstance(); logger->logf(LOG_NOTICE, "Saving bridge info in '%s'", filePath.c_str()); BridgeInfoDb bridgeInfoDb(filePath); bridgeInfoDb.openWrite(); bridgeInfoDb.saveInfo(bridgeInfos); } std::list 
loadBridgeInfo(Settings &settings) { Logger *logger = Logger::getInstance(); std::list bridges = settings.stealBridges(); if (settings.storageDir.empty()) return bridges; const std::string filePath = settings.getBridgeNamesDBFile(); try { logger->logf(LOG_NOTICE, "Loading '%s'", filePath.c_str()); BridgeInfoDb dbfile(filePath); dbfile.openRead(); std::list bridgeInfos = dbfile.readInfo(); for(const BridgeInfoForSerializing &info : bridgeInfos) { for(BridgeConfig &bridgeConfig : bridges) { if (!bridgeConfig.useSavedClientId) continue; if (bridgeConfig.clientidPrefix == info.prefix) { logger->log(LOG_INFO) << "Assigning stored bridge clientid '" << info.clientId << "' to bridge '" << info.prefix << "'."; bridgeConfig.setClientId(info.prefix, info.clientId); break; } } } } catch (PersistenceFileCantBeOpened &ex) { logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str()); } return bridges; } void correctBackupDbPermissions(const std::string &dir) { Globber glob; const std::string backups_glob = dir + "/*.db.*"; auto matches = glob.getGlob(backups_glob); for (const auto &s : matches) { chmod(s.c_str(), S_IRUSR | S_IWUSR); } } ================================================ FILE: persistencefunctions.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2025 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/

#ifndef PERSISTENCEFUNCTIONS_H
#define PERSISTENCEFUNCTIONS_H

#include "bridgeinfodb.h"

// Saves the persistence data selected in 'settings'; std::exception errors are logged rather than thrown.
void saveState(const Settings &settings, const std::list &bridgeInfos, bool in_background);

// Writes the bridge info to a BridgeInfoDb at 'filePath'.
void saveBridgeInfo(const std::string &filePath, const std::list &bridgeInfos);

// Returns the bridge configs obtained via Settings::stealBridges(), with saved client ids assigned where configured.
std::list loadBridgeInfo(Settings &settings);

// Sets owner-only read/write permissions on '*.db.*' files in 'dir'.
void correctBackupDbPermissions(const std::string &dir);

#endif // PERSISTENCEFUNCTIONS_H


================================================
FILE: plugin.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it
under the terms of The Open Software License 3.0 (OSL-3.0).

See LICENSE for license details.
*/

#include "plugin.h"

#include
#include
#include
#include
#include
#include
#include
#include

#include "exceptions.h"
#include "unscopedlock.h"
#include "utils.h"
#include "client.h"
#include "threadglobals.h"
#include "threaddata.h"

// Class-wide mutexes. initMutex is taken by init() and securityInit(); deinitMutex by cleanup() and
// securityCleanup(). Both only when settings.pluginSerializeInit is set. authChecksMutex is not used in this part of the file.
std::mutex Authentication::initMutex;
std::mutex Authentication::deinitMutex;
std::mutex Authentication::authChecksMutex;

// NOTE(review): 'salt' and 'cryptedPassword' are const rvalue references, so the initializers below copy
// instead of moving. Changing that needs the declaration in plugin.h changed too.
MosquittoPasswordFileEntry::MosquittoPasswordFileEntry(PasswordHashType type, const std::vector &&salt, const std::vector &&cryptedPassword, int iterations) :
    type(type),
    salt(salt),
    cryptedPassword(cryptedPassword),
    iterations(iterations)
{

}

/**
 * @brief Keeps a reference to 'settings', copies the Mosquitto password and ACL file paths from it, and
 * initializes a SHA512 digest context.
 * @throws std::runtime_error when SHA512 is not available.
 */
Authentication::Authentication(Settings &settings) :
    settings(settings),
    mosquittoPasswordFile(settings.mosquittoPasswordFile),
    mosquittoAclFile(settings.mosquittoAclFile),
    mosquittoDigestContext(EVP_MD_CTX_new(), EVP_MD_CTX_free)
{
    logger = Logger::getInstance();

    if(!sha512)
    {
        throw std::runtime_error("Failed to initialize SHA512 for decoding auth entry");
    }

    EVP_DigestInit_ex(mosquittoDigestContext.get(), sha512, NULL);
}

/**
 * @brief Looks up 'symbol' in the library 'handle' with dlsym().
 * @param exceptionOnError when true, a NULL result throws FatalError carrying the dlerror() message.
 * @return the symbol's address, or NULL when it isn't found and exceptionOnError is false.
 */
void *Authentication::loadSymbol(void *handle, const char *symbol, bool exceptionOnError) const
{
    void *r = dlsym(handle, symbol);

    if (r == NULL && exceptionOnError)
    {
        // NOTE(review): dlerror() returns NULL when no error is pending (e.g. a symbol whose value really is NULL),
        // and constructing std::string from NULL is undefined behaviour. Worth guarding.
        std::string errmsg(dlerror());
        throw FatalError(errmsg);
    }

    return r;
}

/**
 * @brief Resolves the plugin's entry points, based on the plugin family and FlashMQ plugin version the loader reports.
 *
 * Does nothing when the loader hasn't loaded anything. 'initialized' is false while the symbols are being
 * resolved and true once they all are. Symbols requested with a 'false' second argument are, presumably, optional
 * (PluginLoader::loadSymbol is not shown here — confirm). Note that plugin version 4 reuses the v3
 * alter_publish signature, and version 5 reuses the v4 on_unsubscribe signature.
 * @throws FatalError for an unknown plugin family or version.
 */
void Authentication::loadPlugin(const PluginLoader &l)
{
    if (!l.loaded())
        return;

    initialized = false;
    pluginFamily = l.getPluginFamily();
    flashmqPluginVersionNumber = l.getFlashMQPluginVersion();

    if (pluginFamily == PluginFamily::MosquittoV2)
    {
        init_v2 = (F_plugin_init_v2)l.loadSymbol( "mosquitto_auth_plugin_init");
        cleanup_v2 = (F_plugin_cleanup_v2)l.loadSymbol( "mosquitto_auth_plugin_cleanup");
        security_init_v2 = (F_plugin_security_init_v2)l.loadSymbol("mosquitto_auth_security_init");
        security_cleanup_v2 = (F_plugin_security_cleanup_v2)l.loadSymbol("mosquitto_auth_security_cleanup");
        acl_check_v2 = (F_plugin_acl_check_v2)l.loadSymbol("mosquitto_auth_acl_check");
        unpwd_check_v2 = (F_plugin_unpwd_check_v2)l.loadSymbol("mosquitto_auth_unpwd_check");
        psk_key_get_v2 = (F_plugin_psk_key_get_v2)l.loadSymbol("mosquitto_auth_psk_key_get");
    }
    else if (pluginFamily == PluginFamily::FlashMQ)
    {
        // Entry points shared by all FlashMQ plugin versions.
        flashmq_plugin_allocate_thread_memory_v1 = (F_flashmq_plugin_allocate_thread_memory_v1)l.loadSymbol("flashmq_plugin_allocate_thread_memory");
        flashmq_plugin_deallocate_thread_memory_v1 = (F_flashmq_plugin_deallocate_thread_memory_v1)l.loadSymbol("flashmq_plugin_deallocate_thread_memory");
        flashmq_plugin_init_v1 = (F_flashmq_plugin_init_v1)l.loadSymbol("flashmq_plugin_init");
        flashmq_plugin_deinit_v1 = (F_flashmq_plugin_deinit_v1)l.loadSymbol("flashmq_plugin_deinit");
        flashmq_plugin_login_check_v1 = (F_flashmq_plugin_login_check_v1)l.loadSymbol("flashmq_plugin_login_check");
        flashmq_plugin_periodic_event_v1 = (F_flashmq_plugin_periodic_event_v1)l.loadSymbol("flashmq_plugin_periodic_event", false);
        flashmq_plugin_extended_auth_v1 = (F_flashmq_plugin_extended_auth_v1)l.loadSymbol("flashmq_plugin_extended_auth", false);
        flashmq_plugin_alter_subscription_v1 = (F_flashmq_plugin_alter_subscription_v1)l.loadSymbol("flashmq_plugin_alter_subscription", false);
        flashmq_plugin_client_disconnected_v1 = (F_flashmq_plugin_client_disconnected_v1)l.loadSymbol("flashmq_plugin_client_disconnected", false);
        flashmq_plugin_poll_event_received_v1 = (F_flashmq_plugin_poll_event_received_v1)l.loadSymbol("flashmq_plugin_poll_event_received", false);

        // Entry points whose signature depends on the plugin version.
        if (flashmqPluginVersionNumber == 1)
        {
            flashmq_plugin_acl_check_v1 = (F_flashmq_plugin_acl_check_v1)l.loadSymbol("flashmq_plugin_acl_check");
            flashmq_plugin_alter_publish_v1 = (F_flashmq_plugin_alter_publish_v1)l.loadSymbol("flashmq_plugin_alter_publish", false);
        }
        else if (flashmqPluginVersionNumber == 2)
        {
            flashmq_plugin_acl_check_v2 = (F_flashmq_plugin_acl_check_v2)l.loadSymbol("flashmq_plugin_acl_check");
            flashmq_plugin_alter_publish_v2 = (F_flashmq_plugin_alter_publish_v2)l.loadSymbol("flashmq_plugin_alter_publish", false);
        }
        else if (flashmqPluginVersionNumber == 3)
        {
            flashmq_plugin_acl_check_v3 = (F_flashmq_plugin_acl_check_v3)l.loadSymbol("flashmq_plugin_acl_check");
            flashmq_plugin_alter_publish_v3 = (F_flashmq_plugin_alter_publish_v3)l.loadSymbol("flashmq_plugin_alter_publish", false);
        }
        else if (flashmqPluginVersionNumber == 4)
        {
            flashmq_plugin_acl_check_v4 = (F_flashmq_plugin_acl_check_v4)l.loadSymbol("flashmq_plugin_acl_check");
            flashmq_plugin_alter_publish_v3 = (F_flashmq_plugin_alter_publish_v3)l.loadSymbol("flashmq_plugin_alter_publish", false);
            flashmq_plugin_on_unsubscribe_v4 = (F_flashmq_plugin_on_unsubscribe_v4)l.loadSymbol("flashmq_plugin_on_unsubscribe", false);
        }
        else if (flashmqPluginVersionNumber == 5)
        {
            flashmq_plugin_acl_check_v5 = (F_flashmq_plugin_acl_check_v5)l.loadSymbol("flashmq_plugin_acl_check");
            flashmq_plugin_alter_publish_v5 = (F_flashmq_plugin_alter_publish_v5)l.loadSymbol("flashmq_plugin_alter_publish", false);
            flashmq_plugin_on_unsubscribe_v4 = (F_flashmq_plugin_on_unsubscribe_v4)l.loadSymbol("flashmq_plugin_on_unsubscribe", false);
        }
        else
        {
            throw FatalError("Unreachable error reached in detecting plugin version.");
        }
    }
    else
    {
        throw FatalError("Unreachable error reached?");
    }

    initialized = true;
}

/**
 * @brief plugin::init is like Mosquitto's init(), and is to allow the plugin to init memory. Plugins should not load
 * their authentication data here. That's what securityInit() is for.
 *
 * Serialized across instances via initMutex only when settings.pluginSerializeInit is set. Skipped when 'quitting'.
 * @throws FatalError when the Mosquitto plugin's init returns non-zero.
 */
void Authentication::init()
{
    if (pluginFamily == PluginFamily::None)
        return;

    // Constructed unlocked; only locked when serialization is configured.
    UnscopedLock lock(initMutex);
    if (settings.pluginSerializeInit)
        lock.lock();

    if (quitting)
        return;

    if (pluginFamily == PluginFamily::MosquittoV2)
    {
        AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
        int result = init_v2(&pluginData, authOpts.head(), authOpts.size());
        if (result != 0)
            throw FatalError("Error initialising auth plugin.");
    }
    else if (pluginFamily == PluginFamily::FlashMQ)
    {
        std::unordered_map &authOpts = settings.getFlashmqpluginOpts();
        flashmq_plugin_allocate_thread_memory_v1(&pluginData, authOpts);
    }
}

/**
 * @brief Runs securityCleanup(false), then the plugin's memory cleanup. Cleanup failures are logged, not thrown.
 *
 * NOTE(review): securityCleanup() itself can throw (pluginException), which would skip the cleanup below.
 */
void Authentication::cleanup()
{
    if (pluginFamily == PluginFamily::None)
        return;

    logger->logf(LOG_NOTICE, "Cleaning up authentication.");

    securityCleanup(false);

    UnscopedLock lock(deinitMutex);
    if (settings.pluginSerializeInit)
        lock.lock();

    if (pluginFamily == PluginFamily::MosquittoV2)
    {
        AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat();
        int result = cleanup_v2(pluginData, authOpts.head(), authOpts.size());
        if (result != 0)
            logger->logf(LOG_ERR, "Error cleaning up auth plugin"); // Not doing exception, because we're shutting down anyway.
    }
    else if (pluginFamily == PluginFamily::FlashMQ)
    {
        try
        {
            std::unordered_map &authOpts = settings.getFlashmqpluginOpts();
            flashmq_plugin_deallocate_thread_memory_v1(pluginData, authOpts);
        }
        catch (std::exception &ex)
        {
            logger->logf(LOG_ERR, "Error cleaning up auth plugin: '%s'", ex.what()); // Not doing exception, because we're shutting down anyway.
        }
    }
}

/**
 * @brief plugin::securityInit initializes the security data, like loading users, ACL tables, etc.
* @param reloading */ void Authentication::securityInit(bool reloading) { if (pluginFamily == PluginFamily::None) return; UnscopedLock lock(initMutex); if (settings.pluginSerializeInit) lock.lock(); if (quitting) return; if (pluginFamily == PluginFamily::MosquittoV2) { AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_init_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) { throw pluginException("Plugin function mosquitto_auth_security_init returned an error. If it didn't log anything, we don't know what it was."); } } else if (pluginFamily == PluginFamily::FlashMQ) { // The exception handling is higher up in the call stack, because it needs to be different on first start vs reload. std::unordered_map &authOpts = settings.getFlashmqpluginOpts(); flashmq_plugin_init_v1(pluginData, authOpts, reloading); } initialized = true; periodicEvent(); } void Authentication::securityCleanup(bool reloading) { if (pluginFamily == PluginFamily::None) return; initialized = false; UnscopedLock lock(deinitMutex); if (settings.pluginSerializeInit) lock.lock(); if (pluginFamily == PluginFamily::MosquittoV2) { AuthOptCompatWrap &authOpts = settings.getAuthOptsCompat(); int result = security_cleanup_v2(pluginData, authOpts.head(), authOpts.size(), reloading); if (result != 0) { throw pluginException("Plugin function mosquitto_auth_security_cleanup returned an error. If it didn't log anything, we don't know what it was."); } } else if (pluginFamily == PluginFamily::FlashMQ) { // The exception handling is higher up in the call stack, because it needs to be different on first start vs reload. 
std::unordered_map &authOpts = settings.getFlashmqpluginOpts(); flashmq_plugin_deinit_v1(pluginData, authOpts, reloading); } } void Authentication::onUnsubscribe( const std::shared_ptr &session, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, const std::vector > *userProperties) const { assert(subtopics.size() > 0); if (pluginFamily == PluginFamily::None) { return; } if (!initialized) { logger->logf(LOG_ERR, "Plugin clientDisconnected called, but initialization failed or not performed."); return; } if (pluginFamily == PluginFamily::FlashMQ) { try { if (flashmq_plugin_on_unsubscribe_v4) flashmq_plugin_on_unsubscribe_v4(pluginData, session, clientid, username, topic, subtopics, shareName, userProperties); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Exception in 'flashmq_plugin_on_unsubscribe': '%s'.", ex.what()); } } } /** * @brief Authentication::aclCheck performs a write ACL check on the incoming publish. * @param publishData * @return * * Internal publishes write (publish) access is always allowed (it makes little sense that a plugin would have to explicitly allow * those), but they are passed through the plugin, so you can inspect them. The read access can still be rejected by a plugin. */ AuthResult Authentication::aclCheck(Publish &publishData, std::string_view payload, AclAccess access) { AuthResult result = aclCheck( publishData.client_id, publishData.username, publishData.topic, publishData.getSubtopics(), "", payload, access, publishData.qos, publishData.retain, publishData.correlationData, publishData.responseTopic, publishData.contentType, publishData.expiresAt(), publishData.getUserProperties()); // Anonymous publishes come from FlashMQ internally, like SYS topics. We need to allow them. 
if (access == AclAccess::write && publishData.client_id.empty()) result = AuthResult::success; return result; } AuthResult Authentication::aclCheck( const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &sharename, std::string_view payload, AclAccess access, uint8_t qos, bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties) { assert(subtopics.size() > 0); #ifdef TESTING // I could technically test with empty payload, but so far I don't, and this is a good way for now to check I don't miss the payload // because of the payload copy prevention optimization. assert(retain || access == AclAccess::subscribe || !payload.empty()); #endif auto &threadData = ThreadGlobals::getThreadData(); switch (access) { case AclAccess::read: threadData->aclReadChecks.inc(1); break; case AclAccess::write: threadData->aclWriteChecks.inc(1); break; case AclAccess::subscribe: threadData->aclSubscribeChecks.inc(1); break; case AclAccess::register_will: threadData->aclRegisterWillChecks.inc(1); break; default: break; } AuthResult firstResult = aclCheckFromMosquittoAclFile(clientid, username, subtopics, access); if (firstResult != AuthResult::success) return firstResult; if (pluginFamily == PluginFamily::None) return firstResult; if (!initialized) { logger->logf(LOG_ERR, "ACL check by plugin wanted, but initialization failed. Can't perform check."); return AuthResult::error; } UnscopedLock lock(authChecksMutex); if (settings.pluginSerializeAuthChecks) lock.lock(); if (__builtin_expect(pluginFamily == PluginFamily::FlashMQ, 1)) { // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher // gets disconnected. 
try { if (flashmqPluginVersionNumber == 5) { return flashmq_plugin_acl_check_v5( pluginData, access, clientid, username, topic, subtopics, sharename, payload, qos, retain, correlationData, responseTopic, contentType, expiresAt, userProperties); } else if (flashmqPluginVersionNumber == 4) { return flashmq_plugin_acl_check_v4( pluginData, access, clientid, username, topic, subtopics, sharename, payload, qos, retain, correlationData, responseTopic, userProperties); } else if (flashmqPluginVersionNumber == 3) { return flashmq_plugin_acl_check_v3( pluginData, access, clientid, username, topic, subtopics, payload, qos, retain, correlationData, responseTopic, userProperties); } else if (flashmqPluginVersionNumber == 2) return flashmq_plugin_acl_check_v2(pluginData, access, clientid, username, topic, subtopics, payload, qos, retain, userProperties); else return flashmq_plugin_acl_check_v1(pluginData, access, clientid, username, topic, subtopics, qos, retain, userProperties); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error doing ACL check in plugin: '%s'", ex.what()); logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need."); } } else if (pluginFamily == PluginFamily::MosquittoV2) { // We have to do this, because Mosquitto plugin v2 has no notion of checking subscribes. if (access == AclAccess::subscribe) return AuthResult::success; // We have to do this, because Mosquitto plugins has no notion of will registration ACL. if (access == AclAccess::register_will) return AuthResult::success; int result = acl_check_v2(pluginData, clientid.c_str(), username.c_str(), topic.c_str(), static_cast(access)); AuthResult result_ = static_cast(result); if (result_ == AuthResult::error) { logger->logf(LOG_ERR, "ACL check by plugin returned error for topic '%s'. 
If it didn't log anything, we don't know what it was.", topic.c_str()); } return result_; } return AuthResult::error; } AuthResult Authentication::loginCheck(const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client, const bool allowAnonymous) { /* * This first construct is designed so that even when you allow anonymous, user verification still works, with the password file or * plugin. If login is denied based on a password file, the attempt is still given to the plugin. However, when the auth succeeds * based on the password file, it's not given to the plugin anymore. * * Also, when you allow anonymous, loging in with a non-existing user will work. But, when the user does exist, it must match. */ AuthResult firstResult = allowAnonymous ? AuthResult::success : AuthResult::login_denied; if (!this->mosquittoPasswordFile.empty()) { const std::optional r = loginCheckFromMosquittoPasswordFile(username, password); if (r) firstResult = r.value(); if (firstResult == AuthResult::success) return firstResult; } if (pluginFamily == PluginFamily::None) return firstResult; if (!initialized) { logger->logf(LOG_ERR, "Username+password check with plugin wanted, but initialization failed. Can't perform check."); return AuthResult::error; } UnscopedLock lock(authChecksMutex); if (settings.pluginSerializeAuthChecks) lock.lock(); if (pluginFamily == PluginFamily::MosquittoV2) { int result = unpwd_check_v2(pluginData, username.c_str(), password.c_str()); AuthResult r = static_cast(result); if (r == AuthResult::error) { logger->logf(LOG_ERR, "Username+password check by plugin returned error for user '%s'. If it didn't log anything, we don't know what it was.", username.c_str()); } return r; } else if (pluginFamily == PluginFamily::FlashMQ) { // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher // gets disconnected. 
try { return flashmq_plugin_login_check_v1(pluginData, clientid, username, password, userProperties, client); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error doing login check in plugin: '%s'", ex.what()); logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. There's no need."); } } return AuthResult::error; } AuthResult Authentication::extendedAuth(const std::string &clientid, ExtendedAuthStage stage, const std::string &authMethod, const std::string &authData, const std::vector> *userProperties, std::string &returnData, std::string &username, const std::weak_ptr &client) { if (pluginFamily == PluginFamily::None) return AuthResult::auth_method_not_supported; if (!initialized) { logger->logf(LOG_ERR, "Extended auth check with plugin wanted, but initialization failed. Can't perform check."); return AuthResult::error; } UnscopedLock lock(authChecksMutex); if (settings.pluginSerializeAuthChecks) lock.lock(); if (pluginFamily == PluginFamily::FlashMQ) { if (!flashmq_plugin_extended_auth_v1) return AuthResult::auth_method_not_supported; // I'm using this try/catch because propagating the exception higher up conflicts with who gets the blame, and then the publisher // gets disconnected. try { return flashmq_plugin_extended_auth_v1(pluginData, clientid, stage, authMethod, authData, userProperties, returnData, username, client); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error doing login check in plugin: '%s'", ex.what()); logger->logf(LOG_WARNING, "Throwing exceptions from auth plugin login/ACL checks is slow. 
There's no need."); } } else if (pluginFamily == PluginFamily::MosquittoV2) { throw ProtocolError("Mosquitto v2 plugin doesn't support extended auth.", ReasonCodes::BadAuthenticationMethod); } return AuthResult::error; } bool Authentication::alterSubscribe(const std::string &clientid, std::string &topic, const std::vector &subtopics, uint8_t &qos, const std::vector> *userProperties) { if (pluginFamily == PluginFamily::None) { return false; } if (!initialized) { logger->logf(LOG_ERR, "Plugin alterSubscribe called, but initialization failed or not performed."); return false; } if (pluginFamily == PluginFamily::FlashMQ && flashmq_plugin_alter_subscription_v1) { try { return flashmq_plugin_alter_subscription_v1(pluginData, clientid, topic, subtopics, qos, userProperties); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Exception in 'flashmq_plugin_alter_subscription': '%s'. You now have undefined behavior.", ex.what()); } } return false; } bool Authentication::alterPublish( const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::optional &correlationData, std::optional &responseTopic, std::optional &contentType, std::vector> *userProperties) { #ifdef TESTING // I could technically test with empty payload, but so far I don't, and this is a good way for now to check I don't miss the payload // because of the payload copy prevention optimization. 
assert(retain || !payload.empty()); #endif if (pluginFamily == PluginFamily::None) { return false; } if (!initialized) { logger->logf(LOG_ERR, "Plugin alterPublish called, but initialization failed or not performed."); return false; } if (pluginFamily == PluginFamily::FlashMQ) { try { if (flashmqPluginVersionNumber == 5) { if (flashmq_plugin_alter_publish_v5) return flashmq_plugin_alter_publish_v5( pluginData, clientid, topic, subtopics, payload, qos, retain, correlationData, responseTopic, contentType, userProperties); } else if (flashmqPluginVersionNumber == 3 || flashmqPluginVersionNumber == 4) { if (flashmq_plugin_alter_publish_v3) return flashmq_plugin_alter_publish_v3(pluginData, clientid, topic, subtopics, payload, qos, retain, correlationData, responseTopic, userProperties); } else if (flashmqPluginVersionNumber == 2) { if (flashmq_plugin_alter_publish_v2) return flashmq_plugin_alter_publish_v2(pluginData, clientid, topic, subtopics, payload, qos, retain, userProperties); } else if (flashmqPluginVersionNumber == 1) { if (flashmq_plugin_alter_publish_v1) return flashmq_plugin_alter_publish_v1(pluginData, clientid, topic, subtopics, qos, retain, userProperties); } } catch (std::exception &ex) { logger->logf(LOG_ERR, "Exception in 'flashmq_plugin_alter_publish': '%s'. 
You now have undefined behavior.", ex.what()); } } return false; } void Authentication::clientDisconnected(const std::string &clientid) { if (pluginFamily == PluginFamily::None) { return; } if (!initialized) { logger->logf(LOG_ERR, "Plugin clientDisconnected called, but initialization failed or not performed."); return; } if (pluginFamily == PluginFamily::FlashMQ && flashmq_plugin_client_disconnected_v1) { try { flashmq_plugin_client_disconnected_v1(pluginData, clientid); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Exception in 'flashmq_plugin_client_disconnected': '%s'.", ex.what()); } } } void Authentication::fdReady(int fd, int events, const std::weak_ptr &p) { if (pluginFamily == PluginFamily::None) { return; } if (!initialized) { logger->logf(LOG_ERR, "Plugin fdReady called, but initialization failed or not performed."); return; } if (pluginFamily == PluginFamily::FlashMQ && flashmq_plugin_poll_event_received_v1) { try { flashmq_plugin_poll_event_received_v1(pluginData, fd, events, p); } catch (std::exception &ex) { logger->logf(LOG_ERR, "In 'flashmq_plugin_poll_event_v1(): '", ex.what()); } } } void Authentication::setQuitting() { this->quitting = true; } /** * @brief Authentication::loadMosquittoPasswordFile is called once on startup, and on a frequent interval, and reloads the file if changed. */ void Authentication::loadMosquittoPasswordFile() { if (this->mosquittoPasswordFile.empty()) return; if (access(this->mosquittoPasswordFile.c_str(), R_OK) != 0) { logger->logf(LOG_ERR, "Passwd file '%s' is not there or not readable.", this->mosquittoPasswordFile.c_str()); return; } struct stat statbuf {}; check(stat(mosquittoPasswordFile.c_str(), &statbuf)); struct timespec ctime = statbuf.st_ctim; if (ctime.tv_sec == this->mosquittoPasswordFileLastLoad.tv_sec) return; logger->logf(LOG_NOTICE, "Change detected in '%s'. 
Reloading.", this->mosquittoPasswordFile.c_str()); try { std::ifstream infile(this->mosquittoPasswordFile, std::ios::in); std::unique_ptr> passwordEntries_tmp = std::make_unique>(); for(std::string line; getline(infile, line ); ) { trim(line); if (line.empty()) continue; try { std::vector fields = splitToVector(line, ':'); if (fields.size() != 2) throw std::runtime_error(formatString("Passwd file line '%s' contains more than one ':'", line.c_str())); const std::string &username = fields[0]; for (const std::string &field : fields) { if (field.size() == 0) { throw std::runtime_error(formatString("An empty field was found in '%'", line.c_str())); } } std::vector fields2 = splitToVector(fields[1], '$', 4, false); int iterations = -1; int saltField = -1; int hashField = -1; PasswordHashType type = PasswordHashType::SHA512; if (fields2[0] == "6") { if (fields2.size() != 3) throw std::runtime_error(formatString("Invalid line format in '%s'. Expected three fields separated by '$'", line.c_str())); type = PasswordHashType::SHA512; saltField = 1; hashField = 2; } else if (fields2[0] == "7") { if (fields2.size() != 4) throw std::runtime_error(formatString("Invalid line format in '%s'. Expected four fields separated by '$'", line.c_str())); type = PasswordHashType::SHA512_pbkdf2; iterations = std::stoi(fields2[1]); saltField = 2; hashField = 3; } else { throw std::runtime_error("Password fields must start with $6$ or $7$"); } std::vector salt = base64Decode(fields2[saltField]); std::vector cryptedPassword = base64Decode(fields2[hashField]); passwordEntries_tmp->emplace(username, MosquittoPasswordFileEntry(type, std::move(salt), std::move(cryptedPassword), iterations)); } catch (std::exception &ex) { std::string lineCut = formatString("%s...", line.substr(0, 20).c_str()); logger->logf(LOG_ERR, "Dropping invalid username/password line: '%s'. 
Error: %s", lineCut.c_str(), ex.what()); } } this->mosquittoPasswordEntries = std::move(passwordEntries_tmp); this->mosquittoPasswordFileLastLoad = ctime; } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error loading Mosquitto password file: '%s'. Authentication won't work.", ex.what()); } } void Authentication::loadMosquittoAclFile() { if (this->mosquittoAclFile.empty()) return; if (access(this->mosquittoAclFile.c_str(), R_OK) != 0) { logger->logf(LOG_ERR, "ACL file '%s' is not there or not readable.", this->mosquittoAclFile.c_str()); return; } struct stat statbuf {}; check(stat(mosquittoAclFile.c_str(), &statbuf)); struct timespec ctime = statbuf.st_ctim; if (ctime.tv_sec == this->mosquittoAclFileLastChange.tv_sec) return; logger->logf(LOG_NOTICE, "Change detected in '%s'. Reloading.", this->mosquittoAclFile.c_str()); AclTree newTree; // Not doing by-line error handling, because ingoring one invalid line can completely change the user's intent. try { std::string currentUser; std::ifstream infile(this->mosquittoAclFile, std::ios::in); for(std::string line; getline(infile, line ); ) { trim(line); if (line.empty() || startsWith(line, "#")) continue; const std::vector fields = splitToVector(line, ' ', 3, false); if (fields.size() < 2) throw ConfigFileException(formatString("Line does not have enough fields: %s", line.c_str())); const std::string &firstWord = str_tolower(fields[0]); if (firstWord == "topic" || firstWord == "pattern") { AclGrant g = AclGrant::ReadWrite; std::string topic; if (fields.size() == 3) { topic = fields[2]; g = stringToAclGrant(fields[1]); } else if (fields.size() == 2) { topic = fields[1]; } else throw ConfigFileException(formatString("Invalid markup of 'topic' line: %s", line.c_str())); if (!isValidSubscribePath(topic)) throw ConfigFileException(formatString("Topic '%s' is not a valid ACL topic", topic.c_str())); AclTopicType type = firstWord == "pattern" ? 
AclTopicType::Patterns : AclTopicType::Strings; newTree.addTopic(topic, g, type, currentUser); } else if (firstWord == "user") { currentUser = fields[1]; } else { throw ConfigFileException(formatString("Invalid keyword '%s' in '%s'", firstWord.c_str(), line.c_str())); } } aclTree = std::move(newTree); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error loading Mosquitto ACL file: '%s'. Authorization won't work.", ex.what()); } mosquittoAclFileLastChange = ctime; } AuthResult Authentication::aclCheckFromMosquittoAclFile(const std::string &clientid, const std::string &username, const std::vector &subtopics, AclAccess access) { assert(access != AclAccess::none); if (this->mosquittoAclFile.empty()) return AuthResult::success; // We have to do this because the Mosquitto ACL file has no notion of checking subscribes. if (access == AclAccess::subscribe) return AuthResult::success; AclGrant ag = access == AclAccess::write ? AclGrant::Write : AclGrant::Read; AuthResult result = aclTree.findPermission(subtopics, ag, username, clientid); return result; } std::optional Authentication::loginCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password) { if (!this->mosquittoPasswordEntries) return AuthResult::login_denied; std::optional result; auto it = mosquittoPasswordEntries->find(username); if (it != mosquittoPasswordEntries->end()) { result = AuthResult::login_denied; const MosquittoPasswordFileEntry &entry = it->second; if (entry.type == PasswordHashType::SHA512) { std::array md_value; unsigned int output_len = 0; EVP_MD_CTX_reset(mosquittoDigestContext.get()); EVP_DigestInit_ex(mosquittoDigestContext.get(), sha512, NULL); EVP_DigestUpdate(mosquittoDigestContext.get(), password.c_str(), password.length()); EVP_DigestUpdate(mosquittoDigestContext.get(), entry.salt.data(), entry.salt.size()); EVP_DigestFinal_ex(mosquittoDigestContext.get(), md_value.data(), &output_len); std::vector hashedSalted = make_vector(md_value, 0, output_len); if 
(hashedSalted == entry.cryptedPassword) result = AuthResult::success; } else if (entry.type == PasswordHashType::SHA512_pbkdf2) { const auto len = EVP_MD_size(this->sha512); std::array md_value; PKCS5_PBKDF2_HMAC(password.c_str(), password.size(), entry.salt.data(), entry.salt.size(), entry.iterations, sha512, md_value.size(), md_value.data()); std::vector derivedKey = make_vector(md_value, 0, len); if (derivedKey == entry.cryptedPassword) result = AuthResult::success; } } return result; } void Authentication::periodicEvent() { if (pluginFamily == PluginFamily::None) return; if (!initialized) { logger->logf(LOG_ERR, "Auth plugin period event called, but initialization failed or not performed."); return; } if (pluginFamily == PluginFamily::FlashMQ && flashmq_plugin_periodic_event_v1) { try { flashmq_plugin_periodic_event_v1(pluginData); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Exception in 'flashmq_plugin_periodic_event': '%s'.", ex.what()); } } } std::string AuthResultToString(AuthResult r) { if (r == AuthResult::success) return "success"; if (r == AuthResult::acl_denied) return "ACL denied"; if (r == AuthResult::login_denied) return "login Denied"; if (r == AuthResult::error) return "error in check"; return ""; } ================================================ FILE: plugin.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef PLUGIN_H #define PLUGIN_H #include #include #include #include #include "logger.h" #include "acltree.h" #include "flashmq_plugin.h" #include "pluginloader.h" #include "settings.h" #include "types.h" enum class PasswordHashType { SHA512, SHA512_pbkdf2 }; /** * @brief The MosquittoPasswordFileEntry struct stores the decoded base64 password salt and hash. 
* * The Mosquitto encrypted format looks like that of crypt(2), but it's not. These are example entries: * * one:$6$emTXKCHfxMnZLDWg$gDcJRPojvOX8l7W/DRhSPoxV3CgPfECJVGRzw2Sqjdc2KIQ/CVLS1mNEuZUsp/vLdj7RCuqXCkgG43+XIc8WBA== * two:$7$101$twKcRmS7qxdZtFZiU+yLZHAIRNsm8deqMG9nN44pagg8t5wkUxtyWiNgbUF38cHzmgDja...VPMaNLw== * * $ is the seperator. '6' or '7' is the algorithm. */ struct MosquittoPasswordFileEntry { PasswordHashType type; std::vector salt; std::vector cryptedPassword; int iterations = 0; MosquittoPasswordFileEntry(PasswordHashType type, const std::vector &&salt, const std::vector &&cryptedPassword, int iterations); // The plan was that objects of this type wouldn't be copied, but I can't get emplacing to work without it...? //MosquittoPasswordFileEntry(const MosquittoPasswordFileEntry &other) = delete; }; // Mosquitto functions typedef int (*F_plugin_init_v2)(void **, struct mosquitto_auth_opt *, int); typedef int (*F_plugin_cleanup_v2)(void *, struct mosquitto_auth_opt *, int); typedef int (*F_plugin_security_init_v2)(void *, struct mosquitto_auth_opt *, int, bool); typedef int (*F_plugin_security_cleanup_v2)(void *, struct mosquitto_auth_opt *, int, bool); typedef int (*F_plugin_acl_check_v2)(void *, const char *, const char *, const char *, int); typedef int (*F_plugin_unpwd_check_v2)(void *, const char *, const char *); typedef int (*F_plugin_psk_key_get_v2)(void *, const char *, const char *, char *, int); typedef void(*F_flashmq_plugin_allocate_thread_memory_v1)(void **thread_data, std::unordered_map &auth_opts); typedef void(*F_flashmq_plugin_deallocate_thread_memory_v1)(void *thread_data, std::unordered_map &auth_opts); typedef void(*F_flashmq_plugin_init_v1)(void *thread_data, std::unordered_map &auth_opts, bool reloading); typedef void(*F_flashmq_plugin_deinit_v1)(void *thread_data, std::unordered_map &auth_opts, bool reloading); typedef AuthResult(*F_flashmq_plugin_acl_check_v1)(void *thread_data, const AclAccess access, const std::string 
&clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const uint8_t qos, const bool retain, const std::vector> *userProperties); typedef AuthResult(*F_flashmq_plugin_login_check_v1)(void *thread_data, const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client); typedef void (*F_flashmq_plugin_periodic_event_v1)(void *thread_data); typedef AuthResult(*F_flashmq_plugin_extended_auth_v1)(void *thread_data, const std::string &clientid, ExtendedAuthStage stage, const std::string &authMethod, const std::string &authData, const std::vector> *userProperties, std::string &returnData, std::string &username, const std::weak_ptr &client); typedef bool (*F_flashmq_plugin_alter_subscription_v1)(void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, uint8_t &qos, const std::vector> *userProperties); typedef bool (*F_flashmq_plugin_alter_publish_v1)(void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, uint8_t &qos, bool &retain, std::vector> *userProperties); typedef void (*F_flashmq_plugin_client_disconnected_v1)(void *thread_data, const std::string &clientid); typedef void (*F_flashmq_plugin_poll_event_received_v1)(void *thread_data, int fd, int events, const std::weak_ptr &p); typedef AuthResult(*F_flashmq_plugin_acl_check_v2)(void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, std::string_view payload, const uint8_t qos, const bool retain, const std::vector> *userProperties); typedef bool (*F_flashmq_plugin_alter_publish_v2)(void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::vector> *userProperties); typedef AuthResult(*F_flashmq_plugin_acl_check_v3)( 
void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::vector> *userProperties); typedef bool (*F_flashmq_plugin_alter_publish_v3)( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, const std::optional &correlationData, const std::optional &responseTopic, std::vector> *userProperties); typedef AuthResult(*F_flashmq_plugin_acl_check_v4)( void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &sharename, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::vector> *userProperties); typedef void(*F_flashmq_plugin_on_unsubscribe_v4)( void *thread_data, const std::weak_ptr &session, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, const std::vector> *userProperties); typedef bool (*F_flashmq_plugin_alter_publish_v5)( void *thread_data, const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::optional &correlationData, std::optional &responseTopic, std::optional &contentType, std::vector> *userProperties); typedef AuthResult(*F_flashmq_plugin_acl_check_v5)( void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &sharename, std::string_view payload, const uint8_t qos, const bool retain, const std::optional &correlationData,
const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties); std::string AuthResultToString(AuthResult r); /** * @brief The Authentication class handles our integrated authentication, but also the FlashMQ and Mosquitto plugin interfaces. * * It's a bit of a legacy that both plugin handling and auth are in a class called 'Authentication', but oh well... */ class Authentication { // Mosquitto functions F_plugin_init_v2 init_v2 = nullptr; F_plugin_cleanup_v2 cleanup_v2 = nullptr; F_plugin_security_init_v2 security_init_v2 = nullptr; F_plugin_security_cleanup_v2 security_cleanup_v2 = nullptr; F_plugin_acl_check_v2 acl_check_v2 = nullptr; F_plugin_unpwd_check_v2 unpwd_check_v2 = nullptr; F_plugin_psk_key_get_v2 psk_key_get_v2 = nullptr; F_flashmq_plugin_allocate_thread_memory_v1 flashmq_plugin_allocate_thread_memory_v1 = nullptr; F_flashmq_plugin_deallocate_thread_memory_v1 flashmq_plugin_deallocate_thread_memory_v1 = nullptr; F_flashmq_plugin_init_v1 flashmq_plugin_init_v1 = nullptr; F_flashmq_plugin_deinit_v1 flashmq_plugin_deinit_v1 = nullptr; F_flashmq_plugin_acl_check_v1 flashmq_plugin_acl_check_v1 = nullptr; F_flashmq_plugin_login_check_v1 flashmq_plugin_login_check_v1 = nullptr; F_flashmq_plugin_periodic_event_v1 flashmq_plugin_periodic_event_v1 = nullptr; F_flashmq_plugin_extended_auth_v1 flashmq_plugin_extended_auth_v1 = nullptr; F_flashmq_plugin_alter_subscription_v1 flashmq_plugin_alter_subscription_v1 = nullptr; F_flashmq_plugin_alter_publish_v1 flashmq_plugin_alter_publish_v1 = nullptr; F_flashmq_plugin_client_disconnected_v1 flashmq_plugin_client_disconnected_v1 = nullptr; F_flashmq_plugin_poll_event_received_v1 flashmq_plugin_poll_event_received_v1 = nullptr; F_flashmq_plugin_acl_check_v2 flashmq_plugin_acl_check_v2 = nullptr; F_flashmq_plugin_alter_publish_v2 flashmq_plugin_alter_publish_v2 = nullptr; F_flashmq_plugin_acl_check_v3 flashmq_plugin_acl_check_v3 = nullptr;
F_flashmq_plugin_alter_publish_v3 flashmq_plugin_alter_publish_v3 = nullptr; F_flashmq_plugin_acl_check_v4 flashmq_plugin_acl_check_v4 = nullptr; F_flashmq_plugin_on_unsubscribe_v4 flashmq_plugin_on_unsubscribe_v4 = nullptr; F_flashmq_plugin_acl_check_v5 flashmq_plugin_acl_check_v5 = nullptr; F_flashmq_plugin_alter_publish_v5 flashmq_plugin_alter_publish_v5 = nullptr; static std::mutex initMutex; static std::mutex deinitMutex; static std::mutex authChecksMutex; PluginFamily pluginFamily = PluginFamily::None; int flashmqPluginVersionNumber = 0; Settings &settings; // A ref because I want it to always be the same as the thread's settings void *pluginData = nullptr; Logger *logger = nullptr; bool initialized = false; bool quitting = false; /** * @brief mosquittoPasswordFile is a once set value based on config. It's not reloaded on reload signal currently, because it * forces some decisions when you change files or remove the config option. For instance, do you remove all accounts loaded * from the previous one? Perhaps I'm overthinking it. * * Its content is, however, reloaded every two seconds.
*/ const std::string mosquittoPasswordFile; const std::string mosquittoAclFile; struct timespec mosquittoPasswordFileLastLoad {}; struct timespec mosquittoAclFileLastChange {}; std::unique_ptr> mosquittoPasswordEntries; std::unique_ptr mosquittoDigestContext; const EVP_MD *sha512 = EVP_sha512(); AclTree aclTree; void *loadSymbol(void *handle, const char *symbol, bool exceptionOnError = true) const; public: Authentication(Settings &settings); Authentication(const Authentication &other) = delete; Authentication(Authentication &&other) = delete; void loadPlugin(const PluginLoader &l); void init(); void cleanup(); void securityInit(bool reloading); void securityCleanup(bool reloading); void onUnsubscribe( const std::shared_ptr &session, const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &shareName, const std::vector> *userProperties) const; AuthResult aclCheck(Publish &publishData, std::string_view payload, AclAccess access = AclAccess::write); AuthResult aclCheck( const std::string &clientid, const std::string &username, const std::string &topic, const std::vector &subtopics, const std::string &sharename, std::string_view payload, AclAccess access, uint8_t qos, bool retain, const std::optional &correlationData, const std::optional &responseTopic, const std::optional &contentType, const std::optional> expiresAt, const std::vector> *userProperties); AuthResult loginCheck(const std::string &clientid, const std::string &username, const std::string &password, const std::vector> *userProperties, const std::weak_ptr &client, const bool allowAnonymous); AuthResult extendedAuth(const std::string &clientid, ExtendedAuthStage stage, const std::string &authMethod, const std::string &authData, const std::vector> *userProperties, std::string &returnData, std::string &username, const std::weak_ptr &client); bool alterSubscribe(const std::string &clientid, std::string &topic, const std::vector
&subtopics, uint8_t &qos, const std::vector> *userProperties); bool alterPublish(const std::string &clientid, std::string &topic, const std::vector &subtopics, std::string_view payload, uint8_t &qos, bool &retain, std::optional &correlationData, std::optional &responseTopic, std::optional &contentType, std::vector> *userProperties); void clientDisconnected(const std::string &clientid); void fdReady(int fd, int events, const std::weak_ptr &p); void setQuitting(); void loadMosquittoPasswordFile(); void loadMosquittoAclFile(); AuthResult aclCheckFromMosquittoAclFile(const std::string &clientid, const std::string &username, const std::vector &subtopics, AclAccess access); std::optional loginCheckFromMosquittoPasswordFile(const std::string &username, const std::string &password); void periodicEvent(); }; #endif // PLUGIN_H ================================================ FILE: pluginloader.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "pluginloader.h" #include "logger.h" #include "exceptions.h"

/**
 * @brief dlErrorString returns the pending dlerror() text, or the fallback when there is none.
 * @param fallback message to use when dlerror() reports nothing.
 *
 * dlerror() returns NULL when no error occurred since it was last called (dlsym() can, for instance, legitimately return
 * NULL for an existing symbol), and constructing a std::string from a null pointer is undefined behavior.
 */
static std::string dlErrorString(const std::string &fallback)
{
    const char *err = dlerror();
    return err ? std::string(err) : fallback;
}

PluginLoader::PluginLoader()
{

}

/**
 * @brief PluginLoader::loaded tells whether a plugin library handle is currently open.
 */
bool PluginLoader::loaded() const
{
    return handle != nullptr;
}

/**
 * @brief PluginLoader::loadPlugin opens the plugin, determines its family and version, and looks up the optional
 * main init/deinit functions.
 * @param pathToSoFile path of the shared object. An empty path means 'no plugin' and is a no-op.
 * @throws FatalError when the file is not readable, can't be opened, or is not a supported FlashMQ or Mosquitto plugin.
 *
 * The library is first opened with RTLD_LAZY only to call its version function. It is then closed and opened again with
 * RTLD_NOW, which makes dlopen() resolve all undefined symbols before returning. When the first-stage checks fail, the
 * library is closed again, so loaded() doesn't report a rejected plugin as loaded.
 */
void PluginLoader::loadPlugin(const std::string &pathToSoFile)
{
    if (pathToSoFile.empty())
        return;

    Logger *logger = Logger::getInstance();
    logger->logf(LOG_NOTICE, "Loading auth plugin %s", pathToSoFile.c_str());

    pluginFamily = PluginFamily::Determining;

    if (access(pathToSoFile.c_str(), R_OK) != 0)
    {
        std::ostringstream oss;
        oss << "Error loading auth plugin: The file " << pathToSoFile << " is not there or not readable";
        throw FatalError(oss.str());
    }

    handle = dlopen(pathToSoFile.c_str(), RTLD_LAZY|RTLD_GLOBAL);
    if (handle == nullptr)
        throw FatalError(dlErrorString("dlopen() failed for " + pathToSoFile));

    try
    {
        if ((version = reinterpret_cast<F_plugin_version>(loadSymbol("flashmq_plugin_version", false))) != nullptr)
        {
            pluginFamily = PluginFamily::FlashMQ;
            flashmqPluginVersionNumber = version();

            if (flashmqPluginVersionNumber < 1 || flashmqPluginVersionNumber > 5)
            {
                throw FatalError("This FlashMQ version only support plugin version 1, 2, 3, 4 or 5.");
            }
        }
        else if ((version = reinterpret_cast<F_plugin_version>(loadSymbol("mosquitto_auth_plugin_version", false))) != nullptr)
        {
            if (version() != 2)
            {
                throw FatalError("Only Mosquitto plugin version 2 is supported at this time.");
            }

            pluginFamily = PluginFamily::MosquittoV2;
        }
        else
        {
            throw FatalError("This does not seem to be a FlashMQ native plugin or Mosquitto plugin version 2.");
        }
    }
    catch (...)
    {
        // Don't keep a rejected library mapped, and don't let loaded() report it as loaded.
        dlclose(handle);
        handle = nullptr;
        version = nullptr;
        throw;
    }

    // 'version' points into the library we're about to close.
    version = nullptr;

    const int closeResult = dlclose(handle);
    handle = nullptr; // Never keep a handle that may already be invalid.
    if (closeResult != 0)
        throw FatalError(dlErrorString("dlclose() failed for " + pathToSoFile));

    handle = dlopen(pathToSoFile.c_str(), RTLD_NOW|RTLD_GLOBAL);
    if (handle == nullptr)
        throw FatalError(dlErrorString("dlopen() failed for " + pathToSoFile));

    // Both are optional, so a missing symbol is not an error.
    main_init_v1 = reinterpret_cast<F_flashmq_plugin_main_init_v1>(loadSymbol("flashmq_plugin_main_init", false));
    main_deinit_v1 = reinterpret_cast<F_flashmq_plugin_main_deinit_v1>(loadSymbol("flashmq_plugin_main_deinit", false));
}

/**
 * @brief PluginLoader::loadSymbol looks up a symbol in the opened plugin.
 * @param symbol name to look up.
 * @param exceptionOnError whether a NULL result throws. When false, NULL is simply returned.
 * @return the symbol's address, or nullptr.
 * @throws FatalError when the result is NULL and exceptionOnError is set.
 */
void *PluginLoader::loadSymbol(const char *symbol, bool exceptionOnError) const
{
    // Clear any stale error, so the dlerror() below can only describe this lookup.
    dlerror();

    void *r = dlsym(handle, symbol);

    if (r == nullptr && exceptionOnError)
        throw FatalError(dlErrorString(std::string("Symbol '") + symbol + "' resolved to NULL"));

    return r;
}

PluginFamily PluginLoader::getPluginFamily() const
{
    return this->pluginFamily;
}

int PluginLoader::getFlashMQPluginVersion() const
{
    return this->flashmqPluginVersionNumber;
}

/**
 * @brief PluginLoader::mainInit calls the plugin's optional flashmq_plugin_main_init().
 * @throws std::runtime_error with the plugin's message when it throws a std::exception (which is also logged).
 */
void PluginLoader::mainInit(std::unordered_map<std::string, std::string> &plugin_opts)
{
    if (!main_init_v1)
        return;

    std::optional<std::string> error;

    try
    {
        main_init_v1(plugin_opts);
    }
    catch(std::exception &ex)
    {
        error = ex.what();
        Logger *logger = Logger::getInstance();
        logger->log(LOG_ERR) << "Exception in flashmq_plugin_main_init(): " << ex.what();
    }

    if (error)
        throw std::runtime_error(error.value());
}

/**
 * @brief PluginLoader::mainDeinit calls the plugin's optional flashmq_plugin_main_deinit(). A std::exception from it is
 * logged and not propagated.
 */
void PluginLoader::mainDeinit(std::unordered_map<std::string, std::string> &plugin_opts)
{
    if (!main_deinit_v1)
        return;

    try
    {
        main_deinit_v1(plugin_opts);
    }
    catch(std::exception &ex)
    {
        Logger *logger = Logger::getInstance();
        logger->log(LOG_ERR) << "Exception in flashmq_plugin_main_deinit(): " << ex.what();
    }
}
================================================ FILE: pluginloader.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef PLUGINLOADER_H #define PLUGINLOADER_H #include #include #include #include enum class PluginFamily { None, Determining, FlashMQ, MosquittoV2, }; typedef int (*F_plugin_version)(void); typedef void(*F_flashmq_plugin_main_init_v1)(std::unordered_map &auth_opts); typedef void(*F_flashmq_plugin_main_deinit_v1)(std::unordered_map &auth_opts); class PluginLoader { PluginFamily pluginFamily = PluginFamily::None; int flashmqPluginVersionNumber = 0; void* handle = nullptr; F_plugin_version version = nullptr; F_flashmq_plugin_main_init_v1 main_init_v1 = nullptr; F_flashmq_plugin_main_deinit_v1 main_deinit_v1 = nullptr; public: PluginLoader(); PluginLoader(const PluginLoader &other) = delete; bool loaded() const; void loadPlugin(const std::string &pathToSoFile); void *loadSymbol(const char *symbol, bool exceptionOnError = true) const; PluginFamily getPluginFamily() const; int getFlashMQPluginVersion() const; void mainInit(std::unordered_map &plugin_opts); void mainDeinit(std::unordered_map &plugin_opts); }; #endif // PLUGINLOADER_H ================================================ FILE: publishcopyfactory.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License
3.0 (OSL-3.0). See LICENSE for license details. */ #include #include #include "publishcopyfactory.h" #include "mqttpacket.h" PublishCopyFactory::PublishCopyFactory(MqttPacket *packet) : packet(packet), orgQos(packet->getQos()), orgRetain(packet->getRetain()) { } PublishCopyFactory::PublishCopyFactory(Publish *publish) : publish(publish), orgQos(publish->qos), orgRetain(publish->retain) { } MqttPacket *PublishCopyFactory::getOptimumPacket( const uint8_t max_qos, const ProtocolVersion protocolVersion, uint16_t topic_alias, bool skip_topic, uint32_t subscriptionIdentifier, const std::optional &topic_override) { const uint8_t actualQos = getEffectiveQos(max_qos); if (packet) { // The incoming topic alias is not relevant after initial conversion and it should not propagate. assert(packet->getPublishData().topicAlias == 0); // The incoming packet should not have a subscription identifier stored in it. assert(packet->getPublishData().subscriptionIdentifier == 0); // When the packet contains an data specific to the receiver, we don't cache it. if (topic_override || (protocolVersion >= ProtocolVersion::Mqtt5 && topic_alias > 0) || subscriptionIdentifier > 0) { this->oneShotPacket.emplace(protocolVersion, packet->getPublishData(), actualQos, topic_alias, skip_topic, subscriptionIdentifier, topic_override); return &*this->oneShotPacket; } const int layout_key_target = getPublishLayoutCompareKey(protocolVersion, actualQos); if (!packet->biteArrayCannotBeReused() && getPublishLayoutCompareKey(packet->getProtocolVersion(), orgQos) == layout_key_target) { return packet; } // Note that this cache also possibly contains the expiration interval, but because we're only hitting this block for on-line // publishers, the interval has not decreased and is fine. std::optional &cachedPack = constructedPacketCache[layout_key_target]; if (!cachedPack) { // Don't include arguments that are not part of the cache key.
cachedPack.emplace(protocolVersion, packet->getPublishData(), actualQos, 0, false, 0, std::optional()); } return &*cachedPack; } // Getting an instance of a Publish object happens at least on retained messages, will messages and SYS topics. It's low traffic, anyway. assert(publish); // The incoming topic alias is not relevant after initial conversion and it should not propagate. assert(publish->topicAlias == 0); this->oneShotPacket.emplace(protocolVersion, *publish, actualQos, topic_alias, skip_topic, subscriptionIdentifier, topic_override); return &*this->oneShotPacket; } uint8_t PublishCopyFactory::getEffectiveQos(uint8_t max_qos) const { const uint8_t effectiveQos = std::min(orgQos, max_qos); return effectiveQos; } bool PublishCopyFactory::getEffectiveRetain(bool retainAsPublished) const { return orgRetain && retainAsPublished; } const std::string &PublishCopyFactory::getTopic() const { if (packet) return packet->getTopic(); assert(publish); return publish->topic; } const std::vector &PublishCopyFactory::getSubtopics() { if (packet) { return packet->getSubtopics(); } else if (publish) { return publish->getSubtopics(); } throw std::runtime_error("Bug in &PublishCopyFactory::getSubtopics()"); } std::string_view PublishCopyFactory::getPayload() const { if (packet) return packet->getPayloadView(); assert(publish); return publish->payload; } bool PublishCopyFactory::getRetain() const { // Keeping this here as reminder that it should not be implemented. assert(false); return false; }
/**
 * @brief PublishCopyFactory::getNewPublish gets a new publish object from an existing packet or publish.
 * @param new_max_qos the receiver's maximum QoS; the copy gets the lower of this and the original QoS.
 * @param retainAsPublished whether the original retain flag may be kept (see getEffectiveRetain()).
 * @param subscriptionIdentifier the receiver's subscription identifier, stored in the copy (0 means none).
 * @return an independent Publish copy with the receiver-specific fields applied.
 *
 * It being a public function, the idea is that it's only needed for creating publish objects for storing QoS messages for off-line
 * clients. For on-line clients, you're always making a packet (with getOptimumPacket()).
 *
 * The subscription identifier is applied on both the packet and the Publish path, so that, for instance, will messages queued
 * for an off-line client carry it too, just like getOptimumPacket() does for on-line clients.
 */
Publish PublishCopyFactory::getNewPublish(uint8_t new_max_qos, bool retainAsPublished, uint32_t subscriptionIdentifier) const
{
    // (At time of writing) we only need to construct new publishes for QoS (because we're storing QoS publishes for offline clients). If
    // you're doing it elsewhere, it's a bug.
    assert(orgQos > 0);
    assert(new_max_qos > 0);

    const uint8_t actualQos = getEffectiveQos(new_max_qos);

    if (packet)
    {
        assert(packet->getQos() > 0);

        Publish p(packet->getPublishData());
        p.qos = actualQos;
        p.retain = getEffectiveRetain(retainAsPublished);
        p.subscriptionIdentifier = subscriptionIdentifier;
        return p;
    }

    assert(publish);
    assert(publish->qos > 0); // Same check as above, but then for Publish objects.

    Publish p(*publish);
    p.qos = actualQos;
    p.retain = getEffectiveRetain(retainAsPublished);
    p.subscriptionIdentifier = subscriptionIdentifier;
    return p;
}
const std::vector > *PublishCopyFactory::getUserProperties() const { if (packet) { return packet->getUserProperties(); } assert(publish); return publish->getUserProperties(); } const std::optional &PublishCopyFactory::getCorrelationData() const { if (packet) return packet->getCorrelationData(); assert(publish); return publish->correlationData; } const std::optional &PublishCopyFactory::getResponseTopic() const { if (packet) return packet->getResponseTopic(); assert(publish); return publish->responseTopic; } const std::optional &PublishCopyFactory::getContentType() const { if (packet) return packet->getContentType(); assert(publish); return publish->contentType; } const std::optional > PublishCopyFactory::getExpiresAt() const { if (packet) return packet->getExpiresAt(); assert(publish); return publish->expiresAt(); } int PublishCopyFactory::getPublishLayoutCompareKey(ProtocolVersion pv, uint8_t qos) { int key = 0; switch (pv) { case ProtocolVersion::None: key = 0; break; case ProtocolVersion::Mqtt31: case ProtocolVersion::Mqtt311: key = 1; break; case ProtocolVersion::Mqtt5: key = 2; break; default: key = 3; } key = (key * 10) + static_cast(qos); return key; }
================================================ FILE: publishcopyfactory.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef PUBLISHCOPYFACTORY_H #define PUBLISHCOPYFACTORY_H #include #include #include #include "mqttpacket.h" #include "forward_declarations.h" #include "types.h" /** * @brief The PublishCopyFactory class is for managing copies of an incoming publish, including sometimes not making copies at all. * * The idea is that certain incoming packets can just be written to the receiving client as-is, without constructing a new one. We do have to change the bytes * where the QoS is stored, so we keep track of the original QoS. * * Ownership info: objects of this type are never copied or transferred, so the internal pointers come from (near) the same scope these objects * are created from. * * Exactly one of the two source pointers (packet or publish) is set, depending on which constructor was used.
*/ class PublishCopyFactory { MqttPacket *packet = nullptr; Publish *publish = nullptr; /* Last receiver-specific packet built by getOptimumPacket(); each such call replaces it, so a pointer into it is only valid until the next call. */ std::optional oneShotPacket; /* QoS and retain flag of the source publish, captured at construction. */ const uint8_t orgQos; const bool orgRetain = false; /* Packets rebuilt for receivers, keyed by getPublishLayoutCompareKey(), so each protocol/QoS layout is built at most once. */ std::unordered_map> constructedPacketCache; public: PublishCopyFactory(MqttPacket *packet); PublishCopyFactory(Publish *publish); PublishCopyFactory(const PublishCopyFactory &other) = delete; PublishCopyFactory(PublishCopyFactory &&other) = delete; /* Returns the packet to write to a receiver: the original packet when its layout can be reused, a cached rebuilt packet, or a one-shot packet when receiver-specific data (topic override, MQTT5 topic alias, subscription identifier) is needed. */ MqttPacket *getOptimumPacket( const uint8_t max_qos, const ProtocolVersion protocolVersion, uint16_t topic_alias, bool skip_topic, uint32_t subscriptionIdentifier, const std::optional &topic_override); /* The lower of the original QoS and max_qos. */ uint8_t getEffectiveQos(uint8_t max_qos) const; /* The original retain flag, but only when retainAsPublished is set. */ bool getEffectiveRetain(bool retainAsPublished) const; const std::string &getTopic() const; const std::vector &getSubtopics(); std::string_view getPayload() const; /* Intentionally not implemented (asserts); use getEffectiveRetain(). */ bool getRetain() const; /* Copy of the publish for queueing at an off-line client; only valid for QoS > 0. */ Publish getNewPublish(uint8_t new_max_qos, bool retainAsPublished, uint32_t subscriptionIdentifier) const; const std::vector> *getUserProperties() const; const std::optional &getCorrelationData() const; const std::optional &getResponseTopic() const; const std::optional &getContentType() const; const std::optional> getExpiresAt() const; /* protocol group * 10 + qos, where MQTT 3.1 and 3.1.1 share a group. */ static int getPublishLayoutCompareKey(ProtocolVersion pv, uint8_t qos); }; #endif // PUBLISHCOPYFACTORY_H ================================================ FILE: qospacketqueue.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License
*/ #include "qospacketqueue.h" #include #include "mqttpacket.h" QueuedPublish::QueuedPublish(Publish &&publish, uint16_t packet_id, const std::optional &topic_override) : publish(std::move(publish)), packet_id(packet_id), topic_override(topic_override) { } uint16_t QueuedPublish::getPacketId() const { return this->packet_id; } Publish &QueuedPublish::getPublish() { return publish; } const std::optional &QueuedPublish::getTopicOverride() const { return topic_override; } size_t QueuedPublish::getApproximateMemoryFootprint() const { // TODO: hmm, this is possibly very inaccurate with MQTT5 packets. return publish.topic.length() + publish.payload.length(); } void QoSPublishQueue::addToExpirationQueue(std::shared_ptr &qp) { Publish &pub = qp->getPublish(); if (!pub.expireInfo) return; this->queueExpirations[pub.expireInfo->expiresAt()] = qp->getPacketId(); } bool QoSPublishQueue::erase(const uint16_t packet_id) { bool result = false; auto pos = this->queue.find(packet_id); if (pos != this->queue.end()) { std::shared_ptr &qp = pos->second; const size_t mem = qp->getApproximateMemoryFootprint(); qosQueueBytes -= mem; assert(qosQueueBytes >= 0); if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose. qosQueueBytes = 0; result = true; this->eraseFromMapAndRelinkList(pos); } return result; } /** * @brief QoSPublishQueue::eraseFromMapAndRelinkList Removes a QueuedPublish from the unordered_map and relinks previous and next together. * @param pos * * This is an internal helper, and therefore doesn't do anything with qosQueueBytes.
*/ void QoSPublishQueue::eraseFromMapAndRelinkList(std::unordered_map>::iterator pos) { std::shared_ptr &qp = pos->second; auto _prev = qp->prev.lock(); if (_prev) _prev->next = qp->next; if (this->head == qp) this->head = qp->prev.lock(); auto _next = qp->next.lock(); if (_next) _next->prev = qp->prev; if (this->tail == qp) this->tail = qp->next.lock(); this->queue.erase(pos); } const std::shared_ptr &QoSPublishQueue::getTail() const { return tail; } std::shared_ptr QoSPublishQueue::popNext() { std::shared_ptr result = this->tail; if (result) erase(result->getPacketId()); return result; } size_t QoSPublishQueue::size() const { return queue.size(); } size_t QoSPublishQueue::getByteSize() const { return qosQueueBytes; } void QoSPublishQueue::addToHeadOfLinkedList(std::shared_ptr &qp) { qp->prev = this->head; if (this->head) this->head->next = qp; this->head = qp; if (!this->tail) this->tail = qp; } /** * @brief QoSPublishQueue::queuePublish * * Note that it may seem a bit weird to queue messages with retain flags, because retained messages can only happen on * subscribe, which an offline client can't do. However, MQTT5 introduces 'retained as published', so it becomes valid. Bridge * mode uses this as well. */ void QoSPublishQueue::queuePublish( PublishCopyFactory &copyFactory, uint16_t id, uint8_t new_max_qos, bool retainAsPublished, const uint32_t subscriptionIdentifier, const std::optional &topic_override) { assert(new_max_qos > 0); assert(id > 0); Publish pub = copyFactory.getNewPublish(new_max_qos, retainAsPublished, subscriptionIdentifier); std::shared_ptr qp = std::make_shared(std::move(pub), id, topic_override); addToHeadOfLinkedList(qp); qosQueueBytes += qp->getApproximateMemoryFootprint(); addToExpirationQueue(qp); queue[id] = std::move(qp); } /** * @brief QoSPublishQueue::queuePublish moves the publish into the queue.
* @param pub * @param id * @param topic_override could/should theoretically also have been an rvalue ref, but that required maintaining * various constructors of QueuedPublish with ref and rref arguments, which didn't seem worth it. So far, this * function is only used for loading from disk, so not the hot path. */ void QoSPublishQueue::queuePublish(Publish &&pub, uint16_t id, const std::optional &topic_override) { assert(id > 0); std::shared_ptr qp = std::make_shared(std::move(pub), id, topic_override); addToHeadOfLinkedList(qp); qosQueueBytes += qp->getApproximateMemoryFootprint(); addToExpirationQueue(qp); queue[id] = std::move(qp); }
/**
 * @brief QoSPublishQueue::clearExpiredMessages removes queued publishes whose expiry time has passed.
 * @return the number of publishes removed.
 *
 * Every expiration entry that is due is dropped from queueExpirations when its publish is gone already (acked or popped), when
 * its packet id has since been reused by another publish, or when its publish is removed here. Keeping such entries would make
 * the map grow, and would stop the early-return check below from ever short-circuiting again.
 *
 * Removed publishes are subtracted from qosQueueBytes, like erase() does, because eraseFromMapAndRelinkList() doesn't.
 */
int QoSPublishQueue::clearExpiredMessages()
{
    if (this->queueExpirations.empty())
        return 0;

    const std::chrono::time_point<std::chrono::steady_clock> now = std::chrono::steady_clock::now();

    if (this->queueExpirations.begin()->first > now)
        return 0;

    int removed = 0;

    auto it = this->queueExpirations.begin();
    while (it != this->queueExpirations.end())
    {
        // A copy, because the entry may be erased below.
        const std::chrono::time_point<std::chrono::steady_clock> when = it->first;

        if (when > now)
            break;

        auto qpos = this->queue.find(it->second);

        if (qpos == this->queue.end())
        {
            // The publish was already removed by other means, so the entry is stale.
            it = this->queueExpirations.erase(it);
            continue;
        }

        std::shared_ptr<QueuedPublish> &p = qpos->second;
        Publish &pub = p->getPublish();

        if (pub.hasExpired())
        {
            qosQueueBytes -= p->getApproximateMemoryFootprint();
            assert(qosQueueBytes >= 0);
            if (qosQueueBytes < 0) // Should not happen, but correcting a hypothetical bug is fine for this purpose.
                qosQueueBytes = 0;

            // Invalidates 'p' and 'pub'.
            this->eraseFromMapAndRelinkList(qpos);
            it = this->queueExpirations.erase(it);
            removed++;
            continue;
        }

        if (!pub.expireInfo || pub.expireInfo->expiresAt() != when)
        {
            // The packet id now belongs to a different publish, which gets its own entry if it expires at all.
            it = this->queueExpirations.erase(it);
            continue;
        }

        // Same publish, but not considered expired yet; leave the entry for a later run.
        ++it;
    }

    return removed;
}
================================================ FILE: qospacketqueue.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef QOSPACKETQUEUE_H #define QOSPACKETQUEUE_H #include #include #include "types.h" #include "publishcopyfactory.h" /** * @brief The QueuedPublish class wraps the publish with a packet id.
* * We don't want to store the packet id in the Publish object, because the packet id is determined/tracked per client/session. */ class QueuedPublish { Publish publish; uint16_t packet_id = 0; // We store this separately because we need to retain the original publish path for ACL checking upon resending. std::optional topic_override; public: QueuedPublish(Publish &&publish, uint16_t packet_id, const std::optional &topic_override); QueuedPublish(const QueuedPublish &other) = delete; QueuedPublish(QueuedPublish &&other) = delete; QueuedPublish &operator=(const QueuedPublish&) = delete; QueuedPublish &operator=(QueuedPublish&&) = delete; /* Links in QoSPublishQueue's list: prev points to the entry queued before this one, next to the one queued after it. */ std::weak_ptr prev; std::weak_ptr next; size_t getApproximateMemoryFootprint() const; uint16_t getPacketId() const; Publish &getPublish(); const std::optional &getTopicOverride() const; }; /* Packet-id keyed store of queued publishes, also kept in insertion order as a doubly linked list (head = newest, tail = oldest), plus an index of expiry times. */ class QoSPublishQueue { std::shared_ptr head; std::shared_ptr tail; /* Owns the entries; head and tail also hold a reference, the list links themselves are weak. */ std::unordered_map> queue; /* Expiry time -> packet id, used by clearExpiredMessages(). NOTE(review): a std::map written with operator[], so two publishes with an identical expiry time overwrite each other's entry. */ std::map, uint16_t> queueExpirations; /* Sum of getApproximateMemoryFootprint() of the queued publishes. */ ssize_t qosQueueBytes = 0; void addToExpirationQueue(std::shared_ptr &qp); void eraseFromMapAndRelinkList(std::unordered_map>::iterator pos); void addToHeadOfLinkedList(std::shared_ptr &qp); public: QoSPublishQueue() = default; // We make this uncopyable because of the linked list QueuedPublish objects, making a deep-copy difficult.
QoSPublishQueue(const QoSPublishQueue &other) = delete; QoSPublishQueue &operator=(const QoSPublishQueue &other) = delete; QoSPublishQueue(QoSPublishQueue &&other) = default; QoSPublishQueue &operator=(QoSPublishQueue &&other) = default; /* Removes the publish with this packet id; returns whether one was found. */ bool erase(const uint16_t packet_id); size_t size() const; size_t getByteSize() const; void queuePublish( PublishCopyFactory ©Factory, uint16_t id, uint8_t new_max_qos, bool retainAsPublished, const uint32_t subscriptionIdentifier, const std::optional &topic_override); void queuePublish(Publish &&pub, uint16_t id, const std::optional &topic_override); /* Drops publishes whose expiry time has passed; returns how many were removed. */ int clearExpiredMessages(); const std::shared_ptr &getTail() const; /* Removes and returns the oldest entry (the tail), or an empty pointer when the queue is empty. */ std::shared_ptr popNext(); }; #endif // QOSPACKETQUEUE_H ================================================ FILE: queuedtasks.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License
*/ #include #include "queuedtasks.h" #include "logger.h" bool QueuedTask::operator<(const QueuedTask &rhs) const { return this->when < rhs.when; } QueuedTasks::QueuedTasks() { } uint32_t QueuedTasks::addTask(std::function f, uint32_t delayInMs, bool repeat) { std::chrono::time_point when = std::chrono::steady_clock::now() + std::chrono::milliseconds(delayInMs); while(++nextId == 0 || tasks.find(nextId) != tasks.end()) { } const uint32_t id = nextId; std::shared_ptr> &inserted = tasks[id]; inserted = std::make_shared>(std::move(f)); QueuedTask t; t.id = id; t.f = inserted; t.when = when; t.interval = std::chrono::milliseconds(delayInMs); t.repeat = repeat; queuedTasks.insert(t); return id; } void QueuedTasks::eraseTask(uint32_t id) { tasks.erase(id); } uint32_t QueuedTasks::getTimeTillNext() const { if (__builtin_expect(queuedTasks.empty(), 1)) return std::numeric_limits::max(); std::chrono::time_point next = queuedTasks.begin()->when; std::chrono::milliseconds x = std::chrono::duration_cast(next - std::chrono::steady_clock::now()); std::chrono::milliseconds y = std::max(std::chrono::milliseconds(0), x); return y.count(); } void QueuedTasks::performAll() { const auto now = std::chrono::steady_clock::now(); std::vector>> functions; for (auto pos = queuedTasks.begin(); pos != queuedTasks.end(); ) { if (pos->when > now || functions.size() > 0xFFFF) { break; } const auto cur = pos++; const auto tpos = tasks.find(cur->id); if (tpos != tasks.end() && cur->f.lock() == tpos->second) { functions.push_back(tpos->second); if (cur->repeat) { QueuedTask requeue = *cur; requeue.when = std::chrono::steady_clock::now() + requeue.interval; queuedTasks.insert(requeue); } else { tasks.erase(tpos); } } queuedTasks.erase(cur); } for(const std::shared_ptr> &f : functions) { try { if (!f || !*f) continue; f->operator()(); } catch (std::exception &ex) { Logger *logger = Logger::getInstance(); logger->logf(LOG_ERR, "Error in delayed task: %s", ex.what()); } } } void QueuedTasks::clear() { 
queuedTasks.clear(); tasks.clear(); } ================================================ FILE: queuedtasks.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef QUEUEDTASKS_H #define QUEUEDTASKS_H #include #include #include #include #include struct QueuedTask { std::chrono::time_point when; std::chrono::milliseconds interval; uint32_t id = 0; bool repeat = false; std::weak_ptr> f; bool operator<(const QueuedTask &rhs) const; }; /** * @brief Contains delayed tasks to perform. * * At this point, it's not for cross-thread use, so not protected with mutexes etc. */ class QueuedTasks { uint32_t nextId = 1; std::multiset queuedTasks; std::unordered_map>> tasks; public: QueuedTasks(); uint32_t addTask(std::function f, uint32_t delayInMs, bool repeat=false); void eraseTask(uint32_t id); uint32_t getTimeTillNext() const; void performAll(); void clear(); size_t getTaskCount() const { return tasks.size(); } }; #endif // QUEUEDTASKS_H ================================================ FILE: release.sh ================================================ #!/bin/bash SEMVER_REGEXP="^([0-9]+)\.([0-9]+)\.([0-9]+)$" SCRIPT_NAME=$(basename "$0") SCRIPT_DIR="$(dirname "$(realpath "$(readlink -f "$0")")")" CMAKEFILE="$SCRIPT_DIR/CMakeLists.txt" usage() { echo -e "\e[1m$SCRIPT_NAME\e[22m – Tag a new release of \e[1;38;2;0;55;112;48;2;255;255;255m\e[39;49mFlashMQ\e[0m. Usage: $SCRIPT_NAME $SCRIPT_NAME -h\e[2m|\e[22m--help Options & arguments: A semantic version number, like 1.0.12, conforming to \e[4mhttps://semver.org/spec/v2.0.0.html\e[24m. -h\e[2m|\e[22m--help Display this help. 
" } usage_error() { echo -e "\e[31m$1\e[0m" >&2 echo >&2 usage >&2 exit 2 } while [[ -n "$1" ]]; do case "$1" in -h|--help) usage exit 0 ;; -*) usage_error "Unknown option: \e[1m$1\e[22m" ;; *) semver="$1" shift if [[ -n "$1" ]]; then usage_error "Extraneous arguments after (\e[1m$semver\e[22m)." fi ;; esac done if [[ -z "$semver" ]]; then usage_error "Missing \e[1m\e[22m argument." fi if [[ ! "$semver" =~ $SEMVER_REGEXP ]]; then usage_error "\e[1m$semver\e[22m is not a recognizable \e[1msemver\e[22m; \$SEMVER_REGEXP = \e[1m$SEMVER_REGEXP\e[22m" fi cmd() { msg="$1"; shift shift echo -e "$msg: \e[1m$(printf '%q ' "$@")\e[22m" >&2 "$@" local retval=$? if [[ "$retval" -eq 0 ]]; then echo -e "\e[32;1m✔ done\e[39;22m" >&2 else echo -e "\e[31;1m❌failed with exit code \e[4m$retval\e[24m\e[39;22m" >&2 fi echo >&2 return "$retval" } if output=$(git status --porcelain) && [[ -n "$output" ]]; then 1>&2 echo "Git not clean" exit 3 fi git_tag="v$semver" git_msg="Version $semver" if git rev-parse "$git_tag" &>/dev/null; then 1>&2 echo "Git tag seems already in use" exit 3 fi if ! cmd "Setting project version" -- sed -i -E '/^project\(/s/^(project\([^ ]+ VERSION )[^ ]+( .*)$/\1'"$semver"'\2/' "$CMAKEFILE"; then exit 3 fi if ! cmd "Staging changes to \e[1m$CMAKEFILE\e[22m" -- git add "$CMAKEFILE"; then exit 3 fi if ! cmd "Committing release $semver" -- git commit -m "$git_msg"; then exit 3 fi cmd "Tagging release $semver" -- git tag -a -m "$git_msg" "$git_tag" # vim: set expandtab tabstop=4 shiftwidth=4: ================================================ FILE: retainedmessage.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include "retainedmessage.h" #include "threadglobals.h" #include "settings.h" RetainedMessage::RetainedMessage(const Publish &publish) : publish(publish) { setRetainData(); } bool RetainedMessage::operator==(const RetainedMessage &rhs) const { return this->publish.topic == rhs.publish.topic; } void RetainedMessage::setRetainData() { this->publish.retain = true; const Settings *settings = ThreadGlobals::getSettings(); this->publish.setExpireAfterToCeiling(settings->expireRetainedMessagesAfterSeconds); } RetainedMessage &RetainedMessage::operator=(const Publish &pub) { this->publish = pub; setRetainData(); return *this; } bool RetainedMessage::empty() const { return publish.payload.empty(); } uint32_t RetainedMessage::getSize() const { return publish.topic.length() + publish.payload.length() + 1; } /** * @brief RetainedMessage::hasExpired is more dynamic than a publish's expire info, because the settings may have changed in the mean time. * @return */ bool RetainedMessage::hasExpired() const { if (this->publish.hasExpired()) return true; const Settings *settings = ThreadGlobals::getSettings(); std::chrono::milliseconds expireAge(settings->expireRetainedMessagesAfterSeconds); if (this->publish.getAge() > expireAge) return true; return false; } ================================================ FILE: retainedmessage.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef RETAINEDMESSAGE_H #define RETAINEDMESSAGE_H #include #include "types.h" /** * @brief A retained publish. Equality (operator==) and the std::hash specialization below only look at publish.topic, so a hashed container of these holds at most one message per topic. * * NOTE(review): template arguments and <...> include targets in this header appear stripped by the text dump ("struct hash {", "hash()"); presumably hash<RetainedMessage> and hash<std::string>. */ struct RetainedMessage { Publish publish; /* Copies the publish and applies setRetainData(). */ RetainedMessage(const Publish &publish); RetainedMessage(const RetainedMessage&) = default; /* Assigning another RetainedMessage is deleted; assign a Publish instead, which re-applies setRetainData(). */ RetainedMessage &operator=(const RetainedMessage&) = delete; RetainedMessage &operator=(RetainedMessage&&) = delete; bool operator==(const RetainedMessage &rhs) const; RetainedMessage &operator=(const Publish &pub); /* Forces retain=1 and calls setExpireAfterToCeiling() with the configured retained-message expiry. */ void setRetainData(); /* True when the payload is empty. */ bool empty() const; /* Topic length + payload length + 1. */ uint32_t getSize() const; bool hasExpired() const; }; namespace std { /* Hashes only the topic, consistent with operator==. */ template <> struct hash { std::size_t operator()(const RetainedMessage& k) const { using std::size_t; using std::hash; using std::string; return hash()(k.publish.topic); } }; } #endif // RETAINEDMESSAGE_H ================================================ FILE: retainedmessagesdb.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details.
*/ #include #include #include #include #include #include #include #include #include "retainedmessagesdb.h" #include "utils.h" #include "logger.h" #include "mqttpacket.h" #include "threadglobals.h" #include "client.h" #include "logger.h" RetainedMessagesDB::RetainedMessagesDB(const std::string &filePath) : PersistenceFile(filePath) { } /* Closing patches the message count into the header; see closeFile(). */ RetainedMessagesDB::~RetainedMessagesDB() { closeFile(); } /** * @brief Opens the file for writing and writes the header after the magic string: an int64 save time stamp (a duration_cast of system_clock time since epoch; the target unit is stripped in this dump, presumably seconds, TODO confirm), a uint32 message count placeholder whose offset is kept in length_pos and patched by closeFile(), and RESERVED_SPACE_RETAINED_DB_V2 zero bytes. */ void RetainedMessagesDB::openWrite() { PersistenceFile::openWrite(MAGIC_STRING_V4); this->written_count = 0; const int64_t now_epoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); logger->log(LOG_DEBUG) << "Saving current time stamp " << now_epoch << " in retained messages DB."; writeInt64(now_epoch); length_pos = ftell(f); writeUint32(0); char reserved[RESERVED_SPACE_RETAINED_DB_V2]; std::memset(reserved, 0, RESERVED_SPACE_RETAINED_DB_V2); writeCheck(reserved, 1, RESERVED_SPACE_RETAINED_DB_V2, f); } /** * @brief Opens the file for reading, maps the detected magic string to a ReadVersion and reads the rest of the header. * * For v4 and up, persistence_state_age becomes the time since the file was saved (0 when the stored time stamp lies in the future); readDataV3V4() adds it to each message's stored age. * @throws std::runtime_error on an unknown version or a truncated header. */ void RetainedMessagesDB::openRead() { const std::string current_magic_string(MAGIC_STRING_V4); PersistenceFile::openRead(current_magic_string); if (detectedVersionString == MAGIC_STRING_V1) readVersion = ReadVersion::v1; else if (detectedVersionString == MAGIC_STRING_V2) readVersion = ReadVersion::v2; else if (detectedVersionString == MAGIC_STRING_V3) readVersion = ReadVersion::v3; else if (detectedVersionString == current_magic_string) readVersion = ReadVersion::v4; else throw std::runtime_error("Unknown file version."); bool eofFound = false; if (readVersion >= ReadVersion::v4) { const int64_t fileSavedAt = readInt64(eofFound); if (eofFound) throw std::runtime_error("Error reading retained messages file age: eof reached."); const int64_t now_epoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); persistence_state_age = fileSavedAt > now_epoch ?
0 : now_epoch - fileSavedAt; } to_read_count = readUint32(eofFound); if (eofFound) throw std::runtime_error("Error reading retained messages file message count: eof reached."); fseek(f, RESERVED_SPACE_RETAINED_DB_V2, SEEK_CUR); } /** * @brief Writes the real message count over the placeholder written by openWrite(). Skipped when nothing was written, because the placeholder already holds 0. */ void RetainedMessagesDB::closeFile() { if (!f) return; if (openMode == FileMode::write && length_pos > 0 && written_count > 0) { fseek(f, length_pos, SEEK_SET); writeUint32(written_count); } PersistenceFile::closeFile(); } void RetainedMessagesDB::dontSaveTmpFile() { PersistenceFile::dontSaveTmpFile(); } /** * @brief RetainedMessagesDB::saveData doesn't explicitly name a file version (v1, etc), because we always write the current definition. * @param messages are appended as rows: uint16 fixed header length, uint32 age (0 without expire info), uint32 packet size, client id string, username string, then the serialized MQTT5 PUBLISH packet. written_count is increased per row. */ void RetainedMessagesDB::saveData(const std::vector &messages) { if (!f) return; CirBuf cirbuf(1024); for (const RetainedMessage &rm : messages) { if (logger->wouldLog(LOG_DEBUG)) { logger->log(LOG_DEBUG) << "Saving retained message for topic '" << rm.publish.topic << "' QoS " << static_cast(rm.publish.qos) << ", age " << rm.publish.getAge().count() << " seconds."; } this->written_count++; Publish pcopy(rm.publish); MqttPacket pack(ProtocolVersion::Mqtt5, pcopy); // Dummy, to please the parser on reading. if (pcopy.qos > 0) pack.setPacketId(666); const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); const uint32_t pubAge = pcopy.expireInfo ? ageFromTimePoint(pcopy.expireInfo.value().createdAt) : 0; cirbuf.reset(); cirbuf.ensureFreeSpace(packSize + 32); pack.readIntoBuf(cirbuf); writeUint16(pack.getFixedHeaderLength()); writeUint32(pubAge); writeUint32(packSize); writeString(pcopy.client_id); writeString(pcopy.username); writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f); } fflush(f); } /** * @brief Reads up to max messages. Only v3 and v4 files are read; v1 and v2 were never finalized, so a warning is logged and an empty list is returned. */ std::list RetainedMessagesDB::readData(size_t max) { std::list defaultResult; if (!f) return defaultResult; if (readVersion == ReadVersion::v1) logger->logf(LOG_WARNING, "File '%s' is version 1, an internal development version that was never finalized.
Not reading.", getFilePath().c_str()); if (readVersion == ReadVersion::v2) logger->logf(LOG_WARNING, "File '%s' is version 2, an internal development version that was never finalized. Not reading.", getFilePath().c_str()); if (readVersion == ReadVersion::v3 || readVersion == ReadVersion::v4) return readDataV3V4(max); return defaultResult; } /** * @brief Reads at most max rows in the v3/v4 layout. to_read_count is decremented per row, so a later call continues with the remaining rows. Each packet is parsed as a PUBLISH through a dummy MQTT5 client; when the publish has expire info, its creation time is recomputed with timepointFromAge() from the stored age plus persistence_state_age. * * NOTE(review): newPubAge is a uint32_t computed from the int64_t persistence_state_age, so an extremely large age would be truncated. */ std::list RetainedMessagesDB::readDataV3V4(size_t max) { std::list messages; CirBuf cirbuf(1024); const Settings &settings = *ThreadGlobals::getSettings(); std::shared_ptr dummyThreadData; std::shared_ptr dummyClient(std::make_shared(ClientType::Normal, -1, dummyThreadData, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingretained", {}, "nobody", true, 60); bool eofFound = false; const uint32_t numberOfMessages = std::min(to_read_count, max); for(uint32_t i = 0; i < numberOfMessages; i++) { assert(to_read_count > 0); to_read_count--; const uint16_t fixed_header_length = readUint16(eofFound); /* v3 rows carry no age field. */ uint32_t originalPubAge = 0; if (readVersion >= ReadVersion::v4) { originalPubAge = readUint32(eofFound); } const uint32_t newPubAge = persistence_state_age + originalPubAge; const uint32_t packlen = readUint32(eofFound); const std::string client_id = readString(eofFound); const std::string username = readString(eofFound); if (eofFound) throw std::runtime_error("Error reading retained messages: unexpected end of file"); cirbuf.reset(); cirbuf.ensureFreeSpace(packlen + 32); readCheck(cirbuf.headPtr(), 1, packlen, f); cirbuf.advanceHead(packlen); MqttPacket pack(cirbuf.readToVector(packlen), fixed_header_length, dummyClient); pack.parsePublishData(dummyClient); Publish pub(pack.getPublishData()); pub.client_id = client_id; pub.username = username; if (pub.expireInfo) pub.expireInfo.value().createdAt = timepointFromAge(newPubAge); RetainedMessage msg(pub); if (logger->wouldLog(LOG_DEBUG)) { logger->log(LOG_DEBUG) << "Loading retained
message for topic '" << msg.publish.topic << "' QoS " << static_cast(msg.publish.qos) << ", age " << msg.publish.getAge().count() << " seconds."; } messages.push_back(std::move(msg)); } return messages; } ================================================ FILE: retainedmessagesdb.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef RETAINEDMESSAGESDB_H #define RETAINEDMESSAGESDB_H #include "persistencefile.h" #include "retainedmessage.h" #define MAGIC_STRING_V1 "FlashMQRetainedDBv1" #define MAGIC_STRING_V2 "FlashMQRetainedDBv2" #define MAGIC_STRING_V3 "FlashMQRetainedDBv3" #define MAGIC_STRING_V4 "FlashMQRetainedDBv4" #define RESERVED_SPACE_RETAINED_DB_V2 64 /** * @brief The RetainedMessagesDB class saves and loads the retained messages. * * The DB looks like, from the top: * * MAGIC_STRING_LENGH bytes file header * HASH_SIZE SHA512 * int64 time stamp of saving (v4 and up) * uint32 number of messages * RESERVED_SPACE_RETAINED_DB_V2 reserved bytes * [MESSAGES] * * Each message has a row header: uint16 fixed header length, uint32 age (v4 and up), uint32 packet length, client id string and * username string, followed by the MQTT5 PUBLISH packet itself. See saveData() and readDataV3V4(). The RowHeader struct * (topicLen/payloadLen) is not used by those functions.
* */ class RetainedMessagesDB : private PersistenceFile { enum class ReadVersion { unknown, v1, v2, v3, v4 }; struct RowHeader { uint32_t topicLen = 0; uint32_t payloadLen = 0; }; ReadVersion readVersion = ReadVersion::unknown; std::list readDataV3V4(size_t max); uint32_t written_count = 0; long length_pos = 0; uint32_t to_read_count = 0; int64_t persistence_state_age = 0; public: RetainedMessagesDB(const std::string &filePath); virtual ~RetainedMessagesDB(); void openWrite(); void openRead(); void closeFile(); void dontSaveTmpFile(); void saveData(const std::vector &messages); std::list readData(size_t max=std::numeric_limits::max()); }; #endif // RETAINEDMESSAGESDB_H ================================================ FILE: rwlockguard.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "rwlockguard.h" #include "utils.h" #include #include RWLockGuard::RWLockGuard(pthread_rwlock_t *rwlock) : rwlock(rwlock) { } RWLockGuard::~RWLockGuard() { unlock(); } bool RWLockGuard::trywrlock() { const int rc = pthread_rwlock_trywrlock(rwlock); if (rc == 0) return true; if (rc == EINVAL) throw std::runtime_error("Lock not initialized."); rwlock = nullptr; return false; } /** * @brief RWLockGuard::wrlock locks for writing i.e. exclusive lock. * * Contrary to rdlock, we don't accept deadlock errors, because you may be trying to upgrade a read lock to a write lock, which will * cause actualy dead locks. 
*/ void RWLockGuard::wrlock() { const int rc = pthread_rwlock_wrlock(rwlock); if (rc != 0) throw std::runtime_error("wrlock failed."); } bool RWLockGuard::tryrdlock() { const int rc = pthread_rwlock_tryrdlock(rwlock); if (rc == 0) return true; if (rc == EINVAL) throw std::runtime_error("Lock not initialized."); rwlock = nullptr; return false; } /** * @brief RWLockGuard::tryfirstrdlock is different than pthread_rwlock_timedrdlock. We try to avoid locking for periods. * @param limit * @param sleep_time */ void RWLockGuard::tryfirstrdlock(std::chrono::time_point limit, std::chrono::microseconds sleep_time) { while (std::chrono::steady_clock::now() < limit) { const int rc = pthread_rwlock_tryrdlock(rwlock); if (rc == 0) return; if (rc == EINVAL || rc == EDEADLK) { rwlock = nullptr; throw std::runtime_error("Lock not initialized or you already have a write lock."); } std::this_thread::sleep_for(sleep_time); } rdlock(); } /** * @brief RWLockGuard::rdlock locks for reading, and considers it OK of the current thread already owns the lock for writing. * * The pthread_rwlock_rdlock man page says: "EDEADLK: The current thread already owns the read-write lock for writing." I hope * that is literally the case. */ void RWLockGuard::rdlock() { int rc = pthread_rwlock_rdlock(rwlock); if (rc == EDEADLK) { rwlock = NULL; return; } if (rc != 0) throw std::runtime_error(strerror(rc)); } void RWLockGuard::unlock() { if (rwlock != NULL) { pthread_rwlock_unlock(rwlock); rwlock = NULL; } } ================================================ FILE: rwlockguard.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef RWLOCKGUARD_H #define RWLOCKGUARD_H #include #include /** * @brief Scoped holder of a pthread_rwlock_t. The destructor calls unlock(), which only unlocks while the guard still refers to the lock; failed try-locks and rdlock() on EDEADLK set that pointer to null. */ class RWLockGuard { pthread_rwlock_t *rwlock = NULL; public: RWLockGuard(pthread_rwlock_t *rwlock); ~RWLockGuard(); bool trywrlock(); void wrlock(); bool tryrdlock(); void tryfirstrdlock(std::chrono::time_point limit, std::chrono::microseconds sleep_time); void rdlock(); void unlock(); }; #endif // RWLOCKGUARD_H ================================================ FILE: scopedsocket.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include #include "scopedsocket.h" #include #include #include "utils.h" #include "logger.h" /* Takes ownership of an already created socket fd; throws std::runtime_error when it's negative. */ ScopedSocket::ScopedSocket(int socket, const std::string &unixSocketPath, const std::shared_ptr &listener) : socket(socket), unixSocketPath(unixSocketPath), listener(listener) { if (this->socket < 0) throw std::runtime_error("Cannot create scoped socket"); } /* Implemented via the move assignment; the members start out with their defaults (socket = -1). */ ScopedSocket::ScopedSocket(ScopedSocket &&other) { assert(this != &other); *this = std::move(other); } /* Closes the fd (if any) and calls unlink_if_sock() on unixSocketPath. */ ScopedSocket::~ScopedSocket() { if (socket >= 0) close(socket); socket = -1; listener.reset(); unlink_if_sock(unixSocketPath); } int ScopedSocket::get() const { return socket; } /** * @brief Closes this object's own fd, then takes over the fd, listener, listening state, socket path and listen message of 'other', leaving 'other' empty. * * NOTE(review): unlike the destructor, this does not call unlink_if_sock() on this object's previous unixSocketPath before overwriting it. That may be intentional (e.g. when the new socket reuses the same path); confirm. */ ScopedSocket &ScopedSocket::operator=(ScopedSocket &&other) { assert(this != &other); if (this->socket >= 0) { close(this->socket); this->socket = -1; } this->listener = std::move(other.listener); this->socket = other.socket; other.socket = -1; this->listening = other.listening; other.listening = false; this->unixSocketPath = std::move(other.unixSocketPath); other.unixSocketPath.clear(); this->listenMessage = std::move(other.listenMessage); other.listenMessage.clear(); return *this; } /* Null when the listener no longer exists; it's held as a weak pointer. */ std::shared_ptr ScopedSocket::getListener() const { return listener.lock(); } void
ScopedSocket::setListenMessage(const std::string &s) { this->listenMessage = s; } /** * @brief Logs listenMessage (when set), starts listening with a backlog of 32768 and registers the socket for EPOLLIN on epoll_fd. Does nothing when already listening. */ void ScopedSocket::doListen(int epoll_fd) { if (listening) return; if (!this->listenMessage.empty()) Logger::getInstance()->log(LOG_NOTICE) << this->listenMessage; check(listen(socket, 32768)); struct epoll_event ev {}; ev.data.fd = socket; ev.events = EPOLLIN; check(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, socket, &ev)); listening = true; } ================================================ FILE: scopedsocket.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef SCOPEDSOCKET_H #define SCOPEDSOCKET_H #include #include #include "listener.h" /** * @brief The ScopedSocket class allows for a bit of RAII and move semantics on a socket fd. It's move-only: copy construction is deleted. */ class ScopedSocket { int socket = -1; std::string unixSocketPath; std::weak_ptr listener; bool listening = false; std::string listenMessage; public: ScopedSocket() = default; ScopedSocket(int socket, const std::string &unixSocketPath, const std::shared_ptr &listener); ScopedSocket(const ScopedSocket &other) = delete; ScopedSocket(ScopedSocket &&other); ~ScopedSocket(); int get() const; ScopedSocket &operator=(ScopedSocket &&other); std::shared_ptr getListener() const; void setListenMessage(const std::string &s); void doListen(int epoll_fd); }; #endif // SCOPEDSOCKET_H ================================================ FILE: sdnotify.cpp ================================================ #include "sdnotify.h" /* SPDX-License-Identifier: MIT-0 */ /* Implement the systemd notify protocol without external dependencies.
* Supports both readiness notification on startup and on reloading, * according to the protocol defined at: * https://www.freedesktop.org/software/systemd/man/latest/sd_notify.html * This protocol is guaranteed to be stable as per: * https://systemd.io/PORTABILITY_AND_STABILITY/ */ #define _GNU_SOURCE 1 #include #include #include #include #include #include #include #include #include #include #include #define _cleanup_(f) __attribute__((cleanup(f))) static void closep(int *fd) { if (!fd || *fd < 0) return; close(*fd); *fd = -1; } static int notify(const char *message) { union sockaddr_union { struct sockaddr sa; struct sockaddr_un sun; } socket_addr {}; socket_addr.sun.sun_family = AF_UNIX; size_t path_length, message_length; _cleanup_(closep) int fd = -1; const char *socket_path; /* Verify the argument first */ if (!message) return -EINVAL; message_length = strlen(message); if (message_length == 0) return -EINVAL; /* If the variable is not set, the protocol is a noop */ socket_path = getenv("NOTIFY_SOCKET"); if (!socket_path) return 0; /* Not set? Nothing to do */ /* Only AF_UNIX is supported, with path or abstract sockets */ if (socket_path[0] != '/' && socket_path[0] != '@') return -EAFNOSUPPORT; path_length = strlen(socket_path); /* Ensure there is room for NUL byte */ if (path_length >= sizeof(socket_addr.sun.sun_path)) return -E2BIG; memcpy(socket_addr.sun.sun_path, socket_path, path_length); /* Support for abstract socket */ if (socket_addr.sun.sun_path[0] == '@') socket_addr.sun.sun_path[0] = 0; fd = socket(AF_UNIX, SOCK_DGRAM|SOCK_CLOEXEC, 0); if (fd < 0) return -errno; if (connect(fd, &socket_addr.sa, offsetof(struct sockaddr_un, sun_path) + path_length) != 0) return -errno; ssize_t written = write(fd, message, message_length); if (written != (ssize_t) message_length) return written < 0 ? -errno : -EPROTO; return 1; /* Notified! 
*/ } int notify_ready(void) { return notify("READY=1"); } int notify_reloading(void) { /* A buffer with length sufficient to format the maximum UINT64 value. */ char reload_message[sizeof("RELOADING=1\nMONOTONIC_USEC=18446744073709551615")]; struct timespec ts; uint64_t now; /* Notify systemd that we are reloading, including a CLOCK_MONOTONIC timestamp in usec * so that the program is compatible with a Type=notify-reload service. */ if (clock_gettime(CLOCK_MONOTONIC, &ts) < 0) return -errno; if (ts.tv_sec < 0 || ts.tv_nsec < 0 || (uint64_t) ts.tv_sec > (UINT64_MAX - (ts.tv_nsec / 1000ULL)) / 1000000ULL) return -EINVAL; now = (uint64_t) ts.tv_sec * 1000000ULL + (uint64_t) ts.tv_nsec / 1000ULL; if (snprintf(reload_message, sizeof(reload_message), "RELOADING=1\nMONOTONIC_USEC=%" PRIu64, now) < 0) return -EINVAL; return notify(reload_message); } int notify_stopping(void) { return notify("STOPPING=1"); } ================================================ FILE: sdnotify.h ================================================ #ifndef SDNOTIFY_H #define SDNOTIFY_H int notify_ready(void); int notify_reloading(void); int notify_stopping(void); #endif // SDNOTIFY_H ================================================ FILE: session.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include #include "session.h" #include "client.h" #include "threadglobals.h" #include "threaddata.h" #include "exceptions.h" #include "plugin.h" #include "settings.h" Session::Session(const std::string &clientid, const std::string &username, const std::optional &fmq_client_group_id) : client_id(clientid), username(username), fmq_client_group_id(fmq_client_group_id), // Sessions also get defaults from the handleConnect() method, but when you create sessions elsewhere, we do need some sensible defaults. qos(ThreadGlobals::getSettings()->maxQosMsgPendingPerClient) { this->sessionExpiryInterval = ThreadGlobals::getSettings()->expireSessionsAfterSeconds; } void Session::QoSData::increaseFlowControlQuota() { flowControlQuota++; flowControlQuota = std::min(flowControlQuota, flowControlCealing); } void Session::QoSData::increaseFlowControlQuota(int n) { flowControlQuota += n; flowControlQuota = std::min(flowControlQuota, flowControlCealing); } void Session::QoSData::clearExpiredMessagesFromQueue() { const auto now = std::chrono::steady_clock::now(); if (lastExpiredMessagesAt + std::chrono::seconds(1) > now) return; lastExpiredMessagesAt = now; const int n = qosPacketQueue.clearExpiredMessages(); increaseFlowControlQuota(n); } /** * @brief get next packet ID and decrease the flow control counter. Remember to increase the flow control counter in the proper places. * @return * * You should also check that flowControl is higher than 0 before you use this. */ uint16_t Session::QoSData::getNextPacketId() { nextPacketId++; nextPacketId = std::max(nextPacketId, 1); assert(flowControlQuota > 0); flowControlQuota--; return nextPacketId; } Session::~Session() { logger->log(LOG_DEBUG) << "Session destructor of session with client ID '" << this->client_id << "'."; } /** * @brief Session::makeSharedClient get the client of the session, or a null when it has no active current client. * @return Returns shared_ptr, which can contain null when the client has disconnected. 
* * The lock() operation is atomic and therefore is the only way to get the current active client without race condition, because * typically, this method is called from other client's threads to perform writes, so you have to check validity after * obtaining the shared pointer. */ std::shared_ptr Session::makeSharedClient() { return client.lock(); } void Session::assignActiveConnection(const std::shared_ptr &client) { this->client = client; this->willPublish = client->getWill(); this->removalQueued = false; } void Session::assignActiveConnection(const std::shared_ptr &thisSession, const std::shared_ptr &client, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval, bool clean_start) { assert(this == thisSession.get()); std::lock_guard locker(this->clientSwitchMutex); if (username != client->getUsername()) throw ProtocolError("Cannot take over session with different username", ReasonCodes::NotAuthorized); thisSession->assignActiveConnection(client); client->assignSession(thisSession); thisSession->setSessionProperties(clientReceiveMax, sessionExpiryInterval, clean_start, client->getProtocolVersion()); } /** * @brief Session::writePacket is the main way to give a client a packet -> it goes through the session. * @param packet is not const. We set the qos and packet id for each publish. This should be safe, because the packet * with original packet id and qos is not saved. This saves unnecessary copying. * @param max_qos * @param retain. Keep MQTT-3.3.1-9 in mind: existing subscribers don't get retain=1 on packets. * @param count. Reference value is updated. It's for statistics. */ PacketDropReason Session::writePacket(PublishCopyFactory ©Factory, const uint8_t max_qos, bool retainAsPublished, const uint32_t subscriptionIdentifier) { /* * We want to do as little as possible before the ACL check, because it's code that's called * exponentially for subscribers that don't have access to topics, like wildcard subscribers. 
*/ assert(max_qos <= 2); const uint8_t effectiveQos = copyFactory.getEffectiveQos(max_qos); retainAsPublished = retainAsPublished || clientType == ClientType::Mqtt3DefactoBridge; bool effectiveRetain = copyFactory.getEffectiveRetain(retainAsPublished); const AuthResult aclResult = ThreadGlobals::getThreadData()->authentication.aclCheck( client_id, username, copyFactory.getTopic(), copyFactory.getSubtopics(), "", copyFactory.getPayload(), AclAccess::read, effectiveQos, effectiveRetain, copyFactory.getCorrelationData(), copyFactory.getResponseTopic(), copyFactory.getContentType(), copyFactory.getExpiresAt(), copyFactory.getUserProperties()); if (aclResult != AuthResult::success) { return PacketDropReason::AuthDenied; } const std::shared_ptr c = makeSharedClient(); std::optional topic_override; { // The size check is to prevent making "local/haha/" into "" when the local_prefix is "local/haha/" if (local_prefix && startsWith(copyFactory.getTopic(), *local_prefix) && copyFactory.getTopic().size() > local_prefix->size()) { topic_override = copyFactory.getTopic(); topic_override->erase(0, local_prefix->length()); } if (remote_prefix) { const std::string &to_append = topic_override.value_or(copyFactory.getTopic()); topic_override = *remote_prefix + to_append; } } uint16_t pack_id = 0; if (__builtin_expect(effectiveQos > 0, 0)) { const Settings *settings = ThreadGlobals::getSettings(); MutexLocked qos_locked = qos.lock(); // We don't clear expired messages for online clients. It would slow down the 'happy flow' and those packets are already in the output // buffer, so we can't clear them anyway. 
if (!c) { qos_locked->clearExpiredMessagesFromQueue(); } if (qos_locked->flowControlQuota <= 0 || (qos_locked->qosPacketQueue.getByteSize() >= settings->maxQosBytesPendingPerClient && qos_locked->qosPacketQueue.size() > 0)) { if (qos_locked->QoSLogPrintedAtId != qos_locked->nextPacketId) { if (c) { logger->log(LOG_WARNING) << "Dropping QoS message(s) for on-line client '" << client_id << "', because it hasn't seen " "enough PUBACK/PUBCOMP/PUBRECs to release places " "or it exceeded the queue size. You could increase 'max_qos_msg_pending_per_client' " "or 'max_qos_bytes_pending_per_client' (but this is also subject the client's 'receive max')."; } else { logger->log(LOG_WARNING) << "Dropping QoS message(s) for off-line client '" << client_id << "', because the limit has been reached. " "You can increase 'max_qos_msg_pending_per_client' and/or 'max_qos_bytes_pending_per_client' to buffer more."; } qos_locked->QoSLogPrintedAtId = qos_locked->nextPacketId; } return PacketDropReason::QoSTODOSomethingSomething; } pack_id = qos_locked->getNextPacketId(); if (!destroyOnDisconnect) qos_locked->qosPacketQueue.queuePublish(copyFactory, pack_id, effectiveQos, effectiveRetain, subscriptionIdentifier, topic_override); } PacketDropReason return_value = PacketDropReason::ClientOffline; if (c) { if (!c->isRetainedAvailable()) effectiveRetain = false; return_value = c->writeMqttPacketAndBlameThisClient(copyFactory, effectiveQos, pack_id, effectiveRetain, subscriptionIdentifier, topic_override); } return return_value; } /** * @brief Session::clearQosMessage clears a QOS message from the queue. Note that in QoS 2, that doesn't complete the handshake. * @param packet_id * @param qosHandshakeEnds can be set to true when you know the QoS handshake ends, (like) when PUBREC contains an error. * @return whether the packet_id in question was found. 
*/
bool Session::clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds)
{
    bool result = false;

    MutexLocked qos_locked = qos.lock();

    // The queue size goes through printf-style varargs. It's cast to size_t and printed with '%zu', because passing a
    // size_t for '%d' is undefined behavior. packet_id is promoted to int, so '%d' is fine for it.
    if (logger->wouldLog(LOG_PUBLISH))
    {
        logger->logf(LOG_PUBLISH, "Clearing QoS message for '%s', packet id '%d'. Left in queue: %zu", client_id.c_str(), packet_id,
                     static_cast<size_t>(qos_locked->qosPacketQueue.size()));
    }

    // Sessions that are destroyed on disconnect never queue QoS messages (writePacket() skips queuePublish() for them),
    // so there is nothing to erase and the packet id counts as found.
    if (!destroyOnDisconnect)
        result = qos_locked->qosPacketQueue.erase(packet_id);
    else
        result = true;

    // Only give back a flow control slot for a packet id that was actually known.
    if (qosHandshakeEnds && result)
        qos_locked->increaseFlowControlQuota();

    return result;
}
/** * @brief Session::sendAllPendingQosData sends pending publishes and QoS2 control packets. * * [MQTT-4.4.0-1] (about MQTT 3.1.1): "When a Client reconnects with CleanSession set to 0, both the Client and Server MUST * re-send any unacknowledged PUBLISH Packets (where QoS > 0) and PUBREL Packets using their original Packet Identifiers. This * is the only circumstance where a Client or Server is REQUIRED to redeliver messages." * * Only MQTT 3.1 requires retransmission. MQTT 3.1.1 and MQTT 5 only send on reconnect. At time of writing this comment, * FlashMQ doesn't have a retransmission system. I don't think I want to implement one for the sake of 3.1 compliance, * because it's just not that great an idea in terms of server load and quality of modern TCP. However, receiving clients * can still decide to drop packets, like when their buffers are full. The clients from where the packet originates will * never know that, because IT will have received the PUBACK from FlashMQ. The QoS system is not between publisher * and subscriber. Users are required to implement something themselves.
*/ void Session::sendAllPendingQosData() { Authentication &authentication = ThreadGlobals::getThreadData()->authentication; std::shared_ptr c = makeSharedClient(); if (c) { std::vector, uint16_t>> copiedPublishes; std::vector copiedQoS2Ids; { MutexLocked qos_locked = qos.lock(); std::shared_ptr qp_ = qos_locked->qosPacketQueue.getTail(); while (qp_) { std::shared_ptr qp = qp_; qp_ = qp_->next.lock(); Publish &pub = qp->getPublish(); if (pub.hasExpired() || (authentication.aclCheck(pub, pub.payload) != AuthResult::success)) { qos_locked->qosPacketQueue.erase(qp->getPacketId()); continue; } if (qos_locked->flowControlQuota <= 0) { logger->logf(LOG_WARNING, "Dropping QoS message(s) for client '%s', because it exceeds its receive maximum.", client_id.c_str()); qos_locked->qosPacketQueue.erase(qp->getPacketId()); continue; } qos_locked->flowControlQuota--; copiedPublishes.emplace_back(pub, qp->getTopicOverride(), qp->getPacketId()); } for (const uint16_t packet_id : qos_locked->outgoingQoS2MessageIds) { copiedQoS2Ids.push_back(packet_id); } } for(std::tuple, uint16_t> &p : copiedPublishes) { Publish &pub = std::get(p); PublishCopyFactory fac(&pub); const bool retain = !c->isRetainedAvailable() ? 
false : pub.retain; c->writeMqttPacketAndBlameThisClient(fac, pub.qos, std::get(p), retain, pub.subscriptionIdentifier, std::get>(p)); } for(uint16_t id : copiedQoS2Ids) { PubResponse pubRel(c->getProtocolVersion(), PacketType::PUBREL, ReasonCodes::Success, id); MqttPacket packet(pubRel); c->writeMqttPacketAndBlameThisClient(packet); } } } bool Session::hasActiveClient() { return !client.expired(); } void Session::clearWill() { this->willPublish.reset(); } std::shared_ptr Session::getWill() { return this->willPublish.getCopy(); } void Session::setWill(WillPublish &&pub) { this->willPublish = std::make_shared(std::move(pub)); } void Session::setClientType(ClientType val) { this->clientType = val; } void Session::addIncomingQoS2MessageId(uint16_t packet_id) { assert(packet_id > 0); MutexLocked qos_locked = qos.lock(); qos_locked->incomingQoS2MessageIds.insert(packet_id); } bool Session::incomingQoS2MessageIdInTransit(uint16_t packet_id) { assert(packet_id > 0); MutexLocked qos_locked = qos.lock(); const auto it = qos_locked->incomingQoS2MessageIds.find(packet_id); return it != qos_locked->incomingQoS2MessageIds.end(); } bool Session::removeIncomingQoS2MessageId(u_int16_t packet_id) { assert(packet_id > 0); MutexLocked qos_locked = qos.lock(); if (logger->wouldLog(LOG_PUBLISH)) { logger->logf(LOG_PUBLISH, "As QoS 2 receiver: publish released (PUBREL) for '%s', packet id '%d'. 
Left in queue: %d", client_id.c_str(), packet_id, qos_locked->incomingQoS2MessageIds.size()); } bool result = false; const auto it = qos_locked->incomingQoS2MessageIds.find(packet_id); if (it != qos_locked->incomingQoS2MessageIds.end()) { qos_locked->incomingQoS2MessageIds.erase(it); result = true; } return result; } void Session::addOutgoingQoS2MessageId(uint16_t packet_id) { MutexLocked qos_locked = qos.lock(); qos_locked->outgoingQoS2MessageIds.insert(packet_id); } void Session::removeOutgoingQoS2MessageId(u_int16_t packet_id) { MutexLocked qos_locked = qos.lock(); if (logger->wouldLog(LOG_PUBLISH)) { logger->logf(LOG_PUBLISH, "As QoS 2 sender: publish complete (PUBCOMP) for '%s', packet id '%d'. Left in queue: %d", client_id.c_str(), packet_id, qos_locked->outgoingQoS2MessageIds.size()); } const auto it = qos_locked->outgoingQoS2MessageIds.find(packet_id); if (it != qos_locked->outgoingQoS2MessageIds.end()) { qos_locked->outgoingQoS2MessageIds.erase(it); qos_locked->increaseFlowControlQuota(); } } void Session::increaseFlowControlQuotaLocked() { MutexLocked qos_locked = qos.lock(); qos_locked->increaseFlowControlQuota(); } uint16_t Session::getNextPacketIdLocked() { MutexLocked qos_locked = qos.lock(); return qos_locked->getNextPacketId(); } void Session::resetQoSData() { MutexLocked qos_locked = qos.lock(); QoSData &q = *qos_locked; QoSData new_q(ThreadGlobals::getSettings()->maxQosMsgPendingPerClient); q = std::move(new_q); } /** * @brief Session::getDestroyOnDisconnect * @return * * MQTT5: Setting Clean Start to 1 and a Session Expiry Interval of 0, is equivalent to setting CleanSession to 1 in the MQTT Specification Version 3.1.1. 
*/ bool Session::getDestroyOnDisconnect() const { return destroyOnDisconnect; } void Session::setSessionProperties(uint16_t clientReceiveMax, uint32_t sessionExpiryInterval, bool clean_start, ProtocolVersion protocol_version) { MutexLocked qos_locked = qos.lock(); // Flow control is not part of the session state, so/but/and because we call this function every time a client connects, we reset it properly. qos_locked->flowControlQuota = clientReceiveMax; qos_locked->flowControlCealing = clientReceiveMax; this->sessionExpiryInterval = sessionExpiryInterval; if (protocol_version <= ProtocolVersion::Mqtt311) destroyOnDisconnect = clean_start; else destroyOnDisconnect = sessionExpiryInterval == 0; } void Session::setSessionExpiryInterval(uint32_t newVal) { // This is only the case on disconnect, but there's no other place where this method is called (so far...) if (this->sessionExpiryInterval == 0 && newVal > 0) { throw ProtocolError("Setting a non-zero session expiry after it was 0 initially is a protocol error.", ReasonCodes::ProtocolError); } this->sessionExpiryInterval = newVal; } void Session::setQueuedRemovalAt() { this->removalQueuedAt = std::chrono::steady_clock::now(); this->removalQueued = true; } uint32_t Session::getSessionExpiryInterval() const { return this->sessionExpiryInterval; } uint32_t Session::getCurrentSessionExpiryInterval() { if (!this->removalQueued || hasActiveClient()) return this->sessionExpiryInterval; const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - this->removalQueuedAt); const uint32_t ageInSeconds = age.count(); const uint32_t result = ageInSeconds <= this->sessionExpiryInterval ? this->sessionExpiryInterval - age.count() : 0; return result; } void Session::setLocalPrefix(const std::optional &s) { // A work-around for prefixes of session being accessed cross-thread. The std::optional sets the 'engaged' flag last, so this works. 
if (this->local_prefix) return; this->local_prefix = s; } void Session::setRemotePrefix(const std::optional &s) { // A work-around for prefixes of session being accessed cross-thread. The std::optional sets the 'engaged' flag last, so this works. if (this->remote_prefix) return; this->remote_prefix = s; } ================================================ FILE: session.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef SESSION_H #define SESSION_H #include #include #include #include #include #include "forward_declarations.h" #include "logger.h" #include "sessionsandsubscriptionsdb.h" #include "qospacketqueue.h" #include "publishcopyfactory.h" #include "lockedweakptr.h" #include "lockedsharedptr.h" #include "mutexowned.h" class Session { #ifdef TESTING friend class MainTests; #endif struct QoSData { QoSPublishQueue qosPacketQueue; std::set incomingQoS2MessageIds; std::set outgoingQoS2MessageIds; uint16_t nextPacketId = 0; uint16_t QoSLogPrintedAtId = 0; std::chrono::time_point lastExpiredMessagesAt = std::chrono::steady_clock::now(); /** * Even though flow control data is not part of the session state, I'm keeping it here because there are already * mutexes that they can be placed under, saving additional synchronization. */ int flowControlCealing = 0xFFFF; int flowControlQuota = 0xFFFF; QoSData(const uint16_t maxQosMsgPendingPerClient) : flowControlQuota(maxQosMsgPendingPerClient) { } void clearExpiredMessagesFromQueue(); void increaseFlowControlQuota(); void increaseFlowControlQuota(int n); uint16_t getNextPacketId(); }; friend class SessionsAndSubscriptionsDB; /* * THREADING WARNING * * Sessions are accessed cross-thread. 
Use unprotected primitives, atomics, MutexOwned objects, or * const-constructed objects with careful consideration. */ LockedWeakPtr client; const std::string client_id; const std::string username; const std::optional fmq_client_group_id; MutexOwned qos; // Note, we set these write-once to avoid threading issues. As a work-around to avoid mutexing in a hot path. std::optional local_prefix; std::optional remote_prefix; std::mutex clientSwitchMutex; uint32_t sessionExpiryInterval = 0; bool destroyOnDisconnect = false; LockedSharedPtr willPublish; bool removalQueued = false; ClientType clientType = ClientType::Normal; std::chrono::time_point removalQueuedAt; Logger *logger = Logger::getInstance(); Session(const Session &other) = delete; public: Session(const std::string &clientid, const std::string &username, const std::optional &fmq_client_group_id); Session(Session &&other) = delete; ~Session(); const std::string &getClientId() const { return client_id; } const std::string &getUsername() const { return username; } const std::optional &getFmqClientGroupId() const { return fmq_client_group_id; } std::shared_ptr makeSharedClient(); void assignActiveConnection(const std::shared_ptr &client); void assignActiveConnection(const std::shared_ptr &thisSession, const std::shared_ptr &client, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval, bool clean_start); PacketDropReason writePacket(PublishCopyFactory ©Factory, const uint8_t max_qos, bool retainAsPublished, const uint32_t subscriptionIdentifier); bool clearQosMessage(uint16_t packet_id, bool qosHandshakeEnds); void sendAllPendingQosData(); bool hasActiveClient(); void clearWill(); std::shared_ptr getWill(); void setWill(WillPublish &&pub); ClientType getClientType() const { return clientType; } void setClientType(ClientType val); void addIncomingQoS2MessageId(uint16_t packet_id); bool incomingQoS2MessageIdInTransit(uint16_t packet_id); bool removeIncomingQoS2MessageId(u_int16_t packet_id); void 
addOutgoingQoS2MessageId(uint16_t packet_id); void removeOutgoingQoS2MessageId(u_int16_t packet_id); void increaseFlowControlQuotaLocked(); uint16_t getNextPacketIdLocked(); void resetQoSData(); bool getDestroyOnDisconnect() const; void setSessionProperties(uint16_t clientReceiveMax, uint32_t sessionExpiryInterval, bool clean_start, ProtocolVersion protocol_version); void setSessionExpiryInterval(uint32_t newVal); void setQueuedRemovalAt(); uint32_t getSessionExpiryInterval() const; uint32_t getCurrentSessionExpiryInterval(); void setLocalPrefix(const std::optional &s); void setRemotePrefix(const std::optional &s); const std::optional &getLocalPrefix() const { return local_prefix; } const std::optional &getRemotePrefix() const { return remote_prefix; } }; #endif // SESSION_H ================================================ FILE: sessionsandsubscriptionsdb.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include "sessionsandsubscriptionsdb.h" #include #include #include "mqttpacket.h" #include "threadglobals.h" #include "utils.h" #include "client.h" #include "session.h" #include "settings.h" #include SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionidentifier) : clientId(clientId), qos(qos), noLocal(noLocal), retainAsPublished(retainAsPublished), subscriptionidentifier(subscriptionidentifier) { } SubscriptionForSerializing::SubscriptionForSerializing(const std::string &clientId, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionidentifier, const std::string &shareName) : clientId(clientId), qos(qos), shareName(shareName), noLocal(noLocal), retainAsPublished(retainAsPublished), subscriptionidentifier(subscriptionidentifier) { } SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionidentifier) : clientId(std::move(clientId)), qos(qos), noLocal(noLocal), retainAsPublished(retainAsPublished), subscriptionidentifier(subscriptionidentifier) { } SubscriptionForSerializing::SubscriptionForSerializing(const std::string &&clientId, SubscriptionOptionsByte options, uint32_t subscriptionidentifier, const std::string &shareName) : clientId(std::move(clientId)), qos(options.getQos()), shareName(shareName), noLocal(options.getNoLocal()), retainAsPublished(options.getRetainAsPublished()), subscriptionidentifier(subscriptionidentifier) { } SubscriptionOptionsByte SubscriptionForSerializing::getSubscriptionOptions() const { return SubscriptionOptionsByte(qos, noLocal, retainAsPublished, RetainHandling::SendRetainedMessagesAtSubscribe); } SessionsAndSubscriptionsDB::SessionsAndSubscriptionsDB(const std::string &filePath) : PersistenceFile(filePath) { } void SessionsAndSubscriptionsDB::openWrite() { 
PersistenceFile::openWrite(MAGIC_STRING_SESSION_FILE_V8); } void SessionsAndSubscriptionsDB::openRead() { const std::string current_magic_string(MAGIC_STRING_SESSION_FILE_V8); PersistenceFile::openRead(current_magic_string); if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V1) readVersion = ReadVersion::v1; else if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V2) readVersion = ReadVersion::v2; else if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V3) readVersion = ReadVersion::v3; else if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V4) readVersion = ReadVersion::v4; else if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V5) readVersion = ReadVersion::v5; else if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V6) readVersion = ReadVersion::v6; else if (detectedVersionString == MAGIC_STRING_SESSION_FILE_V7) readVersion = ReadVersion::v7; else if (detectedVersionString == current_magic_string) readVersion = ReadVersion::v8; else throw std::runtime_error("Unknown file version."); } SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readDataV3V4V5V6V7() { const Settings &settings = *ThreadGlobals::getSettings(); SessionsAndSubscriptionsResult result; while (!feof(f)) { bool eofFound = false; const int64_t fileSavedAt = readInt64(eofFound); if (eofFound) continue; const int64_t now_epoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); const int64_t persistence_state_age = fileSavedAt > now_epoch ? 0 : now_epoch - fileSavedAt; logger->log(LOG_DEBUG) << "Session file was saved at " << fileSavedAt << ". That's " << persistence_state_age << " seconds ago."; const uint32_t nrOfSessions = readUint32(eofFound); if (eofFound) continue; std::vector reserved(RESERVED_SPACE_SESSIONS_DB_V2); CirBuf cirbuf(1024); std::shared_ptr dummyThreadData; // which thread am I going get/use here? 
std::shared_ptr dummyClient(std::make_shared(ClientType::Normal, -1, dummyThreadData, FmqSsl(), ConnectionProtocol::Mqtt, HaProxyMode::Off, nullptr, settings, false)); dummyClient->setClientProperties(ProtocolVersion::Mqtt5, "Dummyforloadingqueuedqos", {}, "nobody", true, 60); for (uint32_t i = 0; i < nrOfSessions; i++) { readCheck(buf.data(), 1, RESERVED_SPACE_SESSIONS_DB_V2, f); std::string username = readString(eofFound); std::string clientId = readString(eofFound); std::optional fmq_client_group_id; if (readVersion >= ReadVersion::v8) fmq_client_group_id = readOptionalString(eofFound); std::shared_ptr ses = std::make_shared(clientId, username, fmq_client_group_id); result.sessions.push_back(ses); logger->logf(LOG_DEBUG, "Loading session '%s'.", ses->getClientId().c_str()); { MutexLocked qos_locked = ses->qos.lock(); const uint32_t nrOfQueuedQoSPackets = readUint32(eofFound); for (uint32_t i = 0; i < nrOfQueuedQoSPackets; i++) { const uint16_t fixed_header_length = readUint16(eofFound); const uint16_t id = readUint16(eofFound); const uint32_t originalPubAge = readUint32(eofFound); const uint32_t packlen = readUint32(eofFound); const std::string sender_clientid = readString(eofFound); const std::string sender_username = readString(eofFound); std::optional topic_override; if (readVersion >= ReadVersion::v7) topic_override = readOptionalString(eofFound); assert(id > 0); cirbuf.reset(); cirbuf.ensureFreeSpace(packlen + 32); readCheck(cirbuf.headPtr(), 1, packlen, f); cirbuf.advanceHead(packlen); MqttPacket pack(cirbuf.readToVector(packlen), fixed_header_length, dummyClient); pack.parsePublishData(dummyClient); Publish pub(pack.getPublishData()); pub.client_id = sender_clientid; pub.username = sender_username; const uint32_t newPubAge = persistence_state_age + originalPubAge; if (pub.expireInfo) pub.expireInfo->createdAt = timepointFromAge(newPubAge); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Loaded QoS %d message for topic '%s' for session '%s'.", 
pub.qos, pub.topic.c_str(), ses->getClientId().c_str()); qos_locked->qosPacketQueue.queuePublish(std::move(pub), id, topic_override); } const uint32_t nrOfIncomingPacketIds = readUint32(eofFound); for (uint32_t i = 0; i < nrOfIncomingPacketIds; i++) { uint16_t id = readUint16(eofFound); assert(id > 0); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Loaded incomming QoS2 message id %d.", id); qos_locked->incomingQoS2MessageIds.insert(id); } const uint32_t nrOfOutgoingPacketIds = readUint32(eofFound); for (uint32_t i = 0; i < nrOfOutgoingPacketIds; i++) { uint16_t id = readUint16(eofFound); assert(id > 0); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Loaded outgoing QoS2 message id %d.", id); qos_locked->outgoingQoS2MessageIds.insert(id); } const uint16_t nextPacketId = readUint16(eofFound); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Loaded next packetid %d.", qos_locked->nextPacketId); qos_locked->nextPacketId = nextPacketId; } const uint32_t originalSessionExpiryInterval = readUint32(eofFound); const uint32_t compensatedSessionExpiry = persistence_state_age > originalSessionExpiryInterval ? 0 : originalSessionExpiryInterval - persistence_state_age; const uint32_t sessionExpiryInterval = std::min(compensatedSessionExpiry, settings.getExpireSessionAfterSeconds()); // We will set the session expiry interval as it would have had time continued. If a connection picks up session, it will update // it with a more relevant value. // The protocol version 5 is just dummy, to get the behavior I want. ses->setSessionProperties(0xFFFF, sessionExpiryInterval, 0, ProtocolVersion::Mqtt5); const uint16_t hasWill = readUint16(eofFound); if (hasWill) { const uint16_t fixed_header_length = readUint16(eofFound); const uint32_t originalWillDelay = readUint32(eofFound); const uint32_t originalWillQueueAge = readUint32(eofFound); const uint32_t newWillDelayAfterMaybeAlreadyBeingQueued = originalWillQueueAge < originalWillDelay ? 
originalWillDelay - originalWillQueueAge : 0; const uint32_t packlen = readUint32(eofFound); const std::string sender_clientid = readString(eofFound); const std::string sender_username = readString(eofFound); const uint32_t stateAgecompensatedWillDelay = persistence_state_age > newWillDelayAfterMaybeAlreadyBeingQueued ? 0 : newWillDelayAfterMaybeAlreadyBeingQueued - persistence_state_age; cirbuf.reset(); cirbuf.ensureFreeSpace(packlen + 32); readCheck(cirbuf.headPtr(), 1, packlen, f); cirbuf.advanceHead(packlen); MqttPacket publishpack(cirbuf.readToVector(packlen), fixed_header_length, dummyClient); publishpack.parsePublishData(dummyClient); WillPublish willPublish = publishpack.getPublishData(); willPublish.will_delay = stateAgecompensatedWillDelay; willPublish.client_id = sender_clientid; willPublish.username = sender_username; if (settings.willsEnabled) ses->setWill(std::move(willPublish)); } } const uint32_t nrOfSubscriptions = readUint32(eofFound); for (uint32_t i = 0; i < nrOfSubscriptions; i++) { const std::string topic = readString(eofFound); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Loading subscriptions to topic '%s'.", topic.c_str()); const uint32_t nrOfClientIds = readUint32(eofFound); for (uint32_t i = 0; i < nrOfClientIds; i++) { std::string sharename; if (readVersion >= ReadVersion::v4) sharename = readString(eofFound); std::string clientId = readString(eofFound); const SubscriptionOptionsByte subscriptionOptions(readUint8(eofFound)); uint32_t subscription_identifier = 0; if (readVersion >= ReadVersion::v6) subscription_identifier = readUint32(eofFound); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Loading session '%s' subscription to '%s' QoS %d.", clientId.c_str(), topic.c_str(), subscriptionOptions.getQos()); SubscriptionForSerializing sub(std::move(clientId), subscriptionOptions, subscription_identifier, sharename); result.subscriptions[topic].push_back(std::move(sub)); } } } return result; } void 
SessionsAndSubscriptionsDB::writeRowHeader() { } void SessionsAndSubscriptionsDB::saveData(const std::vector> &sessions, const std::unordered_map> &subscriptions) { if (!f) return; char reserved[RESERVED_SPACE_SESSIONS_DB_V2]; std::memset(reserved, 0, RESERVED_SPACE_SESSIONS_DB_V2); const int64_t now_epoch = std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); logger->log(LOG_DEBUG) << "Saving current time stamp " << now_epoch << "."; writeInt64(now_epoch); std::vector> sessionsToSave; // Sessions created with clean session need to be destroyed when disconnecting, so no point in saving them. std::copy_if(sessions.begin(), sessions.end(), std::back_inserter(sessionsToSave), [](const std::shared_ptr &ses) { return ses && !ses->destroyOnDisconnect; }); writeUint32(sessionsToSave.size()); CirBuf cirbuf(1024); for (std::shared_ptr &ses : sessionsToSave) { { MutexLocked qos_locked = ses->qos.lock(); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Saving session '%s'.", ses->getClientId().c_str()); writeRowHeader(); writeCheck(reserved, 1, RESERVED_SPACE_SESSIONS_DB_V2, f); writeString(ses->username); writeString(ses->client_id); writeOptionalString(ses->fmq_client_group_id); const size_t qosPacketsExpected = qos_locked->qosPacketQueue.size(); size_t qosPacketsCounted = 0; writeUint32(qosPacketsExpected); std::shared_ptr qp = qos_locked->qosPacketQueue.getTail(); while (qp) { qosPacketsCounted++; Publish &pub = qp->getPublish(); assert(!pub.skipTopic); assert(pub.topicAlias == 0); if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Saving QoS %d message for topic '%s'.", pub.qos, pub.topic.c_str()); MqttPacket pack(ProtocolVersion::Mqtt5, pub); pack.setPacketId(qp->getPacketId()); const uint32_t packSize = pack.getSizeIncludingNonPresentHeader(); cirbuf.reset(); cirbuf.ensureFreeSpace(packSize + 32); pack.readIntoBuf(cirbuf); const uint32_t pubAge = pub.expireInfo ? 
ageFromTimePoint(pub.expireInfo.value().createdAt) : 0; writeUint16(pack.getFixedHeaderLength()); writeUint16(qp->getPacketId()); writeUint32(pubAge); writeUint32(packSize); writeString(pub.client_id); writeString(pub.username); writeOptionalString(qp->getTopicOverride()); writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f); qp = qp->next.lock(); } assert(qosPacketsExpected == qosPacketsCounted); writeUint32(qos_locked->incomingQoS2MessageIds.size()); for (uint16_t id : qos_locked->incomingQoS2MessageIds) { if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Writing incomming QoS2 message id %d.", id); writeUint16(id); } writeUint32(qos_locked->outgoingQoS2MessageIds.size()); for (uint16_t id : qos_locked->outgoingQoS2MessageIds) { if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Writing outgoing QoS2 message id %d.", id); writeUint16(id); } if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Writing next packetid %d.", qos_locked->nextPacketId); writeUint16(qos_locked->nextPacketId); writeUint32(ses->getCurrentSessionExpiryInterval()); std::shared_ptr will = ses->getWill(); const bool hasWillThatShouldSurviveRestart = will.operator bool() && will->will_delay > 0; writeUint16(static_cast(hasWillThatShouldSurviveRestart)); if (hasWillThatShouldSurviveRestart) { MqttPacket willpacket(ProtocolVersion::Mqtt5, *will); // Dummy, to please the parser on reading. if (will->qos > 0) willpacket.setPacketId(666); const uint32_t packSize = willpacket.getSizeIncludingNonPresentHeader(); cirbuf.reset(); cirbuf.ensureFreeSpace(packSize + 32); willpacket.readIntoBuf(cirbuf); writeUint16(willpacket.getFixedHeaderLength()); writeUint32(will->will_delay); writeUint32(will->getQueuedAtAge()); writeUint32(packSize); writeString(will->client_id); writeString(will->username); writeCheck(cirbuf.tailPtr(), 1, cirbuf.usedBytes(), f); } } // Keep flushing outside of session lock, to reduce the amount of flushing while holding that lock. 
fflush(f); } writeUint32(subscriptions.size()); for (auto &pair : subscriptions) { const std::string &topic = pair.first; const std::list &subscriptions = pair.second; if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Writing subscriptions to topic '%s'.", topic.c_str()); writeString(topic); writeUint32(subscriptions.size()); for (const SubscriptionForSerializing &subscription : subscriptions) { if (!subscription.shareName.empty()) { if (logger->wouldLog(LOG_DEBUG)) { logger->logf(LOG_DEBUG, "Saving session '%s' subscription with sharename '%s' to '%s' QoS %d.", subscription.clientId.c_str(), subscription.shareName.c_str(), topic.c_str(), subscription.qos); } } else { if (logger->wouldLog(LOG_DEBUG)) logger->logf(LOG_DEBUG, "Saving session '%s' subscription to '%s' QoS %d.", subscription.clientId.c_str(), topic.c_str(), subscription.qos); } writeString(subscription.shareName); writeString(subscription.clientId); writeUint8(subscription.getSubscriptionOptions().b); writeUint32(subscription.subscriptionidentifier); // Added in file version 6. } } fflush(f); } SessionsAndSubscriptionsResult SessionsAndSubscriptionsDB::readData() { SessionsAndSubscriptionsResult defaultResult; if (!f) return defaultResult; if (readVersion == ReadVersion::v1) logger->logf(LOG_WARNING, "File '%s' is version 1, an internal development version that was never finalized. Not reading.", getFilePath().c_str()); if (readVersion == ReadVersion::v2) logger->logf(LOG_WARNING, "File '%s' is version 2, an internal development version that was never finalized. 
Not reading.", getFilePath().c_str()); if (readVersion >= ReadVersion::v3) return readDataV3V4V5V6V7(); return defaultResult; } ================================================ FILE: sessionsandsubscriptionsdb.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef SESSIONSANDSUBSCRIPTIONSDB_H #define SESSIONSANDSUBSCRIPTIONSDB_H #include #include #include "forward_declarations.h" #include "persistencefile.h" #include "types.h" #define MAGIC_STRING_SESSION_FILE_V1 "FlashMQRetainedDBv1" // That this is called 'retained' was a bug... #define MAGIC_STRING_SESSION_FILE_V2 "FlashMQSessionDBv2" #define MAGIC_STRING_SESSION_FILE_V3 "FlashMQSessionDBv3" #define MAGIC_STRING_SESSION_FILE_V4 "FlashMQSessionDBv4" #define MAGIC_STRING_SESSION_FILE_V5 "FlashMQSessionDBv5" #define MAGIC_STRING_SESSION_FILE_V6 "FlashMQSessionDBv6" #define MAGIC_STRING_SESSION_FILE_V7 "FlashMQSessionDBv7" #define MAGIC_STRING_SESSION_FILE_V8 "FlashMQSessionDBv8" #define RESERVED_SPACE_SESSIONS_DB_V2 32 /** * @brief The SubscriptionForSerializing struct contains the fields we're interested in when saving a subscription. 
*/ struct SubscriptionForSerializing { const std::string clientId; const uint8_t qos = 0; const std::string shareName; const bool noLocal = false; const bool retainAsPublished = false; const uint32_t subscriptionidentifier = 0; SubscriptionForSerializing(const std::string &clientId, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionidentifier); SubscriptionForSerializing(const std::string &clientId, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionidentifier, const std::string &shareName); SubscriptionForSerializing(const std::string &&clientId, uint8_t qos, bool noLocal, bool retainAsPublished, uint32_t subscriptionidentifier); SubscriptionForSerializing(const std::string &&clientId, SubscriptionOptionsByte options, uint32_t subscriptionidentifier, const std::string &shareName); SubscriptionOptionsByte getSubscriptionOptions() const; }; struct SessionsAndSubscriptionsResult { std::list> sessions; std::unordered_map> subscriptions; }; class SessionsAndSubscriptionsDB : private PersistenceFile { enum class ReadVersion { unknown, v1, v2, v3, v4, v5, v6, v7, v8 }; ReadVersion readVersion = ReadVersion::unknown; SessionsAndSubscriptionsResult readDataV3V4V5V6V7(); void writeRowHeader(); public: SessionsAndSubscriptionsDB(const std::string &filePath); void openWrite(); void openRead(); void saveData(const std::vector> &sessions, const std::unordered_map> &subscriptions); SessionsAndSubscriptionsResult readData(); }; #endif // SESSIONSANDSUBSCRIPTIONSDB_H ================================================ FILE: settings.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include #include "exceptions.h" #include "settings.h" #include "utils.h" void checkUniqueBridgeNames(const std::list &bridges) { std::unordered_set prefixes; for (const BridgeConfig &bridge : bridges) { const std::string &prefix = bridge.clientidPrefix; if (prefixes.find(bridge.clientidPrefix) != prefixes.end()) { std::string err = formatString("Value '%s' is not unique. All bridge prefixes must be unique.", prefix.c_str()); throw ConfigFileException(err); } prefixes.insert(prefix); } } Settings::Settings() { persistenceDataToSave.setAll(); } AuthOptCompatWrap &Settings::getAuthOptsCompat() { return authOptCompatWrap; } std::unordered_map &Settings::getFlashmqpluginOpts() { return this->flashmqpluginOpts; } std::string Settings::getRetainedMessagesDBFile() const { if (storageDir.empty()) return ""; std::string path = formatString("%s/%s", storageDir.c_str(), "retained.db"); return path; } std::string Settings::getSessionsDBFile() const { if (storageDir.empty()) return ""; std::string path = formatString("%s/%s", storageDir.c_str(), "sessions.db"); return path; } std::string Settings::getBridgeNamesDBFile() const { if (storageDir.empty()) return ""; std::string path = formatString("%s/%s", storageDir.c_str(), "bridgenames.db"); return path; } std::string Settings::getGeneratedShareNamesFilePath() const { if (storageDir.empty()) return ""; std::string path = formatString("%s/%s", storageDir.c_str(), "sharenames.db"); return path; } /** * @brief because 0 means 'forever', we have to translate this. * @return */ uint32_t Settings::getExpireSessionAfterSeconds() const { return expireSessionsAfterSeconds > 0 ? 
expireSessionsAfterSeconds : std::numeric_limits::max(); } bool Settings::matchAddrWithSetRealIpFrom(const sockaddr *addr) const { return std::any_of(setRealIpFrom.begin(), setRealIpFrom.end(), [=](const Network &n) { return n.match(addr);}); } bool Settings::matchAddrWithSetRealIpFrom(const sockaddr_in6 *addr) const { return matchAddrWithSetRealIpFrom(reinterpret_cast(addr)); } bool Settings::matchAddrWithSetRealIpFrom(const sockaddr_in *addr) const { return matchAddrWithSetRealIpFrom(reinterpret_cast(addr)); } std::list Settings::stealBridges() { std::list result = std::move(this->bridges); this->bridges.clear(); return result; } ================================================ FILE: settings.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef SETTINGS_H #define SETTINGS_H #include #include #include #include #include "enums.h" #include "mosquittoauthoptcompatwrap.h" #include "listener.h" #include "network.h" #include "bridgeconfig.h" #include "flags.h" #define ABSOLUTE_MAX_PACKET_SIZE 268435455 #define HEARTBEAT_INTERVAL 1000 #define OVERLOAD_LOGS_MUTE_AFTER_LINES 5000 enum class RetainedMessagesMode { Enabled, EnabledWithoutPersistence, EnabledWithoutRetaining, Downgrade, Drop, DisconnectWithError }; enum class SharedSubscriptionTargeting { RoundRobin, SenderHash, First }; enum class WildcardSubscriptionDenyMode { DenyAll, DenyRetainedOnly }; enum class PersistenceDataToSave { SessionsAndSubscriptions = 1, RetainedMessages = 2, BridgeInfo = 3 }; void checkUniqueBridgeNames(const std::list &bridges); class Settings { friend class ConfigFileParser; AuthOptCompatWrap authOptCompatWrap; std::unordered_map flashmqpluginOpts; std::list bridges; public: // Actual config options with their defaults. 
std::string pluginPath; std::string logPath; std::optional quiet; bool allowUnsafeClientidChars = false; bool allowUnsafeUsernameChars = false; bool pluginSerializeInit = false; bool pluginSerializeAuthChecks = false; bool logSubscriptions = false; bool logPublishes = false; bool subscriptionIdentifierEnabled = true; int clientInitialBufferSize = 1024; // Must be power of 2 uint32_t maxPacketSize = ABSOLUTE_MAX_PACKET_SIZE; uint32_t clientMaxWriteBufferSize = 1048576; uint16_t maxIncomingTopicAliasValue = 65535; uint16_t maxOutgoingTopicAliasValue = 65535; uint8_t maxQos = 2; uint16_t maxStringLength = 4096; Mqtt3QoSExceedAction mqtt3QoSExceedAction = Mqtt3QoSExceedAction::Disconnect; #ifdef TESTING std::optional logDebug; LogLevel logLevel = LogLevel::Debug; #else std::optional logDebug; LogLevel logLevel = LogLevel::Info; #endif std::string mosquittoPasswordFile; std::string mosquittoAclFile; bool allowAnonymous = false; int rlimitNoFile = 1000000; uint32_t expireSessionsAfterSeconds = 1209600; std::chrono::seconds expireRetainedMessagesAfterSeconds = std::chrono::seconds(std::numeric_limits::max()); int pluginTimerPeriod = 60; std::string storageDir; int threadCount = 0; uint16_t maxQosMsgPendingPerClient = 512; uint maxQosBytesPendingPerClient = 65536; bool willsEnabled = true; uint32_t retainedMessagesDeliveryLimit = 2048; std::chrono::seconds subscriptionNodeLifetime = std::chrono::seconds(3600); uint32_t retainedMessagesNodeLimit = std::numeric_limits::max(); std::chrono::seconds retainedMessageNodeLifetime = std::chrono::seconds(0); RetainedMessagesMode retainedMessagesMode = RetainedMessagesMode::Enabled; SharedSubscriptionTargeting sharedSubscriptionTargeting = SharedSubscriptionTargeting::RoundRobin; uint16_t minimumWildcardSubscriptionDepth = 0; uint16_t maxTopicSplitDepth = 128; WildcardSubscriptionDenyMode wildcardSubscriptionDenyMode = WildcardSubscriptionDenyMode::DenyAll; bool zeroByteUsernameIsAnonymous = false; std::chrono::milliseconds 
maxEventLoopDrift = std::chrono::milliseconds(2000); OverloadMode overloadMode = OverloadMode::Log; std::chrono::milliseconds setRetainedMessageDeferTimeout = std::chrono::milliseconds(0); std::chrono::milliseconds setRetainedMessageDeferTimeoutSpread = std::chrono::milliseconds(1000); std::chrono::seconds saveStateInterval = std::chrono::seconds(3623); Flags persistenceDataToSave; std::list> listeners; // Default one is created later, when none are defined. std::list setRealIpFrom; AuthOptCompatWrap &getAuthOptsCompat(); std::unordered_map &getFlashmqpluginOpts(); std::string getRetainedMessagesDBFile() const; std::string getSessionsDBFile() const; std::string getBridgeNamesDBFile() const; std::string getGeneratedShareNamesFilePath() const; uint32_t getExpireSessionAfterSeconds() const; bool matchAddrWithSetRealIpFrom(const struct sockaddr *addr) const; bool matchAddrWithSetRealIpFrom(const struct sockaddr_in6 *addr) const; bool matchAddrWithSetRealIpFrom(const struct sockaddr_in *addr) const; std::list stealBridges(); Settings(); }; #endif // SETTINGS_H ================================================ FILE: sharedsubscribers.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "sharedsubscribers.h" #include SharedSubscribers::SharedSubscribers(const std::string &shareName, const std::optional &fmq_group_id) noexcept : shareName(shareName), overrideSharedSubscriptionTarget(fmq_group_id ? SharedSubscriptionTargeting::SenderHash : std::optional()) { } /** * @brief SharedSubscribers::operator [] access or create a shared subscription in a shared subscription. 
* @param clientid * @return * * Note that the reference returned will likely be invalidated when you call it again, so don't keep lingering references around. */ Subscription &SharedSubscribers::operator[](const std::string &clientid) { auto index_pos = index.find(clientid); if (index_pos != index.end()) { const int index = index_pos->second; assert(index < static_cast(members.size())); return members[index]; } const int newIndex = members.size(); index[clientid] = newIndex; members.emplace_back(); Subscription &r = members.back(); return r; } const Subscription *SharedSubscribers::getFirst() const { const Subscription *result = nullptr; for (const Subscription &s : members) { if (!s.session.expired()) { result = &s; break; } } return result; } const Subscription *SharedSubscribers::getNext() { const Subscription *result = nullptr; for (size_t i = 0; i < members.size(); i++) { // This counter use is not thread safe / atomic, but it doesn't matter much. const Subscription &s = members[roundRobinCounter++ % members.size()]; if (!s.session.expired()) { result = &s; break; } } return result; } const Subscription *SharedSubscribers::getNext(size_t hash) const { const Subscription *result = nullptr; if (members.empty()) return nullptr; size_t pos = hash % members.size(); for (size_t i = 0; i < members.size(); i++) { const Subscription &s = members[pos++ % members.size()]; if (!s.session.expired()) { result = &s; break; } } return result; } void SharedSubscribers::erase(const std::string &clientid) { auto index_pos = index.find(clientid); if (index_pos != index.end()) { const int index = index_pos->second; assert(index < static_cast(members.size())); members[index].reset(); } } void SharedSubscribers::purgeAndReIndex() { int i = 0; std::vector newMembers; std::unordered_map newIndex; for (auto &pair : index) { const int index = pair.second; Subscription &sub = members[index]; if (sub.session.expired()) continue; newMembers.push_back(sub); newIndex[pair.first] = i; i++; } 
this->members = std::move(newMembers); this->index = std::move(newIndex); } bool SharedSubscribers::empty() const { return members.empty(); } void SharedSubscribers::getForSerializing(const std::string &topic, std::unordered_map> &outputList) const { for (const Subscription &member : members) { std::shared_ptr ses = member.session.lock(); if (ses) { SubscriptionForSerializing sub(ses->getClientId(), member.qos, member.noLocal, member.retainAsPublished, member.subscriptionIdentifier, this->shareName); outputList[topic].push_back(sub); } } } ================================================ FILE: sharedsubscribers.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef SHAREDSUBSCRIBERS_H #define SHAREDSUBSCRIBERS_H #include #include #include #include #include #include "forward_declarations.h" #include "subscription.h" #include "settings.h" class SharedSubscribers { #ifdef TESTING friend class MainTests; #endif std::vector members; std::unordered_map index; int roundRobinCounter = 0; public: const std::string shareName; /* * We need to ensure message ordering when using multiple fmq_group_id connections, so we stick publishers on the other * end to one connection (by using SharedSubscriptionTargeting::SenderHash). This doesn't just guarantee order, * but also thread stickiness, which is useful in auth plugins, like when you do caching assuming clients * will send many messages in a row. 
*/ const std::optional overrideSharedSubscriptionTarget; SharedSubscribers(const std::string &shareName, const std::optional &fmq_group_id) noexcept; Subscription& operator[](const std::string &clientid); const Subscription *getFirst() const; const Subscription *getNext(); const Subscription *getNext(size_t hash) const; void erase(const std::string &clientid); void purgeAndReIndex(); bool empty() const; void getForSerializing(const std::string &topic, std::unordered_map> &outputList) const; }; #endif // SHAREDSUBSCRIBERS_H ================================================ FILE: sslctxmanager.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "sslctxmanager.h" #include SslCtxManager::SslCtxManager() : ssl_ctx(SSL_CTX_new(TLS_server_method()), SSL_CTX_free) { } SslCtxManager::SslCtxManager(const SSL_METHOD *method) : ssl_ctx(SSL_CTX_new(method), SSL_CTX_free) { } SSL_CTX *SslCtxManager::get() const { return ssl_ctx.get(); } int SslCtxManager::tlsEnumToInt(TLSVersion v) { switch (v) { case TLSVersion::TLSv1_0: return TLS1_VERSION; case TLSVersion::TLSv1_1: return TLS1_1_VERSION; case TLSVersion::TLSv1_2: return TLS1_2_VERSION; case TLSVersion::TLSv1_3: return TLS1_3_VERSION; default: throw std::runtime_error("Unsupported version in tlsEnumToInt"); } } void SslCtxManager::setMinimumTlsVersion(TLSVersion min_version) { if (!ssl_ctx) return; SSL_CTX_set_min_proto_version(ssl_ctx.get(), tlsEnumToInt(min_version)); } ================================================ FILE: sslctxmanager.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the 
terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef SSLCTXMANAGER_H #define SSLCTXMANAGER_H #include #include #include "enums.h" class SslCtxManager { std::unique_ptr ssl_ctx; public: SslCtxManager(); SslCtxManager(const SSL_METHOD *method); SSL_CTX *get() const; static int tlsEnumToInt(TLSVersion v); void setMinimumTlsVersion(TLSVersion min_version); }; #endif // SSLCTXMANAGER_H ================================================ FILE: subscription.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "subscription.h" /** * @brief Subscription::operator == Compares subscription equality based on client id only. * @param rhs Right-hand side. * @return true or false * * QoS is not used in the comparision. This means you upgrade your QoS by subscribing again. The * specs don't specify what to do there. */ bool Subscription::operator==(const Subscription &rhs) const { if (session.expired() && rhs.session.expired()) return true; if (session.expired() || rhs.session.expired()) return false; const std::shared_ptr lhs_ses = session.lock(); const std::shared_ptr rhs_ses = rhs.session.lock(); return lhs_ses && rhs_ses && lhs_ses->getClientId() == rhs_ses->getClientId(); } void Subscription::reset() { session.reset(); qos = 0; } ================================================ FILE: subscription.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/

#ifndef SUBSCRIPTION_H
#define SUBSCRIPTION_H

// NOTE(review): the target of the next #include is missing in this copy of the file; confirm against upstream.
#include
#include "session.h"

/**
 * One subscriber's entry in the subscription tree: the session it delivers to, plus the options it subscribed with.
 */
struct Subscription
{
    // Weak pointer expires when the session has been cleaned by a 'clean session' connect, or when it was removed because it expired.
    std::weak_ptr session;

    // NOTE(review): unlike the fields below, qos has no default member initializer, so a default-initialized
    // 'Subscription sub;' holds an indeterminate qos until it is assigned.
    uint8_t qos;
    bool noLocal = false;
    bool retainAsPublished = false;
    uint32_t subscriptionIdentifier = 0;

    bool operator==(const Subscription &rhs) const;
    void reset();
};

#endif // SUBSCRIPTION_H

================================================
FILE: subscriptionstore.cpp
================================================
/*
This file is part of FlashMQ (https://www.flashmq.org)
Copyright (C) 2021-2023 Wiebe Cazemier

FlashMQ is free software: you can redistribute it and/or modify it under the
terms of The Open Software License (OSL-3.0).

See LICENSE for license details.
*/

#include "subscriptionstore.h"

// NOTE(review): this #include and the last one in this list have no target in this copy; confirm against upstream.
#include

#include "rwlockguard.h"
#include "retainedmessagesdb.h"
#include "publishcopyfactory.h"
#include "threadglobals.h"
#include "utils.h"
#include "settings.h"
#include "plugin.h"
#include "exceptions.h"
#include "threaddata.h"
#include "globals.h"
#include

// Value holder: copies its three arguments into the members of the same names.
DeferredGetSubscription::DeferredGetSubscription(const std::shared_ptr &node, const std::string &composedTopic, const bool root) :
    node(node),
    composedTopic(composedTopic),
    root(root)
{

}

// Locks the weak session pointer once, at construction; 'session' is null when that session had already expired.
ReceivingSubscriber::ReceivingSubscriber(const std::weak_ptr &ses, uint8_t qos, bool retainAsPublished, const uint32_t subscriptionIdentifier) :
    session(ses.lock()),
    qos(qos),
    retainAsPublished(retainAsPublished),
    subscriptionIdentifier(subscriptionIdentifier)
{

}

SubscriptionNode::SubscriptionNode()
{

}

// Non-shared subscriptions at this node, keyed by client id (see addSubscriber).
const std::unordered_map &SubscriptionNode::getSubscribers() const
{
    return subscribers;
}

// Shared subscriptions at this node, keyed by share name (see addSubscriber).
std::unordered_map &SubscriptionNode::getSharedSubscribers()
{
    return sharedSubscribers;
}

/**
 * @brief Adds, or overwrites, the subscription of 'subscriber' at this node, under the node's lock.
 * @param shareName Empty for a normal subscription; otherwise the share the subscription is a member of.
 * @return Invalid for a null subscriber; NewSubscription when the client's slot held no live session;
 *         ExistingSubscription when a live entry was overwritten.
 *
 * For shared subscriptions, the SharedSubscribers object is created on first use (try_emplace), with the
 * subscribing session's FMQ client group id; later subscribers to the same share do not change that.
 */
AddSubscriptionType SubscriptionNode::addSubscriber(
    const std::shared_ptr &subscriber, uint8_t qos, bool noLocal, bool retainAsPublished, const std::string &shareName, const uint32_t
subscriptionIdentifier)
{
    if (!subscriber)
        return AddSubscriptionType::Invalid;

    Subscription sub;
    sub.session = subscriber;
    sub.qos = qos;
    sub.noLocal = noLocal;
    sub.retainAsPublished = retainAsPublished;
    sub.subscriptionIdentifier = subscriptionIdentifier;

    const std::string &client_id = subscriber->getClientId();

    AddSubscriptionType result = AddSubscriptionType::Invalid;

    std::unique_lock locker(lock);

    lastUpdate = std::chrono::steady_clock::now();

    if (shareName.empty())
    {
        Subscription &s = subscribers[client_id];
        result = s.session.expired() ? AddSubscriptionType::NewSubscription : AddSubscriptionType::ExistingSubscription;
        s = sub;
    }
    else
    {
        // NOTE(review): this local 'subscribers' shadows the member of the same name used in the branch above.
        SharedSubscribers &subscribers = sharedSubscribers.try_emplace(shareName, shareName, subscriber->getFmqClientGroupId()).first->second;
        Subscription &s = subscribers[client_id];
        result = s.session.expired() ? AddSubscriptionType::NewSubscription : AddSubscriptionType::ExistingSubscription;
        s = sub;
    }

    return result;
}

/**
 * @brief Removes the subscription of 'subscriber' at this node, under the node's lock.
 *
 * Normal subscriptions are erased from the map; shared ones only have their slot cleared (SharedSubscribers::erase).
 * Unlike addSubscriber, there is no null check on 'subscriber' here; it is dereferenced right away.
 */
void SubscriptionNode::removeSubscriber(const std::shared_ptr &subscriber, const std::string &shareName)
{
    // NOTE(review): 'sub' is built but never read in this function; it looks like a leftover.
    Subscription sub;
    sub.session = subscriber;
    sub.qos = 0;

    const std::string &clientId = subscriber->getClientId();

    std::unique_lock locker(lock);

    lastUpdate = std::chrono::steady_clock::now();

    if (shareName.empty())
    {
        auto it = subscribers.find(clientId);

        if (it != subscribers.end())
        {
            subscribers.erase(it);
        }
    }
    else
    {
        auto pos = sharedSubscribers.find(shareName);
        if (pos != sharedSubscribers.end())
        {
            SharedSubscribers &subscribers = pos->second;
            subscribers.erase(clientId);
        }
    }
}

// sessionsByIdConst is bound to sessionsById; presumably a const view of the same map (TODO confirm in the header).
SubscriptionStore::SubscriptionStore() :
    sessionsByIdConst(sessionsById)
{

}

/**
 * @brief SubscriptionStore::getDeepestNode gets the node in the tree walking the path of 'the/subscription/topic/path', making new nodes as required.
* @param topic * @param subtopics * @return */ std::shared_ptr SubscriptionStore::getDeepestNode(const std::vector &subtopics, bool abort_on_dead_end) { const std::shared_ptr *start = &root; if (!subtopics.empty()) { const std::string &first = subtopics.front(); if (first.length() > 0 && first[0] == '$') start = &rootDollar; } std::shared_ptr result; bool retry_mode = false; for(int i = 0; i < 2; i++) { assert(i < 1 || retry_mode); std::shared_lock rlock(subscriptions_lock, std::defer_lock); std::unique_lock wlock(subscriptions_lock, std::defer_lock); if (retry_mode) { assert(!abort_on_dead_end); wlock.lock(); } else rlock.lock(); auto subtopic_pos = subtopics.begin(); const std::shared_ptr *deepestNode = start; while(subtopic_pos != subtopics.end()) { const std::string &subtopic = *subtopic_pos; std::shared_ptr *selectedChildren = nullptr; if (subtopic == "#") selectedChildren = &(*deepestNode)->childrenPound; else if (subtopic == "+") selectedChildren = &(*deepestNode)->childrenPlus; else { auto &children = (*deepestNode)->children; if (retry_mode) { assert(wlock.owns_lock()); selectedChildren = &children[subtopic]; } else // read-only path { assert(rlock.owns_lock()); auto child_pos = children.find(subtopic); if (child_pos == children.end()) { if (abort_on_dead_end) return result; retry_mode = true; break; } else { selectedChildren = &child_pos->second; } } } std::shared_ptr &node = *selectedChildren; if (!node) { if (!retry_mode) { assert(rlock.owns_lock()); if (abort_on_dead_end) return result; retry_mode = true; break; } assert(retry_mode); assert(wlock.owns_lock()); node = std::make_shared(); /* * This is not technically correct, because we haven't made a subscription yet, but it: * * 1) Only happens when we're about to make a subscription, and not when we remove one. * 2) Allows continuous subscription and unsubcription on a node without needing to update this atomic counter. 
*/ subscriptionCount++; } deepestNode = &node; subtopic_pos++; } assert(deepestNode); assert(*deepestNode); if (retry_mode && subtopic_pos != subtopics.end()) continue; result = *deepestNode; assert(result); return result; } return result; } AddSubscriptionType SubscriptionStore::addSubscription( const std::shared_ptr &session, const std::vector &subtopics, uint8_t qos, bool noLocal, bool retainAsPublished, const std::string &shareName, const uint32_t subscriptionIdentifier) { if (!session) return AddSubscriptionType::Invalid; const std::shared_ptr deepestNode = getDeepestNode(subtopics); if (!deepestNode) return AddSubscriptionType::Invalid; return deepestNode->addSubscriber(session, qos, noLocal, retainAsPublished, shareName, subscriptionIdentifier); } void SubscriptionStore::removeSubscription( const std::shared_ptr &session, const std::vector &subtopics, const std::string &shareName) { if (!session) return; std::shared_ptr node = getDeepestNode(subtopics, true); if (!node) return; node->removeSubscriber(session, shareName); } std::shared_ptr SubscriptionStore::getBridgeSession(std::shared_ptr &client) { const std::string &client_id = client->getClientId(); std::unique_lock locker(sessions_lock); std::shared_ptr &session = sessionsById[client_id]; if (!session || session->getDestroyOnDisconnect() || client->getCleanStart()) session = std::make_shared(client_id, client->getUsername(), client->getFmqClientGroupId()); session->assignActiveConnection(client); client->assignSession(session); session->setClientType(ClientType::LocalBridge); return session; } /** * @brief SubscriptionStore::registerClientAndKickExistingOne registers a client with previously set parameters for the session. * @param client * * Under normal MQTT operation, the 'if' clause is always used. The 'else' is only in (fuzz) testing and other rare conditions. 
*/ void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client) { const std::unique_ptr ®istrationData = client->getRegistrationData(); if (registrationData) { registerClientAndKickExistingOne(client, registrationData->clean_start, registrationData->clientReceiveMax, registrationData->sessionExpiryInterval); client->clearRegistrationData(); } else { const Settings *settings = ThreadGlobals::getSettings(); registerClientAndKickExistingOne(client, true, settings->maxQosMsgPendingPerClient, settings->expireSessionsAfterSeconds); } } // Removes an existing client when it already exists [MQTT-3.1.4-2]. void SubscriptionStore::registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval) { ThreadGlobals::getThreadData()->queueClientNextKeepAliveCheck(client, true); // These destructors need to be called outside the sessions lock, so placing here. std::shared_ptr session; if (client->getClientId().empty()) throw ProtocolError("Trying to store client without an ID.", ReasonCodes::ProtocolError); { std::unique_lock ses_locker(sessions_lock); auto session_it = sessionsById.find(client->getClientId()); if (session_it != sessionsById.end()) { session = session_it->second; if (session) { if (session->getUsername() != client->getUsername()) throw ProtocolError("Cannot take over session with different username", ReasonCodes::NotAuthorized); if (session->getFmqClientGroupId() != client->getFmqClientGroupId()) throw ProtocolError("Cannot take over session with different FMQ client group ID", ReasonCodes::NotAuthorized); std::shared_ptr clientOfOtherSession = session->makeSharedClient(); if (clientOfOtherSession) { logger->logf(LOG_NOTICE, "Disconnecting existing client with id '%s'", clientOfOtherSession->getClientId().c_str()); std::shared_ptr td = clientOfOtherSession->lockThreadData(); td->serverInitiatedDisconnect(std::move(clientOfOtherSession), ReasonCodes::SessionTakenOver, "Another 
client with this ID connected"); } } } if (!session || session->getDestroyOnDisconnect() || clean_start) { // Don't use sdt::make_shared to avoid the weak pointers from retaining the size of session in the control block. session = std::shared_ptr(new Session(client->getClientId(), client->getUsername(), client->getFmqClientGroupId())); sessionsById[client->getClientId()] = session; } } session->assignActiveConnection(session, client, clientReceiveMax, sessionExpiryInterval, clean_start); } /** * @brief SubscriptionStore::lockSession returns the session if it exists. Returning is done keep the shared pointer active, to * avoid race conditions with session removal. * @param clientid * @return */ std::shared_ptr SubscriptionStore::lockSession(const std::string &clientid) { std::shared_lock ses_locker(sessions_lock); auto it = sessionsByIdConst.find(clientid); if (it != sessionsByIdConst.end()) { return it->second; } return std::shared_ptr(); } void SubscriptionStore::sendWill(const std::shared_ptr will, const std::shared_ptr session, const std::string &log) { if (!will || !session) return; /* * Avoid sending two immediate wills when a session is destroyed with the client disconnect. * Session is null when you're destroying a client before a session is assigned, or * when an old client has no session anymore after a client with the same ID connects. */ session->clearWill(); const Settings *settings = ThreadGlobals::getSettings(); if (!settings->willsEnabled) return; logger->log(LOG_DEBUG) << log << " " << will->topic; Authentication &auth = ThreadGlobals::getThreadData()->authentication; const AuthResult authResult = auth.aclCheck(*will, will->payload, AclAccess::write); if (authResult == AuthResult::success || authResult == AuthResult::success_without_setting_retained) { PublishCopyFactory factory(will.get()); // Not having a stored fmq_client_group_id to pass. Theoretically, it should not be needed, // because a client cannot get its own will. 
queuePacketAtSubscribers(factory, will->client_id, {}); if (will->retain && authResult == AuthResult::success) setRetainedMessage(*will, will->getSubtopics()); } } /** * @brief SubscriptionStore::sendQueuedWillMessages sends queued will messages. * * The expiry interval as set in the properties of the will message is not used to check for expiration here. To * quote the specs: "If present, the Four Byte value is the lifetime of the Will Message in seconds and is sent as * the Publication Expiry Interval when the Server publishes the Will Message." * * If a new Network Connection to this Session is made before the Will Delay Interval has passed, the Server * MUST NOT send the Will Message [MQTT-3.1.3-9]. */ void SubscriptionStore::sendQueuedWillMessages() { const auto now = std::chrono::steady_clock::now(); const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast(now.time_since_epoch()); std::lock_guard locker(this->pendingWillsMutex); auto it = pendingWillMessages.begin(); while (it != pendingWillMessages.end()) { const std::chrono::seconds &sendAt = it->first; if (sendAt > secondsSinceEpoch) break; std::vector &willsOfSlot = it->second; for(QueuedWill &will : willsOfSlot) { std::shared_ptr p = will.getWill().lock(); std::shared_ptr s = will.getSession(); // If the session has been picked up again after the will was originally queued, we should not send it. if (s && s->hasActiveClient()) continue; sendWill(p, s, "Sending delayed will on topic: "); } it = pendingWillMessages.erase(it); } } void SubscriptionStore::queueOrSendWillMessage( const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow) { if (!willMessage) return; const uint32_t delay = forceNow ? 0 : willMessage->will_delay; if (delay > 0) queueWillMessage(willMessage, session); else sendWill(willMessage, session, "Sending immediate will on topic: "); } /** * @brief SubscriptionStore::queueWillMessage queues the will message in a sorted map. 
* * The queued will is only valid for that time. Should a new will be placed in the map for a session, the original shared_ptr * will be cleared and the previously queued entry is void (but still there, so it needs to be checked). */ void SubscriptionStore::queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session) { if (!willMessage) return; if (logger->wouldLog(LOG_DEBUG)) logger->log(LOG_DEBUG) << "Queueing will on topic '" << willMessage->topic << "', with delay of " << willMessage->will_delay << " seconds."; willMessage->setQueuedAt(); QueuedWill queuedWill(willMessage, session); const std::chrono::time_point sendWillAt = std::chrono::steady_clock::now() + std::chrono::seconds(willMessage->will_delay); std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast(sendWillAt.time_since_epoch()); std::lock_guard locker(this->pendingWillsMutex); this->pendingWillMessages[secondsSinceEpoch].push_back(queuedWill); } void SubscriptionStore::publishNonRecursively( SubscriptionNode *this_node, std::vector &targetSessions, const std::string &senderClientId, const std::optional &fmq_client_group_id) noexcept { std::shared_lock locker(this_node->lock); for (auto &pair : this_node->subscribers) { const Subscription &sub = pair.second; targetSessions.emplace_back(sub.session, sub.qos, sub.retainAsPublished, sub.subscriptionIdentifier); /* * Shared pointer expires when session has been cleaned by 'clean session' disconnect. * * By not using a tempory locked shared_ptr for checks, doing an optimistic insertion instead, we avoid * unnecessary copies. The only extra overhead this causes is the list pop when we decice to remove it. 
*/ if (!targetSessions.back().session) { targetSessions.pop_back(); continue; } if (sub.noLocal && targetSessions.back().session->getClientId() == senderClientId) { targetSessions.pop_back(); continue; } } if (this_node->sharedSubscribers.empty()) return; const Settings *settings = ThreadGlobals::getSettings(); for(auto &pair : this_node->sharedSubscribers) { SharedSubscribers &subscribers = pair.second; // See SharedSubscribers for details about the target override. In short: we need to ensure delivery order. const SharedSubscriptionTargeting targetting = subscribers.overrideSharedSubscriptionTarget.value_or(settings->sharedSubscriptionTargeting); const Subscription *sub = nullptr; if (targetting == SharedSubscriptionTargeting::SenderHash) { const size_t hash = std::hash()(senderClientId); sub = subscribers.getNext(hash); } else if (targetting == SharedSubscriptionTargeting::RoundRobin) sub = subscribers.getNext(); else if (targetting == SharedSubscriptionTargeting::First) sub = subscribers.getFirst(); if (sub == nullptr) continue; // Same comment about duplicate copies as above. targetSessions.emplace_back(sub->session, sub->qos, sub->retainAsPublished, sub->subscriptionIdentifier); if (!targetSessions.back().session) { targetSessions.pop_back(); continue; } /* * Normal shared subscriptions are not supposed to filter out the 'no local' subscriptions (MQTT-3.8.3-4), * because it would create inconsistent behavior depending on which client gets it. So, we implement a * custom feature. */ const std::optional &session_group_id = targetSessions.back().session->getFmqClientGroupId(); if (session_group_id && session_group_id == fmq_client_group_id) { targetSessions.pop_back(); continue; } } } /** * @brief SubscriptionStore::publishRecursively * @param cur_subtopic_it * @param end * @param this_node * @param packet * @param count as a reference (vs return value) because a return value introduces an extra call i.e. limits tail recursion optimization. 
* * As noted in the params section, this method was written so that it could be (somewhat) optimized for tail recursion by the compiler. If you refactor this, * look at objdump --disassemble --demangle to see how many calls (not jumps) to itself are made and compare. */ void SubscriptionStore::publishRecursively( std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, SubscriptionNode *this_node, std::vector &targetSessions, const std::string &senderClientId, const std::optional &fmq_client_group_id) noexcept { if (cur_subtopic_it == end) // This is the end of the topic path, so look for subscribers here. { if (this_node) { publishNonRecursively(this_node, targetSessions, senderClientId, fmq_client_group_id); // Subscribing to 'one/two/three/#' also gives you 'one/two/three'. if (this_node->childrenPound) { publishNonRecursively(this_node->childrenPound.get(), targetSessions, senderClientId, fmq_client_group_id); } } return; } // Null nodes in the tree shouldn't happen at this point. It points to bugs elsewhere. However, I don't remember why I check the pointer // inside the if-block above, instead of the start of the method. 
assert(this_node != nullptr); if (this_node->children.empty() && !this_node->childrenPlus && !this_node->childrenPound) return; const std::string &cur_subtop = *cur_subtopic_it; const auto next_subtopic = ++cur_subtopic_it; if (this_node->childrenPound) { publishNonRecursively(this_node->childrenPound.get(), targetSessions, senderClientId, fmq_client_group_id); } const auto &sub_node = this_node->children.find(cur_subtop); if (this_node->childrenPlus) { publishRecursively(next_subtopic, end, this_node->childrenPlus.get(), targetSessions, senderClientId, fmq_client_group_id); } if (sub_node != this_node->children.end()) { publishRecursively(next_subtopic, end, sub_node->second.get(), targetSessions, senderClientId, fmq_client_group_id); } } void SubscriptionStore::queuePacketAtSubscribers( PublishCopyFactory ©Factory, const std::string &senderClientId, const std::optional &fmq_client_group_id, bool dollar) { /* * Sometimes people publish or set as will topics with dollar. Node-to-Node communication for bridges for instance. * We still accept them, but just decide to do nothing with them. Only the FlashMQ internals are allowed to publish * on dollar topics. */ if (!dollar && copyFactory.getTopic()[0] == '$') // String is always 0-terminated, so we can access first element. { return; } SubscriptionNode *startNode = dollar ? 
rootDollar.get() : root.get(); const size_t reserve = this->subscriber_reserve.load(std::memory_order_relaxed); std::vector subscriberSessions; subscriberSessions.reserve(reserve); { const std::vector &subtopics = copyFactory.getSubtopics(); std::shared_lock locker(subscriptions_lock); publishRecursively(subtopics.begin(), subtopics.end(), startNode, subscriberSessions, senderClientId, fmq_client_group_id); } if (subscriberSessions.size() > reserve && subscriberSessions.size() <= 1048576) this->subscriber_reserve.store(reserve, std::memory_order_relaxed); for(const ReceivingSubscriber &x : subscriberSessions) { x.session->writePacket(copyFactory, x.qos, x.retainAsPublished, x.subscriptionIdentifier); } } void SubscriptionStore::giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, const std::shared_ptr &this_node, bool poundMode, const std::shared_ptr &session, const uint8_t max_qos, const uint32_t subscription_identifier, const std::chrono::time_point &limit, std::deque &deferred, int &drop_count, int &processed_nodes_count) { if (!this_node) return; if (cur_subtopic_it == end) { Authentication &auth = ThreadGlobals::getThreadData()->authentication; std::lock_guard locker(this_node->messageSetMutex); if (this_node->message) { const RetainedMessage &rm = *this_node->message; if (!rm.hasExpired()) // We can't also erase here, because we're operating under a read lock. 
{ Publish publish = rm.publish; if (auth.aclCheck(publish, publish.payload) == AuthResult::success) { PublishCopyFactory copyFactory(&publish); const PacketDropReason drop_reason = session->writePacket(copyFactory, max_qos, true, subscription_identifier); if (drop_reason == PacketDropReason::BufferFull || drop_reason == PacketDropReason::QoSTODOSomethingSomething) { drop_count++; DeferredRetainedMessageNodeDelivery d; d.node = this_node; d.cur = cur_subtopic_it; d.end = end; d.poundMode = poundMode; deferred.push_back(d); /* * Capacity-wise, we wouldn't have to return here if it was merely a QoS limit, but then it becomes hard * to filter out the children of this node, which you'd need to avoid duplicate deliveries. So, * this is the easier approach. */ return; } } } } if (poundMode) { const int drop_count_start = drop_count; for (auto &pair : this_node->children) { if (std::chrono::steady_clock::now() >= limit || drop_count_start != drop_count) { DeferredRetainedMessageNodeDelivery d; d.node = pair.second; d.cur = cur_subtopic_it; d.end = end; d.poundMode = poundMode; deferred.push_back(d); } else { std::shared_ptr &child = pair.second; giveClientRetainedMessagesRecursively( cur_subtopic_it, end, child, poundMode, session, max_qos, subscription_identifier, limit, deferred, drop_count, ++processed_nodes_count); } } } return; } const std::string &cur_subtop = *cur_subtopic_it; const auto next_subtopic = ++cur_subtopic_it; if (cur_subtop == "#") { // We start at this node, so that a subscription on 'one/two/three/#' gives you 'one/two/three' too. 
giveClientRetainedMessagesRecursively( next_subtopic, end, this_node, true, session, max_qos, subscription_identifier, limit, deferred, drop_count, ++processed_nodes_count); } else if (cur_subtop == "+") { const int drop_count_start = drop_count; for (std::pair> &pair : this_node->children) { if (std::chrono::steady_clock::now() >= limit || drop_count_start != drop_count) { DeferredRetainedMessageNodeDelivery d; d.node = pair.second; d.cur = next_subtopic; d.end = end; d.poundMode = poundMode; deferred.push_back(d); } else { std::shared_ptr &child = pair.second; giveClientRetainedMessagesRecursively( next_subtopic, end, child, false, session, max_qos, subscription_identifier, limit, deferred, drop_count, ++processed_nodes_count); } } } else { std::shared_ptr children = this_node->getChildren(cur_subtop); if (children) { if (std::chrono::steady_clock::now() >= limit) { DeferredRetainedMessageNodeDelivery d; d.node = children; d.cur = next_subtopic; d.end = end; d.poundMode = poundMode; deferred.push_back(d); } else { giveClientRetainedMessagesRecursively( next_subtopic, end, children, false, session, max_qos, subscription_identifier, limit, deferred, drop_count, ++processed_nodes_count); } } } } void SubscriptionStore::giveClientRetainedMessagesInitiateDeferred(const std::weak_ptr ses, const std::shared_ptr> subscribeSubtopicsCopy, std::shared_ptr> deferred, int &requeue_count, uint &total_node_count, uint8_t max_qos, const uint32_t subscription_identifier) { std::shared_ptr session = ses.lock(); if (!session) return; const Settings *settings = ThreadGlobals::getSettings(); const std::chrono::time_point new_limit = std::chrono::steady_clock::now() + std::chrono::milliseconds(10); int drop_count = 0; int processed_nodes = 0; for (; !deferred->empty(); deferred->pop_front()) { if (std::chrono::steady_clock::now() >= new_limit) break; if (drop_count > 0) break; DeferredRetainedMessageNodeDelivery &d = deferred->front(); std::shared_ptr node = d.node.lock(); if (!node) 
continue; RWLockGuard locker(&retainedMessagesRwlock); locker.rdlock(); giveClientRetainedMessagesRecursively(d.cur, d.end, node, d.poundMode, session, max_qos, subscription_identifier, new_limit, *deferred, drop_count, processed_nodes); } total_node_count += processed_nodes; if (processed_nodes > 0) { requeue_count = 0; } if (!deferred->empty() && ++requeue_count < 100 && total_node_count < settings->retainedMessagesNodeLimit) { auto &t = ThreadGlobals::getThreadData(); auto again = std::bind(&SubscriptionStore::giveClientRetainedMessagesInitiateDeferred, this, ses, subscribeSubtopicsCopy, deferred, requeue_count, total_node_count, max_qos, subscription_identifier); /* * Adding a delayed retry is kind of a cheap way to avoid detecting when there is buffer or QoS space in the event loop, but that comes * with several layers of complexity that makes the the normal (non-retained) flow more complicated. Also, buffer full situations * is probably most likely to happen when clients subscribe to a broad wildcard on a big subscription tree, and this * allows some depriorirzation. */ if (drop_count > 0) t->addDelayedTask(again, 50); else t->addImmediateTask(again); } } void SubscriptionStore::giveClientRetainedMessages(const std::shared_ptr &ses, const std::vector &subscribeSubtopics, uint8_t max_qos, const uint32_t subscriptionIdentifier) { if (!ses) return; const Settings *settings = ThreadGlobals::getSettings(); if (settings->retainedMessagesMode >= RetainedMessagesMode::EnabledWithoutRetaining) return; // The specs aren't clear whether retained messages should be dropped, or just have their retain flag stripped. I chose the former, // otherwise clients have no way of knowing if a message was retained or not. 
{ const std::shared_ptr client = ses->makeSharedClient(); if (client && !client->isRetainedAvailable()) return; } const std::shared_ptr *startNode = &retainedMessagesRoot; if (!subscribeSubtopics.empty() && !subscribeSubtopics[0].empty() > 0 && subscribeSubtopics[0][0] == '$') startNode = &retainedMessagesRootDollar; const std::shared_ptr> subscribeSubtopicsCopy = std::make_shared>(subscribeSubtopics); std::shared_ptr> deferred = std::make_shared>(); DeferredRetainedMessageNodeDelivery start; start.node = *startNode; start.cur = subscribeSubtopicsCopy->begin(); start.end = subscribeSubtopicsCopy->end(); deferred->push_back(start); int requeue_count = 0; uint total_node_count = 0; giveClientRetainedMessagesInitiateDeferred(ses, subscribeSubtopicsCopy, deferred, requeue_count, total_node_count, max_qos, subscriptionIdentifier); } /** * @brief SubscriptionStore::trySetRetainedMessages queues setting of retained messages if not able to set directly. * @param publish * @param subtopics */ void SubscriptionStore::trySetRetainedMessages(const Publish &publish, const std::vector &subtopics) { const Settings *settings = ThreadGlobals::getSettings(); const bool try_lock_fail = settings->setRetainedMessageDeferTimeout.count() != 0; auto &td = ThreadGlobals::getThreadData(); if (!td) return; td->retainedMessageSet.inc(1); // Only do direct setting when there are none queued, to avoid out of order races, which would result in the wrong ultimate value. 
if (td->queuedRetainedMessagesEmpty() && setRetainedMessage(publish, subtopics, try_lock_fail)) return; auto spread {settings->setRetainedMessageDeferTimeoutSpread}; if (spread != std::chrono::milliseconds::zero()) { const int64_t rnd = static_cast(td->randomish()); spread = std::chrono::milliseconds{rnd % spread.count()}; } const auto randomized_timeout = settings->setRetainedMessageDeferTimeout + spread; std::chrono::time_point limit = std::chrono::steady_clock::now() + randomized_timeout; td->queueSettingRetainedMessage(publish, subtopics, limit); } bool SubscriptionStore::setRetainedMessage(const Publish &publish, const std::vector &subtopics, bool try_lock_fail) { assert(!subtopics.empty()); const Settings *settings = ThreadGlobals::getSettings(); if (settings->retainedMessagesMode >= RetainedMessagesMode::EnabledWithoutRetaining) return true; const std::shared_ptr *deepestNode = &retainedMessagesRoot; if (!subtopics.empty() && !subtopics[0].empty() > 0 && subtopics[0][0] == '$') deepestNode = &retainedMessagesRootDollar; bool needsWriteLock = false; auto subtopic_pos = subtopics.begin(); std::shared_ptr selected_node; std::shared_ptr retry_point; // First do a read-only search for the node. 
{ RWLockGuard locker(&retainedMessagesRwlock); if (try_lock_fail) { if (!locker.tryrdlock()) return false; } else locker.rdlock(); while(subtopic_pos != subtopics.end()) { auto pos = (*deepestNode)->children.find(*subtopic_pos); if (pos == (*deepestNode)->children.end()) { needsWriteLock = true; retry_point = *deepestNode; break; } std::shared_ptr &selectedChildren = pos->second; if (!selectedChildren) { needsWriteLock = true; retry_point = *deepestNode; break; } deepestNode = &selectedChildren; subtopic_pos++; } assert(deepestNode); if (!needsWriteLock && deepestNode) { selected_node = *deepestNode; } } if (needsWriteLock) { RWLockGuard locker(&retainedMessagesRwlock); if (try_lock_fail) { if (!locker.trywrlock()) return false; } else locker.wrlock(); deepestNode = &retry_point; while(subtopic_pos != subtopics.end()) { std::shared_ptr &selectedChildren = (*deepestNode)->children[*subtopic_pos]; if (!selectedChildren) { selectedChildren = std::make_shared(); } deepestNode = &selectedChildren; subtopic_pos++; } assert(deepestNode); if (deepestNode) { selected_node = *deepestNode; } } if (selected_node) { const ssize_t diff = selected_node->addPayload(publish); if (diff != 0) this->retainedMessageCount.fetch_add(diff); } return true; } // Clean up the weak pointers to sessions and remove nodes that are empty. int SubscriptionNode::cleanSubscriptions(std::deque> &defferedLeafs, size_t &real_subscriber_count) { const size_t children_amount = children.size(); const bool split = children_amount > 15; int subscribersLeftInChildren = 0; auto childrenIt = children.begin(); while(childrenIt != children.end()) { std::shared_ptr &node = childrenIt->second; if (!node) continue; if (split && !node->empty()) { defferedLeafs.push_back(node); subscribersLeftInChildren += 1; // We just have to be sure it's not 0. 
childrenIt++; continue; } int n = node->cleanSubscriptions(defferedLeafs, real_subscriber_count); subscribersLeftInChildren += n; if (n > 0) childrenIt++; else childrenIt = children.erase(childrenIt); } std::list*> wildcardChildren; wildcardChildren.push_back(&childrenPlus); wildcardChildren.push_back(&childrenPound); for (std::shared_ptr *node : wildcardChildren) { std::shared_ptr &node_ = *node; if (!node_) continue; int n = node_->cleanSubscriptions(defferedLeafs, real_subscriber_count); subscribersLeftInChildren += n; if (n == 0) { if (Logger::getInstance()->wouldLog(LOG_DEBUG)) Logger::getInstance()->logf(LOG_DEBUG, "Resetting wildcard children"); node_.reset(); } } { // This is not particularlly fast when it's many items. But we don't do it often, so is probably okay. auto it = subscribers.begin(); while (it != subscribers.end()) { auto cur_it = it; it++; if (cur_it->second.session.expired()) { if (Logger::getInstance()->wouldLog(LOG_DEBUG)) Logger::getInstance()->logf(LOG_DEBUG, "Removing empty spot in subscribers map"); subscribers.erase(cur_it); } } } { auto shared_it = sharedSubscribers.begin(); while (shared_it != sharedSubscribers.end()) { auto cur_shared = shared_it; shared_it++; SharedSubscribers &subscribers_of_share = cur_shared->second; subscribers_of_share.purgeAndReIndex(); if (subscribers_of_share.empty()) sharedSubscribers.erase(cur_shared); } } const Settings *settings = ThreadGlobals::getSettings(); const bool grace_period_expired = lastUpdate + settings->subscriptionNodeLifetime < std::chrono::steady_clock::now(); const int grace_period_fake = static_cast(!grace_period_expired); const size_t node_subscriber_count = subscribers.size() + sharedSubscribers.size(); real_subscriber_count += node_subscriber_count; return node_subscriber_count + subscribersLeftInChildren + grace_period_fake; } bool SubscriptionNode::empty() const { return children.empty() && subscribers.empty() && sharedSubscribers.empty() && !childrenPlus && !childrenPound; } void 
SubscriptionStore::removeSession(const std::shared_ptr &session) { if (!session) return; const std::string &clientid = session->getClientId(); if (logger->wouldLog(LOG_DEBUG)) logger->log(LOG_DEBUG) << "Removing session of client '" << clientid << "', if it matches the object."; std::list> sessionsToRemove; { std::unique_lock session_locker(sessions_lock); auto session_it = sessionsById.find(clientid); if (session_it != sessionsById.end() && session_it->second == session) { sessionsToRemove.push_back(session_it->second); sessionsById.erase(session_it); } } for(std::shared_ptr &s : sessionsToRemove) { if (!s) continue; std::shared_ptr will = s->getWill(); if (will) { queueOrSendWillMessage(will, s, true); } s.reset(); } } /** * @brief SubscriptionStore::removeExpiredSessionsClients removes expired sessions. * * For Mqtt3 this is non-standard, but the standard doesn't keep real world constraints into account. */ void SubscriptionStore::removeExpiredSessionsClients() { logger->logf(LOG_DEBUG, "Cleaning out old sessions"); const std::chrono::time_point now = std::chrono::steady_clock::now(); const std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast(now.time_since_epoch()); // Collect sessions to remove for a separate step, to avoid holding two locks at the same time. std::vector> sessionsToRemove; int removedSessions = 0; int processedRemovals = 0; int queuedRemovalsLeft = -1; { std::lock_guard locker(this->queuedSessionRemovalsMutex); auto it = queuedSessionRemovals.begin(); while (it != queuedSessionRemovals.end()) { const std::chrono::seconds &removeAt = it->first; if (removeAt > secondsSinceEpoch) { break; } std::vector> &sessionsFromSlot = it->second; for (std::weak_ptr &ses : sessionsFromSlot) { std::shared_ptr lockedSession = ses.lock(); // A session could have been picked up again, so we have to verify its expiration status. 
if (lockedSession && !lockedSession->hasActiveClient()) { sessionsToRemove.push_back(lockedSession); } } it = queuedSessionRemovals.erase(it); processedRemovals++; } queuedRemovalsLeft = queuedSessionRemovals.size(); } for(std::shared_ptr &session : sessionsToRemove) { removeSession(session); removedSessions++; } logger->logf(LOG_DEBUG, "Processed %d queued session removals, resulting in %d deleted expired sessions. %d queued removals in the future.", processedRemovals, removedSessions, queuedRemovalsLeft); } bool SubscriptionStore::hasDeferredSubscriptionTreeNodesForPurging() { std::shared_lock locker(subscriptions_lock); return !deferredSubscriptionLeafsForPurging.empty(); } bool SubscriptionStore::purgeSubscriptionTree() { bool deferredLeavesPresent = hasDeferredSubscriptionTreeNodesForPurging(); if (deferredLeavesPresent) { std::unique_lock locker(subscriptions_lock); logger->log(LOG_INFO) << "Rebuilding subscription tree: we have " << deferredSubscriptionLeafsForPurging.size() << " deferred leafs to clean up. Doing some."; const std::chrono::time_point limit = std::chrono::steady_clock::now() + std::chrono::milliseconds(10); int counter = 0; for (; !deferredSubscriptionLeafsForPurging.empty(); deferredSubscriptionLeafsForPurging.pop_front()) { if (limit < std::chrono::steady_clock::now()) break; std::shared_ptr node = deferredSubscriptionLeafsForPurging.front().lock(); if (node) { counter++; node->cleanSubscriptions(deferredSubscriptionLeafsForPurging, subscriptionDeferredCounter); } } logger->log(LOG_INFO) << "Rebuilding subscription tree: processed " << counter << " deferred leafs. 
Deferred leafs left: " << deferredSubscriptionLeafsForPurging.size(); } else { std::unique_lock locker(subscriptions_lock); logger->logf(LOG_INFO, "Rebuilding subscription tree"); subscriptionDeferredCounter = 0; root->cleanSubscriptions(deferredSubscriptionLeafsForPurging, subscriptionDeferredCounter); logger->log(LOG_INFO) << "Rebuilding subscription tree done, with " << deferredSubscriptionLeafsForPurging.size() << " deferred direct leafs to check"; } const bool done = !hasDeferredSubscriptionTreeNodesForPurging(); if (done) subscriptionCount = subscriptionDeferredCounter; return done; } bool SubscriptionStore::hasDeferredRetainedMessageNodesForPurging() { RWLockGuard lock_guard(&retainedMessagesRwlock); lock_guard.rdlock(); return !deferredRetainedMessageNodeToPurge.empty(); } bool SubscriptionStore::expireRetainedMessages() { bool deferredRetainedCleanup = hasDeferredRetainedMessageNodesForPurging(); const std::chrono::time_point limit = std::chrono::steady_clock::now() + std::chrono::milliseconds(10); if (deferredRetainedCleanup) { RWLockGuard lock_guard(&retainedMessagesRwlock); lock_guard.wrlock(); logger->log(LOG_INFO) << "Expiring retained messages: we have " << deferredRetainedMessageNodeToPurge.size() << " deferred leafs to expire. Doing some."; int counter = 0; for (; !deferredRetainedMessageNodeToPurge.empty(); deferredRetainedMessageNodeToPurge.pop_front()) { if (limit < std::chrono::steady_clock::now()) break; std::shared_ptr node = deferredRetainedMessageNodeToPurge.front().lock(); if (node) { counter++; this->expireRetainedMessages(node.get(), limit, deferredRetainedMessageNodeToPurge, retainedMessageDeferredCounter); } } logger->log(LOG_INFO) << "Expiring retained messages: processed " << counter << " deferred leafs. 
Deferred leafs left: " << deferredRetainedMessageNodeToPurge.size(); } else { logger->log(LOG_INFO) << "Expiring retained messages."; RWLockGuard lock_guard(&retainedMessagesRwlock); lock_guard.wrlock(); retainedMessageDeferredCounter = 0; this->expireRetainedMessages(retainedMessagesRoot.get(), limit, deferredRetainedMessageNodeToPurge, retainedMessageDeferredCounter); logger->log(LOG_INFO) << "Expiring retained messages done, with " << deferredRetainedMessageNodeToPurge.size() << " deferred nodes to check."; } const bool done = !hasDeferredRetainedMessageNodesForPurging(); if (done) retainedMessageCount = retainedMessageDeferredCounter; return done; } /** * @brief SubscriptionStore::queueSessionRemoval places session efficiently in a sorted map that is periodically dequeued. * @param session */ void SubscriptionStore::queueSessionRemoval(const std::shared_ptr &session) { if (!session) return; std::chrono::time_point removeAt = std::chrono::steady_clock::now() + std::chrono::seconds(session->getSessionExpiryInterval()); std::chrono::seconds secondsSinceEpoch = std::chrono::duration_cast(removeAt.time_since_epoch()); session->setQueuedRemovalAt(); std::lock_guard locker(this->queuedSessionRemovalsMutex); queuedSessionRemovals[secondsSinceEpoch].push_back(session); } size_t SubscriptionStore::getRetainedMessageCount() const { return retainedMessageCount; } uint64_t SubscriptionStore::getSessionCount() const { return sessionsByIdConst.size(); } size_t SubscriptionStore::getSubscriptionCount() { return subscriptionCount; } void SubscriptionStore::getRetainedMessages( RetainedMessageNode *this_node, std::vector &outputList, const std::chrono::time_point &limit, const size_t limit_count, std::deque> &deferred) const { { std::lock_guard locker(this_node->messageSetMutex); if (this_node->message) outputList.emplace_back(*this_node->message); } for(auto &pair : this_node->children) { const std::shared_ptr &child = pair.second; if (std::chrono::steady_clock::now() > limit 
|| outputList.size() >= limit_count)
            deferred.push_back(child);
        else
            getRetainedMessages(child.get(), outputList, limit, limit_count, deferred);
    }
}

#ifdef TESTING
/**
 * @brief SubscriptionStore::getAllRetainedMessages collects every retained message below retainedMessagesRoot in one pass.
 * @return the collected retained messages.
 *
 * Test-only helper. The whole walk runs under the retained-messages read lock. The time and count limits are set to their
 * maximum values so that nothing should need deferring; the local 'deferred' queue is discarded unread.
 */
std::vector SubscriptionStore::getAllRetainedMessages()
{
    RWLockGuard locker(&retainedMessagesRwlock);
    locker.rdlock();

    const std::shared_ptr node = retainedMessagesRoot;
    std::vector result;
    std::deque> deferred;
    auto time_limit = std::chrono::time_point::max();
    getRetainedMessages(node.get(), result, time_limit, std::numeric_limits::max(), deferred);
    return result;
}
#endif

/**
 * @brief SubscriptionStore::getSubscriptions copies the subscriptions of this_node and, time permitting, of its descendants into outputList.
 * @param this_node node to collect from. Must not be null (asserted; silently ignored in release builds).
 * @param composedTopic the topic filter leading to this_node; used as the key in outputList.
 * @param root bool. Every subtopic is concatenated with a '/', but not the first topic to 'root'. The root is a bit weird, virtual, so it needs different treatment.
 * @param outputList map from composed topic filter to the subscriptions on it. Shared subscriptions are added via SharedSubscribers::getForSerializing().
 * @param deferred children (regular, '+' and '#') reached after 'limit' are appended here with their composed topic and root=false,
 *        so the caller can resume from them later.
 * @param limit deadline, checked before descending into each child. The subscribers of this_node itself are always collected.
 *
 * Subscriptions whose session can no longer be locked (weak pointer expired) are skipped.
 */
void SubscriptionStore::getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root,
                                         std::unordered_map> &outputList,
                                         std::deque &deferred, const std::chrono::time_point limit) const
{
    // No code should make dummy nodes that are null, but still protecting against it.
    assert(this_node);
    if (!this_node)
        return;

    for (auto &pair : this_node->getSubscribers())
    {
        const Subscription &node = pair.second;
        std::shared_ptr ses = node.session.lock();
        if (ses)
        {
            SubscriptionForSerializing sub(ses->getClientId(), node.qos, node.noLocal, node.retainAsPublished, node.subscriptionIdentifier);
            outputList[composedTopic].push_back(sub);
        }
    }

    for (auto &pair : this_node->getSharedSubscribers())
    {
        const SharedSubscribers &node = pair.second;
        node.getForSerializing(composedTopic, outputList);
    }

    for (auto &pair : this_node->children)
    {
        SubscriptionNode *node = pair.second.get();
        const std::string topicAtNextLevel = root ? pair.first : composedTopic + "/" + pair.first;

        // Out of time: remember the child so the caller's next time slice continues from it.
        if (std::chrono::steady_clock::now() < limit)
            getSubscriptions(node, topicAtNextLevel, false, outputList, deferred, limit);
        else
            deferred.emplace_back(pair.second, topicAtNextLevel, false);
    }

    if (this_node->childrenPlus)
    {
        const std::string topicAtNextLevel = root ? "+" : composedTopic + "/+";

        if (std::chrono::steady_clock::now() < limit)
            getSubscriptions(this_node->childrenPlus.get(), topicAtNextLevel, false, outputList, deferred, limit);
        else
            deferred.emplace_back(this_node->childrenPlus, topicAtNextLevel, false);
    }

    if (this_node->childrenPound)
    {
        const std::string topicAtNextLevel = root ? "#" : composedTopic + "/#";

        if (std::chrono::steady_clock::now() < limit)
            getSubscriptions(this_node->childrenPound.get(), topicAtNextLevel, false, outputList, deferred, limit);
        else
            deferred.emplace_back(this_node->childrenPound, topicAtNextLevel, false);
    }
}

/**
 * @brief SubscriptionStore::getSubscriptions copies all subscriptions in the tree under 'root' for serialization.
 * @return map from composed topic filter to the subscriptions on it.
 *
 * The tree is walked in time slices of about 10 ms. Each slice takes subscriptions_lock in shared mode and resumes from the
 * front of the 'deferred' queue. Nodes that don't fit in the slice are queued for a later slice, and the lock is released
 * between slices because it is scoped to one loop iteration. Unless the application is quitting, the shared lock is first
 * polled with try_lock() for up to 10 ms (sleeping 100 microseconds between attempts) before falling back to a blocking lock().
 *
 * Only the 'root' tree is traversed.
 */
std::unordered_map> SubscriptionStore::getSubscriptions()
{
    std::deque deferred;
    std::unordered_map> subscriptionCopies;

    DeferredGetSubscription start(root, "", true);
    deferred.push_front(std::move(start));

    for (; !deferred.empty(); deferred.pop_front())
    {
        std::shared_lock locker(subscriptions_lock, std::defer_lock);

        if (globals->quitting)
            locker.lock();
        else
        {
            // TODO: C++ doesn't provide the non-portable pthread functions to avoid these try_lock read locks
            // from jumping the queue over waiting write locks. But, I'm not quite sure if I want reader
            // or writer starvation yet anyway...
            const auto try_lock_timeout = std::chrono::steady_clock::now() + std::chrono::milliseconds(10);
            while (std::chrono::steady_clock::now() < try_lock_timeout)
            {
                if (locker.try_lock())
                    break;

                std::this_thread::sleep_for(std::chrono::microseconds(100));
            }

            if (!locker.owns_lock())
                locker.lock();
        }

        // Holding a reference to the front element is safe while the recursive call appends to 'deferred':
        // appending to a std::deque invalidates iterators, but not references to existing elements.
        DeferredGetSubscription &def = deferred.front();

        std::shared_ptr node = def.node.lock();

        // The node may have been removed from the tree since it was deferred.
        if (!node)
            continue;

        const std::chrono::time_point limit = std::chrono::steady_clock::now() + std::chrono::milliseconds(10);
        getSubscriptions(node.get(), def.composedTopic, def.root, subscriptionCopies, deferred, limit);
    }

    return subscriptionCopies;
}

/**
 * @brief SubscriptionStore::expireRetainedMessages drops expired retained messages below this_node and prunes orphaned child nodes.
 * @param this_node node to process.
 * @param limit deadline; children reached after it are pushed onto 'deferred' instead of being visited.
 * @param deferred receives the child nodes still to be processed.
 * @param real_message_counter incremented for every non-expired message encountered (presumably used by the caller to correct
 *        its approximate message count -- confirm at the call site).
 *
 * A child is erased when it is orphaned (no children and no message) and its message was last set longer ago than
 * settings->retainedMessageNodeLifetime. No locking is done here; NOTE(review): the caller presumably holds the
 * retained-messages lock for writing -- confirm.
 */
void SubscriptionStore::expireRetainedMessages(
    RetainedMessageNode *this_node, const std::chrono::time_point &limit,
    std::deque> &deferred, size_t &real_message_counter)
{
    if (this_node->message && this_node->message->hasExpired())
    {
        this_node->message.reset();
    }

    if (this_node->message)
        real_message_counter++;

    auto cpos = this_node->children.begin();
    while (cpos != this_node->children.end())
    {
        // Advance before 'cur' is possibly erased below, so 'cpos' stays valid.
        auto cur = cpos;
        cpos++;

        if (std::chrono::steady_clock::now() > limit)
        {
            deferred.push_back(cur->second);
            continue;
        }

        const std::shared_ptr &child = cur->second;
        expireRetainedMessages(child.get(), limit, deferred, real_message_counter);

        if (child->isOrphaned())
        {
            const Settings *settings = ThreadGlobals::getSettings();

            // Keep recently used orphaned nodes around for retainedMessageNodeLifetime before erasing them.
            if (child->getMessageSetAt() + settings->retainedMessageNodeLifetime < std::chrono::steady_clock::now())
                this_node->children.erase(cur);
        }
    }
}

/**
 * @brief SubscriptionStore::saveRetainedMessages writes the retained messages below retainedMessagesRoot to filePath in batches.
 * @param filePath destination file.
 * @param in_background true when running in a background thread. Enables the short sleep between batches and aborting when quitting.
 *
 * Each batch is collected under the retained-messages read lock for at most 5 ms or 10000 messages, and is then written to
 * disk without holding that lock. If 'in_background' is set and the application is quitting, the batch just collected is
 * not written, RetainedMessagesDB::dontSaveTmpFile() is called and the function returns.
 */
void SubscriptionStore::saveRetainedMessages(const std::string &filePath, bool in_background)
{
    logger->logf(LOG_NOTICE, "Saving retained messages to '%s'", filePath.c_str());

    std::deque> deferred;
    deferred.push_back(retainedMessagesRoot);

    RetainedMessagesDB db(filePath);
    db.openWrite();

    size_t total_count = 0;

    for (; !deferred.empty(); deferred.pop_front())
    {
        std::vector result;

        // Scope for the read lock: only the collection of one batch happens under it.
        {
            RWLockGuard locker(&retainedMessagesRwlock);
            locker.rdlock();

            std::shared_ptr node = deferred.front().lock();

            if (!node)
                continue;

            const std::chrono::time_point limit = std::chrono::steady_clock::now() + std::chrono::milliseconds(5);
            getRetainedMessages(node.get(), result, limit, 10000, deferred);
        }

        if (globals->quitting && in_background)
        {
            logger->log(LOG_NOTICE) << "Aborted background saving of retained messages because we're quitting. It will be reinitiated.";
            db.dontSaveTmpFile();
            return;
        }

        total_count += result.size();

        logger->log(LOG_DEBUG) << "Collected batch of " << result.size() << " retained messages to save.";
        db.saveData(result);

        // Because we only do this operation in background threads or on exit, we don't have to requeue, so can just sleep.
        if (in_background && !deferred.empty())
            std::this_thread::sleep_for(std::chrono::microseconds(100));
    }

    logger->log(LOG_NOTICE) << "Done saving " << total_count << " retained messages.";
}

/**
 * @brief SubscriptionStore::loadRetainedMessages reads retained messages from filePath and stores each with setRetainedMessage().
 * @param filePath file to read.
 *
 * Reads in chunks of 1000 until a chunk comes back empty. A file that can't be opened (PersistenceFileCantBeOpened) is
 * logged as a warning instead of propagated.
 */
void SubscriptionStore::loadRetainedMessages(const std::string &filePath)
{
    try
    {
        logger->logf(LOG_NOTICE, "Loading '%s'", filePath.c_str());

        RetainedMessagesDB db(filePath);
        db.openRead();

        size_t count = 0;
        size_t total_count = 0;

        do
        {
            std::list messages = db.readData(1000);
            count = messages.size();
            total_count += count;

            for (RetainedMessage &rm : messages)
            {
                setRetainedMessage(rm.publish, rm.publish.getSubtopics());
            }
        } while (count > 0);

        logger->log(LOG_NOTICE) << "Done loading " << total_count << " retained messages.";
    }
    catch (PersistenceFileCantBeOpened &ex)
    {
        logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str());
    }
}

/**
 * @brief SubscriptionStore::saveSessionsAndSubscriptions writes all sessions and their subscriptions to filePath.
 * @param filePath destination file.
 *
 * The session pointers are copied under a shared sessions_lock. The subscriptions are copied by getSubscriptions(), which
 * does its own locking. Writing happens without either lock held by this function.
 *
 * NOTE(review): the "subscriptions" figure in both log lines is subscriptionCopies.size(), i.e. the number of distinct
 * topic filters (map keys), not the number of individual subscriptions.
 */
void SubscriptionStore::saveSessionsAndSubscriptions(const std::string &filePath)
{
    logger->logf(LOG_NOTICE, "Saving sessions and subscriptions to '%s' in thread.", filePath.c_str());

    const std::chrono::time_point start = std::chrono::steady_clock::now();

    std::vector> sessionPointers;
    std::unordered_map> subscriptionCopies;

    {
        std::shared_lock session_locker(sessions_lock);

        sessionPointers.reserve(sessionsByIdConst.size());

        for (const auto &pair : sessionsByIdConst)
        {
            sessionPointers.push_back(pair.second);
        }
    }

    subscriptionCopies = getSubscriptions();

    const std::chrono::time_point doneCopying = std::chrono::steady_clock::now();

    const std::chrono::milliseconds copyDuration = std::chrono::duration_cast(doneCopying - start);
    logger->log(LOG_INFO) << "Collected " << sessionPointers.size() << " sessions and " << subscriptionCopies.size()
                          << " subscriptions to save, in " << copyDuration.count() << " ms (including defer time).";

    SessionsAndSubscriptionsDB db(filePath);
    db.openWrite();
    db.saveData(sessionPointers, subscriptionCopies);

    const std::chrono::time_point doneSaving = std::chrono::steady_clock::now();

    const std::chrono::milliseconds saveDuration = std::chrono::duration_cast(doneSaving - doneCopying);
    logger->log(LOG_INFO) << "Saved " << sessionPointers.size() << " sessions and " << subscriptionCopies.size()
                          << " subscriptions to '" << filePath << "', in " << saveDuration.count() << " ms.";
}

/**
 * @brief SubscriptionStore::loadSessionsAndSubscriptions restores sessions and subscriptions from filePath.
 * @param filePath file to read.
 *
 * Every loaded session is put in sessionsById and immediately passed to queueSessionRemoval() and queueWillMessage().
 * Presumably this is so that restored sessions nobody reclaims still expire and send their will -- confirm.
 * Subscriptions whose client id has no entry in sessionsByIdConst are skipped. Everything after reading the file runs
 * with sessions_lock held exclusively. A file that can't be opened is logged as a warning instead of propagated.
 */
void SubscriptionStore::loadSessionsAndSubscriptions(const std::string &filePath)
{
    try
    {
        logger->logf(LOG_NOTICE, "Loading '%s'", filePath.c_str());

        SessionsAndSubscriptionsDB db(filePath);
        db.openRead();
        SessionsAndSubscriptionsResult loadedData = db.readData();

        std::unique_lock session_locker(sessions_lock);

        for (std::shared_ptr &session : loadedData.sessions)
        {
            sessionsById[session->getClientId()] = session;
            queueSessionRemoval(session);
            queueWillMessage(session->getWill(), session);
        }

        for (auto &pair : loadedData.subscriptions)
        {
            const std::string &topic = pair.first;
            const std::list &subs = pair.second;

            for (const SubscriptionForSerializing &sub : subs)
            {
                auto session_it = sessionsByIdConst.find(sub.clientId);

                // No loaded session for this subscription's client; nothing to attach it to.
                if (session_it == sessionsByIdConst.end())
                    continue;

                addSubscription(session_it->second, splitTopic(topic), sub.qos, sub.noLocal, sub.retainAsPublished, sub.shareName, sub.subscriptionidentifier);
            }
        }
    }
    catch (PersistenceFileCantBeOpened &ex)
    {
        logger->logf(LOG_WARNING, "File '%s' is not there (yet)", filePath.c_str());
    }
}
ssize_t RetainedMessageNode::addPayload(const Publish &publish) { std::lock_guard locker(this->messageSetMutex); const bool retained_found = message.operator bool(); ssize_t result = 0; if (retained_found) result--; if (publish.payload.empty()) { if (retained_found) message.reset(); return result; } if (message) { (*message) = publish; } else { message = std::make_unique(publish); } result++; messageSetAt = std::chrono::steady_clock::now(); return result; } /** * @brief RetainedMessageNode::getChildren return the children or nullptr when there are none. Const, so doesn't default construct. * @param subtopic * @return */ std::shared_ptr RetainedMessageNode::getChildren(const std::string &subtopic) const { auto it = children.find(subtopic); if (it != children.end()) return it->second; return std::shared_ptr(); } bool RetainedMessageNode::isOrphaned() const { return children.empty() && !message; } const std::chrono::time_point RetainedMessageNode::getMessageSetAt() const { return this->messageSetAt; } QueuedWill::QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session) : will(will), session(session) { } const std::weak_ptr &QueuedWill::getWill() const { return this->will; } std::shared_ptr QueuedWill::getSession() { return this->session.lock(); } ================================================ FILE: subscriptionstore.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef SUBSCRIPTIONSTORE_H #define SUBSCRIPTIONSTORE_H #include #include #include #include #include #include #include #include #include #include "client.h" #include "session.h" #include "retainedmessage.h" #include "logger.h" #include "subscription.h" #include "sharedsubscribers.h" struct ReceivingSubscriber { const std::shared_ptr session; const uint8_t qos; const bool retainAsPublished; const uint32_t subscriptionIdentifier = 0; public: ReceivingSubscriber(const std::weak_ptr &ses, uint8_t qos, bool retainAsPublished, const uint32_t subscriptionIdentifier); }; enum class AddSubscriptionType { Invalid, NewSubscription, ExistingSubscription }; class SubscriptionNode { friend class SubscriptionStore; std::unordered_map subscribers; std::unordered_map sharedSubscribers; std::shared_mutex lock; std::chrono::time_point lastUpdate; public: SubscriptionNode(); SubscriptionNode(const SubscriptionNode &node) = delete; SubscriptionNode(SubscriptionNode &&node) = delete; const std::unordered_map &getSubscribers() const; std::unordered_map &getSharedSubscribers(); AddSubscriptionType addSubscriber(const std::shared_ptr &subscriber, uint8_t qos, bool noLocal, bool retainAsPublished, const std::string &shareName, const uint32_t subscriptionIdentifier); void removeSubscriber(const std::shared_ptr &subscriber, const std::string &shareName); std::unordered_map> children; std::shared_ptr childrenPlus; std::shared_ptr childrenPound; int cleanSubscriptions(std::deque> &defferedLeafs, size_t &real_subscriber_count); bool empty() const; }; class RetainedMessageNode { friend class SubscriptionStore; std::unordered_map> children; std::mutex messageSetMutex; std::unique_ptr message; std::chrono::time_point messageSetAt; ssize_t addPayload(const Publish &publish); std::shared_ptr getChildren(const std::string &subtopic) const; bool isOrphaned() const; const std::chrono::time_point getMessageSetAt() const; }; class QueuedWill { std::weak_ptr will; std::weak_ptr session; public: 
QueuedWill(const std::shared_ptr &will, const std::shared_ptr &session); const std::weak_ptr &getWill() const; std::shared_ptr getSession(); }; struct DeferredRetainedMessageNodeDelivery { std::weak_ptr node; std::vector::const_iterator cur; std::vector::const_iterator end; bool poundMode = false; }; struct DeferredGetSubscription { const std::weak_ptr node; const std::string composedTopic; const bool root = false; DeferredGetSubscription(const std::shared_ptr &node, const std::string &composedTopic, const bool root); }; class SubscriptionStore { #ifdef TESTING friend class MainTests; #endif const std::shared_ptr root = std::make_shared(); const std::shared_ptr rootDollar = std::make_shared(); std::atomic subscriber_reserve = 1024; std::shared_mutex subscriptions_lock; std::shared_mutex sessions_lock; std::unordered_map> sessionsById; const std::unordered_map> &sessionsByIdConst; std::mutex queuedSessionRemovalsMutex; std::map>> queuedSessionRemovals; pthread_rwlock_t retainedMessagesRwlock = PTHREAD_RWLOCK_INITIALIZER; std::deque> deferredRetainedMessageNodeToPurge; size_t retainedMessageDeferredCounter = 0; const std::shared_ptr retainedMessagesRoot = std::make_shared(); const std::shared_ptr retainedMessagesRootDollar = std::make_shared(); /* * Retained messages are hard to count correctly because of how they expire, so this counter is not 100% * correct. But, it gives a good idea and is corrected periodically by tree maintenance. */ std::atomic retainedMessageCount = 0; /* * Subscription events and expiring sessions are hard to track so this counter is not 100% correct, but it * gives a good idea and it's corrected by the periodic tree maintenance. 
*/ std::atomic subscriptionCount = 0; std::mutex pendingWillsMutex; std::map> pendingWillMessages; std::deque> deferredSubscriptionLeafsForPurging; size_t subscriptionDeferredCounter = 0; Logger *logger = Logger::getInstance(); static void publishNonRecursively( SubscriptionNode *this_node, std::vector &targetSessions, const std::string &senderClientId, const std::optional &fmq_client_group_id) noexcept; static void publishRecursively( std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, SubscriptionNode *this_node, std::vector &targetSessions, const std::string &senderClientId, const std::optional &fmq_client_group_id) noexcept; static void giveClientRetainedMessagesRecursively(std::vector::const_iterator cur_subtopic_it, std::vector::const_iterator end, const std::shared_ptr &this_node, bool poundMode, const std::shared_ptr &session, const uint8_t max_qos, const uint32_t subscription_identifier, const std::chrono::time_point &limit, std::deque &deferred, int &drop_count, int &processed_nodes_count); void getRetainedMessages(RetainedMessageNode *this_node, std::vector &outputList, const std::chrono::time_point &limit, const size_t limit_count, std::deque> &deferred) const; #ifdef TESTING std::vector getAllRetainedMessages(); #endif void getSubscriptions(SubscriptionNode *this_node, const std::string &composedTopic, bool root, std::unordered_map> &outputList, std::deque &deferred, const std::chrono::time_point limit) const; std::unordered_map> getSubscriptions(); static void expireRetainedMessages( RetainedMessageNode *this_node, const std::chrono::time_point &limit, std::deque> &deferred, size_t &real_message_counter); std::shared_ptr getDeepestNode(const std::vector &subtopics, bool abort_on_dead_end=false); void sendWill(const std::shared_ptr will, const std::shared_ptr session, const std::string &log); public: SubscriptionStore(); AddSubscriptionType addSubscription( const std::shared_ptr &session, const std::vector &subtopics, uint8_t 
qos, bool noLocal, bool retainAsPublished, const std::string &shareName, const uint32_t subscriptionIdentifier); void removeSubscription(const std::shared_ptr &session, const std::vector &subtopics, const std::string &shareName); std::shared_ptr getBridgeSession(std::shared_ptr &client); void registerClientAndKickExistingOne(std::shared_ptr &client); void registerClientAndKickExistingOne(std::shared_ptr &client, bool clean_start, uint16_t clientReceiveMax, uint32_t sessionExpiryInterval); std::shared_ptr lockSession(const std::string &clientid); void sendQueuedWillMessages(); void queueOrSendWillMessage( const std::shared_ptr &willMessage, const std::shared_ptr &session, bool forceNow = false); void queueWillMessage(const std::shared_ptr &willMessage, const std::shared_ptr &session); void queuePacketAtSubscribers( PublishCopyFactory ©Factory, const std::string &senderClientId, const std::optional &fmq_client_group_id, bool dollar = false); void giveClientRetainedMessages(const std::shared_ptr &ses, const std::vector &subscribeSubtopics, uint8_t max_qos, const uint32_t subscriptionIdentifier); void giveClientRetainedMessagesInitiateDeferred(const std::weak_ptr ses, const std::shared_ptr> subscribeSubtopicsCopy, std::shared_ptr> deferred, int &requeue_count, uint &total_node_count, uint8_t max_qos, const uint32_t subscription_identifier); void trySetRetainedMessages(const Publish &publish, const std::vector &subtopics); bool setRetainedMessage(const Publish &publish, const std::vector &subtopics, bool try_lock_fail=false); void removeSession(const std::shared_ptr &session); void removeExpiredSessionsClients(); bool hasDeferredSubscriptionTreeNodesForPurging(); bool purgeSubscriptionTree(); bool hasDeferredRetainedMessageNodesForPurging(); bool expireRetainedMessages(); size_t getRetainedMessageCount() const; uint64_t getSessionCount() const; size_t getSubscriptionCount(); void saveRetainedMessages(const std::string &filePath, bool in_background); void 
loadRetainedMessages(const std::string &filePath); void saveSessionsAndSubscriptions(const std::string &filePath); void loadSessionsAndSubscriptions(const std::string &filePath); void queueSessionRemoval(const std::shared_ptr &session); }; #endif // SUBSCRIPTIONSTORE_H ================================================ FILE: threaddata.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "threaddata.h" #include #include #include #include "globalstats.h" #include "subscriptionstore.h" #include "mainapp.h" #include "utils.h" #include "threadloop.h" KeepAliveCheck::KeepAliveCheck(const std::shared_ptr client) : client(client) { } QueuedRetainedMessage::QueuedRetainedMessage(const Publish &p, const std::vector &subtopics, const std::chrono::time_point limit) : p(p), subtopics(subtopics), limit(limit) { } ThreadData::ThreadData(int threadnr, const Settings &settings, const std::shared_ptr &pluginLoader, const std::weak_ptr mainApp) : epollfd(check(epoll_create(999))), pluginLoader(pluginLoader), settingsLocalCopy(settings), authentication(settingsLocalCopy), threadnr(threadnr), mMainApp(mainApp) { logger = Logger::getInstance(); taskEventFd = eventfd(0, EFD_NONBLOCK); if (taskEventFd < 0) throw std::runtime_error("Can't create eventfd."); disconnectingAllEventFd = eventfd(0, EFD_NONBLOCK); if (disconnectingAllEventFd < 0) throw std::runtime_error("Can't create eventfd."); randomish.seed(get_random_int()); std::array event_fds {taskEventFd, disconnectingAllEventFd, acceptQueue.event_fd.get()}; for (int efd : event_fds) { struct epoll_event ev{}; ev.data.fd = efd; ev.events = EPOLLIN; check(epoll_ctl(this->epollfd.get(), EPOLL_CTL_ADD, efd, &ev)); } } ThreadData::~ThreadData() { if (taskEventFd >= 0) 
close(taskEventFd); if (disconnectingAllEventFd >= 0) { close(disconnectingAllEventFd); disconnectingAllEventFd = -1; } } void ThreadData::quit() { running = false; } /** * @brief ThreadData::queuePublishStatsOnDollarTopic makes this thread publish the $SYS topics. * @param threads * * We want to do that in a thread because all authentication state is thread local. */ void ThreadData::queuePublishStatsOnDollarTopic(std::vector> &threads) { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::publishStatsOnDollarTopic, this, threads); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queueSendingQueuedWills() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::sendQueuedWills, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queueRemoveExpiredSessions() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::removeExpiredSessions, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queuePurgeSubscriptionTree() { std::shared_ptr subscriptionStore = globals->subscriptionStore; if (subscriptionStore->hasDeferredSubscriptionTreeNodesForPurging()) return; auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::purgeSubscriptionTree, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queueRemoveExpiredRetainedMessages() { std::shared_ptr subscriptionStore = globals->subscriptionStore; if (subscriptionStore->hasDeferredRetainedMessageNodesForPurging()) return; auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::removeExpiredRetainedMessages, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queueClientNextKeepAliveCheck(std::shared_ptr &client, bool keepRechecking) { assert(pthread_self() == thread_id); const std::chrono::seconds k = client->getSecondsTillKeepAliveAction(); if (k == std::chrono::seconds(0)) return; const std::chrono::seconds when = 
std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch() + k); KeepAliveCheck check(client); check.recheck = keepRechecking; queuedKeepAliveChecks[when].push_back(check); } /** * @brief ThreadData::continuationOfAuthentication is logic that either needs to be called synchronously, or by the a plugin. * @param client * @param authResult * @param authMethod * @param returnData * * It always needs to run in the client's thread. For that, also see queueContinuationOfAuthentication(). */ void ThreadData::continuationOfAuthentication(std::shared_ptr &client, AuthResult authResult, const std::string &authMethod, const std::string &returnData) { assert(pthread_self() == thread_id); std::shared_ptr subscriptionStore = globals->subscriptionStore; if (authResult == AuthResult::auth_continue) { Auth auth(ReasonCodes::ContinueAuthentication, authMethod, returnData); MqttPacket pack(auth); client->writeMqttPacket(pack); } else if (authResult == AuthResult::success) { if (!client->getAuthenticated()) // First auth sends connack packets on success. { if (!returnData.empty()) client->addAuthReturnDataToStagedConnAck(returnData); const std::shared_ptr will = client->getStagedWill(); if (will && authentication.aclCheck(*will, will->payload, AclAccess::register_will) == AuthResult::success) { client->setWillFromStaged(); } subscriptionStore->registerClientAndKickExistingOne(client); client->sendConnackSuccess(); client->setAuthenticated(true); client->getSession()->sendAllPendingQosData(); client->handleAfterAsyncQueue(client); } else // Reauth (to authenticated clients) sends AUTH on success. { Auth auth(ReasonCodes::Success, authMethod, returnData); MqttPacket authPack(auth); client->writeMqttPacket(authPack); logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication successful.", client->getClientId().c_str(), client->getUsername().c_str()); } } else { if (!client->getAuthenticated()) // First auth sends connack with 'deny' code packets on failure. 
{ const ReasonCodes reason = authResultToReasonCode(authResult); client->sendConnackDeny(reason); } else // Reauth (to authenticated clients) sends DISCONNECT on failure. { const ReasonCodes finalResult = authResultToReasonCode(authResult); Disconnect disconnect(client->getProtocolVersion(), finalResult); MqttPacket disconnectPack(disconnect); client->setDisconnectReason("Reauth denied"); client->setDisconnectStage(DisconnectStage::SendPendingAppData); client->writeMqttPacket(disconnectPack); logger->logf(LOG_NOTICE, "Client '%s', user '%s' reauthentication denied.", client->getClientId().c_str(), client->getUsername().c_str()); } } } void ThreadData::clientDisconnectEvent(const std::string &clientid) { authentication.clientDisconnected(clientid); } void ThreadData::bridgeReconnect() { acceptPendingBridges(); bool requeue = false; std::shared_ptr bridge; std::shared_ptr _threadData; for (auto &pair : clients.bridges) { bridge = pair.second; if (!bridge) continue; try { bridge->initSSL(false); std::shared_ptr client; std::shared_ptr session = bridge->session->lock(); if (session) client = session->makeSharedClient(); if (client) continue; if (!bridge->timeForNewReconnectAttempt()) { continue; } _threadData = bridge->threadData->lock(); if (!_threadData) continue; _threadData->publishBridgeState(bridge, false, "Connecting"); if (bridge->dnsResults.empty()) { // If no DNS query is pending, queue one. if (bridge->dns.idle()) { bridge->dns.query(bridge->c.address, bridge->c.inet_protocol, std::chrono::milliseconds(5000)); requeue = true; continue; } const std::list &results = bridge->dns.getResult(); // If empty, we're still waiting for the result but there is no error. 
if (results.empty()) { requeue = true; continue; } bridge->dnsResults = results; } FMQSockaddr addr = bridge->popDnsResult(); bridge->registerReconnect(); int sockfd = check(socket(addr.getFamily(), SOCK_STREAM, 0)); int flags = fcntl(sockfd, F_GETFL); fcntl(sockfd, F_SETFL, flags | O_NONBLOCK); FmqSsl clientSSL; if (bridge->c.tlsMode > BridgeTLSMode::None) { clientSSL = FmqSsl(*bridge->sslctx); if (!clientSSL) { logger->logf(LOG_ERR, "Problem creating SSL object for bridge. Closing client."); close(sockfd); continue; } clientSSL.set_fd(sockfd); } std::shared_ptr c = std::make_shared(ClientType::LocalBridge, sockfd, _threadData, std::move(clientSSL), ConnectionProtocol::Mqtt, HaProxyMode::Off, addr.getSockaddr(), settingsLocalCopy); c->addToEpoll(EPOLLIN | EPOLLOUT); c->setBridgeState(bridge); if (bridge->c.maxBufferSize) c->setMaxBufSizeOverride(bridge->c.maxBufferSize.value()); logger->logf(LOG_NOTICE, "Connecting brige: %s", c->repr().c_str()); clients.by_fd[sockfd] = c; queueClientNextKeepAliveCheck(c, true); if (session) { session->assignActiveConnection(c); c->assignSession(session); } else { std::shared_ptr subscriptionStore = globals->subscriptionStore; session = subscriptionStore->getBridgeSession(c); bridge->session = std::weak_ptr(session); } session->setLocalPrefix(bridge->c.local_prefix); session->setRemotePrefix(bridge->c.remote_prefix); c->connectToBridgeTarget(addr); } catch (std::exception &ex) { logger->log(LOG_WARNING) << "Error creating bridge '" << bridge->c.clientidPrefix << "': " << ex.what(); bridge->registerReconnect(); if (_threadData && bridge) _threadData->publishBridgeState(bridge, false, ex.what()); } } if (requeue) { auto f = std::bind(&ThreadData::bridgeReconnect, this); delayedTasks.addTask(f, 500); } } void ThreadData::queueContinuationOfAuthentication( const std::shared_ptr &client, AuthResult authResult, const std::string &authMethod, const std::string &returnData, const uint32_t delay_in_ms) { auto f = [client, authResult, 
authMethod, returnData] { client->setAsyncAuthResult({authResult, authMethod, returnData}); }; bool wake_up_needed = false; if (delay_in_ms == 0) { auto task_queue_locked = taskQueue.lock(); wake_up_needed = task_queue_locked->empty(); task_queue_locked->push_back(std::move(f)); } else { auto fdelayed = [this, f, delay_in_ms] { this->addDelayedTask(f, delay_in_ms); }; auto task_queue_locked = taskQueue.lock(); wake_up_needed = task_queue_locked->empty(); task_queue_locked->push_back(std::move(fdelayed)); } if (wake_up_needed) wakeUpThread(); } void ThreadData::clientDisconnectActions( bool authenticated, const std::string &clientid, std::shared_ptr &willPublish, std::shared_ptr &session, std::weak_ptr &bridgeState, const std::string &disconnect_reason) { std::shared_ptr store = globals->subscriptionStore; assert(store); publishBridgeState(bridgeState.lock(), false, disconnect_reason); if (willPublish) { store->queueOrSendWillMessage(willPublish, session); } if (session && session->getDestroyOnDisconnect()) { store->removeSession(session); } else { store->queueSessionRemoval(session); } if (authenticated) clientDisconnectEvent(clientid); } void ThreadData::queueClientDisconnectActions( bool authenticated, const std::string &clientid, std::shared_ptr &&willPublish, std::shared_ptr &&session, std::weak_ptr &&bridgeState, const std::string &disconnect_reason) { auto f = std::bind( &ThreadData::clientDisconnectActions, this, authenticated, clientid, std::move(willPublish), std::move(session), std::move(bridgeState), disconnect_reason); assert(!willPublish); assert(!session); auto task_queue_locked = taskQueue.lock(); task_queue_locked->push_back(std::move(f)); wakeUpThread(); } void ThreadData::queueBridgeReconnect() { auto f = std::bind(&ThreadData::bridgeReconnect, this); { auto task_queue_locked = taskQueue.lock(); task_queue_locked->push_back(f); } wakeUpThread(); } void ThreadData::publishStatsOnDollarTopic(std::vector> &threads) { size_t nrOfClients = 0; double 
receivedMessageCountPerSecond = 0; int64_t receivedMessageCount = 0; double sentMessageCountPerSecond = 0; int64_t sentMessageCount = 0; double mqttConnectCountPerSecond = 0; int64_t mqttConnectCount = 0; double aclReadChecksPerSecond = 0; int64_t aclReadCheckCount = 0; double aclWriteChecksPerSecond = 0; int64_t aclWriteCheckCount = 0; double aclSubscribeChecksPerSecond = 0; int64_t aclSubscribeCheckCount = 0; double aclRegisterWillChecksPerSecond = 0; int64_t aclRegisterWillCheckCount = 0; double retainedMessagesSetPerSecond = 0; int64_t retainedMessagesSetCount = 0; for (const std::shared_ptr &thread : threads) { nrOfClients += thread->getNrOfClients(); receivedMessageCountPerSecond += thread->receivedMessageCounter.getPerSecond(); receivedMessageCount += thread->receivedMessageCounter.get(); sentMessageCountPerSecond += thread->sentMessageCounter.getPerSecond(); sentMessageCount += thread->sentMessageCounter.get(); mqttConnectCountPerSecond += thread->mqttConnectCounter.getPerSecond(); mqttConnectCount += thread->mqttConnectCounter.get(); aclReadChecksPerSecond += thread->aclReadChecks.getPerSecond(); aclReadCheckCount += thread->aclReadChecks.get(); aclWriteChecksPerSecond += thread->aclWriteChecks.getPerSecond(); aclWriteCheckCount += thread->aclWriteChecks.get(); aclSubscribeChecksPerSecond += thread->aclSubscribeChecks.getPerSecond(); aclSubscribeCheckCount += thread->aclSubscribeChecks.get(); aclRegisterWillChecksPerSecond += thread->aclRegisterWillChecks.getPerSecond(); aclRegisterWillCheckCount += thread->aclRegisterWillChecks.get(); retainedMessagesSetPerSecond += thread->retainedMessageSet.getPerSecond(); retainedMessagesSetCount += thread->retainedMessageSet.get(); publishStat("$SYS/broker/threads/" + std::to_string(thread->threadnr) + "/drift/latest__ms", thread->driftCounter.getDrift().count()); publishStat("$SYS/broker/threads/" + std::to_string(thread->threadnr) + "/drift/moving_avg__ms", thread->driftCounter.getAvgDrift().count()); 
publishStat("$SYS/broker/threads/" + std::to_string(thread->threadnr) + "/retained_deferrals/count", thread->deferredRetainedMessagesSet.get()); publishStat("$SYS/broker/threads/" + std::to_string(thread->threadnr) + "/retained_deferrals/persecond", static_cast(thread->deferredRetainedMessagesSet.getPerSecond())); publishStat("$SYS/broker/threads/" + std::to_string(thread->threadnr) + "/retained_deferrals/timeout/count", thread->deferredRetainedMessagesSetTimeout.get()); publishStat("$SYS/broker/threads/" + std::to_string(thread->threadnr) + "/retained_deferrals/timeout/persecond", static_cast(thread->deferredRetainedMessagesSetTimeout.getPerSecond())); } publishStat("$SYS/broker/network/socketconnects/total", globals->stats.socketConnects.get()); publishStat("$SYS/broker/network/socketconnects/persecond", static_cast(globals->stats.socketConnects.getPerSecond())); publishStat("$SYS/broker/clients/mqttconnects/total", mqttConnectCount); publishStat("$SYS/broker/clients/mqttconnects/persecond", static_cast(mqttConnectCountPerSecond)); publishStat("$SYS/broker/clients/total", nrOfClients); publishStat("$SYS/broker/load/messages/received/total", receivedMessageCount); publishStat("$SYS/broker/load/messages/received/persecond", static_cast(receivedMessageCountPerSecond)); publishStat("$SYS/broker/load/messages/sent/total", sentMessageCount); publishStat("$SYS/broker/load/messages/sent/persecond", static_cast(sentMessageCountPerSecond)); publishStat("$SYS/broker/load/messages/set_retained/total", retainedMessagesSetCount); publishStat("$SYS/broker/load/messages/set_retained/persecond", static_cast(retainedMessagesSetPerSecond)); publishStat("$SYS/broker/load/aclchecks/read/total", aclReadCheckCount); publishStat("$SYS/broker/load/aclchecks/read/persecond", static_cast(aclReadChecksPerSecond)); publishStat("$SYS/broker/load/aclchecks/write/total", aclWriteCheckCount); publishStat("$SYS/broker/load/aclchecks/write/persecond", static_cast(aclWriteChecksPerSecond)); 
publishStat("$SYS/broker/load/aclchecks/subscribe/total", aclSubscribeCheckCount); publishStat("$SYS/broker/load/aclchecks/subscribe/persecond", static_cast(aclSubscribeChecksPerSecond)); publishStat("$SYS/broker/load/aclchecks/registerwill/total", aclRegisterWillCheckCount); publishStat("$SYS/broker/load/aclchecks/registerwill/persecond", static_cast(aclRegisterWillChecksPerSecond)); std::shared_ptr subscriptionStore = globals->subscriptionStore; publishStat("$SYS/broker/retained messages/count", subscriptionStore->getRetainedMessageCount()); publishStat("$SYS/broker/sessions/total", subscriptionStore->getSessionCount()); publishStat("$SYS/broker/subscriptions/count", subscriptionStore->getSubscriptionCount()); for (auto &pair : globals->stats.getExtras()) { Publish p(pair.first, pair.second, 0); publishWithAcl(p); } } void ThreadData::publishStat(const std::string &topic, int64_t n) { const std::string payload = std::to_string(n); Publish p(topic, payload, 0); publishWithAcl(p, true); } void ThreadData::publishBridgeState(std::shared_ptr bridge, bool connected, const std::optional &error) { if (!bridge) return; { const std::string payload = connected ? "1" : "0"; std::stringstream ss; ss << "$SYS/broker/bridge/" << bridge->c.clientidPrefix << "/connected"; const std::string topic = ss.str(); globals->stats.setExtra(topic, payload); Publish p(topic, payload, 0); publishWithAcl(p, true); } { const std::string message_on_no_error = connected ? 
"Connected" : "Not connected"; const std::string message = error.value_or(message_on_no_error); const std::string topic = "$SYS/broker/bridge/" + bridge->c.clientidPrefix + "/connection_status"; globals->stats.setExtra(topic, message); Publish p(topic, message, 0); publishWithAcl(p, true); } } void ThreadData::queueSettingRetainedMessage(const Publish &p, const std::vector &subtopics, const std::chrono::time_point limit) { assert(pthread_self() == thread_id); const bool wakeup_required = this->queuedRetainedMessages.empty(); this->queuedRetainedMessages.emplace_front(p, subtopics, limit); this->deferredRetainedMessagesSet.inc(1); if (wakeup_required) wakeUpThread(); } void ThreadData::publishWithAcl(Publish &pub, bool setRetain) { authentication.aclCheck(pub, pub.payload, AclAccess::write); PublishCopyFactory factory(&pub); std::shared_ptr subscriptionStore = globals->subscriptionStore; subscriptionStore->queuePacketAtSubscribers(factory, "", {}, true); if (setRetain) subscriptionStore->setRetainedMessage(pub, factory.getSubtopics()); } /** * @brief ThreadData::sendQueuedWills is not an operation per thread, but it's good practice to perform certain tasks in the worker threads, where * the thread-local globals work. */ void ThreadData::sendQueuedWills() { globals->subscriptionStore->sendQueuedWillMessages(); } /** * @brief ThreadData::removeExpiredSessions is not an operation per thread, but it's good practice to perform certain tasks in the worker threads, where * the thread-local globals work. 
*/ void ThreadData::removeExpiredSessions() { globals->subscriptionStore->removeExpiredSessionsClients(); } void ThreadData::purgeSubscriptionTree() { bool done = globals->subscriptionStore->purgeSubscriptionTree(); if (!done) { auto f = std::bind(&ThreadData::purgeSubscriptionTree, this); addDelayedTask(f, 100); } } /** * @brief ThreadData::removeExpiredRetainedMessages is not an operation per thread, but it's good practice to perform certain tasks in the worker threads, where * the thread-local globals work. */ void ThreadData::removeExpiredRetainedMessages() { bool done = globals->subscriptionStore->expireRetainedMessages(); if (!done) { auto f = std::bind(&ThreadData::removeExpiredRetainedMessages, this); #ifdef TESTING addImmediateTask(f); #else addDelayedTask(f, 100); #endif } } void ThreadData::sendAllWills() { assert(pthread_self() == thread_id); for(auto &pair : clients.by_fd) { std::shared_ptr &c = pair.second; c->sendOrQueueWill(); } allWillsQueued = true; } void ThreadData::sendAllDisconnects() { assert(pthread_self() == thread_id); std::vector> clientsFound; clientsFound.reserve(clients.by_fd.size()); // We collect them first, so we are sure they are all added to 'disconnectingClients'. 
for(auto &pair : clients.by_fd) { clientsFound.push_back(pair.second); } for (std::shared_ptr &c : clientsFound) { serverInitiatedDisconnect(c, ReasonCodes::ServerShuttingDown, ""); } auto queued_collect_disconnecting_clients = [clientsFound, this](){ for (const std::shared_ptr &c : clientsFound) { this->disconnectingClients.push_back(c); } uint64_t one = 1; check(write(disconnectingAllEventFd, &one, sizeof(uint64_t))); }; addImmediateTask(queued_collect_disconnecting_clients); } void ThreadData::removeQueuedClients() { assert(pthread_self() == thread_id); for (const std::weak_ptr &c : clientsQueuedForRemoving) { std::shared_ptr client = c.lock(); if (!client) continue; const int fd = client->getFd(); auto pos = clients.by_fd.find(fd); if (pos != clients.by_fd.end() && pos->second == client) { clients.by_fd.erase(pos); } } clientsQueuedForRemoving.clear(); } void ThreadData::giveClient(std::shared_ptr &&client) { acceptQueue.giveClient(std::move(client)); } void ThreadData::giveBridge(std::shared_ptr &bridgeState) { if (!bridgeState) return; acceptQueue.giveBridge(std::move(bridgeState)); } void ThreadData::removeBridgeQueued(const BridgeConfig &bridgeConfig, const std::string &reason) { auto f = std::bind(&ThreadData::removeBridge, this, bridgeConfig, reason); addImmediateTask(f); } void ThreadData::removeBridge(const BridgeConfig &bridgeConfig, const std::string &reason) { assert(pthread_self() == thread_id); auto pos = clients.bridges.find(bridgeConfig.clientidPrefix); if (pos == clients.bridges.end()) return; std::shared_ptr bridge = pos->second; clients.bridges.erase(pos); if (!bridge) return; std::shared_ptr session = bridge->session->lock(); if (!session) return; std::shared_ptr client = session->makeSharedClient(); if (!client) return; if (!reason.empty()) client->setDisconnectReason(reason); publishBridgeState(bridge, false, reason); removeClientQueued(client); } void ThreadData::acceptPendingClients() { assert(pthread_self() == thread_id); std::vector> 
clientsToAccept = acceptQueue.takeClients(); for (std::shared_ptr &client : clientsToAccept) { if (!client) continue; const int fd = client->getFd(); // A non-repeating keep-alive check is for when clients do a TCP connect and then nothing else. queueClientNextKeepAliveCheck(client, false); client->addToEpoll(EPOLLIN); clients.by_fd[fd] = std::move(client); } } void ThreadData::acceptPendingBridges() { assert(pthread_self() == thread_id); std::vector> bridgesToAccept = acceptQueue.takeBridges(); for (std::shared_ptr &bridgeState : bridgesToAccept) { bridgeState->resetThreadOwners(); auto pos = clients.bridges.find(bridgeState->c.clientidPrefix); if (pos != clients.bridges.end()) { std::shared_ptr &existingState = pos->second; if (!existingState) existingState = bridgeState; else { if (existingState->c != bridgeState->c) { logger->log(LOG_NOTICE) << "Bridge '" << existingState->c.clientidPrefix << "' has changed. Reconnecting."; existingState = bridgeState; } } } else { clients.bridges[bridgeState->c.clientidPrefix] = bridgeState; } } } void ThreadData::deleteClients() { // Can have shared pointers to clients taskQueue.lock()->clear(); delayedTasks.clear(); clients.by_fd.clear(); clients.bridges.clear(); acceptQueue.clients.lock()->clear(); acceptQueue.bridges.lock()->clear(); } void ThreadData::setQueuedRetainedMessages() { if (this->queuedRetainedMessages.empty()) return; std::shared_ptr store = globals->subscriptionStore; if (!store) return; auto _pos = this->queuedRetainedMessages.begin(); while (_pos != this->queuedRetainedMessages.end()) { auto cur = _pos; _pos++; const bool try_lock_fail = cur->limit > std::chrono::steady_clock::now(); if (!try_lock_fail) { deferredRetainedMessagesSetTimeout.inc(1); } if (store->setRetainedMessage(cur->p, cur->subtopics, try_lock_fail)) { this->queuedRetainedMessages.erase(cur); continue; } else { wakeUpThread(); return; } } } bool ThreadData::queuedRetainedMessagesEmpty() const { return queuedRetainedMessages.empty(); } void 
ThreadData::clearQueuedRetainedMessages() { queuedRetainedMessages.clear(); } void ThreadData::queueInternalHeartbeat() { auto f = [this](std::chrono::time_point t){ this->driftCounter.update(t); if (this->driftCounter.getDrift() > settingsLocalCopy.maxEventLoopDrift) Logger::getInstance()->log(LOG_WARNING) << "Thread " << threadnr << " drift is: " << this->driftCounter.getDrift().count() << " ms"; }; { auto bound = std::bind(f, std::chrono::steady_clock::now()); auto task_queue_locked = taskQueue.lock(); task_queue_locked->push_back(bound); } wakeUpThread(); } std::shared_ptr ThreadData::getClient(int fd) { assert(pthread_self() == thread_id); auto pos = clients.by_fd.find(fd); if (pos == clients.by_fd.end()) return std::shared_ptr(); return pos->second; } void ThreadData::removeClientQueued(const std::shared_ptr &client) { // This is for same-thread calling, to avoid the calling side ending up with // the last reference on the shared pointer to client. assert(pthread_self() == thread_id); bool wakeUpNeeded = true; wakeUpNeeded = clientsQueuedForRemoving.empty(); clientsQueuedForRemoving.push_front(client); if (wakeUpNeeded) { auto f = std::bind(&ThreadData::removeQueuedClients, this); auto task_queue_locked = taskQueue.lock(); task_queue_locked->push_back(f); wakeUpThread(); } } void ThreadData::removeClientQueued(int fd) { auto f = [this, fd] { bool wakeUpNeeded = true; std::shared_ptr clientFound; auto client_it = clients.by_fd.find(fd); if (client_it != clients.by_fd.end()) { clientFound = client_it->second; } if (!clientFound) return; wakeUpNeeded = clientsQueuedForRemoving.empty(); clientsQueuedForRemoving.push_front(std::move(clientFound)); if (!wakeUpNeeded) return; auto f = std::bind(&ThreadData::removeQueuedClients, this); addImmediateTask(f); }; addImmediateTask(f); } void ThreadData::removeClient(std::shared_ptr client) { // This function is only for same-thread calling. 
assert(pthread_self() == thread_id); if (!client) return; client->setDisconnectStage(DisconnectStage::Now); auto pos = clients.by_fd.find(client->getFd()); if (pos != clients.by_fd.end() && pos->second == client) clients.by_fd.erase(pos); } void ThreadData::serverInitiatedDisconnect(std::shared_ptr &&client, ReasonCodes reason, const std::string &reason_text) { auto c = std::move(client); serverInitiatedDisconnect(c, reason, reason_text); } /** * @brief ThreadData::serverInitiatedDisconnect queues a disconnect packet and when the last bytes are written, the thread loop will disconnect it. * @param client * @param reason * @param reason_text * * Sending clients disconnect packets is only supported by MQTT >= 5, so in case of MQTT3, just close the connection. * * There is a chance that an client's TCP buffers are full (when the client is gone, for example) and epoll will not report the * fd as EPOLLOUT, which means the disconnect will not happen. It will then be up to the keep-alive mechanism to kick the client out. 
*/ void ThreadData::serverInitiatedDisconnect(const std::shared_ptr &client, ReasonCodes reason, const std::string &reason_text) { if (!client) return; auto f = [client, reason, reason_text, this]() { if (!reason_text.empty()) client->setDisconnectReason(reason_text); ReasonCodes new_reason = reason; std::string code_to_text; if (client->getClientType() == ClientType::LocalBridge && reason == ReasonCodes::ServerShuttingDown) { code_to_text = "sending disconnect to server"; new_reason = ReasonCodes::Success; } if (code_to_text.empty()) { code_to_text = reasonCodeToString(reason); } client->setDisconnectReason(code_to_text); if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5) { client->setDisconnectStage(DisconnectStage::SendPendingAppData); Disconnect d(ProtocolVersion::Mqtt5, new_reason); client->writeMqttPacket(d); } else { client->setDisconnectStage(DisconnectStage::Now); removeClientQueued(client); } }; addImmediateTask(f); } void ThreadData::queueDoKeepAliveCheck() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::doKeepAliveCheck, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queueQuit() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::quit, this); task_queue_locked->push_back(f); authentication.setQuitting(); wakeUpThread(); } void ThreadData::queuePasswdFileReload() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&Authentication::loadMosquittoPasswordFile, &authentication); task_queue_locked->push_back(f); auto f2 = std::bind(&Authentication::loadMosquittoAclFile, &authentication); task_queue_locked->push_back(f2); wakeUpThread(); } size_t ThreadData::getNrOfClients() { return this->clientCount.load(std::memory_order_relaxed); } void ThreadData::updateNrOfClients() { this->clientCount.store(clients.by_fd.size(), std::memory_order_relaxed); } void ThreadData::queuepluginPeriodicEvent() { auto task_queue_locked = taskQueue.lock(); auto f = 
std::bind(&ThreadData::pluginPeriodicEvent, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::pluginPeriodicEvent() { authentication.periodicEvent(); } void ThreadData::queueSendWills() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::sendAllWills, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::queueSendDisconnects() { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::sendAllDisconnects, this); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::pollExternalFd(int fd, uint32_t events, const std::weak_ptr &p) { int mode = EPOLL_CTL_MOD; auto pos = externalFds.find(fd); if (pos == externalFds.end()) { mode = EPOLL_CTL_ADD; } if (mode == EPOLL_CTL_ADD || !p.expired()) externalFds[fd] = p; struct epoll_event ev {}; ev.data.fd = fd; ev.events = events; check(epoll_ctl(this->epollfd.get(), mode, fd, &ev)); } void ThreadData::pollExternalRemove(int fd) { this->externalFds.erase(fd); if (epoll_ctl(this->epollfd.get(), EPOLL_CTL_DEL, fd, NULL) != 0) { Logger *logger = Logger::getInstance(); logger->logf(LOG_ERR, "Removing externally watched fd %d from epoll produced error: %s", fd, strerror(errno)); } } uint32_t ThreadData::addDelayedTask(std::function f, uint32_t delayMs) { return delayedTasks.addTask(f, delayMs); } void ThreadData::removeDelayedTask(uint32_t id) { delayedTasks.eraseTask(id); } void ThreadData::addImmediateTask(std::function f) { bool wakeupNeeded = true; { auto task_queue_locked = taskQueue.lock(); wakeupNeeded = task_queue_locked->empty(); task_queue_locked->push_back(f); } if (wakeupNeeded) { wakeUpThread(); } } void ThreadData::performAllImmediateTasks() { uint64_t eventfd_value = 0; if (read(taskEventFd, &eventfd_value, sizeof(uint64_t)) < 0) logger->log(LOG_ERROR) << "Error reading taskEventFd: " << strerror(errno); std::list> copiedTasks; { auto task_queue_locked = taskQueue.lock(); copiedTasks = std::move(*task_queue_locked); 
task_queue_locked->clear(); } for(auto &f : copiedTasks) { try { f(); } catch (std::exception &ex) { Logger *logger = Logger::getInstance(); logger->logf(LOG_ERR, "Error in queued task: %s", ex.what()); } } } void ThreadData::doKeepAliveCheck() { assert(pthread_self() == thread_id); logger->logf(LOG_DEBUG, "doKeepAliveCheck in thread %d", threadnr); const std::chrono::seconds now = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()); try { // Put clients to delete in here, to avoid holding two locks. std::vector> clientsToRemove; std::vector> clientsToRecheck; int slotsTotal = 0; int slotsProcessed = 0; int clientsChecked = 0; { logger->logf(LOG_DEBUG, "Checking clients with pending keep-alive checks in thread %d", threadnr); slotsTotal = queuedKeepAliveChecks.size(); auto pos = queuedKeepAliveChecks.begin(); while (pos != queuedKeepAliveChecks.end()) { const std::chrono::seconds &doCheckAt = pos->first; if (doCheckAt > now) break; slotsProcessed++; std::vector const &checks = pos->second; for (KeepAliveCheck const &k : checks) { std::shared_ptr client = k.client.lock(); if (client) { clientsChecked++; if (client->isOutgoingConnection() && client->getAuthenticated()) { client->writePing(); } if (client->keepAliveExpired()) { clientsToRemove.push_back(client); } else if (k.recheck) { clientsToRecheck.push_back(client); } } } pos = queuedKeepAliveChecks.erase(pos); } for (std::shared_ptr &c : clientsToRecheck) { c->resetBuffersIfEligible(); queueClientNextKeepAliveCheck(c, true); } } logger->logf(LOG_DEBUG, "Checked %d clients in %d of %d keep-alive slots in thread %d", clientsChecked, slotsProcessed, slotsTotal, threadnr); for (std::shared_ptr c : clientsToRemove) { c->setDisconnectReason("Keep-alive expired: " + c->getKeepAliveInfoString()); clients.by_fd.erase(c->getFd()); } } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error handling keep-alives: %s.", ex.what()); } } void ThreadData::initplugin() { 
authentication.loadMosquittoPasswordFile(); authentication.loadMosquittoAclFile(); authentication.loadPlugin(*pluginLoader); authentication.init(); authentication.securityInit(false); } void ThreadData::cleanupplugin() { authentication.cleanup(); } void ThreadData::reload(const Settings &settings) { assert(pthread_self() == thread_id); logger->logf(LOG_DEBUG, "Doing reload in thread %d", threadnr); try { // Because the auth plugin has a reference to it, it will also be updated. settingsLocalCopy = settings; for (auto &pair : clients.bridges) { std::shared_ptr b = pair.second; if (!b) continue; b->initSSL(true); } authentication.securityCleanup(true); authentication.securityInit(true); } catch (std::exception &ex) { logger->logf(LOG_ERR, "Error reloading auth plugin: %s. Security checks will now fail, because we don't know the status of the plugin anymore.", ex.what()); } } void ThreadData::queueReload(const Settings &settings) { auto task_queue_locked = taskQueue.lock(); auto f = std::bind(&ThreadData::reload, this, settings); task_queue_locked->push_back(f); wakeUpThread(); } void ThreadData::wakeUpThread() { uint64_t one = 1; check(write(taskEventFd, &one, sizeof(uint64_t))); } ThreadDataOwner::ThreadDataOwner(int threadnr, const Settings &settings, const std::shared_ptr &pluginLoader, const std::weak_ptr mainApp) : td(std::make_shared(threadnr, settings, pluginLoader, mainApp)) { } ThreadDataOwner::~ThreadDataOwner() { waitForQuit(); } void ThreadDataOwner::start() { this->thread = std::thread(&do_thread_work, td); pthread_t native = this->thread.native_handle(); td->thread_id = native; std::ostringstream threadName; threadName << "FlashMQ T " << td->threadnr; threadName.flush(); std::string name = threadName.str(); const char *c_str = name.c_str(); pthread_setname_np(native, c_str); } void ThreadDataOwner::waitForQuit() { if (thread.joinable()) thread.join(); } ThreadData *ThreadDataOwner::operator->() const { return td.get(); } std::shared_ptr 
ThreadDataOwner::getThreadData() const { return td; } ================================================ FILE: threaddata.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef THREADDATA_H #define THREADDATA_H #include #include #include #include #include #include #include #include #include #include #include #include #include "client.h" #include "plugin.h" #include "logger.h" #include "derivablecounter.h" #include "queuedtasks.h" #include "settings.h" #include "bridgeconfig.h" #include "driftcounter.h" #include "fdmanaged.h" #include "mutexowned.h" #include "clientacceptqueue.h" class MainApp; struct KeepAliveCheck { std::weak_ptr client; bool recheck = true; KeepAliveCheck(const std::shared_ptr client); }; struct QueuedRetainedMessage { const Publish p; const std::vector subtopics; const std::chrono::time_point limit; QueuedRetainedMessage(const Publish &p, const std::vector &subtopics, const std::chrono::time_point limit); }; struct Clients { std::unordered_map> by_fd; std::unordered_map> bridges; }; struct ThreadDataOwner { std::shared_ptr td; std::thread thread; ThreadDataOwner() = delete; ThreadDataOwner(const ThreadDataOwner &other) = delete; ThreadDataOwner(ThreadDataOwner &&other) = default; ThreadDataOwner(int threadnr, const Settings &settings, const std::shared_ptr &pluginLoader, const std::weak_ptr mainApp); ~ThreadDataOwner(); ThreadDataOwner &operator=(const ThreadDataOwner &other) = delete; ThreadDataOwner &operator=(ThreadDataOwner &&other) = delete; void start(); void waitForQuit(); ThreadData *operator->() const; std::shared_ptr getThreadData() const; }; class ThreadData { FdManaged epollfd; Clients clients; Logger *logger; std::forward_list> clientsQueuedForRemoving; std::map> 
queuedKeepAliveChecks; std::list queuedRetainedMessages; const std::shared_ptr pluginLoader; void reload(const Settings &settings); void wakeUpThread(); void doKeepAliveCheck(); void quit(); void publishStatsOnDollarTopic(std::vector> &threads); void publishStat(const std::string &topic, int64_t n); void sendQueuedWills(); void removeExpiredSessions(); void purgeSubscriptionTree(); void removeExpiredRetainedMessages(); void sendAllWills(); void sendAllDisconnects(); void clientDisconnectEvent(const std::string &clientid); void clientDisconnectActions( bool authenticated, const std::string &clientid, std::shared_ptr &willPublish, std::shared_ptr &session, std::weak_ptr &bridgeState, const std::string &disconnect_reason); void bridgeReconnect(); void removeQueuedClients(); void publishWithAcl(Publish &pub, bool setRetain=false); void removeBridge(const BridgeConfig &bridgeConfig, const std::string &reason); public: Settings settingsLocalCopy; // Is updated on reload, within the thread loop. Authentication authentication; bool deferThreadReady = false; bool running = true; bool finished = false; bool allWillsQueued = false; pthread_t thread_id = pthread_self(); // Gets set later, but this helps asserts dummy use in fuzzing and tests. 
int threadnr = 0; int taskEventFd = -1; int disconnectingAllEventFd = -1; std::atomic clientCount{0}; MutexOwned>> taskQueue; QueuedTasks delayedTasks; DriftCounter driftCounter; std::unordered_map> externalFds; std::vector> disconnectingClients; ClientAcceptQueue acceptQueue; std::weak_ptr mMainApp; DerivableCounter receivedMessageCounter; DerivableCounter sentMessageCounter; DerivableCounter mqttConnectCounter; DerivableCounter aclReadChecks; DerivableCounter aclWriteChecks; DerivableCounter aclSubscribeChecks; DerivableCounter aclRegisterWillChecks; DerivableCounter deferredRetainedMessagesSet; DerivableCounter deferredRetainedMessagesSetTimeout; DerivableCounter retainedMessageSet; std::minstd_rand randomish; ThreadData(int threadnr, const Settings &settings, const std::shared_ptr &pluginLoader, const std::weak_ptr mainApp); ThreadData(const ThreadData &other) = delete; ThreadData(ThreadData &&other) = delete; ~ThreadData(); int getEpollFd() const { return epollfd.get(); } void giveClient(std::shared_ptr &&client); void giveBridge(std::shared_ptr &bridgeState); void removeBridgeQueued(const BridgeConfig &bridgeConfig, const std::string &reason); std::shared_ptr getClient(int fd); void removeClientQueued(const std::shared_ptr &client); void removeClientQueued(int fd); void removeClient(std::shared_ptr client); void serverInitiatedDisconnect(std::shared_ptr &&client, ReasonCodes reason, const std::string &reason_text); void serverInitiatedDisconnect(const std::shared_ptr &client, ReasonCodes reason, const std::string &reason_text); void initplugin(); void cleanupplugin(); void queueReload(const Settings &settings); void queueDoKeepAliveCheck(); void queueQuit(); void queuePasswdFileReload(); void queuePublishStatsOnDollarTopic(std::vector> &threads); void queueSendingQueuedWills(); void queueRemoveExpiredSessions(); void queuePurgeSubscriptionTree(); void queueRemoveExpiredRetainedMessages(); void queueClientNextKeepAliveCheck(std::shared_ptr &client, bool 
keepRechecking); void continuationOfAuthentication(std::shared_ptr &client, AuthResult authResult, const std::string &authMethod, const std::string &returnData); void queueContinuationOfAuthentication( const std::shared_ptr &client, AuthResult authResult, const std::string &authMethod, const std::string &returnData, const uint32_t delay_in_ms); void queueClientDisconnectActions( bool authenticated, const std::string &clientid, std::shared_ptr &&willPublish, std::shared_ptr &&session, std::weak_ptr &&bridgeState, const std::string &disconnect_reason); void queueBridgeReconnect(); void publishBridgeState(std::shared_ptr bridge, bool connected, const std::optional &error); void queueSettingRetainedMessage(const Publish &p, const std::vector &subtopics, const std::chrono::time_point limit); void setQueuedRetainedMessages(); bool queuedRetainedMessagesEmpty() const; void clearQueuedRetainedMessages(); void acceptPendingClients(); void acceptPendingBridges(); void deleteClients(); size_t getNrOfClients(); void updateNrOfClients(); void queuepluginPeriodicEvent(); void pluginPeriodicEvent(); void queueSendWills(); void queueSendDisconnects(); void queueInternalHeartbeat(); void pollExternalFd(int fd, uint32_t events, const std::weak_ptr &p); void pollExternalRemove(int fd); uint32_t addDelayedTask(std::function f, uint32_t delayMs); void removeDelayedTask(uint32_t id); void addImmediateTask(std::function f); void performAllImmediateTasks(); }; #endif // THREADDATA_H ================================================ FILE: threadglobals.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include "threadglobals.h" #include thread_local CheckedSharedPtr ThreadGlobals::threadData; thread_local Settings *ThreadGlobals::settings = nullptr; void ThreadGlobals::assignThreadData(const std::shared_ptr &threadData) { #ifndef TESTING assert(!static_cast(ThreadGlobals::threadData)); #endif ThreadGlobals::threadData = threadData; } const CheckedSharedPtr &ThreadGlobals::getThreadData() { return threadData; } void ThreadGlobals::assignSettings(Settings *settings) { #ifndef TESTING assert(ThreadGlobals::settings == nullptr || ThreadGlobals::settings == settings); #endif ThreadGlobals::settings = settings; } Settings *ThreadGlobals::getSettings() { return settings; } ================================================ FILE: threadglobals.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef THREADGLOBALS_H #define THREADGLOBALS_H #include "forward_declarations.h" #include "checkedsharedptr.h" class Authentication; class ThreadGlobals { static thread_local CheckedSharedPtr threadData; static thread_local Settings *settings; public: static void assign(Authentication *auth); static Authentication *getAuth(); static void assignThreadData(const std::shared_ptr &threadData); static const CheckedSharedPtr &getThreadData(); static void assignSettings(Settings *settings); static Settings *getSettings(); }; #endif // THREADGLOBALS_H ================================================ FILE: threadlocalutils.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifdef __SSE4_2__ #include "threadlocalutils.h" #include #include #include #include #include std::vector SimdUtils::splitTopic(const std::string &topic) { const unsigned s = topic.size(); if (s > 65535) throw std::runtime_error("Trying to split a string longer than the maximum MQTT topic length."); // Prefill the last 16 byte "line" with zeros _mm_store_si128((__m128i *)(topicCopy.begin() + (s & ~15u)), _mm_setzero_si128()); std::copy_n(topic.begin(), s, topicCopy.begin()); /* Add a trailing '/' * The reason is that we then always find a last / so a special case to handle the last subtopic is not necessary. * We can just stop searching when the location is of this trailing / * */ topicCopy[s] = '/'; std::vector output; output.reserve(16); const char * b = topicCopy.data(); const char * i = topicCopy.data(); const char * const e = topicCopy.data() + s; while (true) { __m128i loaded = _mm_loadu_si128((const __m128i *)i); unsigned index = _mm_cmpestri(slashes, 1, loaded, 16, 0); i += index; if (index < 16) { // This means that a '/' was found // i will point at the position where '/' was found output.emplace_back(b, i); if (i == e) break; ++i; // advance over the separator b = i; } } return output; } /** * @brief SimdUtils::isValidUtf8 checks UTF-8 validity 16 bytes at a time, using SSE 4.2. * @param s * @param alsoCheckInvalidPublishChars is for checking the presence of '#' and '+' which is not allowed in publishes. 
* @return */ bool SimdUtils::isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars) { const int len = s.size(); if (len + 16 > TOPIC_MEMORY_LENGTH) return false; std::memcpy(topicCopy.data(), s.c_str(), len); std::memset(&topicCopy.data()[len], 0x20, 16); // I fill out with spaces, as valid chars int n = 0; const char *i = topicCopy.data(); while (n < len) { const int len_left = len - n; assert(len_left > 0); __m128i loaded = _mm_loadu_si128((__m128i*)&i[n]); __m128i loaded_AND_non_ascii = _mm_and_si128(loaded, non_ascii_mask); if (alsoCheckInvalidPublishChars && (_mm_movemask_epi8(_mm_cmpeq_epi8(loaded, pound) || _mm_movemask_epi8(_mm_cmpeq_epi8(loaded, plus))))) return false; int index = _mm_cmpestri(non_ascii_mask, 1, loaded_AND_non_ascii, len_left, 0); n += index; // Checking multi-byte chars one by one. With some effort, this may be done using SIMD too, but the majority of uses will // have a minimum of multi byte chars. if (index < 16) { uint8_t x = i[n++]; int8_t char_len_left = 0; int8_t total_char_len = 0; uint32_t cur_code_point = 0; if((x & 0b11100000) == 0b11000000) // 2 byte char { char_len_left = 1; cur_code_point += ((x & 0b00011111) << 6); } else if((x & 0b11110000) == 0b11100000) // 3 byte char { char_len_left = 2; cur_code_point += ((x & 0b00001111) << 12); } else if((x & 0b11111000) == 0b11110000) // 4 byte char { char_len_left = 3; cur_code_point += ((x & 0b00000111) << 18); } else return false; total_char_len = char_len_left + 1; while (char_len_left > 0) { if (n >= len) return false; x = i[n++]; if((x & 0b11000000) != 0b10000000) // All remainer bytes of this code point needs to start with 10 return false; char_len_left--; cur_code_point += ((x & 0b00111111) << (6*char_len_left)); } // Check overlong values, to avoid having mulitiple representations of the same value. 
if (total_char_len == 2 && cur_code_point < 0x80) return false; else if (total_char_len == 3 && cur_code_point < 0x800) return false; else if (total_char_len == 4 && cur_code_point < 0x10000) return false; if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343 return false; if (cur_code_point >= 0x7F && cur_code_point <= 0x009F) return false; // Unicode spec: "Which code points are noncharacters?". if (cur_code_point >= 0xFDD0 && cur_code_point <= 0xFDEF) return false; // The last two code points of each of the 17 planes are the remaining 34 non-chars. const uint32_t plane = (cur_code_point & 0x1F0000) >> 16; const uint32_t last_16_bit = cur_code_point & 0xFFFF; if (plane <= 16 && (last_16_bit == 0xFFFE || last_16_bit == 0xFFFF)) return false; } else { if (_mm_movemask_epi8(_mm_cmplt_epi8(loaded, lowerBound))) return false; if (_mm_movemask_epi8(_mm_cmpgt_epi8(loaded, lastAsciiChar))) return false; } } return true; } #endif ================================================ FILE: threadlocalutils.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef THREADLOCALUTILS_H #define THREADLOCALUTILS_H #ifdef __SSE4_2__ #include #include #include #include #define TOPIC_MEMORY_LENGTH 65560 class SimdUtils { alignas(64) std::array topicCopy; __m128i slashes = _mm_set1_epi8('/'); __m128i lowerBound = _mm_set1_epi8(0x20); __m128i lastAsciiChar = _mm_set1_epi8(0x7E); __m128i non_ascii_mask = _mm_set1_epi8(0b10000000); __m128i pound = _mm_set1_epi8('#'); __m128i plus = _mm_set1_epi8('+'); public: SimdUtils() = default; std::vector splitTopic(const std::string &topic); bool isValidUtf8(const std::string &s, bool alsoCheckInvalidPublishChars = false); }; #endif #endif // THREADLOCALUTILS_H ================================================ FILE: threadlocked.h ================================================ #ifndef THREADLOCKED_H #define THREADLOCKED_H #include #include #include template class ThreadLocked { T d; #ifndef NDEBUG std::optional user; #endif public: ThreadLocked() = default; template ThreadLocked(Args... args) : d(args...) #ifndef NDEBUG ,user(std::this_thread::get_id()) #endif { } ThreadLocked &operator=(ThreadLocked &&other) { #ifndef NDEBUG assert(!user || user.value() == std::this_thread::get_id()); #endif d = std::forward(other.d); return *this; } ThreadLocked &operator=(T &&other) { #ifndef NDEBUG assert(!user || user.value() == std::this_thread::get_id()); #endif d = std::forward(other); return *this; } T &operator*() noexcept { #ifndef NDEBUG if (!user) user = std::this_thread::get_id(); assert(user.value() == std::this_thread::get_id()); #endif return d; } T *operator->() noexcept { #ifndef NDEBUG if (!user) user = std::this_thread::get_id(); assert(user.value() == std::this_thread::get_id()); #endif return &d; } void reset_thread() { #ifndef NDEBUG user.reset(); #endif } }; #endif // THREADLOCKED_H ================================================ FILE: threadloop.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 
2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "threadloop.h" #include "settings.h" #include "threadglobals.h" #include "mainapp.h" #include "utils.h" #include "exceptions.h" void do_thread_work(std::shared_ptr threadData) { maskAllSignalsCurrentThread(); int epoll_fd = threadData->getEpollFd(); ThreadGlobals::assignThreadData(threadData); ThreadGlobals::assignSettings(&threadData->settingsLocalCopy); std::vector events(MAX_EVENTS); std::vector packetQueueIn; Logger *logger = Logger::getInstance(); threadData->running = false; try { logger->logf(LOG_NOTICE, "Thread %d doing auth init.", threadData->threadnr); threadData->initplugin(); threadData->running = true; auto lockedMainApp = threadData->mMainApp.lock(); if (!threadData->deferThreadReady && lockedMainApp) { lockedMainApp->queueThreadInitDecrement(); } } catch(std::exception &ex) { logger->logf(LOG_ERR, "Error initializing auth back-end: %s", ex.what()); auto lockedMainApp = threadData->mMainApp.lock(); if (lockedMainApp) { lockedMainApp->quit(); } } std::vector ready_clients; while (threadData->running) { VectorClearGuard clear_ready_clients(ready_clients); const uint32_t next_task_delay = threadData->delayedTasks.getTimeTillNext(); const uint32_t epoll_wait_time = std::min(next_task_delay, 100); int fdcount = epoll_wait(epoll_fd, events.data(), events.size(), epoll_wait_time); if (__builtin_expect(epoll_wait_time == 0, 0)) { threadData->delayedTasks.performAll(); } threadData->updateNrOfClients(); if (fdcount < 0) { if (errno == EINTR) continue; logger->logf(LOG_ERR, "Problem waiting for fd: %s", strerror(errno)); } for (int i = 0; i < fdcount; i++) { struct epoll_event cur_ev = events[i]; int fd = cur_ev.data.fd; if (fd == threadData->taskEventFd) { threadData->performAllImmediateTasks(); try { threadData->setQueuedRetainedMessages(); } catch (std::exception 
&ex) { Logger *logger = Logger::getInstance(); logger->log(LOG_ERR) << "Error in setting queued retained messages: " << ex.what() << ". This shouldn't " << "happen and is likely a bug. Clearing the queue for safety."; threadData->clearQueuedRetainedMessages(); } } else if (fd == threadData->disconnectingAllEventFd) { /* * This block (and related) ensures all the clients get processed at least once when a shutdown * is initiated, so that all clients (that can), are sent a disconnect packet. */ uint64_t eventfd_value = 0; if (read(fd, &eventfd_value, sizeof(uint64_t)) < 0) logger->log(LOG_ERROR) << "Error reading event fd: " << strerror(errno); for (std::weak_ptr &wc : threadData->disconnectingClients) { std::shared_ptr c = wc.lock(); if (!c) continue; ready_clients.emplace_back(EPOLLOUT, std::move(c)); } threadData->disconnectingClients.clear(); threadData->queueQuit(); } else if (fd == threadData->acceptQueue.event_fd.get()) { threadData->acceptQueue.readFd(); threadData->acceptPendingClients(); } else { ready_clients.emplace_back(static_cast(cur_ev.events), threadData->getClient(fd)); if (__builtin_expect(!ready_clients.back().client, 0)) { ready_clients.pop_back(); // If the fd is not a client, it may be an externally monitored fd, from the plugin. 
auto pos = threadData->externalFds.find(fd); if (pos != threadData->externalFds.end()) { std::weak_ptr &p = pos->second; threadData->authentication.fdReady(fd, cur_ev.events, p); } } } } for (ReadyClient &ready_client : ready_clients) { std::shared_ptr &client = ready_client.client; if (!client) continue; try { if (client->getHaProxyStage() != HaProxyStage::DoneOrNotNeeded) { if (client->readHaProxyData() == HaProxyStage::DoneOrNotNeeded) { if (client->getHaProxyConnectionType() == HaProxyConnectionType::Local) { client->setDisconnectReason("HAProxy health check"); threadData->removeClient(client); } } continue; } if (client->isOutgoingConnection() && !client->getOutgoingConnectionEstablished()) { client->detectOutgoingConnectionEstablished(); continue; } if (ready_client.events & EPOLLHUP) { client->setDisconnectReason("Hang up"); } if (ready_client.events & EPOLLERR) { client->setDisconnectReasonFromSocketError(); threadData->removeClient(client); continue; } if (client->isSsl() && !client->isSslAccepted()) { client->startOrContinueSslHandshake(); continue; } if (__builtin_expect((ready_client.events & EPOLLOUT) && client->hasAsyncAuthResult(), 0)) { const std::unique_ptr auth = client->stealAsyncAuthResult(); if (auth) threadData->continuationOfAuthentication(client, auth->result, auth->authMethod, auth->authData); } if ((ready_client.events & EPOLLIN) || ((ready_client.events & EPOLLOUT) && client->getSslReadWantsWrite())) { VectorClearGuard vectorClear(packetQueueIn); const DisconnectStage disconnect = client->readFdIntoBuffer(); client->bufferToMqttPackets(packetQueueIn, client); for (MqttPacket &packet : packetQueueIn) { #ifdef TESTING if (client->onPacketReceived) client->onPacketReceived(packet); else #endif if (packet.handle(client) == HandleResult::Defer) { client->addPacketToAfterAsyncQueue(std::move(packet)); } } if (disconnect == DisconnectStage::Now) { client->setDisconnectReason("socket disconnect detected"); threadData->removeClient(client); 
continue; } } if ((ready_client.events & EPOLLOUT) || ((ready_client.events & EPOLLIN) && client->getSslWriteWantsRead())) { client->writeBufIntoFd(); if (client->getDisconnectStage() == DisconnectStage::Now) { threadData->removeClient(client); continue; } } } catch (ProtocolError &ex) { std::string reason("Protocol error: "); reason.append(ex.what()); client->setDisconnectReason(reason, LOG_WARNING); bool clientRemoved = true; try { if (!client->getAuthenticated()) { ConnAck connAck(client->getProtocolVersion(), ex.reasonCode); if (connAck.supported_reason_code && !client->isOutgoingConnection()) { MqttPacket p(connAck); client->writeMqttPacket(p); client->setDisconnectStage(DisconnectStage::SendPendingAppData); } else { clientRemoved = false; } } else if (client->getProtocolVersion() >= ProtocolVersion::Mqtt5) { Disconnect d(client->getProtocolVersion(), ex.reasonCode); MqttPacket p(d); client->writeMqttPacket(p); client->setDisconnectStage(DisconnectStage::SendPendingAppData); // When a client's TCP buffers are full (when the client is gone, for instance), EPOLLOUT will never be // reported. In those cases, the client is not removed; not until the keep-alive mechanism anyway. Is // that a problem? } else { clientRemoved = false; } } catch (std::exception &inner_ex) { clientRemoved = false; logger->log(LOG_ERROR) << "Exception when notyfing client about ProtocolError: " << inner_ex.what(); } if (!clientRemoved) { threadData->removeClient(client); } } catch(BadClientException &ex) { client->setDisconnectReason(ex.what(), ex.getLogLevel().value_or(-1)); threadData->removeClient(client); } catch(std::exception &ex) { client->setDisconnectReason(ex.what()); logger->log(LOG_ERR) << "Error handling client: " << ex.what() << ". 
Removing client " << client->repr(); threadData->removeClient(client); } } } try { logger->logf(LOG_NOTICE, "Thread %d doing auth cleanup.", threadData->threadnr); threadData->cleanupplugin(); threadData->deleteClients(); } catch(std::exception &ex) { logger->logf(LOG_ERR, "Error cleaning auth back-end: %s", ex.what()); } threadData->finished = true; } ReadyClient::ReadyClient(uint32_t events, std::shared_ptr &&client) : events(events), client(std::move(client)) { } ================================================ FILE: threadloop.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef THREADLOOP_H #define THREADLOOP_H #define MAX_EVENTS 65536 #include #include #include "forward_declarations.h" template class VectorClearGuard { std::vector &v; public: VectorClearGuard(std::vector &v) : v(v) { } ~VectorClearGuard() { v.clear(); } }; struct ReadyClient { uint32_t events; std::shared_ptr client; ReadyClient() = delete; ReadyClient(const ReadyClient &other) = delete; ReadyClient(ReadyClient &&other) = default; ReadyClient(uint32_t events, std::shared_ptr &&client); }; void do_thread_work(std::shared_ptr threadData); #endif // THREADLOOP_H ================================================ FILE: types.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #include #include "types.h" #include "mqtt5properties.h" #include "exceptions.h" #include "utils.h" SubscriptionOptionsByte::SubscriptionOptionsByte(uint8_t byte) : b(byte) { } SubscriptionOptionsByte::SubscriptionOptionsByte(uint8_t qos, bool noLocal, bool retainAsPublished, RetainHandling retainHandling) : b( qos | (static_cast(noLocal) << 2) | (static_cast(retainAsPublished) << 3) | (static_cast(retainHandling) << 4) ) { } bool SubscriptionOptionsByte::getNoLocal() const { uint8_t bit = b & 0x04; return static_cast(bit); } bool SubscriptionOptionsByte::getRetainAsPublished() const { uint8_t bit = b & 0x08; return static_cast(bit); } RetainHandling SubscriptionOptionsByte::getRetainHandling() const { uint8_t x = b & 0b00110000; x = x >> 4; if (x == 4) throw ProtocolError("Retain Handling value of 4 is protocol error", ReasonCodes::ProtocolError); return static_cast(x); } uint8_t SubscriptionOptionsByte::getQos() const { uint8_t qos = b & 0x03; return static_cast(qos); } ConnAck::ConnAck(const ProtocolVersion protVersion, ReasonCodes return_code, bool session_present) : protocol_version(protVersion), session_present(session_present) { if (this->protocol_version <= ProtocolVersion::Mqtt311) { this->supported_reason_code = true; ConnAckReturnCodes mqtt3_return = ConnAckReturnCodes::Accepted; switch (return_code) { case ReasonCodes::Success: mqtt3_return = ConnAckReturnCodes::Accepted; break; case ReasonCodes::UnsupportedProtocolVersion: mqtt3_return = ConnAckReturnCodes::UnacceptableProtocolVersion; break; case ReasonCodes::ClientIdentifierNotValid: mqtt3_return = ConnAckReturnCodes::ClientIdRejected; break; case ReasonCodes::ServerUnavailable: mqtt3_return = ConnAckReturnCodes::ServerUnavailable; break; case ReasonCodes::BadUserNameOrPassword: mqtt3_return = ConnAckReturnCodes::MalformedUsernameOrPassword; break; case ReasonCodes::NotAuthorized: mqtt3_return = ConnAckReturnCodes::NotAuthorized; break; default: /* * The MQTT 3 says: "If none of the return codes 
listed in Table 3.1 – Connect Return code values are deemed applicable, * then the Server MUST close the Network Connection without sending a CONNACK" * * But throwing an exception here was removed, because it was too late, and it could bite us when trying we * were already in error handling code. So, you have the option to check the bool, but if you don't, * it just sends ServerUnavailable. */ supported_reason_code = false; mqtt3_return = ConnAckReturnCodes::ServerUnavailable; } // [MQTT-3.2.2-4] if (mqtt3_return > ConnAckReturnCodes::Accepted) this->session_present = false; this->return_code = static_cast(mqtt3_return); } else { this->supported_reason_code = true; this->return_code = static_cast(return_code); // MQTT-3.2.2-6 if (this->return_code > 0) this->session_present = false; } } size_t ConnAck::getLengthWithoutFixedHeader() const { size_t result = 2; if (this->protocol_version >= ProtocolVersion::Mqtt5) { const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; result += proplen; } return result; } SubAck::SubAck(const ProtocolVersion protVersion, uint16_t packet_id, const std::list &subs_qos_reponses) : protocol_version(protVersion), packet_id(packet_id) { assert(!subs_qos_reponses.empty()); for (const ReasonCodes ack_code : subs_qos_reponses) { assert(protVersion >= ProtocolVersion::Mqtt311 || ack_code <= ReasonCodes::GrantedQoS2); ReasonCodes _ack_code = ack_code; if (protVersion < ProtocolVersion::Mqtt5 && ack_code >= ReasonCodes::UnspecifiedError) _ack_code = ReasonCodes::UnspecifiedError; // Equals Mqtt 3.1.1 'suback failure' responses.push_back(static_cast(_ack_code)); } } size_t SubAck::getLengthWithoutFixedHeader() const { size_t result = responses.size(); result += 2; // Packet ID if (this->protocol_version >= ProtocolVersion::Mqtt5) { const size_t proplen = propertyBuilder ? 
propertyBuilder->getLength() : 1; result += proplen; } return result; } Publish::Publish(const std::string &topic, const std::string &payload, uint8_t qos) : topic(topic), payload(payload), qos(qos) { } Publish::Publish( const std::string &topic, const std::string &payload, uint8_t qos, bool retain, uint32_t expiryInterval, const std::vector> *userProperties, const std::string *responseTopic, const std::string *correlationData, const std::string *contentType) : Publish(topic, payload, qos) { this->retain = retain; if (userProperties) { for (const std::pair &pair : *userProperties) { this->addUserProperty(pair.first, pair.second); } } if (expiryInterval) this->setExpireAfter(expiryInterval); if (responseTopic) this->responseTopic = *responseTopic; if (correlationData) this->correlationData = *correlationData; if (contentType) this->contentType = *contentType; } bool Publish::hasExpired() const { if (!expireInfo) return false; return expireInfo->expiresAt() < std::chrono::steady_clock::now(); } std::optional> Publish::expiresAt() const { if (!expireInfo) return {}; return expireInfo->expiresAt(); } std::vector> *Publish::getUserProperties() const { return userProperties.get(); } void Publish::addUserProperty(const std::string &key, const std::string &val) { if (!userProperties) userProperties = std::make_shared>>(); userProperties->emplace_back(key, val); } void Publish::addUserProperty(std::string &&key, std::string &&val) { if (!userProperties) userProperties = std::make_shared>>(); if (userProperties->size() > 50) throw ProtocolError("Trying to set more than 50 user properties. 
Likely a bad actor.", ReasonCodes::ImplementationSpecificError); userProperties->emplace_back(std::move(key), std::move(val)); } std::optional Publish::getFirstUserProperty(const std::string &key) const { if (!userProperties) return {}; auto pos = std::find_if(userProperties->begin(), userProperties->end(), [&key](const std::pair &p) { return (p.first == key); }); if (pos == userProperties->end()) return {}; return pos->second; } std::optional Publish::getPropertyBuilder() const { std::optional property_builder; if (expireInfo) non_optional(property_builder)->writeMessageExpiryInterval(expireInfo->expiresAfter.count()); if (correlationData) non_optional(property_builder)->writeCorrelationData(*correlationData); if (responseTopic) non_optional(property_builder)->writeResponseTopic(*responseTopic); if (contentType) non_optional(property_builder)->writeContentType(*contentType); if (payloadUtf8) non_optional(property_builder)->writePayloadFormatIndicator(1); if (topicAlias > 0) non_optional(property_builder)->writeTopicAlias(topicAlias); if (subscriptionIdentifier > 0) non_optional(property_builder)->writeSubscriptionIdentifier(subscriptionIdentifier); if (userProperties) non_optional(property_builder)->writeUserProperties(*userProperties); return property_builder; } void Publish::setExpireAfter(uint32_t s) { this->expireInfo.emplace(); this->expireInfo->expiresAfter = std::chrono::seconds(s); } void Publish::setExpireAfterToCeiling(std::chrono::seconds ceiling) { if (expireInfo) { this->expireInfo->expiresAfter = std::min(this->expireInfo->expiresAfter, ceiling); return; } this->expireInfo.emplace(); this->expireInfo->expiresAfter = ceiling; } const std::vector &Publish::getSubtopics() { if (!subtopics) subtopics = splitTopic(this->topic); return subtopics.value(); } void Publish::resplitTopic() { subtopics = splitTopic(this->topic); } WillPublish::WillPublish(const Publish &other) : Publish(other) { } void WillPublish::setQueuedAt() { this->isQueued = true; 
this->queuedAt = std::chrono::steady_clock::now(); } /** * @brief WillPublish::getQueuedAtAge gets the time ago in seconds when this will was queued. The time is set externally by the queue action. * @return * * This age is required when saving wills to disk, because the new will delay to set on load is not the original will delay, but minus the * elapsed time after queueing. */ uint32_t WillPublish::getQueuedAtAge() const { if (!isQueued) return 0; const std::chrono::seconds age = std::chrono::duration_cast(std::chrono::steady_clock::now() - this->queuedAt); return age.count(); } std::optional WillPublish::getPropertyBuilder() const { auto property_builder = Publish::getPropertyBuilder(); if (this->will_delay > 0) non_optional(property_builder)->writeWillDelay(this->will_delay); return property_builder; } PubResponse::PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id) : packet_type(packet_type), protocol_version(protVersion), reason_code(protVersion >= ProtocolVersion::Mqtt5 ? reason_code : ReasonCodes::Success), packet_id(packet_id) { assert(packet_type == PacketType::PUBACK || packet_type == PacketType::PUBREC || packet_type == PacketType::PUBREL || packet_type == PacketType::PUBCOMP); } uint8_t PubResponse::getLengthIncludingFixedHeader() const { return 2 + getRemainingLength(); } uint8_t PubResponse::getRemainingLength() const { // I'm leaving out the property length of 0: "If the Remaining Length is less than 4 there is no Property Length and the value of 0 is used" const uint8_t result = needsReasonCode() ? 
3 : 2; return result; } /** * @brief "The Reason Code and Property Length can be omitted if the Reason Code is 0x00 (Success) and there are no Properties" * @return */ bool PubResponse::needsReasonCode() const { return this->protocol_version >= ProtocolVersion::Mqtt5 && this->reason_code > ReasonCodes::Success; } UnsubAck::UnsubAck(const ProtocolVersion protVersion, uint16_t packet_id, const int unsubCount) : protocol_version(protVersion), packet_id(packet_id), reasonCodes(unsubCount) { if (protVersion >= ProtocolVersion::Mqtt5) { // At this point, FlashMQ has no mechanism that would reject unsubscribes, so just marking them all as success. for(ReasonCodes &rc : this->reasonCodes) { rc = ReasonCodes::Success; } } } size_t UnsubAck::getLengthWithoutFixedHeader() const { size_t result = 2; // Start with room for packet id if (this->protocol_version >= ProtocolVersion::Mqtt5) { result += this->reasonCodes.size(); const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; result += proplen; } return result; } Disconnect::Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code) : protocolVersion(protVersion), reasonCode(reason_code) { } size_t Disconnect::getLengthWithoutFixedHeader() const { if (this->protocolVersion < ProtocolVersion::Mqtt5) return 0; size_t result = 1; const size_t proplen = propertyBuilder ? propertyBuilder->getLength() : 1; result += proplen; return result; } Auth::Auth(ReasonCodes reasonCode, const std::string &authMethod, const std::string &authData) : reasonCode(reasonCode), propertyBuilder(std::make_shared()) { if (!authMethod.empty()) propertyBuilder->writeAuthenticationMethod(authMethod); if (!authData.empty()) propertyBuilder->writeAuthenticationData(authData); } size_t Auth::getLengthWithoutFixedHeader() const { size_t result = 1; const size_t proplen = propertyBuilder ? 
propertyBuilder->getLength() : 1; result += proplen; return result; } Connect::Connect(ProtocolVersion protocolVersion, const std::string &clientid) : protocolVersion(protocolVersion), clientid(clientid) { } std::string_view Connect::getMagicString() const { if (protocolVersion <= ProtocolVersion::Mqtt31) return "MQIsdp"; else return "MQTT"; } Subscribe::Subscribe(const std::string &topic, uint8_t qos) : topic(topic), qos(qos) { } Unsubscribe::Unsubscribe(const std::string &topic) : topic(topic) { } std::chrono::time_point PublishExpireInfo::expiresAt() const { auto result = this->createdAt + this->expiresAfter; return result; } std::chrono::seconds PublishExpireInfo::getCurrentTimeToExpire() const { const auto now = std::chrono::steady_clock::now(); std::chrono::seconds delay = std::chrono::duration_cast(now - createdAt); std::chrono::seconds newExpireAfter = std::max(expiresAfter - delay, std::chrono::seconds(0)); return newExpireAfter; } ================================================ FILE: types.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef TYPES_H #define TYPES_H #include #include #include #include #include #include #include #include "forward_declarations.h" #include "nocopy.h" #define FMQ_CLIENT_GROUP_ID "fmq_client_group_id" /* MQTT control packet types, numbered as in the MQTT specification; 0 is reserved. */ enum class PacketType { Reserved = 0, CONNECT = 1, CONNACK = 2, PUBLISH = 3, PUBACK = 4, PUBREC = 5, PUBREL = 6, PUBCOMP = 7, SUBSCRIBE = 8, SUBACK = 9, UNSUBSCRIBE = 10, UNSUBACK = 11, PINGREQ = 12, PINGRESP = 13, DISCONNECT = 14, AUTH = 15 }; /* Values of the protocol version byte: 3 = MQTT 3.1, 4 = MQTT 3.1.1, 5 = MQTT 5; None (0) means no version. The ordering is used in comparisons such as protocol_version >= Mqtt5. */ enum class ProtocolVersion { None = 0, Mqtt31 = 0x03, Mqtt311 = 0x04, Mqtt5 = 0x05 }; /* MQTT 5 property identifiers, numbered as in the MQTT 5 specification (the numbering has gaps). */ enum class Mqtt5Properties { None = 0, PayloadFormatIndicator = 1, MessageExpiryInterval = 2, ContentType = 3, ResponseTopic = 8, CorrelationData = 9, SubscriptionIdentifier = 11, SessionExpiryInterval = 17, AssignedClientIdentifier = 18, ServerKeepAlive = 19, AuthenticationMethod = 21, AuthenticationData = 22, RequestProblemInformation = 23, WillDelayInterval = 24, RequestResponseInformation = 25, ResponseInformation = 26, ServerReference = 28, ReasonString = 31, ReceiveMaximum = 33, TopicAliasMaximum = 34, TopicAlias = 35, MaximumQoS = 36, RetainAvailable = 37, UserProperty = 38, MaximumPacketSize = 39, WildcardSubscriptionAvailable = 40, SubscriptionIdentifierAvailable = 41, SharedSubscriptionAvailable = 42 }; /** * @brief The ConnAckReturnCodes enum holds the CONNACK return codes of MQTT 3.1 and 3.1.1. */ enum class ConnAckReturnCodes { Accepted = 0, UnacceptableProtocolVersion = 1, ClientIdRejected = 2, ServerUnavailable = 3, MalformedUsernameOrPassword = 4, NotAuthorized = 5 }; /** * @brief The ReasonCodes enum holds the MQTT 5 reason codes. Several names share a value, e.g. Success and GrantedQoS0 are both 0.
*/ enum class ReasonCodes { Success = 0, GrantedQoS0 = 0, GrantedQoS1 = 1, GrantedQoS2 = 2, DisconnectWithWill = 4, NoMatchingSubscribers = 16, NoSubscriptionExisted = 17, ContinueAuthentication = 24, ReAuthenticate = 25, /* Codes from 128 (0x80) upwards report a failure; SubAck reports all of them as UnspecifiedError to pre-MQTT 5 clients. */ UnspecifiedError = 128, MalformedPacket = 129, ProtocolError = 130, ImplementationSpecificError = 131, UnsupportedProtocolVersion = 132, ClientIdentifierNotValid = 133, BadUserNameOrPassword = 134, NotAuthorized = 135, ServerUnavailable = 136, ServerBusy = 137, Banned = 138, ServerShuttingDown = 139, BadAuthenticationMethod = 140, KeepAliveTimeout = 141, SessionTakenOver = 142, TopicFilterInvalid = 143, TopicNameInvalid = 144, PacketIdentifierInUse = 145, PacketIdentifierNotFound = 146, ReceiveMaximumExceeded = 147, TopicAliasInvalid = 148, PacketTooLarge = 149, MessageRateTooHigh = 150, QuotaExceeded = 151, AdministrativeAction = 152, PayloadFormatInvalid = 153, RetainNotSupported = 154, QosNotSupported = 155, UseAnotherServer = 156, ServerMoved = 157, SharedSubscriptionsNotSupported = 158, ConnectionRateExceeded = 159, MaximumConnectTime = 160, SubscriptionIdentifiersNotSupported = 161, WildcardSubscriptionsNotSupported = 162 }; /** * The most important thing to know is that in MQTT 3 there is a de facto standard where the MSB of the * protocol version byte signifies that the client is a bridge. This helps with loop prevention. MQTT 5 no longer * does that, because it has 'subscription options' for that.
*/ enum class ClientType : uint8_t { Normal, Mqtt3DefactoBridge, LocalBridge }; /* MQTT 5 Retain Handling subscription option, stored in bits 4-5 of the subscription options byte. There is no enumerator for 3. */ enum class RetainHandling: uint8_t { SendRetainedMessagesAtSubscribe = 0, SendRetainedMessagesAtNewSubscribeOnly = 1, DoNotSendRetainedMessages = 2 }; /* Wrapper around a subscription options byte: bits 0-1 QoS, bit 2 no-local, bit 3 retain-as-published, bits 4-5 retain handling (the getters in types.cpp apply the masks). */ struct SubscriptionOptionsByte { const uint8_t b = 0; SubscriptionOptionsByte(uint8_t byte); SubscriptionOptionsByte(uint8_t qos, bool noLocal, bool retainAsPublished, RetainHandling retainHandling); bool getNoLocal() const; bool getRetainAsPublished() const; RetainHandling getRetainHandling() const; uint8_t getQos() const; }; /* CONNACK packet data. For MQTT 3.1 and 3.1.1 the constructor maps the MQTT 5 reason code to a ConnAckReturnCodes value and sets supported_reason_code to false when there is no equivalent. */ class ConnAck { public: ConnAck(const ProtocolVersion protVersion, ReasonCodes return_code, bool session_present=false); const ProtocolVersion protocol_version; uint8_t return_code; bool session_present = false; bool supported_reason_code = false; std::shared_ptr propertyBuilder; size_t getLengthWithoutFixedHeader() const; }; /* SUBACK packet data: one response byte per entry passed to the constructor. */ class SubAck { public: const ProtocolVersion protocol_version; const uint16_t packet_id; std::list responses; std::shared_ptr propertyBuilder; SubAck(const ProtocolVersion protVersion, uint16_t packet_id, const std::list &subs_qos_reponses); size_t getLengthWithoutFixedHeader() const; }; /* UNSUBACK packet data; reasonCodes holds one entry per unsubscribed topic filter and is only counted in the length for MQTT 5. */ class UnsubAck { public: const ProtocolVersion protocol_version; const uint16_t packet_id; std::shared_ptr propertyBuilder; std::vector reasonCodes; UnsubAck(const ProtocolVersion protVersion, uint16_t packet_id, const int unsubCount); size_t getLengthWithoutFixedHeader() const; }; /* Expiry bookkeeping for a message: createdAt defaults to the construction time (steady_clock) and expiresAfter is the lifetime. */ struct PublishExpireInfo { std::chrono::time_point createdAt = std::chrono::steady_clock::now(); std::chrono::seconds expiresAfter; std::chrono::time_point expiresAt() const; std::chrono::seconds getCurrentTimeToExpire() const; }; /* Data of an application message; also the base class of WillPublish. */ class Publish { public: std::string client_id; std::string username; std::string topic; std::string payload; private: NoCopy> subtopics; public: uint8_t qos = 0; bool retain = false; // Note: existing subscribers don't get publishes of retained messages with retain=1.
[MQTT-3.3.1-9] uint16_t topicAlias = 0; bool skipTopic = false; bool payloadUtf8 = false; std::optional expireInfo; std::optional correlationData; std::optional responseTopic; std::optional contentType; uint32_t subscriptionIdentifier = 0; #ifdef TESTING uint32_t subscriptionIdentifierTesting = 0; // Clunky... #endif std::shared_ptr>> userProperties; Publish() = default; Publish(const std::string &topic, const std::string &payload, uint8_t qos); Publish(const std::string &topic, const std::string &payload, uint8_t qos, bool retain, uint32_t expiryInterval, const std::vector> *userProperties, const std::string *responseTopic, const std::string *correlationData, const std::string *contentType); bool hasExpired() const; std::optional> expiresAt() const; /* Time since expireInfo->createdAt, converted to duration type T; T(0) when the message has no expiry info. */ template T getAge() const { if (!expireInfo) return T(0); return std::chrono::duration_cast(std::chrono::steady_clock::now() - this->expireInfo->createdAt); } std::vector> *getUserProperties() const; void addUserProperty(const std::string &key, const std::string &val); void addUserProperty(std::string &&key, std::string &&val); std::optional getFirstUserProperty(const std::string &key) const; std::optional getPropertyBuilder() const; void setExpireAfter(uint32_t s); void setExpireAfterToCeiling(std::chrono::seconds s); const std::vector &getSubtopics(); void resplitTopic(); }; /* A Publish used as a will message: will_delay is written as the will delay property when non-zero, and setQueuedAt() records when the will was queued so getQueuedAtAge() can report its age in seconds. */ class WillPublish : public Publish { bool isQueued = false; std::chrono::time_point queuedAt; public: uint32_t will_delay = 0; WillPublish() = default; WillPublish(const Publish &other); void setQueuedAt(); uint32_t getQueuedAtAge() const; std::optional getPropertyBuilder() const; }; /* PUBACK, PUBREC, PUBREL or PUBCOMP packet data (asserted in the constructor); not copyable. */ class PubResponse { public: PubResponse(const PubResponse &other) = delete; PubResponse(const ProtocolVersion protVersion, const PacketType packet_type, ReasonCodes reason_code, uint16_t packet_id); const PacketType packet_type; const ProtocolVersion protocol_version; const ReasonCodes reason_code; uint16_t packet_id; uint8_t getLengthIncludingFixedHeader() const;
uint8_t getRemainingLength() const; bool needsReasonCode() const; }; /* DISCONNECT packet data; the reason code and properties are only counted in the length for MQTT 5. */ class Disconnect { public: ProtocolVersion protocolVersion; ReasonCodes reasonCode; std::shared_ptr propertyBuilder; Disconnect(const ProtocolVersion protVersion, ReasonCodes reason_code); size_t getLengthWithoutFixedHeader() const; }; /* AUTH packet data; the constructor writes the authentication method and data as properties when they are non-empty. */ class Auth { public: ReasonCodes reasonCode; std::shared_ptr propertyBuilder; Auth(ReasonCodes reasonCode, const std::string &authMethod, const std::string &authData); size_t getLengthWithoutFixedHeader() const; }; /* Data for building a CONNECT packet; getMagicString() returns the protocol name matching protocolVersion. */ struct Connect { const ProtocolVersion protocolVersion; bool clean_start = true; bool bridgeProtocolBit = false; std::string clientid; std::optional fmq_client_group_id; std::optional username; std::optional password; uint16_t keepalive = 60; uint32_t sessionExpiryInterval = 0; uint16_t maxIncomingTopicAliasValue = 0; std::shared_ptr will; std::optional authenticationMethod; std::optional authenticationData; Connect(ProtocolVersion protocolVersion, const std::string &clientid); std::string_view getMagicString() const; }; /** * @brief The Subscribe struct can be used to construct an MQTT packet of type 'subscribe'. * * It's rudimentary. Officially you can subscribe to multiple topics at once, but I have no need for that. */ struct Subscribe { std::string topic; uint8_t qos; bool noLocal = false; bool retainAsPublished = false; RetainHandling retainHandling = RetainHandling::SendRetainedMessagesAtSubscribe; Subscribe(const std::string &topic, uint8_t qos); }; /** * @brief The Unsubscribe struct can be used to construct an MQTT packet of type 'unsubscribe'. * * It's rudimentary. Officially you can unsubscribe from multiple topics at once, but I have no need for that.
*/ struct Unsubscribe { std::string topic; Unsubscribe(const std::string &topic); }; /* Presumably the outcomes of trying to deliver a packet to a client, with Success meaning it was delivered; no users are visible here (TODO confirm against callers). NOTE(review): QoSTODOSomethingSomething looks like a placeholder name. */ enum class PacketDropReason { Success, ClientError, ClientOffline, AuthDenied, BiggerThanPacketLimit, BufferFull, QoSTODOSomethingSomething }; #endif // TYPES_H ================================================ FILE: unscopedlock.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "unscopedlock.h" /* Releases the mutex only if lock() was called on this object. */ UnscopedLock::~UnscopedLock() { if (locked) { managedMutex.unlock(); } } /* Stores a reference to the mutex without locking it; the mutex must outlive this object. */ UnscopedLock::UnscopedLock(std::mutex &mutex) : managedMutex(mutex) { } /* Blocks until the mutex is acquired and marks it for release in the destructor. Calling this twice on the same object would lock a std::mutex the thread already owns, which is undefined behaviour. */ void UnscopedLock::lock() { managedMutex.lock(); locked = true; } ================================================ FILE: unscopedlock.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef UNSCOPEDLOCK_H #define UNSCOPEDLOCK_H #include /** * @brief The UnscopedLock class is a simple variety of the std::lock_guard or std::scoped_lock that allows optional locking using RAII. * * STL doesn't provide a similar feature, or am I missing something? You could do it with smart pointers, but I want to avoid having to * use the free store.
*/ class UnscopedLock { std::mutex &managedMutex; bool locked = false; public: ~UnscopedLock(); UnscopedLock(std::mutex &mutex); /* Copying is disabled: a copy made after lock() would carry locked == true, and both destructors would then unlock the same mutex. */ UnscopedLock(const UnscopedLock &other) = delete; UnscopedLock &operator=(const UnscopedLock &other) = delete; /* Locks the mutex; the destructor then unlocks it. Call at most once per object. */ void lock(); }; #endif // UNSCOPEDLOCK_H ================================================ FILE: utils.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include #include "utils.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "exceptions.h" #include "cirbuf.h" #include "sslctxmanager.h" #include "logger.h" #include "evpencodectxmanager.h" #include "settings.h" #include "threadglobals.h" /* Splits input on sep. Empty parts are skipped unless keep_empty_parts is set. Once max parts have been collected, the remainder of the input (separators included) becomes the final element, so at most max + 1 parts are returned. */ std::list split(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { std::list list; std::string::const_iterator start = input.begin(); const std::string::const_iterator end = input.end(); std::string::const_iterator sep_pos; while (list.size() < max && (sep_pos = std::find(start, end, sep)) != end) { if (start != sep_pos || keep_empty_parts) list.emplace_back(start, sep_pos); start = sep_pos + 1; // increase by length of separator } if (start != end || keep_empty_parts) list.emplace_back(start, end); return list; } /* True when needle occurs anywhere in s. */ bool strContains(const std::string &s, const std::string &needle) { return s.find(needle) != std::string::npos; } // Only necessary for tests at this point.
bool isValidUtf8Generic(const char *s, bool alsoCheckInvalidPublishChars) { const std::string s2(s); return isValidUtf8Generic(s2, alsoCheckInvalidPublishChars); } bool isValidPublishPath(const std::string &s) { if (s.empty()) return false; for (const char c : s) { if (c == '#' || c == '+') return false; } return true; } bool isValidSubscribePath(const std::string &s) { bool wildcardAllowed = true; bool nextMustBeSlash = false; bool poundSeen = false; for (const char c : s) { if (!wildcardAllowed && (c == '+' || c == '#')) return false; if (nextMustBeSlash && c != '/') return false; if (poundSeen) return false; wildcardAllowed = c == '/'; nextMustBeSlash = c == '+'; poundSeen = c == '#'; } return true; } bool isValidShareName(const std::string &s) { if (s.empty()) return false; for (const char c : s) { if ((c == '#') | (c == '+') | (c == '/')) return false; } return true; } bool containsDangerousCharacters(const std::string &s) { if (s.empty()) return false; for (const char c : s) { switch(c) { case '#': return true; case '+': return true; } } return false; } std::vector splitTopic(const std::string &topic) { const Settings *settings = ThreadGlobals::getSettings(); const size_t limit{settings ? 
settings->maxTopicSplitDepth : std::numeric_limits::max()}; assert(limit > 0); #ifdef __SSE4_2__ thread_local static SimdUtils simdUtils; const std::vector output = simdUtils.splitTopic(topic); #else std::vector output; output.reserve(16); std::string::const_iterator start = topic.begin(); std::string::const_iterator sep_pos; do { sep_pos = std::find(start, topic.end(), '/'); output.emplace_back(start, sep_pos); start = sep_pos + 1; } while (sep_pos != topic.end()); #endif if (output.size() > limit) throw std::runtime_error("Topic/filter contains more components than 'max_topic_split_depth'"); return output; } std::vector splitToVector(const std::string &input, const char sep, size_t max, bool keep_empty_parts) { std::vector output; output.reserve(16); std::string::const_iterator start = input.begin(); std::string::const_iterator sep_pos; while (output.size() < max && (sep_pos = std::find(start, input.end(), sep)) != input.end()) { if (start != sep_pos || keep_empty_parts) output.emplace_back(start, sep_pos); start = sep_pos + 1; // increase by length of separator } if (start != input.end() || keep_empty_parts) output.emplace_back(start, input.end()); return output; } void ltrim(std::string &s) { s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); } void rtrim(std::string &s) { s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); } void rtrim(std::string &s, char c) { s.erase(std::find_if(s.rbegin(), s.rend(), [c](unsigned char ch) { return c != ch; }).base(), s.end()); } void trim(std::string &s) { ltrim(s); rtrim(s); } std::string &rtrim(std::string &s, unsigned char c) { s.erase(std::find_if(s.rbegin(), s.rend(), [=](unsigned char ch) { return (c != ch); }).base(), s.end()); return s; } bool startsWith(const std::string &s, const std::string &needle) { if (s.length() < needle.length()) return false; size_t i; for (i = 0; i < needle.length(); i++) { 
if (s[i] != needle[i]) return false; } return i == needle.length(); } bool endsWith(const std::string &s, const std::string &ending) { if (ending.size() > s.size()) return false; return std::equal(ending.rbegin(), ending.rend(), s.rbegin()); } std::string getSecureRandomString(const ssize_t len) { std::vector buf(len); const ssize_t random_len = len * 8; ssize_t actual_len = -1; while ((actual_len = getrandom(buf.data(), random_len, 0)) < 0) { if (errno == EINTR) continue; break; } if (actual_len < 0 || actual_len != random_len) { throw std::runtime_error("Error requesting random data"); } static constexpr std::string_view possibleCharacters{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrtsuvwxyz1234567890"}; static constexpr size_t possibleCharactersCount = possibleCharacters.size(); std::string randomString(buf.size(), '\0'); std::transform(buf.begin(), buf.end(), randomString.begin(), [&](uint64_t v){ return possibleCharacters[v % possibleCharactersCount];}); return randomString; } std::string str_tolower(std::string s) { std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c){ return std::tolower(c); }); return s; } bool stringTruthiness(const std::string &val) { std::string val_ = str_tolower(val); trim(val_); if (val_ == "yes" || val_ == "true" || val_ == "on") return true; if (val_ == "no" || val_ == "false" || val_ == "off") return false; throw ConfigFileException("Value '" + val + "' can't be converted to boolean"); } bool isPowerOfTwo(int n) { return (n != 0) && (n & (n - 1)) == 0; } std::vector base64Decode(const std::string &s) { if (s.length() % 4 != 0) throw std::runtime_error("Decoding invalid base64 string"); if (s.empty()) throw std::runtime_error("Trying to base64 decode an empty string."); const std::vector input(s.begin(), s.end()); std::vector tmp(input.size() + 16); int outl = 0; int outl_total = 0; EvpEncodeCtxManager b64_ctx; if (EVP_DecodeUpdate(b64_ctx.ctx.get(), tmp.data(), &outl, input.data(), input.size()) < 0) throw 
std::runtime_error("Failure in EVP_DecodeUpdate()"); outl_total += outl; if (EVP_DecodeFinal(b64_ctx.ctx.get(), tmp.data() + outl_total, &outl) < 0) throw std::runtime_error("Failure in EVP_DecodeFinal()"); std::vector result = make_vector(tmp, 0, outl_total); return result; } std::string base64Encode(const unsigned char *input, const int length) { const int pl = 4*((length+2)/3); std::vector output(pl + 1); const int ol = EVP_EncodeBlock(output.data(), input, length); if (pl != ol) throw std::runtime_error("Base64 encode error."); std::string result = make_string(output, 0, ol); return result; } // Using a separate ssl context to test, because it's the easiest way to load certs and key atomitcally. void testSsl(const std::string &fullchain, const std::string &privkey) { if (fullchain.empty() && privkey.empty()) throw ConfigFileException("No privkey and fullchain specified."); if (fullchain.empty()) throw ConfigFileException("No private key specified for fullchain"); if (privkey.empty()) throw ConfigFileException("No fullchain specified for private key"); if (getFileSize(fullchain) == 0) throw ConfigFileException(formatString("SSL 'fullchain' file '%s' is empty or invalid", fullchain.c_str())); if (getFileSize(privkey) == 0) throw ConfigFileException(formatString("SSL 'privkey' file '%s' is empty or invalid", privkey.c_str())); SslCtxManager sslCtx; if (SSL_CTX_use_certificate_chain_file(sslCtx.get(), fullchain.c_str()) != 1) { ERR_print_errors_cb(logSslError, NULL); throw ConfigFileException("Error loading full chain " + fullchain); } if (SSL_CTX_use_PrivateKey_file(sslCtx.get(), privkey.c_str(), SSL_FILETYPE_PEM) != 1) { ERR_print_errors_cb(logSslError, NULL); throw ConfigFileException("Error loading private key " + privkey); } if (SSL_CTX_check_private_key(sslCtx.get()) != 1) { ERR_print_errors_cb(logSslError, NULL); throw ConfigFileException("Private key and certificate don't match."); } } void testSslVerifyLocations(const std::string &caFile, const std::string 
&caDir, const std::string &error) { if (!caFile.empty() && getFileSize(caFile) <= 0) throw ConfigFileException(formatString("SSL 'ca_file' file '%s' is empty or invalid", caFile.c_str())); SslCtxManager sslCtx(TLS_client_method()); const char *ca_file = caFile.empty() ? nullptr : caFile.c_str(); const char *ca_dir = caDir.empty() ? nullptr : caDir.c_str(); if (ca_file == nullptr && ca_dir == nullptr) return; if (SSL_CTX_load_verify_locations(sslCtx.get(), ca_file, ca_dir) != 1) { ERR_print_errors_cb(logSslError, NULL); throw ConfigFileException(error); } } std::string formatString(const std::string str, ...) { constexpr const int bufsize = 512; char buf[bufsize + 1]; buf[bufsize] = 0; va_list valist; va_start(valist, str); int rc = vsnprintf(buf, bufsize, str.c_str(), valist); va_end(valist); if (rc < 0) return std::string(); size_t len = std::min(strlen(buf), bufsize); std::string result(buf, len); return result; } std::string_view dirnameOf(std::string_view path) { size_t pos = path.find_last_of("\\/"); return (std::string::npos == pos) ? "" : path.substr(0, pos); } size_t getFileSize(const std::string &path) { struct stat statbuf {}; if (stat(path.c_str(), &statbuf) < 0) throw std::runtime_error("Can't get size of " + path); if (statbuf.st_size < 0) throw std::runtime_error("Size of " + path + " negative?"); return statbuf.st_size; } uint64_t getFreeSpace(const std::string &path) { struct statvfs statbuf {}; if (statvfs(path.c_str(), &statbuf) < 0) throw std::runtime_error("Can't get free space of " + path); const uint64_t result {statbuf.f_bsize * statbuf.f_bfree}; return result; } /** * @brief Get socket family from addr that works with strict type aliasing. * * Disgruntled: if type aliasing rules are so strict, why is there no library * function to obtain the family from a sockaddr...? 
*/ sa_family_t getFamilyFromSockAddr(const sockaddr *addr) { if (!addr) return AF_UNSPEC; sockaddr tmp {}; std::memcpy(&tmp, addr, sizeof(tmp)); return tmp.sa_family; } std::string sockaddrToString(const sockaddr *addr) { if (!addr) return "[unknown address]"; std::array buf; std::fill(buf.begin(), buf.end(), 0); const int family = getFamilyFromSockAddr(addr); const char *rc = nullptr; if (family == AF_INET) { struct sockaddr_in ipv4sockAddr; std::memcpy(&ipv4sockAddr, addr, sizeof(struct sockaddr_in)); rc = inet_ntop(family, &ipv4sockAddr.sin_addr, buf.data(), buf.size()); } else if (family == AF_INET6) { struct sockaddr_in6 ipv6sockAddr; std::memcpy(&ipv6sockAddr, addr, sizeof(struct sockaddr_in6)); rc = inet_ntop(family, &ipv6sockAddr.sin6_addr, buf.data(), buf.size()); } else if (family == AF_UNIX) { return std::string("unix"); } if (rc == nullptr) return "[unknown address]"; std::string remote_addr(rc); return remote_addr; } std::string protocolVersionString(ProtocolVersion p) { switch (p) { case ProtocolVersion::None: return "none"; case ProtocolVersion::Mqtt31: return "3.1"; case ProtocolVersion::Mqtt311: return "3.1.1"; case ProtocolVersion::Mqtt5: return "5.0"; default: return "unknown"; } } /** * @brief Returns the edit distance between the two given strings * * This function uses the Wagner–Fischer algorithm to calculate the Levenshtein * distance between two strings: the total number of insertions, swaps, and * deletions that are needed to transform the one into the other. */ unsigned int distanceBetweenStrings(const std::string &stringA, const std::string &stringB) { // The matrix contains the distances between the substrings. // You can find a description of the algorithm online. 
// // Roughly: // // line_a: "dog" // line_b: "horse" // // --> // | | # | d | o | g // V --+---+---+---+---+---+---+--- // # | P // h | // o | // r | Q X Y // s | // e | Z // // P = [0, 0] = the distance from "" to "" (which is 0) // Q = [0, 3] = the distance from "" to "hor" (which is 3, all inserts) // X = [1, 3] = the distance from "d" to "hor" (which is 3, 1 swap and 2 inserts) // Y = [3, 3] = the distance from "dog" to "hor" (which is 2, both swaps) // Z = [3, 5] = the distance from "dog" to "horse" (which is 4, two swaps and 2 inserts) // // The matrix does not have to be square, the dimensions depends on the inputs // // the position within stringA should always be referred to as x // the position within stringB should always be referred to as y using mymatrix = std::vector>; // +1 because we also need to store the length from the empty strings int width = stringA.size() + 1; int height = stringB.size() + 1; mymatrix distances(width, std::vector(height)); // We know that the distance from the substrings of line_a to "" // is equal to the length of the substring of line_a for (int x = 0; x < width; x++) { distances.at(x).at(0) = x; } // We know that the distance from "" to the substrings of line_b // is equal to the length of the substring of line_b for (int y = 0; y < height; y++) { distances.at(0).at(y) = y; } // Now all we do is to fill out the rest of the matrix, easy peasy // note we start at 1 because the top row and left column have already been calculated for (int x = 1; x < width; x++) { for (int y = 1; y < height; y++) { if (stringA.at(x - 1) == stringB.at(y - 1)) { // the letters in both words are the same: we can travel from the top-left for free to the current state distances.at(x).at(y) = distances.at(x - 1).at(y - 1); } else { // let's calculate the different costs and pick the cheapest option // We use "+1" for all costs since they are all equally likely in our case int dinstance_with_deletion = distances.at(x).at(y - 1) + 1; int 
dinstance_with_insertion = distances.at(x - 1).at(y) + 1; int dinstance_with_substitution = distances.at(x - 1).at(y - 1) + 1; distances.at(x).at(y) = std::min({dinstance_with_deletion, dinstance_with_insertion, dinstance_with_substitution}); } } } return distances.at(width - 1).at(height - 1); // the last cell contains our answer } uint32_t ageFromTimePoint(const std::chrono::time_point &point) { auto duration = std::chrono::steady_clock::now() - point; auto seconds = std::chrono::duration_cast(duration); return seconds.count(); } std::chrono::time_point timepointFromAge(const uint32_t age) { std::chrono::seconds seconds(age); std::chrono::time_point newPoint = std::chrono::steady_clock::now() - seconds; return newPoint; } ReasonCodes authResultToReasonCode(AuthResult authResult) { switch (authResult) { case AuthResult::success: return ReasonCodes::Success; case AuthResult::auth_method_not_supported: return ReasonCodes::BadAuthenticationMethod; case AuthResult::acl_denied: case AuthResult::login_denied: return ReasonCodes::NotAuthorized; case AuthResult::server_not_available: return ReasonCodes::ServerUnavailable; case AuthResult::error: return ReasonCodes::UnspecifiedError; case AuthResult::auth_continue: return ReasonCodes::ContinueAuthentication; default: return ReasonCodes::UnspecifiedError; } } int maskAllSignalsCurrentThread() { sigset_t set; sigfillset(&set); int r = pthread_sigmask(SIG_SETMASK, &set, NULL); return r; } void parseSubscriptionShare(std::vector &subtopics, std::string &shareName, std::string &topic) { if (subtopics.size() < 3) { if (subtopics.size() == 2 && subtopics[0] == "$share") { throw ProtocolError("Topic filter for shared subscription cannot be empty.", ReasonCodes::ProtocolError); } return; } const std::string &match = subtopics[0]; if (match != "$share") return; const std::string _shareName = subtopics[1]; if (!isValidShareName(_shareName)) throw ProtocolError("Invalid character in share name", ReasonCodes::ProtocolError); for (int i 
= 0; i < 2; i++) { subtopics.erase(subtopics.begin()); } if (!(subtopics.size() > 1 || (subtopics.size() == 1 && !subtopics[0].empty()) )) throw ProtocolError("The / character after a shared subscription name MUST be followed by a topic filter.", ReasonCodes::ProtocolError); topic.clear(); for(const std::string &s : subtopics) { if (!topic.empty()) topic.append("/"); topic.append(s); } shareName = _shareName; } std::string timestampWithMillis() { const auto now = std::chrono::system_clock::now(); const auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; const time_t timer = std::chrono::system_clock::to_time_t(now); struct tm my_tm {}; struct tm *my_tm_result = localtime_r(&timer, &my_tm); if (!my_tm_result) return std::string("localtime-failed"); std::ostringstream oss; oss << std::put_time(my_tm_result, "%Y-%m-%d %H:%M:%S"); oss << '.' << std::setfill('0') << std::setw(3) << ms.count(); return oss.str(); } void exceptionOnNonMqtt(const std::vector &data) { const std::string str(data.data(), data.size()); std::istringstream is(str); bool firstLine = true; std::string line; while (std::getline(is, line)) { if (firstLine) { firstLine = false; if (strContains(line, "HTTP")) { throw BadClientException("This looks like HTTP traffic."); } } } } /** * # is 0 * one/# is 1 * one/two/+/four is 2 * one/two/three is 65535 * */ uint16_t getFirstWildcardDepth(const std::vector &subtopics) { uint16_t result = std::numeric_limits::max(); uint16_t i = 0; for (const std::string &s : subtopics) { if (s == "#" || s == "+") { result = i; break; } i++; } return result; } std::string reasonCodeToString(ReasonCodes code) { switch (code) { case ReasonCodes::Success: return "Success"; //case ReasonCodes::GrantedQoS0: // return "GrantedQoS0"; case (ReasonCodes::GrantedQoS1): return "GrantedQoS1"; case (ReasonCodes::GrantedQoS2): return "GrantedQoS2"; case (ReasonCodes::DisconnectWithWill): return "DisconnectWithWill"; case (ReasonCodes::NoMatchingSubscribers): return 
"NoMatchingSubscribers"; case (ReasonCodes::NoSubscriptionExisted): return "NoSubscriptionExisted"; case (ReasonCodes::ContinueAuthentication): return "ContinueAuthentication"; case (ReasonCodes::ReAuthenticate): return "ReAuthenticate"; case (ReasonCodes::UnspecifiedError): return "UnspecifiedError"; case (ReasonCodes::MalformedPacket): return "MalformedPacket"; case (ReasonCodes::ProtocolError): return "ProtocolError"; case (ReasonCodes::ImplementationSpecificError): return "ImplementationSpecificError"; case (ReasonCodes::UnsupportedProtocolVersion): return "UnsupportedProtocolVersion"; case (ReasonCodes::ClientIdentifierNotValid): return "ClientIdentifierNotValid"; case (ReasonCodes::BadUserNameOrPassword): return "BadUserNameOrPassword"; case (ReasonCodes::NotAuthorized): return "NotAuthorized"; case (ReasonCodes::ServerUnavailable): return "ServerUnavailable"; case (ReasonCodes::ServerBusy): return "ServerBusy"; case (ReasonCodes::Banned): return "Banned"; case (ReasonCodes::ServerShuttingDown): return "ServerShuttingDown"; case (ReasonCodes::BadAuthenticationMethod): return "BadAuthenticationMethod"; case (ReasonCodes::KeepAliveTimeout): return "KeepAliveTimeout"; case (ReasonCodes::SessionTakenOver): return "SessionTakenOver"; case (ReasonCodes::TopicFilterInvalid): return "TopicFilterInvalid"; case (ReasonCodes::TopicNameInvalid): return "TopicNameInvalid"; case (ReasonCodes::PacketIdentifierInUse): return "PacketIdentifierInUse"; case (ReasonCodes::PacketIdentifierNotFound): return "PacketIdentifierNotFound"; case (ReasonCodes::ReceiveMaximumExceeded): return "ReceiveMaximumExceeded"; case (ReasonCodes::TopicAliasInvalid): return "TopicAliasInvalid"; case (ReasonCodes::PacketTooLarge): return "PacketTooLarge"; case (ReasonCodes::MessageRateTooHigh): return "MessageRateTooHigh"; case (ReasonCodes::QuotaExceeded): return "QuotaExceeded"; case (ReasonCodes::AdministrativeAction): return "AdministrativeAction"; case (ReasonCodes::PayloadFormatInvalid): return 
"PayloadFormatInvalid"; case (ReasonCodes::RetainNotSupported): return "RetainNotSupported"; case (ReasonCodes::QosNotSupported): return "QosNotSupported"; case (ReasonCodes::UseAnotherServer): return "UseAnotherServer"; case (ReasonCodes::ServerMoved): return "ServerMoved"; case (ReasonCodes::SharedSubscriptionsNotSupported): return "SharedSubscriptionsNotSupported"; case (ReasonCodes::ConnectionRateExceeded): return "ConnectionRateExceeded"; case (ReasonCodes::MaximumConnectTime): return "MaximumConnectTime"; case (ReasonCodes::SubscriptionIdentifiersNotSupported): return "SubscriptionIdentifiersNotSupported"; case (ReasonCodes::WildcardSubscriptionsNotSupported): return "WildcardSubscriptionsNotSupported"; default: break; } std::ostringstream oss; oss << static_cast(code); return oss.str(); } std::string packetTypeToString(PacketType ptype) { switch (ptype) { case (PacketType::Reserved): return "Reserved"; case (PacketType::CONNECT): return "CONNECT"; case (PacketType::CONNACK): return "CONNACK"; case (PacketType::PUBLISH): return "PUBLISH"; case (PacketType::PUBACK): return "PUBACK"; case (PacketType::PUBREC): return "PUBREC"; case (PacketType::PUBREL): return "PUBREL"; case (PacketType::PUBCOMP): return "PUBCOMP"; case (PacketType::SUBSCRIBE): return "SUBSCRIBE"; case (PacketType::SUBACK): return "SUBACK"; case (PacketType::UNSUBSCRIBE): return "UNSUBSCRIBE"; case (PacketType::UNSUBACK): return "UNSUBACK"; case (PacketType::PINGREQ): return "PINGREQ"; case (PacketType::PINGRESP): return "PINGRESP"; case (PacketType::DISCONNECT): return "DISCONNECT"; case (PacketType::AUTH): return "AUTH"; default: break; } std::ostringstream oss; oss << static_cast(ptype); return oss.str(); } std::string propertyToString(Mqtt5Properties p) { switch (p) { case (Mqtt5Properties::None): return "None"; case (Mqtt5Properties::PayloadFormatIndicator): return "PayloadFormatIndicator"; case (Mqtt5Properties::MessageExpiryInterval): return "MessageExpiryInterval"; case 
(Mqtt5Properties::ContentType): return "ContentType"; case (Mqtt5Properties::ResponseTopic): return "ResponseTopic"; case (Mqtt5Properties::CorrelationData): return "CorrelationData"; case (Mqtt5Properties::SubscriptionIdentifier): return "SubscriptionIdentifier"; case (Mqtt5Properties::SessionExpiryInterval): return "SessionExpiryInterval"; case (Mqtt5Properties::AssignedClientIdentifier): return "AssignedClientIdentifier"; case (Mqtt5Properties::ServerKeepAlive): return "ServerKeepAlive"; case (Mqtt5Properties::AuthenticationMethod): return "AuthenticationMethod"; case (Mqtt5Properties::AuthenticationData): return "AuthenticationData"; case (Mqtt5Properties::RequestProblemInformation): return "RequestProblemInformation"; case (Mqtt5Properties::WillDelayInterval): return "WillDelayInterval"; case (Mqtt5Properties::RequestResponseInformation): return "RequestResponseInformation"; case (Mqtt5Properties::ResponseInformation): return "ResponseInformation"; case (Mqtt5Properties::ServerReference): return "ServerReference"; case (Mqtt5Properties::ReasonString): return "ReasonString"; case (Mqtt5Properties::ReceiveMaximum): return "ReceiveMaximum"; case (Mqtt5Properties::TopicAliasMaximum): return "TopicAliasMaximum"; case (Mqtt5Properties::TopicAlias): return "TopicAlias"; case (Mqtt5Properties::MaximumQoS): return "MaximumQoS"; case (Mqtt5Properties::RetainAvailable): return "RetainAvailable"; case (Mqtt5Properties::UserProperty): return "UserProperty"; case (Mqtt5Properties::MaximumPacketSize): return "MaximumPacketSize"; case (Mqtt5Properties::WildcardSubscriptionAvailable): return "WildcardSubscriptionAvailable"; case (Mqtt5Properties::SubscriptionIdentifierAvailable): return "SubscriptionIdentifierAvailable"; case (Mqtt5Properties::SharedSubscriptionAvailable): return "SharedSubscriptionAvailable"; default: break; } std::ostringstream oss; oss << static_cast(p); return oss.str(); } void unlink_if_sock(const std::string &path) { if (path.empty()) return; struct stat 
statbuf {}; if (lstat(path.c_str(), &statbuf) < 0) return; if ((statbuf.st_mode & S_IFMT) == S_IFSOCK) { unlink(path.c_str()); } } void fmq_ensure_fail(const char *file, int line) { std::cerr << "Assertion failure: " << file << ", line " << line << "." << std::endl; raise(SIGABRT); } std::optional get_pw_name(const std::string &user) { struct passwd pwd {}; struct passwd *result = nullptr; std::vector buf(16384); if (getpwnam_r(user.c_str(), &pwd, buf.data(), buf.size(), &result) != 0) { throw std::runtime_error("getpwnam_r error"); } if (result == nullptr) { return {}; } SysUserFields answer; answer.name = pwd.pw_name; answer.uid = pwd.pw_uid; answer.gid = pwd.pw_gid; return answer; } std::optional get_gr_name(const std::string &group) { struct group grp {}; struct group *result = nullptr; std::vector buf(16384); if (getgrnam_r(group.c_str(), &grp, buf.data(), buf.size(), &result) != 0) { throw std::runtime_error("getgrnam_r error"); } if (result == nullptr) { return {}; } SysGroupFields answer; answer.name = group; answer.gid = grp.gr_gid; return answer; } std::optional try_stoul(const std::string &s) noexcept { try { size_t len = 0; unsigned long answer = std::stoul(s, &len); if (len != s.length()) return {}; return answer; } catch (std::exception &ex) { return {}; } return {}; } ================================================ FILE: utils.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. 
*/ #ifndef UTILS_H #define UTILS_H #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "cirbuf.h" #include "bindaddr.h" #include "types.h" #include "flashmq_plugin.h" #include "threadlocalutils.h" #define UNUSED(expr) do { (void)(expr); } while (0) void fmq_ensure_fail(const char *file, int line); #define FMQ_ENSURE(val) \ if (!static_cast(val)) \ { \ fmq_ensure_fail(__FILE__, __LINE__); \ } template int check(ssize_t rc) { if (rc < 0) { char *err = strerror(errno); std::string msg(err); throw T(msg); } return rc; } std::list split(const std::string &input, const char sep, size_t max = std::numeric_limits::max(), bool keep_empty_parts = true); std::vector splitToVector(const std::string &input, const char sep, size_t max = std::numeric_limits::max(), bool keep_empty_parts = true); std::vector splitTopic(const std::string &topic); bool isValidUtf8Generic(const char *s, bool alsoCheckInvalidPublishChars = false); template bool isValidUtf8Generic(const T &s, bool alsoCheckInvalidPublishChars = false) { int multibyte_remain = 0; uint32_t cur_code_point = 0; int total_char_len = 0; for(const uint8_t x : s) { if (alsoCheckInvalidPublishChars && (x == '#' || x == '+')) return false; if(!multibyte_remain) { cur_code_point = 0; if ((x & 0b10000000) == 0) // when the MSB is 0, it's ASCII, most common case { cur_code_point += (x & 0b01111111); } else if((x & 0b11100000) == 0b11000000) // 2 byte char { multibyte_remain = 1; cur_code_point += ((x & 0b00011111) << 6); } else if((x & 0b11110000) == 0b11100000) // 3 byte char { multibyte_remain = 2; cur_code_point += ((x & 0b00001111) << 12); } else if((x & 0b11111000) == 0b11110000) // 4 byte char { multibyte_remain = 3; cur_code_point += ((x & 0b00000111) << 18); } else if((x & 0b10000000) != 0) return false; else cur_code_point += (x & 0b01111111); total_char_len = multibyte_remain + 1; } else // All remainer bytes of this code 
point needs to start with 10 { if((x & 0b11000000) != 0b10000000) return false; multibyte_remain--; cur_code_point += ((x & 0b00111111) << (6*multibyte_remain)); } if (multibyte_remain == 0) { // Check overlong values, to avoid having mulitiple representations of the same value. if (total_char_len == 1) { } else if (total_char_len == 2 && cur_code_point < 0x80) return false; else if (total_char_len == 3 && cur_code_point < 0x800) return false; else if (total_char_len == 4 && cur_code_point < 0x10000) return false; if (cur_code_point <= 0x001F) return false; if (cur_code_point >= 0x007F && cur_code_point <= 0x009F) return false; if (total_char_len > 1) { // Invalid range for MQTT. [MQTT-1.5.3-1] if (cur_code_point >= 0xD800 && cur_code_point <= 0xDFFF) // Dec 55296-57343 return false; // Unicode spec: "Which code points are noncharacters?". if (cur_code_point >= 0xFDD0 && cur_code_point <= 0xFDEF) return false; // The last two code points of each of the 17 planes are the remaining 34 non-chars. 
const uint32_t plane = (cur_code_point & 0x1F0000) >> 16; const uint32_t last_16_bit = cur_code_point & 0xFFFF; if (plane <= 16 && (last_16_bit == 0xFFFE || last_16_bit == 0xFFFF)) return false; } cur_code_point = 0; } } return multibyte_remain == 0; } template bool isValidUtf8(const T &s, bool alsoCheckInvalidPublishChars = false) { #ifdef __SSE4_2__ thread_local static SimdUtils simdUtils; return simdUtils.isValidUtf8(s, alsoCheckInvalidPublishChars); #else return isValidUtf8Generic(s, alsoCheckInvalidPublishChars); #endif } bool strContains(const std::string &s, const std::string &needle); bool isValidShareName(const std::string &s); bool isValidPublishPath(const std::string &s); bool isValidSubscribePath(const std::string &s); bool containsDangerousCharacters(const std::string &s); void ltrim(std::string &s); void rtrim(std::string &s); void rtrim(std::string &s, char c); void trim(std::string &s); bool startsWith(const std::string &s, const std::string &needle); bool endsWith(const std::string &s, const std::string &ending); std::string &rtrim(std::string &s, unsigned char c); std::string getSecureRandomString(const ssize_t len); std::string str_tolower(std::string s); bool stringTruthiness(const std::string &val); bool isPowerOfTwo(int val); std::vector base64Decode(const std::string &s); std::string base64Encode(const unsigned char *input, const int length); void testSsl(const std::string &fullchain, const std::string &privkey); void testSslVerifyLocations(const std::string &caFile, const std::string &caDir, const std::string &error); std::string formatString(const std::string str, ...); std::string_view dirnameOf(std::string_view fname); size_t getFileSize(const std::string &path); uint64_t getFreeSpace(const std::string &path); sa_family_t getFamilyFromSockAddr(const struct sockaddr *addr); std::string sockaddrToString(const struct sockaddr *addr); template void checkWritableDir(const std::string &path) { if (path.empty()) throw ex("Dir path to check is an 
empty string."); if (access(path.c_str(), W_OK) != 0) { std::string msg = formatString("Path '%s' is not there or not writable", path.c_str()); throw ex(msg); } struct stat statbuf; memset(&statbuf, 0, sizeof(struct stat)); if (stat(path.c_str(), &statbuf) < 0) { // We checked for W_OK above, so this shouldn't happen. std::string msg = formatString("Error getting information about '%s'.", path.c_str()); throw ex(msg); } if (!S_ISDIR(statbuf.st_mode)) { std::string msg = formatString("Path '%s' is not a directory.", path.c_str()); throw ex(msg); } } std::string protocolVersionString(ProtocolVersion p); unsigned int distanceBetweenStrings(const std::string &stringA, const std::string &stringB); template it findCloseStringMatch(it first, it last, const std::string &s) { it alternative = last; unsigned int alternative_distance = UINT_MAX; for (auto possible_key = first; possible_key != last; ++possible_key) { unsigned int distance = distanceBetweenStrings(s, *possible_key); // We only want to suggest options that look a bit like the unknown // one. Experimentally I found 50% of the total length a decent // cutoff. // // The mathemathical formula "distance/length < 0.5" can be // approximated with integers as "distance*2/length < 1" if ((distance * 2) / s.length() < 1 && distance < alternative_distance) { alternative = possible_key; alternative_distance = distance; } } return alternative; } uint32_t ageFromTimePoint(const std::chrono::time_point &point); std::chrono::time_point timepointFromAge(const uint32_t age); ReasonCodes authResultToReasonCode(AuthResult authResult); int maskAllSignalsCurrentThread(); void parseSubscriptionShare(std::vector &subtopics, std::string &shareName, std::string &topic); std::string timestampWithMillis(); template T get_random_int() { std::vector buf(1); // We use urandom, so we don't have check for blocking / interrupted conditions. 
if (getrandom(buf.data(), sizeof(T), 0) < 0) throw std::runtime_error(strerror(errno)); T val = buf.at(0); return val; } void exceptionOnNonMqtt(const std::vector &data); uint16_t getFirstWildcardDepth(const std::vector &subtopics); std::string reasonCodeToString(ReasonCodes code); std::string packetTypeToString(PacketType ptype); std::string propertyToString(Mqtt5Properties p); /** * @brief parseValuesWithOptionalQuoting parses argument lists space encoded, with quote and escaping support. * @param s * @return * * So, like * * "hallo" "you": becomes a vector with those two strings, but without the quote. * hallo you: is the same as above. * "hallo" you: is the same as above. * "hallo you": is a vector with one element. * "I quote you with \"" is: I quote you with " * 'I quote you with "' is: I quote you with " */ template std::vector parseValuesWithOptionalQuoting(std::string s) { trim(s); std::vector result; char quote = 0; std::string cur; bool escape = false; for (char c : s) { if (escape) { if (!(c == '"' || c == '\'' || c == '\\')) throw ex("Invalid escape"); cur.push_back(c); escape = false; } else if (c == '\\') { escape = true; } else if (std::isspace(c)) { if (quote) cur.push_back(c); else if (!cur.empty()) { result.push_back(cur); cur.clear(); } } else if (c == '"' || c == '\'') { if (quote == 0) { quote = c; } else if (quote == c) { result.push_back(cur); cur.clear(); quote = 0; } else { cur.push_back(c); } } else { cur.push_back(c); } } if (!cur.empty()) result.push_back(cur); if (quote) throw ex("Unterminated quote"); return result; } template class DecrementGuard { T &n; public: DecrementGuard(T &n) : n(n) { } ~DecrementGuard() { n--; assert(n >= 0); } }; template std::optional &non_optional(std::optional &o) { if (!o) o.emplace(); return o; } void unlink_if_sock(const std::string &path); struct SysUserFields { std::string name; uid_t uid = 0; gid_t gid = 0; }; std::optional get_pw_name(const std::string &user); struct SysGroupFields { std::string 
name; gid_t gid = 0; }; std::optional get_gr_name(const std::string &group); std::optional try_stoul(const std::string &s) noexcept; template std::vector make_vector(const Tinput &input, const size_t offset, const size_t len) { if (len + offset > input.size()) throw std::out_of_range("make_vector"); return std::vector(input.begin() + offset, input.begin() + offset + len); } template std::string make_string(const T &input, const size_t offset, const size_t len) { if (len + offset > input.size()) throw std::out_of_range("make_string"); return std::string(input.begin() + offset, input.begin() + offset + len); } #endif // UTILS_H ================================================ FILE: variablebyteint.cpp ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #include "variablebyteint.h" #include #include #include VariableByteInt::VariableByteInt(uint32_t val) { *this = val; } void VariableByteInt::readIntoBuf(CirBuf &buf) const { assert(len > 0); buf.writerange(bytes.begin(), bytes.begin() + len); } VariableByteInt &VariableByteInt::operator=(uint32_t x) { if (x > 268435455) throw std::runtime_error("Value of variable byte int to encode too big. 
Bug or corrupt packet?"); len = 0; do { uint8_t encodedByte = x % 128; x = x / 128; if (x > 0) encodedByte = encodedByte | 128; bytes[len++] = encodedByte; } while(x > 0); return *this; } uint8_t VariableByteInt::getLen() const { return len; } const char *VariableByteInt::data() const { return &bytes[0]; } ================================================ FILE: variablebyteint.h ================================================ /* This file is part of FlashMQ (https://www.flashmq.org) Copyright (C) 2021-2023 Wiebe Cazemier FlashMQ is free software: you can redistribute it and/or modify it under the terms of The Open Software License 3.0 (OSL-3.0). See LICENSE for license details. */ #ifndef VARIABLEBYTEINT_H #define VARIABLEBYTEINT_H #include #include "cirbuf.h" class VariableByteInt { std::array bytes; uint8_t len = 0; public: VariableByteInt(uint32_t val); VariableByteInt() = default; void readIntoBuf(CirBuf &buf) const; VariableByteInt &operator=(uint32_t x); uint8_t getLen() const; const char *data() const; auto begin() const { return bytes.cbegin(); } auto end() const { return bytes.begin() + len; } size_t size() const { return len; } }; #endif // VARIABLEBYTEINT_H ================================================ FILE: x509manager.cpp ================================================ #include "x509manager.h" #include X509Manager::X509Manager(const SSL *ssl) : d(nullptr, X509_free) { #if OPENSSL_VERSION_NUMBER < 0x30000000L this->d.reset(SSL_get_peer_certificate(ssl)); #else this->d.reset(SSL_get1_peer_certificate(ssl)); #endif } X509 *X509Manager::get() { return this->d.get(); } X509Manager::operator bool() const { return d != nullptr; } ================================================ FILE: x509manager.h ================================================ #ifndef X509MANAGER_H #define X509MANAGER_H #include #include class X509Manager { std::unique_ptr d; public: X509Manager(const SSL *ssl); X509 *get(); operator bool() const; }; #endif // X509MANAGER_H