Repository: dragonflydb/dragonfly Branch: main Commit: 9dc515b4c7fa Files: 1014 Total size: 8.9 MB Directory structure: gitextract_74218km6/ ├── .agent/ │ └── rules/ │ └── ANTIGRAVITY_INSTRUCTIONS.md ├── .circleci/ │ └── config.yml ├── .clang-format ├── .clang-tidy ├── .clangd ├── .claude/ │ ├── hooks/ │ │ └── format-after-edit.sh │ ├── settings.json │ └── skills/ │ └── reproduce-fuzz-crash/ │ └── SKILL.md ├── .ct.yaml ├── .cursorrules ├── .devcontainer/ │ ├── alpine/ │ │ ├── devcontainer.json │ │ └── post-create.sh │ ├── fedora/ │ │ └── devcontainer.json │ ├── fedora41/ │ │ └── devcontainer.json │ ├── ubuntu20/ │ │ ├── cmake-tools-kits.json │ │ ├── devcontainer.json │ │ └── post-create.sh │ ├── ubuntu20-gcc14/ │ │ └── devcontainer.json │ ├── ubuntu22/ │ │ ├── devcontainer.json │ │ └── post-create.sh │ └── ubuntu24/ │ └── devcontainer.json ├── .dockerignore ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── config.yml │ │ └── feature_request.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── actions/ │ │ ├── builder/ │ │ │ └── action.yml │ │ ├── fuzzing/ │ │ │ └── action.yml │ │ ├── lint-test-chart/ │ │ │ └── action.yml │ │ ├── multi-registry-docker-login/ │ │ │ └── action.yml │ │ ├── regression-tests/ │ │ │ └── action.yml │ │ ├── repeat/ │ │ │ └── action.yml │ │ ├── sync-valkey-tests/ │ │ │ └── action.yml │ │ └── test-docker/ │ │ └── action.yml │ ├── bullmq-skipped-tests.txt │ ├── copilot-instructions.md │ ├── dependabot.yml │ ├── instructions/ │ │ └── code-review.instructions.md │ └── workflows/ │ ├── benchmark.yml │ ├── bullmq-tests.yml │ ├── ci.yml │ ├── copilot-setup-steps.yml │ ├── cov.yml │ ├── daily-builds.yml │ ├── docker-dev-release.yml │ ├── docker-release2.yml │ ├── epoll-regression-tests.yml │ ├── fuzz-long.yml │ ├── fuzz-pr.yml │ ├── generate-osrepo-site.yml │ ├── heavy-tests.yml │ ├── ioloop-v2-regtests.yml │ ├── mastodon-ruby-tests.yml │ ├── package-install.yml │ ├── regression-tests.yml │ ├── release.yml │ ├── repeat-tests.yml │ └── 
test-fakeredis.yml ├── .gitignore ├── .gitmodules ├── .gitorderfile ├── .nvmrc ├── .pre-commit-config.yaml ├── .pre-commit-hooks.yaml ├── .snyk ├── .vscode/ │ └── c_cpp_properties.json ├── AGENTS.md ├── CLA.txt ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── CONTRIBUTORS.md ├── LICENSE.md ├── Makefile ├── README.ja-JP.md ├── README.ko-KR.md ├── README.md ├── README.pt-BR.md ├── README.zh-CN.md ├── TODO.md ├── contrib/ │ ├── charts/ │ │ └── dragonfly/ │ │ ├── .helmignore │ │ ├── Chart.yaml │ │ ├── README.md │ │ ├── ci/ │ │ │ ├── affinity-values.golden.yaml │ │ │ ├── affinity-values.yaml │ │ │ ├── command_extraargs-values.golden.yaml │ │ │ ├── command_extraargs-values.yaml │ │ │ ├── commonlabels-values.golden.yaml │ │ │ ├── commonlabels-values.yaml │ │ │ ├── extracontainer-string-values.golden.yaml │ │ │ ├── extracontainer-string-values.yaml │ │ │ ├── extracontainer-tpl-values.golden.yaml │ │ │ ├── extracontainer-tpl-values.yaml │ │ │ ├── extraenv-and-passwordSecret-values.golden.yaml │ │ │ ├── extraenv-and-passwordSecret-values.yaml │ │ │ ├── extraenv-values.golden.yaml │ │ │ ├── extraenv-values.yaml │ │ │ ├── extravolumes-values.golden.yaml │ │ │ ├── extravolumes-values.yaml │ │ │ ├── initcontainer-string-values.golden.yaml │ │ │ ├── initcontainer-string-values.yaml │ │ │ ├── initcontainer-tpl-values.golden.yaml │ │ │ ├── initcontainer-tpl-values.yaml │ │ │ ├── password-old-env-values.golden.yaml │ │ │ ├── password-old-env-values.yaml │ │ │ ├── passwordsecret-values.golden.yaml │ │ │ ├── passwordsecret-values.tpl.golden.yaml │ │ │ ├── passwordsecret-values.tpl.yaml │ │ │ ├── passwordsecret-values.yaml │ │ │ ├── persistence-and-existing-secret.golden.yaml │ │ │ ├── persistence-and-existing-secret.yaml │ │ │ ├── persistent-values.golden.yaml │ │ │ ├── persistent-values.yaml │ │ │ ├── priorityclassname-values.golden.yaml │ │ │ ├── priorityclassname-values.yaml │ │ │ ├── prometheusrules-values.golden.yaml │ │ │ ├── prometheusrules-values.yaml │ │ │ 
├── resources-values.golden.yaml │ │ │ ├── resources-values.yaml │ │ │ ├── securitycontext-values.golden.yaml │ │ │ ├── securitycontext-values.yaml │ │ │ ├── service-loadbalancer-ip.golden.yaml │ │ │ ├── service-loadbalancer-ip.yaml │ │ │ ├── service-monitor-values.golden.yaml │ │ │ ├── service-monitor-values.yaml │ │ │ ├── taints-tolerations-values.golden.yaml │ │ │ ├── taints-tolerations-values.yaml │ │ │ ├── tls-values.golden.yaml │ │ │ ├── tls-values.yaml │ │ │ ├── tolerations-values.golden.yaml │ │ │ └── tolerations-values.yaml │ │ ├── go.mod │ │ ├── go.sum │ │ ├── golden_test.go │ │ ├── templates/ │ │ │ ├── NOTES.txt │ │ │ ├── _helpers.tpl │ │ │ ├── _pod.tpl │ │ │ ├── certificate.yaml │ │ │ ├── deployment.yaml │ │ │ ├── extra-manifests.yaml │ │ │ ├── metrics-service.yaml │ │ │ ├── prometheusrule.yaml │ │ │ ├── service.yaml │ │ │ ├── serviceaccount.yaml │ │ │ ├── servicemonitor.yaml │ │ │ ├── statefulset.yaml │ │ │ └── tls-secret.yaml │ │ └── values.yaml │ ├── docker/ │ │ ├── README.md │ │ └── docker-compose.yml │ └── scripts/ │ ├── conventional-commits │ └── signed-commit ├── docs/ │ ├── README.md │ ├── async-tiering.md │ ├── cluster-node-health.md │ ├── coordinator.excalidraw │ ├── dashtable.md │ ├── dense_set.excalidraw │ ├── dense_set.md │ ├── df-share-nothing.md │ ├── differences.md │ ├── faq.md │ ├── memcached_benchmark.md │ ├── memory_bgsave.tsv │ ├── namespaces.md │ ├── quick-start/ │ │ └── README.md │ ├── rdbsave.excalidraw │ ├── rdbsave.md │ ├── shard-serialization.md │ ├── thread-per-core.excalidraw │ └── transaction.md ├── fuzz/ │ ├── FUZZING.md │ ├── dict/ │ │ ├── memcache.dict │ │ └── resp.dict │ ├── generate_targeted_seeds.py │ ├── memcache_mutator.py │ ├── package_crash.sh │ ├── replay_crash.py │ ├── resp_mutator.py │ ├── run_fuzzer.sh │ ├── seeds/ │ │ ├── memcache/ │ │ │ ├── add_replace.mc │ │ │ ├── append_prepend.mc │ │ │ ├── cas.mc │ │ │ ├── delete.mc │ │ │ ├── expiry.mc │ │ │ ├── flags.mc │ │ │ ├── flush.mc │ │ │ ├── gat.mc │ │ │ ├── 
incr_decr.mc │ │ │ ├── large_value.mc │ │ │ ├── meta_commands.mc │ │ │ ├── multiget.mc │ │ │ ├── noreply.mc │ │ │ ├── set_get.mc │ │ │ └── stats_version.mc │ │ └── resp/ │ │ ├── acl.resp │ │ ├── acl_ops.resp │ │ ├── acl_ops2.resp │ │ ├── bf_add.resp │ │ ├── bitfield.resp │ │ ├── bitfield_ops.resp │ │ ├── bitops.resp │ │ ├── bloom_ops.resp │ │ ├── client.resp │ │ ├── config.resp │ │ ├── copy.resp │ │ ├── del.resp │ │ ├── eval.resp │ │ ├── expire_ops.resp │ │ ├── function.resp │ │ ├── function_ops.resp │ │ ├── generic_ops.resp │ │ ├── generic_ops2.resp │ │ ├── geo_ops.resp │ │ ├── geo_ops2.resp │ │ ├── geoadd.resp │ │ ├── get.resp │ │ ├── getdel.resp │ │ ├── hash_ops.resp │ │ ├── hash_ops2.resp │ │ ├── hll_ops.resp │ │ ├── hset.resp │ │ ├── json.resp │ │ ├── json_ops.resp │ │ ├── json_ops2.resp │ │ ├── list_blocking.resp │ │ ├── list_ops.resp │ │ ├── lmpop.resp │ │ ├── lpos.resp │ │ ├── lpush.resp │ │ ├── memory.resp │ │ ├── monitor.resp │ │ ├── mset.resp │ │ ├── multi_type_pipeline.resp │ │ ├── object.resp │ │ ├── pfadd.resp │ │ ├── ping.resp │ │ ├── pipeline.resp │ │ ├── pubsub_ops.resp │ │ ├── pubsub_ops2.resp │ │ ├── rename.resp │ │ ├── rpoplpush.resp │ │ ├── sadd.resp │ │ ├── scan_hscan.resp │ │ ├── script_ops.resp │ │ ├── script_ops2.resp │ │ ├── sdiffstore.resp │ │ ├── search_ops.resp │ │ ├── search_ops2.resp │ │ ├── server_ops.resp │ │ ├── server_ops2.resp │ │ ├── set.resp │ │ ├── set_ops.resp │ │ ├── set_ops2.resp │ │ ├── smove.resp │ │ ├── sort.resp │ │ ├── srandmember.resp │ │ ├── stream_ops.resp │ │ ├── stream_ops2.resp │ │ ├── string_ops.resp │ │ ├── string_ops2.resp │ │ ├── subscribe.resp │ │ ├── throttle.resp │ │ ├── transaction.resp │ │ ├── transaction_ops2.resp │ │ ├── watch.resp │ │ ├── watch_multi.resp │ │ ├── xadd.resp │ │ ├── xread.resp │ │ ├── zadd.resp │ │ ├── zmpop.resp │ │ ├── zrangebyscore.resp │ │ ├── zset_ops.resp │ │ └── zset_ops2.resp │ └── triage_crashes.sh ├── go.work ├── go.work.sum ├── patches/ │ └── mimalloc-v2.2.4/ │ ├── 
0_base.patch │ ├── 1_add_stat_type.patch │ ├── 2_return_stat.patch │ ├── 3_track_full_size.patch │ └── 4_fix_heap_collect.patch ├── pyproject.toml ├── src/ │ ├── .gitignore │ ├── CMakeLists.txt │ ├── GetGitRevisionDescription.cmake │ ├── GetGitRevisionDescription.cmake.in │ ├── common/ │ │ ├── arg_range.h │ │ ├── backed_args.h │ │ ├── heap_size.h │ │ └── string_or_view.h │ ├── core/ │ │ ├── CMakeLists.txt │ │ ├── allocation_tracker.cc │ │ ├── allocation_tracker.h │ │ ├── allocation_tracker_test.cc │ │ ├── bloom.cc │ │ ├── bloom.h │ │ ├── bloom_test.cc │ │ ├── bptree_set.h │ │ ├── bptree_set_test.cc │ │ ├── cms.cc │ │ ├── cms.h │ │ ├── cms_test.cc │ │ ├── collection_entry.h │ │ ├── compact_object.cc │ │ ├── compact_object.h │ │ ├── compact_object_test.cc │ │ ├── dash.h │ │ ├── dash_bench.cc │ │ ├── dash_internal.h │ │ ├── dash_test.cc │ │ ├── dense_set.cc │ │ ├── dense_set.h │ │ ├── detail/ │ │ │ ├── bitpacking.cc │ │ │ ├── bitpacking.h │ │ │ ├── bptree_internal.h │ │ │ ├── gen_utils.h │ │ │ ├── listpack.cc │ │ │ ├── listpack.h │ │ │ ├── listpack_wrap.cc │ │ │ ├── listpack_wrap.h │ │ │ └── stateless_allocator.h │ │ ├── dfly_core_test.cc │ │ ├── dict_builder.cc │ │ ├── dict_builder.h │ │ ├── dict_builder_test.cc │ │ ├── dragonfly_core.cc │ │ ├── expire_period.h │ │ ├── extent_tree.cc │ │ ├── extent_tree.h │ │ ├── extent_tree_test.cc │ │ ├── flatbuffers.h │ │ ├── flatbuffers_test.cc │ │ ├── generate_bin_sizes.py │ │ ├── glob_matcher.cc │ │ ├── glob_matcher.h │ │ ├── huff_coder.cc │ │ ├── huff_coder.h │ │ ├── intent_lock.h │ │ ├── interpreter.cc │ │ ├── interpreter.h │ │ ├── interpreter_polyfill.h │ │ ├── interpreter_test.cc │ │ ├── json/ │ │ │ ├── CMakeLists.txt │ │ │ ├── detail/ │ │ │ │ ├── common.h │ │ │ │ ├── flat_dfs.cc │ │ │ │ ├── flat_dfs.h │ │ │ │ ├── interned_blob.cc │ │ │ │ ├── interned_blob.h │ │ │ │ ├── interned_string.cc │ │ │ │ ├── interned_string.h │ │ │ │ ├── jsoncons_dfs.cc │ │ │ │ └── jsoncons_dfs.h │ │ │ ├── driver.cc │ │ │ ├── driver.h │ │ │ ├── 
interned_blob_test.cc │ │ │ ├── json_object.cc │ │ │ ├── json_object.h │ │ │ ├── json_test.cc │ │ │ ├── jsonpath_grammar.y │ │ │ ├── jsonpath_lexer.lex │ │ │ ├── jsonpath_test.cc │ │ │ ├── lexer_impl.cc │ │ │ ├── lexer_impl.h │ │ │ ├── path.cc │ │ │ └── path.h │ │ ├── linear_search_map.h │ │ ├── linear_search_map_test.cc │ │ ├── listpack_test.cc │ │ ├── memory_test.cc │ │ ├── mi_memory_resource.cc │ │ ├── mi_memory_resource.h │ │ ├── oah_entry.cc │ │ ├── oah_entry.h │ │ ├── oah_set.h │ │ ├── oah_set_test.cc │ │ ├── overloaded.h │ │ ├── page_usage/ │ │ │ ├── CMakeLists.txt │ │ │ ├── page_usage_stats.cc │ │ │ └── page_usage_stats.h │ │ ├── page_usage_stats_test.cc │ │ ├── qlist.cc │ │ ├── qlist.h │ │ ├── qlist_test.cc │ │ ├── score_map.cc │ │ ├── score_map.h │ │ ├── score_map_test.cc │ │ ├── sds_utils.cc │ │ ├── sds_utils.h │ │ ├── search/ │ │ │ ├── CMakeLists.txt │ │ │ ├── ast_expr.cc │ │ │ ├── ast_expr.h │ │ │ ├── base.cc │ │ │ ├── base.h │ │ │ ├── block_list.cc │ │ │ ├── block_list.h │ │ │ ├── block_list_test.cc │ │ │ ├── compressed_sorted_set.cc │ │ │ ├── compressed_sorted_set.h │ │ │ ├── compressed_sorted_set_test.cc │ │ │ ├── hnsw_alg.h │ │ │ ├── hnsw_index.cc │ │ │ ├── hnsw_index.h │ │ │ ├── index_result.h │ │ │ ├── indices.cc │ │ │ ├── indices.h │ │ │ ├── lexer.lex │ │ │ ├── mrmw_mutex.h │ │ │ ├── mrmw_mutex_test.cc │ │ │ ├── parser.y │ │ │ ├── query_driver.cc │ │ │ ├── query_driver.h │ │ │ ├── range_tree.cc │ │ │ ├── range_tree.h │ │ │ ├── range_tree_test.cc │ │ │ ├── rax_tree.h │ │ │ ├── rax_tree_test.cc │ │ │ ├── renewable_quota.cc │ │ │ ├── renewable_quota.h │ │ │ ├── scanner.h │ │ │ ├── search.cc │ │ │ ├── search.h │ │ │ ├── search_parser_test.cc │ │ │ ├── search_test.cc │ │ │ ├── sort_indices.cc │ │ │ ├── sort_indices.h │ │ │ ├── stateless_allocator.h │ │ │ ├── synonyms.cc │ │ │ ├── synonyms.h │ │ │ ├── tag_types.h │ │ │ ├── vector_utils.cc │ │ │ └── vector_utils.h │ │ ├── segment_allocator.cc │ │ ├── segment_allocator.h │ │ ├── size_tracking_channel.h 
│ │ ├── small_string.cc │ │ ├── small_string.h │ │ ├── sorted_map.cc │ │ ├── sorted_map.h │ │ ├── sorted_map_test.cc │ │ ├── sse_port.h │ │ ├── string_map.cc │ │ ├── string_map.h │ │ ├── string_map_test.cc │ │ ├── string_set.cc │ │ ├── string_set.h │ │ ├── string_set_test.cc │ │ ├── task_queue.cc │ │ ├── task_queue.h │ │ ├── testdata/ │ │ │ ├── ids.txt.zst │ │ │ └── list.txt.zst │ │ ├── tiering_types.cc │ │ ├── tiering_types.h │ │ ├── top_keys.cc │ │ ├── top_keys.h │ │ ├── top_keys_test.cc │ │ ├── topk.cc │ │ ├── topk.h │ │ ├── topk_test.cc │ │ ├── tx_queue.cc │ │ ├── tx_queue.h │ │ └── zstd_test.cc │ ├── external_libs.cmake │ ├── facade/ │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── cmd_arg_parser.cc │ │ ├── cmd_arg_parser.h │ │ ├── cmd_arg_parser_test.cc │ │ ├── command_id.h │ │ ├── conn_context.h │ │ ├── connection_ref.h │ │ ├── disk_backed_queue.cc │ │ ├── disk_backed_queue.h │ │ ├── disk_backed_queue_test.cc │ │ ├── dragonfly_connection.cc │ │ ├── dragonfly_connection.h │ │ ├── dragonfly_listener.cc │ │ ├── dragonfly_listener.h │ │ ├── error.h │ │ ├── facade.cc │ │ ├── facade_stats.h │ │ ├── facade_test.cc │ │ ├── facade_test.h │ │ ├── facade_types.h │ │ ├── memcache_parser.cc │ │ ├── memcache_parser.h │ │ ├── memcache_parser_test.cc │ │ ├── ok_main.cc │ │ ├── op_status.cc │ │ ├── op_status.h │ │ ├── parsed_command.cc │ │ ├── parsed_command.h │ │ ├── redis_parser.cc │ │ ├── redis_parser.h │ │ ├── redis_parser_test.cc │ │ ├── reply_builder.cc │ │ ├── reply_builder.h │ │ ├── reply_builder_test.cc │ │ ├── reply_capture.cc │ │ ├── reply_capture.h │ │ ├── reply_mode.h │ │ ├── reply_payload.h │ │ ├── resp_expr.cc │ │ ├── resp_expr.h │ │ ├── resp_expr_test_utils.cc │ │ ├── resp_expr_test_utils.h │ │ ├── resp_parser.cc │ │ ├── resp_parser.h │ │ ├── resp_parser_test.cc │ │ ├── resp_srv_parser.cc │ │ ├── resp_srv_parser.h │ │ ├── resp_srv_parser_test.cc │ │ ├── resp_validator.cc │ │ ├── service_interface.cc │ │ ├── service_interface.h │ │ ├── socket_utils.cc │ │ ├── 
socket_utils.h │ │ ├── tls_helpers.cc │ │ └── tls_helpers.h │ ├── huff/ │ │ ├── LICENSE │ │ ├── README.md │ │ ├── hist.h │ │ ├── huf.h │ │ └── mem.h │ ├── redis/ │ │ ├── CMakeLists.txt │ │ ├── LICENSE.redis │ │ ├── config.h │ │ ├── crc16.c │ │ ├── crc16.h │ │ ├── crc64.c │ │ ├── crc64.h │ │ ├── crcspeed.c │ │ ├── crcspeed.h │ │ ├── debug.c │ │ ├── dict.c │ │ ├── dict.h │ │ ├── endianconv.h │ │ ├── geo.c │ │ ├── geo.h │ │ ├── geohash.c │ │ ├── geohash.h │ │ ├── geohash_helper.c │ │ ├── geohash_helper.h │ │ ├── hiredis.c │ │ ├── hiredis.h │ │ ├── hyperloglog.c │ │ ├── hyperloglog.h │ │ ├── intset.c │ │ ├── intset.h │ │ ├── listpack.c │ │ ├── listpack.h │ │ ├── lua/ │ │ │ ├── CMakeLists.txt │ │ │ ├── README.md │ │ │ ├── bit/ │ │ │ │ └── bit.c │ │ │ ├── cjson/ │ │ │ │ ├── fpconv.c │ │ │ │ ├── fpconv.h │ │ │ │ ├── lua_cjson.c │ │ │ │ ├── strbuf.c │ │ │ │ └── strbuf.h │ │ │ ├── cmsgpack/ │ │ │ │ └── lua_cmsgpack.c │ │ │ └── struct/ │ │ │ └── lua_struct.c │ │ ├── lzf.h │ │ ├── lzfP.h │ │ ├── lzf_c.c │ │ ├── lzf_d.c │ │ ├── rax.c │ │ ├── rax.h │ │ ├── rax_malloc.h │ │ ├── rdb.h │ │ ├── read.c │ │ ├── read.h │ │ ├── redis_aux.c │ │ ├── redis_aux.h │ │ ├── sds.c │ │ ├── sds.h │ │ ├── sdsalloc.h │ │ ├── siphash.c │ │ ├── stream.h │ │ ├── t_stream.c │ │ ├── util.c │ │ ├── util.h │ │ ├── ziplist.c │ │ ├── ziplist.h │ │ ├── zmalloc.c │ │ ├── zmalloc.h │ │ └── zmalloc_mi.c │ └── server/ │ ├── CMakeLists.txt │ ├── acl/ │ │ ├── acl_commands_def.h │ │ ├── acl_family.cc │ │ ├── acl_family.h │ │ ├── acl_family_test.cc │ │ ├── acl_log.cc │ │ ├── acl_log.h │ │ ├── user.cc │ │ ├── user.h │ │ ├── user_registry.cc │ │ ├── user_registry.h │ │ ├── validator.cc │ │ └── validator.h │ ├── bitops_family.cc │ ├── bitops_family_test.cc │ ├── blocking_controller.cc │ ├── blocking_controller.h │ ├── blocking_controller_test.cc │ ├── bloom_family.cc │ ├── bloom_family_test.cc │ ├── channel_store.cc │ ├── channel_store.h │ ├── cluster/ │ │ ├── CMakeLists.txt │ │ ├── cluster_config.cc │ │ ├── 
cluster_config.h │ │ ├── cluster_config_test.cc │ │ ├── cluster_defs.cc │ │ ├── cluster_defs.h │ │ ├── cluster_family.cc │ │ ├── cluster_family.h │ │ ├── cluster_family_test.cc │ │ ├── cluster_utility.cc │ │ ├── cluster_utility.h │ │ ├── coordinator.cc │ │ ├── coordinator.h │ │ ├── incoming_slot_migration.cc │ │ ├── incoming_slot_migration.h │ │ ├── outgoing_slot_migration.cc │ │ ├── outgoing_slot_migration.h │ │ └── slot_set.h │ ├── cluster_support.cc │ ├── cluster_support.h │ ├── cmd_support.cc │ ├── cmd_support.h │ ├── cms_family.cc │ ├── cms_family_test.cc │ ├── collection_family_fallback.cc │ ├── command_families.h │ ├── command_registry.cc │ ├── command_registry.h │ ├── common.cc │ ├── common.h │ ├── common_types.h │ ├── config_registry.cc │ ├── config_registry.h │ ├── conn_context.cc │ ├── conn_context.h │ ├── container_utils.cc │ ├── container_utils.h │ ├── db_slice.cc │ ├── db_slice.h │ ├── debugcmd.cc │ ├── debugcmd.h │ ├── detail/ │ │ ├── compressor.cc │ │ ├── compressor.h │ │ ├── decompress.cc │ │ ├── decompress.h │ │ ├── save_stages_controller.cc │ │ ├── save_stages_controller.h │ │ ├── snapshot_storage.cc │ │ ├── snapshot_storage.h │ │ ├── table.h │ │ └── wrapped_json_path.h │ ├── dfly_bench.cc │ ├── dfly_main.cc │ ├── dflycmd.cc │ ├── dflycmd.h │ ├── dragonfly_test.cc │ ├── engine_shard.cc │ ├── engine_shard.h │ ├── engine_shard_set.cc │ ├── engine_shard_set.h │ ├── engine_shard_set_test.cc │ ├── error.cc │ ├── error.h │ ├── execution_state.cc │ ├── execution_state.h │ ├── family_utils.cc │ ├── family_utils.h │ ├── generic_family.cc │ ├── generic_family.h │ ├── generic_family_test.cc │ ├── geo_family.cc │ ├── geo_family_test.cc │ ├── hll_family.cc │ ├── hll_family_test.cc │ ├── hset_family.cc │ ├── hset_family.h │ ├── hset_family_test.cc │ ├── http_api.cc │ ├── http_api.h │ ├── journal/ │ │ ├── CMakeLists.txt │ │ ├── cmd_serializer.cc │ │ ├── cmd_serializer.h │ │ ├── executor.cc │ │ ├── executor.h │ │ ├── journal.cc │ │ ├── journal.h │ │ ├── 
journal_slice.cc │ │ ├── journal_slice.h │ │ ├── journal_test.cc │ │ ├── pending_buf.h │ │ ├── serializer.cc │ │ ├── serializer.h │ │ ├── streamer.cc │ │ ├── streamer.h │ │ ├── tx_executor.cc │ │ ├── tx_executor.h │ │ ├── types.cc │ │ └── types.h │ ├── json_family.cc │ ├── json_family_memory_test.cc │ ├── json_family_test.cc │ ├── list_family.cc │ ├── list_family_test.cc │ ├── main_service.cc │ ├── main_service.h │ ├── memory_cmd.cc │ ├── memory_cmd.h │ ├── multi_command_squasher.cc │ ├── multi_command_squasher.h │ ├── multi_test.cc │ ├── namespaces.cc │ ├── namespaces.h │ ├── protocol_client.cc │ ├── protocol_client.h │ ├── rdb_extensions.h │ ├── rdb_load.cc │ ├── rdb_load.h │ ├── rdb_load_context.cc │ ├── rdb_load_context.h │ ├── rdb_save.cc │ ├── rdb_save.h │ ├── rdb_test.cc │ ├── replica.cc │ ├── replica.h │ ├── replica_types.h │ ├── script_mgr.cc │ ├── script_mgr.h │ ├── search/ │ │ ├── CMakeLists.txt │ │ ├── aggregator.cc │ │ ├── aggregator.h │ │ ├── aggregator_test.cc │ │ ├── doc_accessors.cc │ │ ├── doc_accessors.h │ │ ├── doc_index.cc │ │ ├── doc_index.h │ │ ├── doc_index_fallback.cc │ │ ├── global_hnsw_index.cc │ │ ├── global_hnsw_index.h │ │ ├── index_builder.cc │ │ ├── index_builder.h │ │ ├── index_join.cc │ │ ├── index_join.h │ │ ├── index_join_test.cc │ │ ├── search_family.cc │ │ ├── search_family.h │ │ └── search_family_test.cc │ ├── serializer_base.cc │ ├── serializer_base.h │ ├── serializer_base_test.cc │ ├── serializer_commons.cc │ ├── serializer_commons.h │ ├── server_family.cc │ ├── server_family.h │ ├── server_family_test.cc │ ├── server_state.cc │ ├── server_state.h │ ├── set_family.cc │ ├── set_family.h │ ├── set_family_test.cc │ ├── sharding.cc │ ├── sharding.h │ ├── slowlog.cc │ ├── slowlog.h │ ├── snapshot.cc │ ├── snapshot.h │ ├── stats.cc │ ├── stats.h │ ├── stream_family.cc │ ├── stream_family.h │ ├── stream_family_test.cc │ ├── string_family.cc │ ├── string_family_test.cc │ ├── string_stats.cc │ ├── string_stats.h │ ├── 
string_stats_test.cc │ ├── synchronization.cc │ ├── synchronization.h │ ├── table.cc │ ├── table.h │ ├── test_utils.cc │ ├── test_utils.h │ ├── testdata/ │ │ ├── RDB_TYPE_STREAM_LISTPACKS_2.rdb │ │ ├── RDB_TYPE_STREAM_LISTPACKS_3.rdb │ │ ├── empty.rdb │ │ ├── hll.rdb │ │ ├── ignore_expiry.rdb │ │ ├── redis6_small.rdb │ │ ├── redis6_stream.rdb │ │ ├── redis7_small.rdb │ │ └── redis_json.rdb │ ├── tiered_storage.cc │ ├── tiered_storage.h │ ├── tiered_storage_test.cc │ ├── tiering/ │ │ ├── CMakeLists.txt │ │ ├── common.h │ │ ├── decoders.cc │ │ ├── decoders.h │ │ ├── disk_storage.cc │ │ ├── disk_storage.h │ │ ├── disk_storage_test.cc │ │ ├── entry_map.h │ │ ├── external_alloc.cc │ │ ├── external_alloc.h │ │ ├── external_alloc_test.cc │ │ ├── op_manager.cc │ │ ├── op_manager.h │ │ ├── op_manager_test.cc │ │ ├── serialized_map.cc │ │ ├── serialized_map.h │ │ ├── serialized_map_test.cc │ │ ├── small_bins.cc │ │ ├── small_bins.h │ │ ├── small_bins_test.cc │ │ └── test_common.h │ ├── transaction.cc │ ├── transaction.h │ ├── tx_base.cc │ ├── tx_base.h │ ├── version.cc.in │ ├── version.h │ ├── version_monitor.cc │ ├── version_monitor.h │ ├── zset_family.cc │ ├── zset_family.h │ └── zset_family_test.cc ├── tests/ │ ├── README.md │ ├── dragonfly/ │ │ ├── __init__.py │ │ ├── acl_family_test.py │ │ ├── bull_sidekiq_test.py │ │ ├── celery_test.py │ │ ├── cluster_mgr_test.py │ │ ├── cluster_test.py │ │ ├── config_test.py │ │ ├── conftest.py │ │ ├── connection_test.py │ │ ├── eval_test.py │ │ ├── generic_test.py │ │ ├── http_conf_test.py │ │ ├── instance.py │ │ ├── json_test.py │ │ ├── list_family_test.py │ │ ├── management_test.py │ │ ├── memcache_meta.py │ │ ├── memory_test.py │ │ ├── proxy.py │ │ ├── pymemcached_test.py │ │ ├── redis_replication_test.py │ │ ├── replication_test.py │ │ ├── requirements.txt │ │ ├── search_benchmark_test.py │ │ ├── search_benchmark_utils.py │ │ ├── search_test.py │ │ ├── seeder/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── 
script-generate.lua │ │ │ ├── script-genlib.lua │ │ │ ├── script-hash.lua │ │ │ ├── script-hashlib.lua │ │ │ └── script-utillib.lua │ │ ├── seeder_test.py │ │ ├── sentinel_test.py │ │ ├── server_family_test.py │ │ ├── set_test.py │ │ ├── shutdown_test.py │ │ ├── snapshot_test.py │ │ ├── test_dash_gc.py │ │ ├── tiering_test.py │ │ ├── tls_conf_test.py │ │ ├── utility.py │ │ └── valkey_search/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── sync-valkey-search-tests.sh │ │ ├── util.py │ │ └── valkey_search_test_case_dragonfly.py │ ├── fakeredis/ │ │ ├── README.md │ │ ├── pyproject.toml │ │ └── test/ │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_asyncredis.py │ │ ├── test_hypotesis_joint/ │ │ │ ├── __init__.py │ │ │ └── test_joint.py │ │ ├── test_hypothesis/ │ │ │ ├── __init__.py │ │ │ ├── _server_info.py │ │ │ ├── base.py │ │ │ ├── test_connection.py │ │ │ ├── test_hash.py │ │ │ ├── test_list.py │ │ │ ├── test_server.py │ │ │ ├── test_set.py │ │ │ ├── test_string.py │ │ │ ├── test_transaction.py │ │ │ └── test_zset.py │ │ ├── test_issues.py │ │ ├── test_json/ │ │ │ ├── __init__.py │ │ │ ├── test_json.py │ │ │ ├── test_json_arr_commands.py │ │ │ └── test_json_commands.py │ │ ├── test_mixins/ │ │ │ ├── __init__.py │ │ │ ├── test_bitmap_commands.py │ │ │ ├── test_connection.py │ │ │ ├── test_generic_commands.py │ │ │ ├── test_geo_commands.py │ │ │ ├── test_hash_commands.py │ │ │ ├── test_list_commands.py │ │ │ ├── test_pubsub_commands.py │ │ │ ├── test_scan.py │ │ │ ├── test_scripting.py │ │ │ ├── test_server_commands.py │ │ │ ├── test_set_commands.py │ │ │ ├── test_sortedset_commands.py │ │ │ ├── test_streams_commands.py │ │ │ ├── test_string_commands.py │ │ │ └── test_zadd.py │ │ ├── test_stack/ │ │ │ ├── __init__.py │ │ │ ├── test_bloomfilter.py │ │ │ ├── test_cms.py │ │ │ ├── test_cuckoofilter.py │ │ │ ├── test_tdigest.py │ │ │ └── test_topk.py │ │ ├── test_transactions.py │ │ └── testtools.py │ ├── integration/ │ │ ├── .dockerignore │ │ ├── 
.run_ioredis_valid_test.sh │ │ ├── async.py │ │ ├── gen_sets.sh │ │ ├── generate_sets.py │ │ ├── ioredis.Dockerfile │ │ ├── jedis.Dockerfile │ │ ├── node-redis.Dockerfile │ │ ├── pascaldekloe.Dockerfile │ │ ├── relay.Dockerfile │ │ ├── run_ioredis_on_docker.sh │ │ └── stress_shutdown.sh │ └── pytest.ini └── tools/ ├── balls_bins.py ├── benchmark/ │ ├── k8s-benchmark-job.yaml │ └── post_run_checks.py ├── cache_logs_player.py ├── cache_testing.py ├── cluster_mgr.py ├── defrag_db.py ├── defrag_mem_test.py ├── docker/ │ ├── entrypoint.sh │ ├── fetch_release.sh │ └── healthcheck.sh ├── eviction/ │ ├── fill_db.py │ ├── run_fill_db.sh │ └── stop_fill_db.sh ├── faulty_io.sh ├── generate-tls-files.sh ├── json_benchmark.py ├── local/ │ ├── gen-test-certs.sh │ └── monitoring/ │ ├── docker-compose.yml │ ├── grafana/ │ │ ├── config.monitoring │ │ └── provisioning/ │ │ ├── dashboards/ │ │ │ ├── dashboard.yml │ │ │ ├── dragonfly.json │ │ │ ├── memcached.json │ │ │ ├── node-exporter.json │ │ │ └── redis.json │ │ └── datasources/ │ │ └── datasource.yml │ └── prometheus/ │ └── prometheus.yml ├── packaging/ │ ├── Dockerfile.alpine-dev │ ├── Dockerfile.ubuntu-dev │ ├── Dockerfile.ubuntu-prod │ ├── README.md │ ├── debian/ │ │ ├── compat │ │ ├── control │ │ ├── dragonfly.conf │ │ ├── dragonfly.install │ │ ├── dragonfly.logrotate │ │ ├── dragonfly.postinst │ │ ├── dragonfly.postrm │ │ ├── dragonfly.preinst │ │ ├── dragonfly.service │ │ └── rules │ ├── generate_changelog.sh │ ├── generate_debian_package.sh │ ├── osrepos/ │ │ ├── README.md │ │ ├── dragonfly.repo │ │ ├── dragonfly.sources │ │ ├── pgp-key.public │ │ ├── reprepro-config/ │ │ │ ├── distributions │ │ │ └── options │ │ ├── requirements.txt │ │ └── scripts/ │ │ ├── fetch-releases.py │ │ ├── generate-apt-repo.sh │ │ ├── generate-index.py │ │ └── sign-rpms.sh │ └── rpm/ │ ├── build_rpm.sh │ ├── dragonfly.service │ └── dragonfly.spec ├── parse_allocator_tracking_logs.py ├── plot_memtier_latency.py ├── release.sh ├── replay/ │ ├── 
go.mod │ ├── go.sum │ ├── main.go │ ├── parsing.go │ └── workers.go ├── requirements.txt ├── run_master_replica.sh └── vector-benches/ ├── README.md ├── go.mod ├── go.sum └── main.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .agent/rules/ANTIGRAVITY_INSTRUCTIONS.md ================================================ # Antigravity Agent Instructions for Dragonfly **READ [AGENTS.md](../../AGENTS.md)** All project information, workflows, patterns, and guidelines are in `AGENTS.md`. ================================================ FILE: .circleci/config.yml ================================================ version: 2.1 machine: true jobs: build-ubuntu: docker: - image: ghcr.io/romange/ubuntu-dev:22 steps: - checkout - run: name: Set up environment environment: BUILD_TYPE: Debug command: | git submodule update --init --recursive cmake -B build -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -GNinja -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - run: name: Build & Test command: | cd build && pwd ninja -j4 src/all ctest -V -L DFLY # Orchestrate our job run sequence workflows: build_and_test: jobs: - build-ubuntu ================================================ FILE: .clang-format ================================================ # --- # We'll use defaults from the Google style, but with 2 columns indentation. 
BasedOnStyle: Google IndentWidth: 2 ColumnLimit: 100 --- Language: Cpp AllowShortLoopsOnASingleLine: false AllowShortFunctionsOnASingleLine: false AllowShortIfStatementsOnASingleLine: false AlwaysBreakTemplateDeclarations: false PackConstructorInitializers: NextLine DerivePointerAlignment: false PointerAlignment: Left BasedOnStyle: Google ColumnLimit: 100 --- Language: Proto BasedOnStyle: Google ================================================ FILE: .clang-tidy ================================================ --- Checks: > -abseil-no-namespace, bugprone*, # Sadly narrowing conversions is too noisy -bugprone-narrowing-conversions, -bugprone-easily-swappable-parameters, -bugprone-branch-clone, -bugprone-implicit-widening-of-multiplication-result, -bugprone-too-small-loop-variable, -bugprone-reserved-identifier, boost-use-to-string, performance*, -cert-err58-cpp, -cert-dcl58-cpp, # Ignore std changes -cert-dcl51-cpp, # bugprone-reserved-identifier # Doesn't work with abseil flags clang-analyzer*, google-*, -google-runtime-int, -google-readability-*, -google-build-using-namespace, misc-definitions-in-headers, misc-misleading*, misc-misplaced-const, misc-new-delete-overloads, misc-non-copyable-objects, misc-redundant-expression, misc-static-assert, misc-throw-by-value-catch-by-reference, misc-unconventional-assign-operator, misc-uniqueptr-reset-release, misc-unused-alias-decls, misc-unused-using-decls, modernize-deprecated-headers, modernize-macro-to-enum, modernize-make-shared, modernize-make-unique, modernize-pass-by-value, modernize-raw-string-literal, modernize-redundant-void-arg, modernize-replace-disallow-copy-and-assign-macro, modernize-return-braced-init-list, modernize-shrink-to-fit, modernize-unary-static-assert, modernize-use-emplace, modernize-use-equals-delete, modernize-use-noexcept, modernize-use-transparent-functors, modernize-use-uncaught-exceptions, modernize-use-using, readability-avoid-const-params-in-decls, readability-const-return-type, 
readability-container-contains, readability-container-size-empty, readability-delete-null-pointer, readability-duplicate-include, readability-function-size, readability-identifier-naming, -readability-inconsistent-declaration-parameter-name, readability-make-member-function-const, readability-misplaced-array-index, readability-named-parameter, readability-non-const-parameter, readability-redundant-access-specifiers, readability-redundant-control-flow, readability-redundant-declaration, readability-redundant-function-ptr-dereference, readability-redundant-member-init, readability-redundant-preprocessor, readability-redundant-smartptr-get, readability-redundant-string-cstr, readability-redundant-string-init, readability-simplify-subscript-expr, readability-static-definition-in-anonymous-namespace, readability-string-compare, readability-suspicious-call-argument, readability-uniqueptr-delete-release, readability-use-anyofallof # Disabled because they're currently too disruptive, but one day might be nice to have: # modernize-use-nullptr, # modernize-use-equals-default, # readability-qualified-auto, CheckOptions: - key: bugprone-narrowing-conversions.WarnOnIntegerNarrowingConversion value: false - key: bugprone-narrowing-conversions.WarnOnEquivalentBitWidth value: false ================================================ FILE: .clangd ================================================ Diagnostics: UnusedIncludes: None MissingIncludes: None Includes: IgnoreHeader: base/*.h CompileFlags: CompilationDatabase: build-dbg/ # Search for compile_commands.json ================================================ FILE: .claude/hooks/format-after-edit.sh ================================================ #!/bin/bash # Hook to automatically format files after Edit/Write operations # Filters out src/redis directory from formatting # Read JSON input from stdin INPUT=$(cat) FILE_PATH=$(echo "$INPUT" | jq -r '.tool_input.file_path // empty') # Skip if no file path if [ -z "$FILE_PATH" ]; then 
exit 0 fi # Skip if file is in src/redis directory if [[ "$FILE_PATH" == */src/redis/* ]]; then echo "Skipping formatting for src/redis file: $FILE_PATH" >&2 exit 0 fi # Skip if file doesn't exist if [ ! -f "$FILE_PATH" ]; then exit 0 fi # Run pre-commit on the file pre-commit run --files "$FILE_PATH" 2>&1 # Always exit 0 to not block the operation even if formatting fails exit 0 ================================================ FILE: .claude/settings.json ================================================ { "permissions": { "allow": [ "Read($CLAUDE_PROJECT_DIR/**)", "Edit($CLAUDE_PROJECT_DIR/**)", "Write($CLAUDE_PROJECT_DIR/**)", "Bash(./*_test:*)", "Bash(ninja:*)", "Bash(git add:*)", "Bash(git reset:*)", "Bash(gh issue view:*)", "Bash(git log:*)", "Bash(git show:*)", "WebSearch", "Bash(grep:*)", "Bash(pre-commit run:*)", "Bash(clang-format:*)", "Bash(git checkout:*)", "Bash(tee:*)", "Bash(sort:*)", "Bash(git patch-id:*)" ] }, "hooks": { "PostToolUse": [ { "matcher": "Edit|Write", "hooks": [ { "type": "command", "command": "\"$CLAUDE_PROJECT_DIR\"/.claude/hooks/format-after-edit.sh", "timeout": 30, "statusMessage": "Formatting code..." } ] } ] } } ================================================ FILE: .claude/skills/reproduce-fuzz-crash/SKILL.md ================================================ --- name: reproduce-fuzz-crash description: > Reproduce AFL++ fuzz crashes from GitHub Actions. Use when user provides a GitHub Actions fuzz run URL and wants to reproduce and analyze the crash locally. argument-hint: allowed-tools: Bash, Read, Grep, Glob, Write --- # Reproduce Fuzz Crash Given a GitHub Actions fuzz run URL, download crash artifacts, triage them with `fuzz/triage_crashes.sh`, and produce a crash analysis report. **Input**: `$ARGUMENTS` — a GitHub Actions run URL like: `https://github.com/dragonflydb/dragonfly/actions/runs/22906484769` or with query params like `?pr=6855`. ## Workflow ### Step 1: Parse the URL Extract `owner/repo` and `run_id` from the URL. 
``` https://github.com/{owner}/{repo}/actions/runs/{run_id}[?...] ``` Strip any query parameters from `run_id`. ### Step 2: Download artifacts List crash artifacts via the GitHub API, then download each as a `.zip` directly: **IMPORTANT**: Run each command as a separate Bash tool call (no `&&` chaining) to ensure auto-approval works with the user's permission patterns. ```bash # List artifacts — filter for names containing "crash" gh api repos/{owner}/{repo}/actions/runs/{run_id}/artifacts # Create output directory mkdir -p /tmp/fuzz-repro-{run_id} # Download each crash artifact by ID (separate command) gh api repos/{owner}/{repo}/actions/artifacts/{artifact_id}/zip > /tmp/fuzz-repro-{run_id}/{artifact_name}.zip ``` This gives real `.zip` files that the triage script can consume directly. If no crash artifacts are found, report that the run has no crash artifacts and stop. Note: there may be duplicate artifact names (same name, different IDs) from retried jobs. Download the **most recent** one (highest artifact ID). ### Step 3: Determine mode Infer the protocol mode from the artifact name: - Contains "memcache" → `memcache` - Otherwise → `resp` ### Step 4: Check Dragonfly binary Check if the debug binary already exists and runs: ```bash ./build-dbg/dragonfly --version ``` Only build if the binary doesn't exist or fails to run: ```bash ninja -C build-dbg dragonfly ``` If `build-dbg` doesn't exist, run `./helio/blaze.sh` first. ### Step 5: Run triage_crashes.sh For each zip file, run: ```bash ./fuzz/triage_crashes.sh ./build-dbg/dragonfly /tmp/fuzz-repro-{run_id}/{artifact_name}.zip ``` Capture the full output. ### Step 6: Analyze and report Parse the triage output for confirmed crashes. For each confirmed crash: 1. **Read the source** at the crash location — use the stack trace to identify the source file and line number, then read that code. 2. **Provide analysis**: likely root cause, what to investigate.
Print a structured report: ``` ## Fuzz Crash Report **Run**: {url} **Artifacts**: {number} crash(es) found --- ### Crash NNNNNN **Reproduced**: Yes / No (false positive) **Signal**: SIGABRT (6) / SIGSEGV (11) / etc. **Stack trace**: \``` \``` **Analysis**: <1-3 sentences explaining the likely root cause based on the stack trace, the assertion message, and the crash input. Identify the source file and line number. Suggest what to investigate.> ``` ## Important Notes - The triage script uses port **6379** (resp) or **11211** (memcache). Ensure no other Dragonfly or Redis instance is using these ports. - The script adds `--rename_command` flags to avoid false positives from commands like DEBUG SLEEP that the fuzzer might generate. - Some crashes are non-deterministic (thread timing). The script reports these as "FALSE POSITIVE" — note this clearly: it doesn't mean the bug is invalid, just that it didn't reproduce on this run. - The script handles its own cleanup of Dragonfly processes. - Do NOT delete `/tmp/fuzz-repro-{run_id}/` — the user may want to inspect it. - If the `gh api` artifact download fails with a permissions error, suggest the user authenticate with `gh auth login`. ================================================ FILE: .ct.yaml ================================================ # See https://github.com/helm/chart-testing#configuration remote: origin target-branch: main chart-dirs: - contrib/charts helm-extra-args: --debug --timeout 60s check-version-increment: false validate-maintainers: false ================================================ FILE: .cursorrules ================================================ # Cursor AI Rules for Dragonfly **READ `AGENTS.md`** All project information, workflows, patterns, and guidelines are in `AGENTS.md`.
================================================ FILE: .devcontainer/alpine/devcontainer.json ================================================ { "name": "alpine-dev", "image": "ghcr.io/romange/alpine-dev", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake" ], "settings": { "cmake.buildDirectory": "/build", "extensions.ignoreRecommendations": true, "cmake.configureArgs": [] } } }, "mounts": [ "source=alpine-vol,target=/build,type=volume" ], "postCreateCommand": ".devcontainer/alpine/post-create.sh ${containerWorkspaceFolder}" } ================================================ FILE: .devcontainer/alpine/post-create.sh ================================================ #!/bin/bash containerWorkspaceFolder=$1 git config --global --add safe.directory ${containerWorkspaceFolder} git config --global --add safe.directory ${containerWorkspaceFolder}/helio mkdir -p /root/.local/share/CMakeTools ================================================ FILE: .devcontainer/fedora/devcontainer.json ================================================ { "name": "fedora30", "image": "ghcr.io/romange/fedora:30", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake" ], "settings": { "cmake.buildDirectory": "/build", "extensions.ignoreRecommendations": true } } }, "mounts": [ "source=fedora-vol,target=/build,type=volume" ] } ================================================ FILE: .devcontainer/fedora41/devcontainer.json ================================================ { "name": "fedora41", "image": "ghcr.io/romange/fedora:41", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake" ], "settings": { "cmake.buildDirectory": "/build", "extensions.ignoreRecommendations": true } } }, "mounts": [ "source=fedora41-vol,target=/build,type=volume" ] } 
================================================ FILE: .devcontainer/ubuntu20/cmake-tools-kits.json ================================================ [ { "name": "GCC x86_64-linux-gnu", "compilers": { "C": "gcc", "CXX": "g++" }, "isTrusted": true } ] ================================================ FILE: .devcontainer/ubuntu20/devcontainer.json ================================================ { "name": "ubuntu20", "image": "ghcr.io/romange/ubuntu-dev:20", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake" ], "settings": { "cmake.buildDirectory": "/build", "extensions.ignoreRecommendations": true } } }, "mounts": [ "source=ubuntu20-vol,target=/build,type=volume" ], "postCreateCommand": ".devcontainer/ubuntu20/post-create.sh ${containerWorkspaceFolder}" } ================================================ FILE: .devcontainer/ubuntu20/post-create.sh ================================================ #!/bin/bash containerWorkspaceFolder=$1 git config --global --add safe.directory '*' mkdir -p /root/.local/share/CMakeTools cp ${containerWorkspaceFolder}/.devcontainer/ubuntu20/cmake-tools-kits.json /root/.local/share/CMakeTools/ ================================================ FILE: .devcontainer/ubuntu20-gcc14/devcontainer.json ================================================ { "name": "ubuntu20-gcc14", "image": "ghcr.io/romange/ubuntu-dev:20-gcc14", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake", "mk12.better-git-line-blame" ], "settings": { "cmake.buildDirectory": "/build", "cmake.configureArgs": [ "-DWITH_AWS=OFF", "-DWITH_GCP=OFF", "-DWITH_GPERF=OFF" ], "extensions.ignoreRecommendations": true } } }, "mounts": [ "source=ubuntu20-gcc14-vol,target=/build,type=volume" ], "postCreateCommand": ".devcontainer/ubuntu20/post-create.sh ${containerWorkspaceFolder}" } 
================================================ FILE: .devcontainer/ubuntu22/devcontainer.json ================================================ { "name": "ubuntu22", "image": "ghcr.io/romange/ubuntu-dev:22", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake" ], "settings": { "cmake.buildDirectory": "/build", "extensions.ignoreRecommendations": true } } }, "mounts": [ "source=ubuntu22-vol,target=/build,type=volume" ], "postCreateCommand": ".devcontainer/ubuntu22/post-create.sh ${containerWorkspaceFolder}" } ================================================ FILE: .devcontainer/ubuntu22/post-create.sh ================================================ #!/bin/bash containerWorkspaceFolder=$1 git config --global --add safe.directory ${containerWorkspaceFolder} git config --global --add safe.directory ${containerWorkspaceFolder}/helio mkdir -p /root/.local/share/CMakeTools ================================================ FILE: .devcontainer/ubuntu24/devcontainer.json ================================================ { "name": "ubuntu24", "image": "ghcr.io/romange/ubuntu-dev:24", "customizations": { "vscode": { "extensions": [ "ms-vscode.cpptools", "ms-vscode.cmake-tools", "ms-vscode.cpptools-themes", "twxs.cmake" ], "settings": { "cmake.buildDirectory": "/build", "extensions.ignoreRecommendations": true } } }, "mounts": [ "source=ubuntu24-vol,target=/build,type=volume" ], "postCreateCommand": ".devcontainer/ubuntu24/post-create.sh ${containerWorkspaceFolder}" } ================================================ FILE: .dockerignore ================================================ _deps/* build-* tools/packaging/* .github/* ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help Dragonfly DB improve title: '' labels: 'bug' assignees: '' --- **Describe the 
bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Insert records using `command` 2. Query records using `command` 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Environment (please complete the following information):** - OS: [ubuntu 20.04] - Kernel: # Command: `uname -a` - Containerized?: [Bare Metal, Docker, Docker Compose, Docker Swarm, Kubernetes, Other] - Dragonfly Version: [e.g. 0.3.0] **Reproducible Code Snippet** ``` # Minimal code snippet to reproduce this bug ``` **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: true contact_links: - name: Dragonfly DB Discord Channel url: https://discord.gg/HsPjXGVH85 about: Get help! Ask questions, get support, and share ideas. - name: GitHub Discussions url: https://github.com/dragonflydb/dragonfly/discussions about: Ask Questions. Benchmark Questions Belong here. - name: Twitter url: https://twitter.com/romanger about: Follow Roman on Twitter ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for Dragonfly DB title: '' labels: 'feature request' assignees: '' --- **Did you search GitHub Issues and GitHub Discussions First?** Many users may find their feature is already being discussed. Help us keep duplicates to a minimum by searching for your feature first to see if it is already in progress. **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 
**Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ ================================================ FILE: .github/actions/builder/action.yml ================================================ name: Build Dragonfly description: "Build Dragonfly with configurable CMake options" inputs: build-type: description: "CMake build type (Debug or Release)" required: false default: 'Debug' type: string build-dir: description: "Build directory name (relative to workspace root)" required: false default: 'build' type: string c-compiler: description: "C compiler to use" required: false default: '' type: string cxx-compiler: description: "C++ compiler to use" required: false default: '' type: string cxx-flags: description: "C++ compiler flags" required: false default: '-no-pie' type: string sanitizers: description: "Enable sanitizers (NoSanitizers or Sanitizers)" required: false default: 'NoSanitizers' type: string with-aws: description: "Build with AWS support" required: false default: 'ON' type: string targets: description: "Build targets to compile" required: false default: 'src/all' type: string runs: using: "composite" steps: - name: Configure CMake shell: bash run: | # Set sanitizer flags ASAN="OFF" USAN="OFF" if [ '${{ inputs.sanitizers }}' = 'Sanitizers' ]; then echo "Enabling ASAN/USAN" ASAN="ON" USAN="ON" fi # Build cmake command array CMAKE_CMD=(cmake -B "${GITHUB_WORKSPACE}/${{ inputs.build-dir }}" -DCMAKE_BUILD_TYPE="${{ inputs.build-type }}" -GNinja ) # Add optional compiler flags if [ -n "${{ inputs.c-compiler }}" ]; then CMAKE_CMD+=(-DCMAKE_C_COMPILER="${{ 
inputs.c-compiler }}") fi if [ -n "${{ inputs.cxx-compiler }}" ]; then CMAKE_CMD+=(-DCMAKE_CXX_COMPILER="${{ inputs.cxx-compiler }}") fi if [ -n "${{ inputs.cxx-flags }}" ]; then CMAKE_CMD+=(-DCMAKE_CXX_FLAGS="${{ inputs.cxx-flags }}") fi # Add fixed options CMAKE_CMD+=( -DPRINT_STACKTRACES_ON_SIGNAL=ON -DWITH_AWS="${{ inputs.with-aws }}" -DWITH_GCP=OFF -DWITH_UNWIND=OFF -DWITH_GPERF=OFF -DWITH_ASAN="${ASAN}" -DWITH_USAN="${USAN}" ) # Execute CMake echo "Running: ${CMAKE_CMD[@]}" "${CMAKE_CMD[@]}" - name: Build shell: bash run: | cd ${GITHUB_WORKSPACE}/${{ inputs.build-dir }} echo "Building target: ${{ inputs.targets }}" ninja ${{ inputs.targets }} ================================================ FILE: .github/actions/fuzzing/action.yml ================================================ name: Run AFL++ Fuzzing description: "Run AFL++ fuzzing campaign with configurable parameters" inputs: mode: description: "Fuzzing mode: 'smoke' (stop on first crash) or 'long' (collect all crashes)" required: true type: string target: description: "Fuzz target: 'resp' or 'memcache'" required: false default: 'resp' type: string duration-minutes: description: "Fuzzing duration in minutes" required: true type: string run-number: description: "GitHub run number for artifact naming" required: true type: string extra-seeds-dir: description: "Directory with additional seed files (initial fuzzer inputs) to merge into the corpus" required: false default: '' focus-commands: description: "JSON list of command names for the mutator to prefer (~70% selection weight)" required: false default: '' build: description: "Build the binary before fuzzing. Set to 'false' when reusing a binary built by a previous action call in the same job — fails if the binary is missing." 
required: false default: 'true' outputs: hang_count: description: "Number of unique hangs found during fuzzing" value: ${{ steps.analyze.outputs.hang_count }} crash_count: description: "Number of unique crashes found during fuzzing" value: ${{ steps.analyze.outputs.crash_count }} runs: using: "composite" steps: - name: Verify AFL++ installation shell: bash run: | echo "Verifying AFL++ installation..." afl-fuzz -h | head -5 || true # Verify AFL++ compilers are available which afl-clang-fast which afl-clang-fast++ afl-clang-fast --version - name: Configure system for fuzzing shell: bash run: | echo "Configuring system for AFL++ fuzzing..." afl-system-config || true echo core > /proc/sys/kernel/core_pattern || echo "Warning: Could not set core_pattern" echo "System configured" - name: Build Dragonfly with AFL++ shell: bash run: | if [ "${{ inputs.build }}" = "false" ]; then if [ ! -f "./build-dbg/dragonfly" ]; then echo "::error::build=false but binary not found at ./build-dbg/dragonfly" exit 1 fi echo "Skipping build, reusing existing binary" ls -lh ./build-dbg/dragonfly else echo "Building Dragonfly with AFL++ instrumentation..." ./helio/blaze.sh -DUSE_AFL:BOOL=ON cd ./build-dbg && ninja dragonfly && cd .. echo "Build complete" ls -lh ./build-dbg/dragonfly fi - name: Merge targeted seeds shell: bash if: ${{ inputs.extra-seeds-dir != '' }} run: | EXTRA_DIR="${{ inputs.extra-seeds-dir }}" SEEDS_DIR="fuzz/seeds/${{ inputs.target }}" # Copy only seed files, skip metadata like focus_commands.json COUNT=$(find "$EXTRA_DIR" -maxdepth 1 -type f ! -name '*.json' 2>/dev/null | wc -l) if [ "$COUNT" -gt 0 ]; then echo "Merging ${COUNT} targeted seeds into corpus" find "$EXTRA_DIR" -maxdepth 1 -type f ! -name '*.json' -exec cp -t "$SEEDS_DIR/" {} + else echo "No targeted seed files to merge" fi - name: Run AFL++ fuzzing shell: bash run: | MODE="${{ inputs.mode }}" DURATION_MINUTES="${{ inputs.duration-minutes }}" echo "Starting AFL++ fuzzing..." 
echo "Configuration:" echo " Target: ${{ inputs.target }}" echo " Mode: ${MODE}" echo " Duration: ${DURATION_MINUTES} minutes" cd fuzz export BUILD_DIR="${GITHUB_WORKSPACE}/build-dbg" # Run fuzzer with timeout timeout ${DURATION_MINUTES}m ./run_fuzzer.sh "${{ inputs.target }}" || EXIT_CODE=$? # timeout returns 124 if it timed out (expected), 0 if finished naturally if [ "${EXIT_CODE:-0}" -eq 124 ]; then echo "Fuzzing completed (timeout reached)" elif [ "${EXIT_CODE:-0}" -eq 0 ]; then echo "Fuzzing completed normally" else echo "::error::Fuzzer failed with exit code ${EXIT_CODE}" exit 1 fi env: # Mode-specific environment variables AFL_BENCH_UNTIL_CRASH: ${{ inputs.mode == 'smoke' && '1' || '' }} AFL_NO_UI: 1 AFL_AUTORESUME: 1 AFL_I_DONT_CARE_ABOUT_MISSING_CRASHES: 1 AFL_TESTCACHE_SIZE: ${{ inputs.mode == 'smoke' && '50' || '500' }} AFL_SKIP_CPUFREQ: 1 AFL_FAST_CAL: ${{ inputs.mode == 'long' && '1' || '' }} AFL_PERSISTENT_RECORD: 1000 AFL_CUSTOM_MUTATOR_ONLY: 1 FUZZ_FOCUS_COMMANDS: ${{ inputs.focus-commands }} - name: Analyze fuzzing results shell: bash if: always() id: analyze run: | echo "Analyzing fuzzing results..." TARGET="${{ inputs.target }}" CRASHES_DIR="fuzz/artifacts/${TARGET}/default/crashes" HANGS_DIR="fuzz/artifacts/${TARGET}/default/hangs" QUEUE_DIR="fuzz/artifacts/${TARGET}/default/queue" # Count results CRASH_COUNT=0 HANG_COUNT=0 CORPUS_SIZE=0 if [ -d "$CRASHES_DIR" ]; then CRASH_COUNT=$(find "$CRASHES_DIR" -maxdepth 1 -type f -name 'id:*' 2>/dev/null | wc -l) fi if [ -d "$HANGS_DIR" ]; then HANG_COUNT=$(find "$HANGS_DIR" -maxdepth 1 -type f -name 'id:*' ! -name 'RECORD:*' 2>/dev/null | wc -l) fi if [ -d "$QUEUE_DIR" ]; then CORPUS_SIZE=$(find "$QUEUE_DIR" -type f ! 
-name ".state" 2>/dev/null | wc -l) fi echo "Fuzzing Results:" echo " Crashes: $CRASH_COUNT" echo " Hangs: $HANG_COUNT" echo " Corpus size: $CORPUS_SIZE" # Show statistics for long mode if [ "${{ inputs.mode }}" = "long" ]; then STATS_FILE="fuzz/artifacts/${TARGET}/default/fuzzer_stats" if [ -f "$STATS_FILE" ]; then echo "" echo "Key Statistics:" grep -E "execs_done|execs_per_sec|paths_total|corpus_count|unique_crashes|unique_hangs|last_crash|last_hang" "$STATS_FILE" || true fi fi echo "hang_count=${HANG_COUNT}" >> "$GITHUB_OUTPUT" echo "crash_count=${CRASH_COUNT}" >> "$GITHUB_OUTPUT" # Fail the job if crashes or hangs were found if [ "$CRASH_COUNT" -gt 0 ]; then echo "::error::Found $CRASH_COUNT crash(es)!" echo "" echo "Crash input files (excluding RECORD):" find "$CRASHES_DIR" -maxdepth 1 -name 'id:*' ! -name 'RECORD:*' -type f | sort || true exit 1 fi if [ "$HANG_COUNT" -gt 0 ]; then echo "::error::Found $HANG_COUNT hang(s)!" echo "" echo "Hang input files (excluding RECORD):" find "$HANGS_DIR" -maxdepth 1 -name 'id:*' ! -name 'RECORD:*' -type f | sort || true exit 1 fi if [ "$CORPUS_SIZE" -gt 0 ]; then echo "No crashes found - fuzzing test passed!" else echo "No fuzzing artifacts found (fuzzer may not have started)" fi - name: Package crash artifacts shell: bash if: failure() && steps.analyze.outputs.crash_count > 0 run: | CRASHES_DIR="$(pwd)/fuzz/artifacts/${{ inputs.target }}/default/crashes" if [ ! -d "$CRASHES_DIR" ] || [ -z "$(ls -A "$CRASHES_DIR" 2>/dev/null)" ]; then echo "No crash artifacts to package" exit 0 fi echo "Raw crash directory contents:" ls -la "$CRASHES_DIR" mkdir -p fuzz/packaged # Find crash input files (not RECORD files) find "$CRASHES_DIR" -maxdepth 1 -name 'id:*' ! -name 'RECORD:*' -type f | while read -r f; do CRASH_ID=$(basename "$f" | sed 's/^id:\([0-9]*\),.*/\1/') echo "Packaging crash ${CRASH_ID}..." 
if ( cd fuzz && ./package_crash.sh "$CRASH_ID" "$CRASHES_DIR" ); then mv "fuzz/crash-${CRASH_ID}.tar.gz" fuzz/packaged/ 2>/dev/null || true else echo "Warning: failed to package crash ${CRASH_ID}, continuing..." fi done echo "Packaged crashes:" ls -lh fuzz/packaged/ 2>/dev/null || echo " (none)" - name: Upload crash artifacts if: failure() && steps.analyze.outputs.crash_count > 0 uses: actions/upload-artifact@v4 with: name: fuzz-${{ inputs.mode }}-${{ inputs.target }}-crashes-${{ inputs.run-number }} path: | fuzz/packaged/*.tar.gz fuzz/artifacts/${{ inputs.target }}/default/fuzzer_stats retention-days: 10 if-no-files-found: ignore - name: Package hang artifacts shell: bash if: failure() && steps.analyze.outputs.hang_count > 0 run: | HANGS_DIR="fuzz/artifacts/${{ inputs.target }}/default/hangs" if [ ! -d "$HANGS_DIR" ] || [ -z "$(ls -A "$HANGS_DIR" 2>/dev/null)" ]; then echo "No hang artifacts to package" exit 0 fi mkdir -p fuzz/packaged_hangs tar -czf "fuzz/packaged_hangs/hangs-${{ inputs.target }}.tar.gz" \ -C "$(dirname "$HANGS_DIR")" hangs/ echo "Packaged hangs:" ls -lh fuzz/packaged_hangs/ - name: Upload hang artifacts if: failure() && steps.analyze.outputs.hang_count > 0 uses: actions/upload-artifact@v4 with: name: fuzz-${{ inputs.mode }}-${{ inputs.target }}-hangs-${{ inputs.run-number }} path: | fuzz/packaged_hangs/*.tar.gz fuzz/artifacts/${{ inputs.target }}/default/fuzzer_stats retention-days: 10 if-no-files-found: ignore ================================================ FILE: .github/actions/lint-test-chart/action.yml ================================================ name: Lint test chart description: "Run lint test chart" runs: using: "composite" steps: - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Helm uses: azure/setup-helm@v4 - uses: actions/setup-python@v5 with: python-version: "3.9" check-latest: true - name: Chart Rendering Tests shell: bash run: | go test -v ./contrib/charts/dragonfly/... 
- name: Set up chart-testing uses: helm/chart-testing-action@v2.6.1 - name: Run chart-testing (list-changed) id: list-changed shell: bash run: | changed=$(ct list-changed --config .ct.yaml) if [[ -n "$changed" ]]; then echo "changed=true" >> $GITHUB_OUTPUT fi - name: Run chart-testing (lint) shell: bash run: | ct \ lint \ --config .ct.yaml \ ${{github.event_name == 'workflow_dispatch' && '--all'}} ; - name: Create kind cluster uses: helm/kind-action@v1 - name: Install Dependencies shell: bash run: | curl -sL https://github.com/prometheus-operator/prometheus-operator/releases/download/v0.73.0/bundle.yaml | kubectl create -f - - name: Getting cluster ready shell: bash run: | kubectl label nodes chart-testing-control-plane key/node-kind=high-memory - name: Run chart-testing (install) shell: bash run: | ct \ install \ --config .ct.yaml \ --debug \ --helm-extra-set-args "--set=image.repository=ghcr.io/${{ github.repository }},probes=null" \ ${{github.event_name == 'workflow_dispatch' && '--all'}} ; ================================================ FILE: .github/actions/multi-registry-docker-login/action.yml ================================================ name: 'Multi-Registry Docker Login' description: 'Authenticate with both GHCR and Google Artifact Registry' inputs: GITHUB_TOKEN: description: 'GitHub token for GHCR' required: true GCP_SA_KEY: description: 'Google Service Account JSON key' required: true runs: using: "composite" steps: - name: Login to GHCR uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} password: ${{ inputs.GITHUB_TOKEN }} - name: Login to Google Artifact Registry uses: docker/login-action@v3 with: registry: us-central1-docker.pkg.dev username: _json_key password: ${{ inputs.GCP_SA_KEY }} ================================================ FILE: .github/actions/regression-tests/action.yml ================================================ name: Regression Tests description: "Run regression tests" inputs: 
dfly-executable: required: true type: string gspace-secret: required: false type: string run-only-on-ubuntu-latest: # 'true' or 'false' cause boolean # is not supported in composite actions required: true type: string build-folder-name: required: true type: string filter: required: false type: string aws-access-key-id: required: false type: string description: "AWS access key ID (optional if using OIDC - credentials set by workflow)" aws-secret-access-key: required: false type: string description: "AWS secret access key (optional if using OIDC - credentials set by workflow)" s3-bucket: required: true type: string epoll: required: false type: string runs: using: "composite" # bring back timeouts once composite actions start supporting them # timeout-minutes: 20 steps: - name: Sync valkey-search tests uses: ./.github/actions/sync-valkey-tests - name: Free disk space if: contains(runner.labels, 'self-hosted') == false shell: bash run: | echo "===================Before freeing up space ============================================" df -h rm -rf /hostroot/usr/share/dotnet rm -rf /hostroot/usr/local/share/boost rm -rf /hostroot/usr/local/lib/android rm -rf /hostroot/opt/ghc echo "===================After freeing up space ============================================" df -h - name: Install Python test requirements shell: bash run: | cd ${GITHUB_WORKSPACE}/tests # https://peps.python.org/pep-0668/#keep-the-marker-file-in-container-images if compgen -G '/usr/lib/python3.*/EXTERNALLY-MANAGED' > /dev/null; then pip3 install --break-system-packages -r dragonfly/requirements.txt else pip3 install -r dragonfly/requirements.txt fi - name: Run S3 snapshot tests with MinIO if: inputs.s3-bucket != '' shell: bash run: | echo "=== Running S3 snapshot tests with local MinIO ===" cd ${GITHUB_WORKSPACE}/tests export DRAGONFLY_PATH="${GITHUB_WORKSPACE}/${{inputs.build-folder-name}}/${{inputs.dfly-executable}}" # MinIO binary is downloaded and started by conftest.py when MINIO_S3_ENDPOINT is 
set MINIO_S3_ENDPOINT=http://localhost:9000 timeout 10m pytest -k "s3" --timeout=300 --color=yes dragonfly/snapshot_test.py --log-cli-level=INFO -v - name: Run PyTests id: main shell: bash run: | ls -l ${GITHUB_WORKSPACE}/ cd ${GITHUB_WORKSPACE}/tests echo "Current commit is ${{github.sha}}" # used by PyTests export DRAGONFLY_PATH="${GITHUB_WORKSPACE}/${{inputs.build-folder-name}}/${{inputs.dfly-executable}}" export ROOT_DIR="${GITHUB_WORKSPACE}/tests/dragonfly/valkey_search" export UBSAN_OPTIONS=print_stacktrace=1:halt_on_error=1 # to crash on errors export FILTER="${{inputs.filter}}" # Exclude large tests unless explicitly requested if [[ "$FILTER" == "large" ]]; then : # keep as-is, run only large tests elif [[ -n "$FILTER" ]]; then FILTER="(not large) and ($FILTER)" else FILTER="not large" fi if [[ "${{inputs.epoll}}" == 'epoll' ]]; then FILTER="$FILTER and not exclude_epoll" # Run only replication tests with epoll timeout 80m pytest -m "$FILTER" --durations=10 --timeout=300 --color=yes --json-report --json-report-file=report.json dragonfly --df force_epoll=true --log-cli-level=INFO || code=$? else # Run only replication tests with iouring timeout 80m pytest -m "$FILTER" --durations=10 --timeout=300 --color=yes --json-report --json-report-file=report.json dragonfly --log-cli-level=INFO || code=$? 
fi # timeout returns 124 if we exceeded the timeout duration if [[ $code -eq 124 ]]; then # Add an extra new line here because when tests timeout the first line below continues from the test failure name echo "\n" echo "🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑" echo "🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 TESTS TIMEDOUT 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑" echo "🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑" # Copy the last log file because we timedout and pytest did not copy it over # the /tmp/failed/ folder cat /tmp/last_test_log_dir.txt | xargs -I {} mv {}/ /tmp/failed/ exit 1 fi # when a test fails in pytest it returns 1 but there are other return codes as well so we just check if the code is non zero if [[ $code -ne 0 ]]; then exit 1 fi env: # Add environment variables to enable the S3 snapshot test. # AWS credentials: if inputs provided, use them; otherwise rely on workflow OIDC auth DRAGONFLY_S3_BUCKET: ${{ inputs.s3-bucket }} AWS_ACCESS_KEY_ID: ${{ inputs.aws-access-key-id || env.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ inputs.aws-secret-access-key || env.AWS_SECRET_ACCESS_KEY }} AWS_SESSION_TOKEN: ${{ env.AWS_SESSION_TOKEN }} AWS_REGION: ${{ env.AWS_REGION || 'us-east-1' }} - name: Send notification on failure if: failure() && github.ref == 'refs/heads/main' shell: bash run: | get_failed_tests() { local report_file=$1 echo $(jq -r '.tests[] | select(.outcome == "failed") | .nodeid' "$report_file") } cd ${GITHUB_WORKSPACE}/tests failed_tests="" if [ -f report.json ]; then failed_tests=$(get_failed_tests report.json) fi KIND="iouring" if [[ "${{inputs.epoll}}" == 'epoll' ]]; then KIND="epoll" fi job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="Regression $KIND tests failed.\\n The commit is: ${{github.sha}}.\\n $failed_tests \\n Job Link: ${job_link}\\n" curl -s \ -X POST \ -H 'Content-Type: application/json' \ '${{ inputs.gspace-secret }}' \ -d '{"text": "'"${message}"'"}' - name: Copy binary on a self hosted runner if: failure() 
&& contains(runner.labels, 'self-hosted') shell: bash run: | cd ${GITHUB_WORKSPACE}/build timestamp=$(date +%Y-%m-%d_%H:%M:%S) mv ./dragonfly /var/crash/dragonfly_${timestamp} ================================================ FILE: .github/actions/repeat/action.yml ================================================ name: Run Tests On Repeat description: "Repeat specific tests" inputs: dfly-executable: required: true type: string run-only-on-ubuntu-latest: required: true type: string build-folder-name: required: true type: string expression: required: false type: string aws-access-key-id: required: false type: string description: "AWS access key ID (optional if using OIDC - credentials set by workflow)" aws-secret-access-key: required: false type: string description: "AWS secret access key (optional if using OIDC - credentials set by workflow)" s3-bucket: required: true type: string count: required: true type: number timeout: required: true type: string epoll: required: true type: string vmodule_expression: required: true type: string runs: using: "composite" steps: - name: Repeat pytests id: main shell: bash run: | ls -l ${GITHUB_WORKSPACE}/ cd ${GITHUB_WORKSPACE}/tests echo "Current commit is ${{github.sha}}" pip3 install -r dragonfly/requirements.txt # used by PyTests export DRAGONFLY_PATH="${GITHUB_WORKSPACE}/${{inputs.build-folder-name}}/${{inputs.dfly-executable}}" export UBSAN_OPTIONS=print_stacktrace=1:halt_on_error=1 # to crash on errors if [[ "${{ inputs.epoll }}" == "epoll" ]]; then FORCE_EPOLL="--df force_epoll=true" else FORCE_EPOLL="" fi # pass --df vmodule only when a vmodule expression was actually supplied if [[ "${{ inputs.vmodule_expression }}" != "" ]]; then VMOD="--df vmodule=${{ inputs.vmodule_expression }}" else VMOD="" fi echo Running command: timeout ${{ inputs.timeout }} pytest ${{ inputs.expression }} --drop-data-after-each-test ${FORCE_EPOLL} ${VMOD} --color=yes --json-report --json-report-file=report.json --log-cli-level=DEBUG --count=${{ inputs.count }} timeout ${{ inputs.timeout }} pytest ${{ inputs.expression
}} --drop-data-after-each-test ${FORCE_EPOLL} ${VMOD} --color=yes --json-report --json-report-file=report.json --log-cli-level=DEBUG --count=${{ inputs.count }} || code=$? # timeout returns 124 if we exceeded the timeout duration if [[ $code -eq 124 ]]; then # Add an extra new line here because when tests timeout the first line below continues from the test failure name echo "\n" echo "🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑" echo "🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 TESTS TIMEDOUT 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑" echo "🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑 🛑" # Copy the last log file because we timedout and pytest did not copy it over # the /tmp/failed/ folder cat /tmp/last_test_log_dir.txt | xargs -I {} mv {}/ /tmp/failed/ exit 1 fi # when a test fails in pytest it returns 1 but there are other return codes as well so we just check if the code is non zero if [[ $code -ne 0 ]]; then exit 1 fi env: # Add environment variables to enable the S3 snapshot test. # AWS credentials: if inputs provided, use them; otherwise rely on workflow OIDC auth DRAGONFLY_S3_BUCKET: ${{ inputs.s3-bucket }} AWS_ACCESS_KEY_ID: ${{ inputs.aws-access-key-id || env.AWS_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ inputs.aws-secret-access-key || env.AWS_SECRET_ACCESS_KEY }} AWS_SESSION_TOKEN: ${{ env.AWS_SESSION_TOKEN }} AWS_REGION: ${{ env.AWS_REGION || 'us-east-1' }} ================================================ FILE: .github/actions/sync-valkey-tests/action.yml ================================================ name: Sync valkey-search tests description: "Synchronizes valkey-search tests using a fixed revision" runs: using: composite steps: - name: Sync valkey-search tests shell: bash run: | cd ${GITHUB_WORKSPACE}/tests/dragonfly/valkey_search # main branch revision ./sync-valkey-search-tests.sh 90124dc91756b24cb2e58e5c4eea5b8d53004ea6 ================================================ FILE: .github/actions/test-docker/action.yml ================================================ name: Test Docker 
Image inputs: image_id: required: true type: string name: required: true type: string runs: using: "composite" steps: - name: Test Image shell: bash run: | echo "Testing ${{ inputs.name }} image" docker pull ${{inputs.image_id}} docker image inspect ${{inputs.image_id}} # docker run with port-forwarding docker run --name test -d -p 6379:6379 ${{inputs.image_id}} until [ "$(docker inspect -f '{{.State.Health.Status}}' test)" = "healthy" ]; do sleep 0.1; done; ================================================ FILE: .github/bullmq-skipped-tests.txt ================================================ # BullMQ tests excluded from CI runs against Dragonfly # # Format: one pattern per line (used as JS regex in mocha --grep --invert) # Categories: # DRAGONFLY_BUG - Dragonfly does not support this behaviour yet # FLAKY - Test has race conditions / timing issues unrelated to Dragonfly # ── DRAGONFLY BUG ──────────────────────────────────────────────────────────── # BullMQ Lua scripts access keys that are not declared in KEYS[]. # Dragonfly enforces strict Lua key declaration; allow-undeclared-keys causes # global transaction mode and breaks other tests. handle errors.*for flows Flows - addBulk.*handle errors # Job.finished: job hash persists after removeOnComplete instead of being deleted. rejects with missing key for job message # ── FLAKY ───────────────────────────────────────────────────────────────────── # deduplication key removal races with the 'deduplicated' QueueEvents listener. # XREAD from '$' is noted as unstable in upstream BullMQ code. removes deduplication key # QueueEvents 'waiting' event: XREAD from '$' is unstable on CI. # Upstream comment: "additional delay since XREAD from '$' is unstable" emits waiting when a job has been added # getWorkers: race between worker 'ready' event and assertion. gets all workers for this queue only # getWorkers (shared connection): upstream test file has comment # "Test is very flaky on CI, so we skip it for now."
gets all workers for a given queue # Job Scheduler monthly repeat: sinon fake-timer races with real Redis async ops. # The worker loop does not advance in time before the 200 s timeout expires. should repeat 7:th day every month at 9:25 ================================================ FILE: .github/copilot-instructions.md ================================================ --- description: 'Code review guidelines for GitHub copilot in this project' applyTo: '**' excludeAgent: ["coding-agent"] --- # Code Review Instructions Keep reviews high-signal and minimal. Only comment on real bugs with high confidence. ## Comment Only When - The issue is a correctness, security, concurrency, or architecture problem. - The impact is clear and non-trivial. - You can point to concrete evidence in the diff (not speculation). ## Avoid - Style, formatting, naming, or minor performance nits. - Optional refactors or “nice to have” suggestions. - Praise, restating the code, or long explanations. - Duplicate comments for the same root cause. ## Review Style - Be terse: 1-2 sentences per issue. - Include file and line references when possible. - If no issues are found, say “No issues found.” - Provide concrete suggestions for fixes when possible, or examples to illustrate the problem. 
================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" open-pull-requests-limit: 1 groups: actions: patterns: - "*" - package-ecosystem: "gomod" directories: - "/contrib/charts/dragonfly" - "/tools/replay" schedule: interval: "weekly" open-pull-requests-limit: 1 #uncomment it to group dependency updates #groups: #go-mod: #patterns: #- "*" ignore: # Disable all updates except security updates #remove an item from ignore list to get dependency updates of that kind - dependency-name: "*" update-types: - "version-update:semver-major" - "version-update:semver-minor" - "version-update:semver-patch" - package-ecosystem: "pip" directories: - "/tests/dragonfly" - "/tools" schedule: interval: "weekly" #uncomment it to group dependency updates #groups: #py-dep: #patterns: #- "*" ignore: # Disable all updates except security updates #remove an item from ignore list to get dependency updates of that kind - dependency-name: "*" update-types: - "version-update:semver-major" - "version-update:semver-minor" - "version-update:semver-patch" ================================================ FILE: .github/instructions/code-review.instructions.md ================================================ --- description: 'Code review instructions for Dragonfly' applyTo: '**' excludeAgent: ["coding-agent"] --- # Dragonfly Code Review Instructions Dragonfly is a high-performance, Redis-compatible in-memory data store written in C++20 with a unique shared-nothing, fiber-based architecture. Code reviews must prioritize correctness, security, and architectural compliance specific to this threading model. 
## Review Priorities ### 🔴 CRITICAL (Block merge immediately) **Threading Model Violations** (causes deadlocks/crashes): - ❌ **NEVER** use `std::thread`, `std::mutex`, `std::condition_variable`, or standard library threading primitives - ✅ **ALWAYS** use fiber-aware equivalents: `util::fb2::Mutex`, `util::fb2::Fiber`, `util::fb2::CondVar` from `util/fibers/` **Architecture Violations**: - ❌ Cross-shard data access without proper synchronization - ✅ Per-shard operations only (see `src/server/db_slice.cc` for patterns) **Security Vulnerabilities**: - Authentication/authorization bypass in ACL code (`src/server/acl/`) - Exposed secrets, credentials in code or logs - Buffer overflows, use-after-free, memory safety issues **Correctness Issues**: - Race conditions in fiber scheduling - Logic errors in transaction handling (`src/server/transaction.cc`) - Data corruption risks in DashTable operations (`src/core/dash.h`) ### 🟡 IMPORTANT (Requires discussion) **Code Quality**: - Missing error handling (should return `OpStatus` from `facade/op_status.h`) - Obvious memory leaks (check ASAN reports) - Performance bottlenecks in hot paths (unnecessary allocations, N+1 patterns) **Test Coverage**: - New features without tests (both C++ unit tests and Python integration tests) - Changes to critical paths (transactions, replication, cluster) without test coverage - Modified code that fails existing tests **Style Violations** (severe only): - Not following naming conventions: `snake_case` variables, `PascalCase` functions, `kPascalCase` constants - Code that won't pass pre-commit hooks (clang-format, 100 char limit) ### 🟢 SUGGESTIONS (Non-blocking, comment only if obvious) - Over-engineering: adding abstraction layers, feature flags, or configurability not requested - Missing comments on complex fiber synchronization logic - Premature optimization without profiling ## Dragonfly-Specific Patterns ### ✅ DO: Correct Patterns **Threading & Synchronization**: ```cpp // ✅ CORRECT: 
Fiber-aware mutex util::fb2::Mutex mutex_; std::lock_guard lock(mutex_); // ✅ CORRECT: Fiber-aware operations util::fb2::Fiber fb = util::fb2::Fiber("name", [&] { /* work */ }); ``` **Per-Shard Design**: ```cpp // ✅ CORRECT: Operate on shard-local data void DbSlice::SomeOperation() { // Access only this shard's data auto& db_slice = cntx->ns->GetCurrentDbSlice(); } ``` ### ❌ DON'T: Anti-Patterns **Threading**: ```cpp // ❌ WRONG: Standard library threading (causes deadlocks!) std::mutex mutex_; std::thread worker; std::condition_variable cv_; ``` **Global State**: ```cpp // ❌ WRONG: Global mutable state (breaks shared-nothing architecture) static std::unordered_map global_cache; ``` **Build Commands**: - ❌ Don't suggest `./tools/docker/build.sh` or `make` for incremental builds - ✅ Use `cd build-dbg && ninja ` instead ## Code Review Checklist When reviewing Dragonfly code, verify: 1. **Architecture Compliance**: - [ ] No standard library threading primitives (`std::thread`, `std::mutex`) - [ ] No global mutable state - [ ] Fiber-aware synchronization used correctly - [ ] Follows per-shard, shared-nothing design 2. **Security**: - [ ] No OWASP vulnerabilities (injection, XSS, auth bypass) - [ ] No hardcoded secrets or credentials - [ ] Input validation on command arguments - [ ] Safe memory operations (no buffer overflows) 3. **Testing**: - [ ] New functionality has test coverage - [ ] Tests build and pass: `cd build-dbg && ninja && ./` - [ ] No test regressions 4. **Style & Formatting**: - [ ] Follows naming conventions (snake_case vars, PascalCase functions) - [ ] Will pass pre-commit checks (clang-format, 100 char limit) - [ ] Code compiles without warnings (CI uses `-Werror`) 5. **Helio Submodule**: - [ ] No direct edits to `helio/` directory (it's a git submodule) ## Common False Positives to Ignore These are **NOT** issues in Dragonfly's design. Do not comment on: 1. **Single-threaded-looking code**: Per-shard operations intentionally avoid locks 2. 
**Custom allocators**: mimalloc is used intentionally for performance 3. **Manual memory management**: Required for performance-critical paths 4. **Complex template metaprogramming**: DashTable uses advanced C++20 features 5. **Missing const**: Not always applicable in high-performance code ## Review Style Guidelines 1. **Be specific**: Reference file:line, explain WHY it's wrong 2. **Show examples**: Demonstrate the correct pattern with code 3. **Prioritize**: Security and correctness over style 4. **Link to docs**: Reference `docs/df-share-nothing.md`, `docs/transaction.md`, etc. 5. **Be concise**: Dragonfly team values focused, actionable feedback ## Example Review Comments **❌ BAD - Too noisy**: > "Consider using auto here for type inference" **✅ GOOD - Actionable and specific**: > "🔴 CRITICAL: Line 42 uses `std::mutex`. This will cause fiber deadlocks. Replace with `util::fb2::Mutex` from helio/util/fibers/. See src/server/set_family.cc:123 for correct pattern." **✅ GOOD - Security focused**: > "🔴 SECURITY: Line 58 doesn't validate `user_input` before passing to eval(). Vulnerable to command injection. Add validation or use SafeEval()." **✅ GOOD - Architecture violation**: > "🟡 ARCHITECTURE: Line 91 accesses global `cache_map`. Dragonfly uses shared-nothing design - each shard must have its own cache. See docs/df-share-nothing.md" --- **Key Files Reference**: See AGENTS.md for complete codebase structure, build commands, and testing procedures. 
================================================ FILE: .github/workflows/benchmark.yml ================================================ name: benchmark-tests on: schedule: - cron: "0 9 * * *" # run at 9 AM UTC workflow_dispatch: permissions: contents: read jobs: benchmark: if: github.repository == 'dragonflydb/dragonfly' strategy: matrix: config: - operator: apiVersion: "dragonflydb.io/v1alpha1" kind: "Dragonfly" metadata: labels: app.kubernetes.io/name: "dragonfly" app.kubernetes.io/instance: "dragonfly-sample" app.kubernetes.io/part-of: "dragonfly-operator" app.kubernetes.io/managed-by: "kustomize" app.kubernetes.io/created-by: "dragonfly-operator" name: "dragonfly-sample" spec: image: "ghcr.io/dragonflydb/dragonfly:latest" args: ["--cache_mode"] replicas: 2 resources: requests: cpu: "2" memory: "2000Mi" limits: cpu: "2" memory: "2000Mi" runs-on: ubuntu-latest container: image: ghcr.io/romange/benchmark-dev:latest options: --security-opt seccomp=unconfined permissions: id-token: write steps: - name: Setup namespace name id: setup run: echo "namespace=benchmark-$(date +"%Y-%m-%d-%s")" >> $GITHUB_OUTPUT - uses: actions/checkout@v6 with: submodules: true - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }} aws-region: ${{ vars.AWS_REGION }} - name: Update kube config run: aws eks update-kubeconfig --name "$EKS_CLUSTER_NAME" --region "$AWS_REGION" env: AWS_REGION: ${{ vars.AWS_REGION }} EKS_CLUSTER_NAME: dev - name: Scale up run: | set -x aws autoscaling set-desired-capacity --auto-scaling-group-name "$AUTOSCALING_GROUP" --desired-capacity "$DESIRED_CAPACITY" env: AUTOSCALING_GROUP: ${{ vars.DEV_EKS_AS_GROUP }} DESIRED_CAPACITY: 1 - name: Install the CRD and Operator run: | # Install the CRD and Operator kubectl apply -f https://raw.githubusercontent.com/dragonflydb/dragonfly-operator/main/manifests/dragonfly-operator.yaml - name: Apply Configuration run: | set -x kubectl create
namespace ${{ steps.setup.outputs.namespace }} || true echo '${{ toJson(matrix.config.operator) }}' | kubectl apply -n ${{ steps.setup.outputs.namespace }} -f - - name: Wait For Service run: | set -x kubectl wait -n ${{ steps.setup.outputs.namespace }} dragonfly/dragonfly-sample --for=jsonpath='{.status.phase}'=ready --timeout=180s kubectl wait -n ${{ steps.setup.outputs.namespace }} pods --selector app=dragonfly-sample --for condition=Ready --timeout=120s kubectl describe -n ${{ steps.setup.outputs.namespace }} pod dragonfly-sample-0 - name: Run Memtier Benchmark shell: bash run: | kubectl apply -n ${{ steps.setup.outputs.namespace }} -f tools/benchmark/k8s-benchmark-job.yaml - name: Version upgrade shell: bash run: | # benchmark is running, wait for 30 seconds before version upgrade sleep 30 kubectl patch dragonfly dragonfly-sample -n ${{ steps.setup.outputs.namespace }} --type merge -p '{"spec":{"image":"ghcr.io/dragonflydb/dragonfly-weekly:latest"}}' - name: Wait for Memtier Benchmark fail shell: bash run: | # Memtier benchmark run will fail at some point because old master shutdown on version upgrade kubectl wait --for=condition=failed --timeout=120s -n ${{ steps.setup.outputs.namespace }} jobs/memtier-benchmark 2>/dev/null kubectl logs -n ${{ steps.setup.outputs.namespace }} -f jobs/memtier-benchmark kubectl delete -n ${{ steps.setup.outputs.namespace }} jobs/memtier-benchmark - name: Run Memtier Benchmark again shell: bash run: | kubectl apply -n ${{ steps.setup.outputs.namespace }} -f tools/benchmark/k8s-benchmark-job.yaml while true; do if kubectl wait --for=condition=complete --timeout=0 -n ${{ steps.setup.outputs.namespace }} jobs/memtier-benchmark 2>/dev/null; then job_result=0 break fi if kubectl wait --for=condition=failed --timeout=0 -n ${{ steps.setup.outputs.namespace }} jobs/memtier-benchmark 2>/dev/null; then job_result=1 break fi sleep 3 done kubectl logs -n ${{ steps.setup.outputs.namespace }} -f jobs/memtier-benchmark if [[ $job_result -eq 1 
]]; then exit 1 fi - name: Server checks run: | nohup kubectl port-forward -n ${{ steps.setup.outputs.namespace }} service/dragonfly-sample 6379:6379 & pip install -r tools/requirements.txt python3 tools/benchmark/post_run_checks.py - name: Get Dragonfly logs uses: nick-fields/retry@v3 if: always() with: timeout_minutes: 1 max_attempts: 3 command: | kubectl logs -n ${{ steps.setup.outputs.namespace }} dragonfly-sample-0 - name: Get Dragonfly replica logs uses: nick-fields/retry@v3 if: always() with: timeout_minutes: 1 max_attempts: 3 command: | kubectl logs -n ${{ steps.setup.outputs.namespace }} dragonfly-sample-1 - name: Describe dragonflydb object uses: nick-fields/retry@v3 if: always() with: timeout_minutes: 1 max_attempts: 3 command: | kubectl describe dragonflies.dragonflydb.io -n ${{ steps.setup.outputs.namespace }} dragonfly-sample - name: Scale down to zero if: always() run: | set -x aws autoscaling set-desired-capacity --auto-scaling-group-name "$AUTOSCALING_GROUP" --desired-capacity 0 env: AUTOSCALING_GROUP: ${{ vars.DEV_EKS_AS_GROUP }} - name: Cleanup if: always() run: | set -x kubectl delete namespace ${{ steps.setup.outputs.namespace }} kubectl delete namespace dragonfly-operator-system - name: Send notification on failure if: failure() && github.ref == 'refs/heads/main' shell: bash run: | job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="Benchmark tests failed.\\n Job Link: ${job_link}\\n" curl -s \ -X POST \ -H 'Content-Type: application/json' \ '${{ secrets.GSPACES_BOT_DF_BUILD }}' \ -d '{"text": "'"${message}"'"}' ================================================ FILE: .github/workflows/bullmq-tests.yml ================================================ name: bullmq-tests on: schedule: - cron: '0 7 * * *' # run at 7 AM daily workflow_dispatch: permissions: contents: read env: NODE_VERSION: "22.12.0" jobs: build: if: github.repository == 'dragonflydb/dragonfly' runs-on: ubuntu-latest name: Build 
timeout-minutes: 60 container: image: ghcr.io/romange/ubuntu-dev:20-gcc14 options: --security-opt seccomp=unconfined credentials: username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} steps: - uses: actions/checkout@v6 with: submodules: true - name: Build Dragonfly run: | cmake -B ${GITHUB_WORKSPACE}/build \ -DCMAKE_BUILD_TYPE=Release \ -DWITH_AWS=OFF \ -DWITH_GCP=OFF \ -DWITH_UNWIND=OFF \ -DWITH_GPERF=OFF \ -GNinja \ -L cd ${GITHUB_WORKSPACE}/build && ninja dragonfly - name: Install Node.js run: | wget -q https://unofficial-builds.nodejs.org/download/release/v${NODE_VERSION}/node-v${NODE_VERSION}-linux-x64-glibc-217.tar.xz tar -xf node-v${NODE_VERSION}-linux-x64-glibc-217.tar.xz cp -r node-v${NODE_VERSION}-linux-x64-glibc-217/* /usr/local/ apt-get update && apt-get install -y jq redis-tools npm install -g yarn node --version yarn --version - name: Start Dragonfly run: | ${GITHUB_WORKSPACE}/build/dragonfly \ --alsologtostderr \ --cluster_mode=emulated \ --lock_on_hashtags \ --dbfilename= \ --port 6379 & timeout 15s bash -c 'until redis-cli -p 6379 PING 2>/dev/null | grep -q PONG; do sleep 0.1; done' - name: Build BullMQ run: | cd ${GITHUB_WORKSPACE} git clone https://github.com/dragonflydb/bullmq cd bullmq yarn install yarn build - name: Run BullMQ tests run: | cd ${GITHUB_WORKSPACE}/bullmq SKIP_PATTERN=$(grep -v '^#' ${GITHUB_WORKSPACE}/.github/bullmq-skipped-tests.txt | grep -v '^[[:space:]]*$' | paste -sd '|' || true) if [ -n "${SKIP_PATTERN}" ]; then BULLMQ_TEST_PREFIX={b} yarn test --grep "${SKIP_PATTERN}" --invert else BULLMQ_TEST_PREFIX={b} yarn test fi - name: Upload logs on failure if: failure() uses: actions/upload-artifact@v6 with: name: unit_logs path: /tmp/dragonfly.* - name: Send notification on failure if: failure() && github.ref == 'refs/heads/main' run: | job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="BullMQ tests failed.\\n Commit: ${{github.sha}}\\n Job Link: ${job_link}\\n" 
curl -s \ -X POST \ -H 'Content-Type: application/json' \ '${{ secrets.GSPACES_BOT_DF_BUILD }}' \ -d '{"text": "'"${message}"'"}' ================================================ FILE: .github/workflows/ci.yml ================================================ name: ci-tests on: # push: # branches: [ main ] pull_request: branches: [main] workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: pre-commit: if: github.event_name == 'pull_request' runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: actions/setup-python@v6 with: python-version: '3.12' cache: 'pip' - uses: actions/cache@v4 with: path: ~/.cache/pre-commit key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }} - uses: pre-commit/action@v3.0.1 with: extra_args: >- --show-diff-on-failure --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} build: strategy: matrix: # Test of these containers container: ["ubuntu-dev:24", "alpine-dev:latest"] build-type: [Debug, Release] compiler: [{ cxx: g++, c: gcc }] # -no-pie to disable address randomization so we could symbolize stacktraces cxx_flags: ["-Werror -no-pie"] sanitizers: ["NoSanitizers"] include: - container: "alpine-dev:latest" build-type: Debug compiler: { cxx: clang++, c: clang } cxx_flags: "" sanitizers: "NoSanitizers" - container: "ubuntu-dev:24" build-type: Debug compiler: { cxx: clang++, c: clang } # https://maskray.me/blog/2023-08-25-clang-wunused-command-line-argument (search for compiler-rt) cxx_flags: "-Wno-error=unused-command-line-argument" sanitizers: "Sanitizers" runs-on: ubuntu-latest container: image: ghcr.io/romange/${{ matrix.container }} # Seems that docker by default prohibits running iouring syscalls options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /:/hostroot - /mnt:/mnt 
credentials: username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} steps: - uses: actions/checkout@v6 with: submodules: true - name: Prepare Environment run: | uname -a cmake --version mkdir -p ${GITHUB_WORKSPACE}/build mount echo "===================Before freeing up space ============================================" df -h rm -rf /hostroot/usr/share/dotnet rm -rf /hostroot/usr/local/share/boost rm -rf /hostroot/usr/local/lib/android rm -rf /hostroot/opt/ghc echo "===================After freeing up space ============================================" df -h touch /mnt/foo ls -la /mnt/foo - name: System diagnostics run: | echo "ulimit is" ulimit -s echo "-----------------------------" echo "disk space is:" df -h echo "-----------------------------" - name: Build Dragonfly uses: ./.github/actions/builder with: build-type: ${{matrix.build-type}} c-compiler: ${{matrix.compiler.c}} cxx-compiler: ${{matrix.compiler.cxx}} cxx-flags: ${{matrix.cxx_flags}} sanitizers: ${{matrix.sanitizers}} with-aws: 'OFF' - name: PostFail if: failure() run: | echo "disk space is:" df -h - name: C++ Unit Tests - IoUring run: | cd ${GITHUB_WORKSPACE}/build echo Run ctest -V -L DFLY GLOG_alsologtostderr=1 GLOG_vmodule=rdb_load=1,rdb_save=1,snapshot=1,op_manager=1,op_manager_test=1 \ FLAGS_fiber_safety_margin=4096 timeout 20m ctest -V -L DFLY -E allocation_tracker_test # Run allocation tracker test separately without alsologtostderr because it generates a TON of logs. 
FLAGS_fiber_safety_margin=4096 timeout 5m ./allocation_tracker_test timeout 5m ./dragonfly_test timeout 5m ./json_family_test --jsonpathv2=false timeout 5m ./tiered_storage_test --vmodule=db_slice=2 --logtostderr timeout 5m ./search_test --use_numeric_range_tree=false timeout 5m ./search_family_test --use_numeric_range_tree=false - name: C++ Unit Tests - Epoll run: | cd ${GITHUB_WORKSPACE}/build # Create a rule that automatically prints stacktrace upon segfault cat > ./init.gdb <> $GITHUB_OUTPUT - name: Docker meta id: metadata uses: docker/metadata-action@v5 with: images: | ${{ env.image }} ${{ env.GCS_IMAGE }} tags: | type=sha,enable=true,prefix=${{ matrix.flavor}}-,suffix=-${{ matrix.os.arch }},format=short labels: | org.opencontainers.image.vendor=DragonflyDB LTD org.opencontainers.image.title=Dragonfly Development Image org.opencontainers.image.description=The fastest in-memory store - name: Build image id: build uses: docker/build-push-action@v6 with: context: . push: true provenance: false # Prevent pushing a docker manifest tags: | ${{ steps.metadata.outputs.tags }} labels: ${{ steps.metadata.outputs.labels }} file: tools/packaging/Dockerfile.${{ matrix.flavor }}-dev cache-from: type=gha,scope=tagged${{ matrix.flavor }} cache-to: type=gha,scope=tagged${{ matrix.flavor }},mode=max load: true # Load the build images into the local docker. 
- name: Test Image run: | echo ${{ steps.build.outputs.digest }} image_tags=(${{ steps.metadata.outputs.tags }}) # install redis-tools sudo apt-get install redis-tools -y for image_tag in "${image_tags[@]}"; do echo "Testing image: ${image_tag}" docker image inspect ${image_tag} echo "Testing ${{ matrix.flavor }} image" # docker run with port-forwarding docker run -d -p 6379:6379 ${image_tag} sleep 5 redis-cli -h localhost ping | grep -q "PONG" || exit 1 docker stop $(docker ps -q --filter ancestor=${image_tag}) done - name: Extract and Upload Binaries if: matrix.flavor == 'ubuntu' # Only run once per flavor run: | # Get the image tag image_tags=(${{ steps.metadata.outputs.tags }}) image_tag=${image_tags[0]} # Extract version from the image echo "Extracting version from image..." VERSION=$(docker run --rm ${image_tag} dragonfly --version | sed -r "s/\x1B\[([0-9]{1,3}(;[0-9]{1,2})?)?[mGK]//g" | head -n1 | cut -d' ' -f2 | cut -d'-' -f1) # Check if version starts with a release version (v*.*.*) if [[ ! 
$VERSION =~ ^v[0-9]+\.[0-9]+\.[0-9]+ ]]; then # Get the latest release version to use as prefix LATEST_RELEASE=$(curl -s https://api.github.com/repos/dragonflydb/dragonfly/releases/latest | jq -r .tag_name) VERSION="${LATEST_RELEASE}+${VERSION}" fi echo "Dragonfly version: $VERSION" echo "Extracting binary from ${image_tag} for ${{ matrix.os.arch }}" # Create a temporary container and copy the binary container_id=$(docker create ${image_tag}) docker cp ${container_id}:/usr/local/bin/dragonfly ./dragonfly docker rm ${container_id} # Create a tar archive if [[ "${{ matrix.os.arch }}" == "arm64" ]]; then arch_name="aarch64" else arch_name="x86_64" fi tar_name="dragonfly-${arch_name}-dbgsym.tar.gz" tar czf ${tar_name} dragonfly # Upload to GCS echo "Uploading ${tar_name} to GCS" gcloud storage cp "$tar_name" "gs://${{ secrets.STAGING_BINARY_BUCKET }}/dragonfly/$VERSION/$tar_name" # Upload to AWS echo "Uploading ${tar_name} to AWS" aws s3 cp "$tar_name" "s3://${{ secrets.STAGING_BINARY_BUCKET }}/dragonfly/$VERSION/$tar_name" # Cleanup rm -f dragonfly ${tar_name} outputs: # matrix jobs outputs override each other, but we use the same sha # for all images, so we can use the same output name. 
sha: ${{ steps.build_info.outputs.short_sha }} merge_manifest: if: github.repository == 'dragonflydb/dragonfly' needs: [build_and_tag] runs-on: ubuntu-latest strategy: matrix: flavor: [alpine,ubuntu] steps: - name: checkout uses: actions/checkout@v6 - name: Login to Registries uses: ./.github/actions/multi-registry-docker-login with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} - name: Merge and Push run: | # Function to create and push manifests for a given registry create_and_push_manifests() { local registry=$1 local flavor=$2 local sha=$3 # Create and push the manifest like dragonfly-dev:alpine- local sha_tag="${registry}:${flavor}-${sha}" docker manifest create ${sha_tag} --amend ${sha_tag}-amd64 --amend ${sha_tag}-arm64 docker manifest push ${sha_tag} # Create and push the manifest like dragonfly-dev:alpine local flavor_tag="${registry}:${flavor}" docker manifest create ${flavor_tag} --amend ${sha_tag}-amd64 --amend ${sha_tag}-arm64 docker manifest push ${flavor_tag} } # GitHub Container Registry manifests create_and_push_manifests "${{ env.image }}" "${{ matrix.flavor }}" "${{ needs.build_and_tag.outputs.sha }}" # Google Artifact Registry manifests create_and_push_manifests "${{ env.GCS_IMAGE }}" "${{ matrix.flavor }}" "${{ needs.build_and_tag.outputs.sha }}" ================================================ FILE: .github/workflows/docker-release2.yml ================================================ name: Docker Release-v2 on: workflow_dispatch: inputs: TAG_NAME: description: 'Tag name that the major tag will point to' required: true PRERELEASE: description: 'Whether this is a prerelease' type: boolean required: true release: types: [published] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true env: TAG_NAME: ${{ github.event.inputs.TAG_NAME || github.event.release.tag_name }} IS_PRERELEASE: ${{ github.event.release.prerelease || 
github.event.inputs.PRERELEASE }} IMAGE: ghcr.io/dragonflydb/dragonfly GCS_IMAGE: us-central1-docker.pkg.dev/dragonflydb-public/dragonfly-registry/dragonfly jobs: build_and_tag: name: Build and Push ${{matrix.flavor}} ${{ matrix.os.arch }} image strategy: matrix: flavor: [ubuntu] os: - image: ubuntu-24.04 arch: amd64 - image: ubuntu-24.04-arm arch: arm64 runs-on: ${{ matrix.os.image }} permissions: contents: read packages: write id-token: write steps: - name: checkout uses: actions/checkout@v6 with: fetch-depth: 0 submodules: true - name: Set up Docker Build uses: docker/setup-buildx-action@v3 - name: Login to Registries uses: ./.github/actions/multi-registry-docker-login with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} - name: Fetch release asset uses: dsaltares/fetch-gh-release-asset@1.1.2 with: version: "tags/${{ env.TAG_NAME }}" regex: true file: "dragonfly-.*\\.tar\\.gz" target: 'releases/' token: ${{ secrets.GITHUB_TOKEN }} - name: Extract artifacts run: | echo "Event prerelease ${{ github.event.release.prerelease }}" echo "Input prerelease ${{ github.event.inputs.PRERELEASE }}" ls -l ls -l releases for f in releases/*.tar.gz; do tar xvfz $f -C releases; done rm releases/*.tar.gz - name: Docker meta id: metadata uses: docker/metadata-action@v5 with: images: | ${{ env.IMAGE }} ${{ env.GCS_IMAGE }} flavor: | latest=false prefix=${{ matrix.flavor}}- suffix=-${{ matrix.os.arch }} tags: | type=semver,pattern={{version}},enable=true,value=${{ env.TAG_NAME }} type=semver,pattern={{raw}},enable=true,value=${{ env.TAG_NAME }} type=ref,event=pr labels: | org.opencontainers.image.vendor=DragonflyDB LTD org.opencontainers.image.title=Dragonfly Production Image org.opencontainers.image.description=The fastest in-memory store org.opencontainers.image.version=${{ env.TAG_NAME }} - name: Build image id: build uses: docker/build-push-action@v6 with: context: . 
push: true provenance: false # Prevent pushing a docker manifest tags: | ${{ steps.metadata.outputs.tags }} labels: ${{ steps.metadata.outputs.labels }} file: tools/packaging/Dockerfile.${{ matrix.flavor }}-prod cache-from: type=gha,scope=prod-${{ matrix.flavor }} cache-to: type=gha,scope=prod-${{ matrix.flavor }},mode=max load: true # Load the build images into the local docker. - name: Test Image uses: ./.github/actions/test-docker timeout-minutes: 1 with: image_id: ${{ env.IMAGE }}@${{ steps.build.outputs.digest }} name: ${{ matrix.flavor }}-${{ matrix.os.arch }} - id: output-sha run: | echo "sha_${{ matrix.os.arch }}=${{ steps.build.outputs.digest }}" >> $GITHUB_OUTPUT outputs: sha_amd: ${{ steps.output-sha.outputs.sha_amd64 }} sha_arm: ${{ steps.output-sha.outputs.sha_arm64 }} merge_manifest: needs: [build_and_tag] runs-on: ubuntu-latest strategy: matrix: flavor: [ubuntu] steps: - name: checkout uses: actions/checkout@v6 - name: Login to Registries uses: ./.github/actions/multi-registry-docker-login with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GCP_SA_KEY: ${{ secrets.GCP_SA_KEY }} - name: Merge and Push run: | # Function to create and push manifests for a given registry create_and_push_manifests() { local registry=$1 local sha_amd=$2 local sha_arm=$3 local flavor=$4 local tag_name=$5 local is_prerelease=$6 # Function for semantic version comparison # Returns true if current_version >= latest_version semver_cmp() { local current_version=$1 local latest_version=$2 local should_update=true # Extract major.minor.patch components IFS='.' read -ra CURRENT_PARTS <<< "$current_version" IFS='.' 
read -ra LATEST_PARTS <<< "$latest_version" # Pad arrays to same length for comparison while [ ${#CURRENT_PARTS[@]} -lt 3 ]; do CURRENT_PARTS+=(0); done while [ ${#LATEST_PARTS[@]} -lt 3 ]; do LATEST_PARTS+=(0); done # Compare major.minor.patch numerically if (( 10#${CURRENT_PARTS[0]} < 10#${LATEST_PARTS[0]} )); then should_update=false elif (( 10#${CURRENT_PARTS[0]} == 10#${LATEST_PARTS[0]} )) && (( 10#${CURRENT_PARTS[1]} < 10#${LATEST_PARTS[1]} )); then should_update=false elif (( 10#${CURRENT_PARTS[0]} == 10#${LATEST_PARTS[0]} )) && (( 10#${CURRENT_PARTS[1]} == 10#${LATEST_PARTS[1]} )) && (( 10#${CURRENT_PARTS[2]} < 10#${LATEST_PARTS[2]} )); then should_update=false fi # Log debug info to stderr instead of stdout echo "Version comparison: current=${CURRENT_PARTS[0]}.${CURRENT_PARTS[1]}.${CURRENT_PARTS[2]} vs latest=${LATEST_PARTS[0]}.${LATEST_PARTS[1]}.${LATEST_PARTS[2]}" >&2 # Return only the result echo $should_update } if [[ "$is_prerelease" == 'true' ]]; then # Create and push the manifest like dragonfly:alpha-ubuntu tag="${registry}:alpha-${flavor}" docker manifest create ${tag} --amend ${sha_amd} --amend ${sha_arm} docker manifest push ${tag} elif [[ "$flavor" == 'ubuntu' ]]; then # Checking if this version should be tagged as latest echo "Checking if ${tag_name} should be tagged as latest..." # Remove 'v' prefix if present for semantic comparison current_version=${tag_name#v} # Get the current latest version by running the latest image latest_version="" if docker pull ${registry}:latest &>/dev/null; then echo "Found latest tag, checking its version..." # First try to get version from image labels using docker inspect echo "Method 1: Trying to get version from image labels..." 
label_version=$(docker image inspect --format '{{ index .Config.Labels "org.opencontainers.image.version" }}' ${registry}:latest 2>/dev/null || echo "") if [[ -n "$label_version" ]]; then echo "Found version from image labels: $label_version" # Extract version from format like "ubuntu-1.28.1-arm64" if [[ $label_version == ubuntu-*-* ]]; then # Extract the middle part (version) from ubuntu-VERSION-arch latest_full_version=$(echo "$label_version" | cut -d'-' -f2) else # Use the label as is latest_full_version=$label_version fi echo "Extracted version: $latest_full_version" else # Fallback to running the container if label inspect failed echo "Method 2: Falling back to container execution..." latest_full_version=$(docker run --rm --entrypoint /bin/sh ${registry}:latest -c "dragonfly --version | cut -d' ' -f2 | head -n 1") fi echo "Latest full version: ${latest_full_version}" # Extract only the semantic version part (before any dash) latest_version=$(echo "${latest_full_version}" | cut -d'-' -f1) # Remove 'v' prefix if present latest_version=${latest_version#v} echo "Current latest version: ${latest_version}" else echo "No latest tag found yet or couldn't pull it" fi # Compare versions only if we have a latest version should_update_latest=true if [[ -n "$latest_version" ]]; then # Call our semver comparison function should_update_latest=$(semver_cmp "$current_version" "$latest_version") fi if [[ "$should_update_latest" == true ]]; then echo "Version ${tag_name} is newer than or equal to current latest, updating latest tag" tag="${registry}:latest" # Create and push the manifest like dragonfly:latest docker manifest create ${tag} --amend ${sha_amd} --amend ${sha_arm} docker manifest push ${tag} else echo "Version ${tag_name} is older than current latest (${latest_version}), NOT updating latest tag" fi fi # Create and push the manifest like dragonfly:v1.26.4 tag="${registry}:${tag_name}" docker manifest create ${tag} --amend ${sha_amd} --amend ${sha_arm} docker manifest 
push ${tag} } # GitHub Container Registry manifests ghcr_sha_amd=${{ env.IMAGE }}@${{ needs.build_and_tag.outputs.sha_amd }} ghcr_sha_arm=${{ env.IMAGE }}@${{ needs.build_and_tag.outputs.sha_arm }} create_and_push_manifests "${{ env.IMAGE }}" "$ghcr_sha_amd" "$ghcr_sha_arm" "${{ matrix.flavor }}" "${{ env.TAG_NAME }}" "${{ env.IS_PRERELEASE }}" # Google Artifact Registry manifests gar_sha_amd=${{ env.GCS_IMAGE }}@${{ needs.build_and_tag.outputs.sha_amd }} gar_sha_arm=${{ env.GCS_IMAGE }}@${{ needs.build_and_tag.outputs.sha_arm }} create_and_push_manifests "${{ env.GCS_IMAGE }}" "$gar_sha_amd" "$gar_sha_arm" "${{ matrix.flavor }}" "${{ env.TAG_NAME }}" "${{ env.IS_PRERELEASE }}" release_helm_and_notify: needs: [merge_manifest] runs-on: ubuntu-latest permissions: contents: write packages: write pull-requests: write steps: - name: print_env run: env - name: checkout uses: actions/checkout@v6 with: token: ${{ secrets.DRAGONFLY_TOKEN }} # PAT to push to main fetch-depth: 0 - name: Install helm uses: azure/setup-helm@v4 - name: Setup Go uses: actions/setup-go@v6 - name: Configure Git if: env.IS_PRERELEASE != 'true' run: | git config user.name "$GITHUB_ACTOR" git config user.email "$GITHUB_ACTOR@users.noreply.github.com" - name: Update helm chart if: env.IS_PRERELEASE != 'true' run: | git checkout -b helm-chart-update/${{ env.TAG_NAME }} origin/main sed -Ei \ -e 's/^(version\:) .*/\1 '${{ env.TAG_NAME }}'/g' \ -e 's/^(appVersion\:) .*/\1 "'${{ env.TAG_NAME }}'"/g' \ contrib/charts/dragonfly/Chart.yaml go test ./contrib/charts/dragonfly/... 
-update git commit \ -m 'chore(helm-chart): update to ${{ env.TAG_NAME }}' \ contrib/charts/dragonfly/Chart.yaml \ contrib/charts/dragonfly/ci || true - name: Push Helm chart as OCI to Github if: env.IS_PRERELEASE != 'true' run: | echo "${{ secrets.GITHUB_TOKEN }}" | \ helm registry login -u ${{ github.actor }} --password-stdin ghcr.io helm package contrib/charts/dragonfly helm push dragonfly-${{ env.TAG_NAME }}.tgz oci://ghcr.io/${{ github.repository }}/helm - name: Discord notification env: DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }} uses: Ilshidur/action-discord@d2594079a10f1d6739ee50a2471f0ca57418b554 with: args: 'DragonflyDB version [${{ env.TAG_NAME }}](https://github.com/dragonflydb/dragonfly/releases/tag/${{ env.TAG_NAME }}) has been released 🎉' - name: Re-build Docs if: env.IS_PRERELEASE != 'true' run: | curl -s -X POST '${{ secrets.VERCEL_DOCS_WEBHOOK }}' - name: Create Helm Chart PR if: env.IS_PRERELEASE != 'true' env: GH_TOKEN: ${{ secrets.DRAGONFLY_TOKEN }} run: | git push origin helm-chart-update/${{ env.TAG_NAME }} gh pr create \ --base main \ --head helm-chart-update/${{ env.TAG_NAME }} \ --title 'chore(helm-chart): update to ${{ env.TAG_NAME }}' \ --body 'Automated Helm chart version bump to ${{ env.TAG_NAME }}.' 
\ --reviewer vyavdoshenko ================================================ FILE: .github/workflows/epoll-regression-tests.yml ================================================ name: Epoll Regression Tests on: schedule: - cron: "0 0/3 * * *" workflow_dispatch: jobs: build: if: github.repository == 'dragonflydb/dragonfly' strategy: matrix: # Test on these containers container: ["ubuntu-dev:24"] proactor: [Epoll] build-type: [Debug] runner: [ubuntu-latest, [self-hosted, linux, ARM64]] runs-on: ${{ matrix.runner }} permissions: id-token: write contents: read container: image: ghcr.io/romange/${{ matrix.container }} options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /var/crash:/var/crash - /:/hostroot - /mnt:/mnt steps: - uses: actions/checkout@v6 with: submodules: true - name: Print environment info run: | cat /proc/cpuinfo ulimit -a env - name: Build Dragonfly uses: ./.github/actions/builder with: build-type: ${{matrix.build-type}} targets: 'dragonfly' - name: Authenticate to AWS uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_CI_S3_ROLE_ARN }} aws-region: us-east-1 - name: Run regression tests action uses: ./.github/actions/regression-tests with: dfly-executable: dragonfly gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} build-folder-name: build filter: ${{ matrix.build-type == 'Release' && 'not empty' || 'not opt_only' }} s3-bucket: ${{ secrets.S3_REGTEST_BUCKET }} # Chained ternary operator of the form (which can be nested) # (expression == condition && <true expression> || <false expression>) epoll: ${{ matrix.proactor == 'Epoll' && 'epoll' || 'iouring' }} - name: Upload logs on failure if: failure() uses: actions/upload-artifact@v6 with: # Artifact names must be unique within a workflow run, so include the matrix job index. name: logs-${{ matrix.build-type }}-${{ strategy.job-index }} path: /tmp/failed/* lint-test-chart: if: github.repository == 'dragonflydb/dragonfly' runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: ./.github/actions/lint-test-chart ================================================ FILE: .github/workflows/fuzz-long.yml
================================================ name: AFL++ Long Fuzzing Campaign on: schedule: # Run nightly at 2 AM UTC - cron: '0 2 * * *' workflow_dispatch: inputs: resp_duration: description: 'RESP fuzzing duration in minutes' required: false default: '60' type: string memcache_duration: description: 'Memcache fuzzing duration in minutes' required: false default: '30' type: string concurrency: group: ${{ github.workflow }} cancel-in-progress: true jobs: fuzz-long: if: github.repository == 'dragonflydb/dragonfly' runs-on: CI-LARGE-86 timeout-minutes: 120 strategy: fail-fast: false matrix: include: - target: resp duration: '60' - target: memcache duration: '30' container: image: ghcr.io/romange/ubuntu-dev:24-afl options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" credentials: username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} steps: - name: Checkout code uses: actions/checkout@v6 with: submodules: true - name: Run AFL++ long fuzzing campaign (${{ matrix.target }}) uses: ./.github/actions/fuzzing with: mode: long target: ${{ matrix.target }} duration-minutes: ${{ matrix.target == 'resp' && (github.event.inputs.resp_duration || matrix.duration) || (github.event.inputs.memcache_duration || matrix.duration) }} run-number: ${{ github.run_number }} - name: Send notification on failure if: failure() && github.ref == 'refs/heads/main' run: | job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="AFL++ ${{ matrix.target }} fuzzing found crashes.\\n Commit: ${{github.sha}}\\n Job Link: ${job_link}\\n" curl -s \ -X POST \ -H 'Content-Type: application/json' \ '${{ secrets.GSPACES_BOT_DF_BUILD }}' \ -d '{"text": "'"${message}"'"}' ================================================ FILE: .github/workflows/fuzz-pr.yml ================================================ # Run AFL++ fuzzing on PRs that touch C++ code. # # For each PR, an LLM analyzes the diff and generates: # 1. 
Targeted seed files — initial inputs crafted to exercise the changed code paths. # (A "seed" is a RESP-encoded sequence of Redis commands that the fuzzer starts from # and mutates; see fuzz/seeds/resp/*.resp for the existing seed corpus.) # 2. Focus command list — commands the mutator should prefer (~70% of the time), # so mutations concentrate on the affected code instead of spreading randomly. # # The fuzzer then runs for 15 minutes in "smoke" mode (stop on first crash). # When ANTHROPIC_API_KEY is unavailable (e.g. fork PRs), seed generation is skipped # and the fuzzer uses the existing seed corpus as-is. # # Additionally, if the PR touches memcache-related code (memcache_parser, mc_family, # fuzz/memcache_mutator.py, or fuzz/seeds/memcache/), a focused memcache fuzzing step # runs automatically after RESP fuzzing passes, reusing the already-built binary. name: AFL++ PR Fuzzing on: pull_request: branches: [main] paths: - 'src/**/*.cc' - 'src/**/*.h' - 'helio/**/*.cc' - 'helio/**/*.h' - 'fuzz/**' - '.github/workflows/fuzz-pr.yml' - '.github/actions/fuzzing/**' workflow_dispatch: inputs: duration: description: 'Fuzzing duration in minutes' required: false default: '15' type: string memcache-duration: description: 'Memcache fuzzing duration in minutes' required: false default: '10' type: string concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: fuzz-pr: runs-on: CI-LARGE-86 timeout-minutes: 60 container: image: ghcr.io/romange/ubuntu-dev:24-afl options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" credentials: username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} steps: - name: Checkout code uses: actions/checkout@v6 with: submodules: true fetch-depth: 0 - name: Generate PR diff id: diff run: | if [ "${{ github.event_name }}" = "pull_request" ]; then git config --global --add safe.directory "$GITHUB_WORKSPACE" BASE=${{ 
github.event.pull_request.base.sha }} HEAD_SHA=${{ github.event.pull_request.head.sha }} MERGE_BASE=$(git merge-base "$BASE" "$HEAD_SHA") git diff "$MERGE_BASE".."$HEAD_SHA" > /tmp/pr_diff.txt else echo "" > /tmp/pr_diff.txt fi DIFF_LINES=$(wc -l < /tmp/pr_diff.txt) echo "diff_lines=${DIFF_LINES}" >> "$GITHUB_OUTPUT" echo "::group::PR diff summary" echo "C++ diff lines: ${DIFF_LINES}" if [ "$DIFF_LINES" -gt 0 ]; then echo "Changed files:" grep '^diff --git' /tmp/pr_diff.txt | sed 's|diff --git a/.* b/| |' || true else echo "No C++ file changes in this PR — seed generation will be skipped" fi echo "::endgroup::" - name: Generate targeted seeds id: seeds run: | pip install 'anthropic>=0.39,<1' 2>/dev/null || pip install --break-system-packages 'anthropic>=0.39,<1' 2>/dev/null || true SEEDS_DIR="${GITHUB_WORKSPACE}/fuzz/seeds/pr_targeted" mkdir -p "$SEEDS_DIR" python3 fuzz/generate_targeted_seeds.py \ --output-dir "$SEEDS_DIR" \ < /tmp/pr_diff.txt FOCUS="" if [ -f "$SEEDS_DIR/focus_commands.json" ]; then FOCUS=$(cat "$SEEDS_DIR/focus_commands.json") fi echo "focus_commands=${FOCUS}" >> "$GITHUB_OUTPUT" echo "seeds_dir=${SEEDS_DIR}" >> "$GITHUB_OUTPUT" SEED_COUNT=$(ls "$SEEDS_DIR"/*.resp 2>/dev/null | wc -l || echo 0) echo "::group::Seed generation results" echo "Seeds generated: ${SEED_COUNT}" echo "Focus commands: ${FOCUS:-none}" if [ "$SEED_COUNT" -gt 0 ]; then ls -la "$SEEDS_DIR"/*.resp fi echo "::endgroup::" # Job summary { echo "### Fuzzing Seed Generation" echo "" if [ "$SEED_COUNT" -gt 0 ]; then echo "- **Seeds generated:** ${SEED_COUNT}" echo "- **Focus commands:** \`${FOCUS}\`" elif [ "$(wc -l < /tmp/pr_diff.txt)" -eq 0 ]; then echo "- No C++ changes in PR — using default seed corpus" elif [ -z "$ANTHROPIC_API_KEY" ]; then echo "- No API key — using default seed corpus" else echo "- LLM did not produce usable seeds — using default seed corpus" fi } >> "$GITHUB_STEP_SUMMARY" env: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - name: Run AFL++ PR fuzzing 
uses: ./.github/actions/fuzzing with: mode: smoke duration-minutes: ${{ github.event.inputs.duration || '15' }} run-number: ${{ github.run_number }} extra-seeds-dir: ${{ steps.seeds.outputs.seeds_dir }} focus-commands: ${{ steps.seeds.outputs.focus_commands }} # Reuses the binary built by the RESP step above (build: false). # Only runs when RESP fuzzing passed (default success() condition) and memcache # code was actually touched in this PR. - name: Check if memcache-related files changed id: memcache-check run: | if [ "${{ github.event_name }}" = "pull_request" ]; then CHANGED=$(grep -E '^diff --git a/(src/(facade/memcache|server/mc_family)|fuzz/(memcache_mutator|seeds/memcache))' /tmp/pr_diff.txt || true) if [ -n "$CHANGED" ]; then echo "run=true" >> "$GITHUB_OUTPUT" echo "Memcache-related files changed — will run memcache fuzzing:" echo "$CHANGED" | sed 's|diff --git a/.* b/| |' else echo "run=false" >> "$GITHUB_OUTPUT" echo "No memcache-related files changed — skipping memcache fuzzing" fi else echo "run=true" >> "$GITHUB_OUTPUT" echo "Manual trigger — running memcache fuzzing" fi - name: Run AFL++ memcache fuzzing if: success() && steps.memcache-check.outputs.run == 'true' uses: ./.github/actions/fuzzing with: mode: smoke target: memcache build: 'false' duration-minutes: ${{ github.event.inputs['memcache-duration'] || '10' }} run-number: ${{ github.run_number }} ================================================ FILE: .github/workflows/generate-osrepo-site.yml ================================================ name: generate-site on: workflow_dispatch: release: types: [published] jobs: gen-site: runs-on: ubuntu-latest env: SiteRoot: _site name: Generate index and site assets steps: - name: Checkout Repository uses: actions/checkout@v6 - name: Install packaging tools # RPM tools are available on ubuntu run: sudo apt install -y rpm gpg createrepo-c dpkg-dev reprepro - name: Setup requirements working-directory: tools/packaging/osrepos run: pip install -r 
requirements.txt - name: Download packages working-directory: tools/packaging/osrepos run: python scripts/fetch-releases.py $SiteRoot - name: Import GPG key id: gpg-import uses: crazy-max/ghaction-import-gpg@v6 with: gpg_private_key: ${{ secrets.GPG_PRIVATE_KEY }} - name: Sign RPMs shell: sh working-directory: tools/packaging/osrepos run: sh scripts/sign-rpms.sh ${{ steps.gpg-import.outputs.fingerprint }} - name: Create YUM repository # Creates metadata for YUM/DNF repository, the files were copied in the download step shell: sh working-directory: tools/packaging/osrepos run: createrepo_c -v $SiteRoot/rpm - name: Sign YUM repository shell: sh working-directory: tools/packaging/osrepos run: gpg --armor --detach-sign $SiteRoot/rpm/repodata/repomd.xml - name: Create APT repository # The configuration for apt repo is in tools/packaging/osrepos/reprepro-config, # which ensures the same GPG key used elsewhere in this action is used to sign # the repository shell: sh working-directory: tools/packaging/osrepos run: sh -x scripts/generate-apt-repo.sh - name: Prepare assets working-directory: tools/packaging/osrepos run: | cp -aRv dragonfly.repo pgp-key.public dragonfly.sources $SiteRoot/ rm -rf $SiteRoot/deb/conf - name: Generate Directory Listings working-directory: tools/packaging/osrepos run: python scripts/generate-index.py $SiteRoot - name: Authenticate uses: 'google-github-actions/auth@v3' with: project_id: 'dragonflydb' credentials_json: ${{ secrets.GCP_BUCKET_CREDENTIALS }} - name: GCloud setup uses: 'google-github-actions/setup-gcloud@v3' - name: Deploy site working-directory: tools/packaging/osrepos run: | gcloud storage rm ${{ secrets.GCP_PACKAGES_BUCKET }}/** gcloud storage rsync $SiteRoot ${{ secrets.GCP_PACKAGES_BUCKET }} --recursive --delete-unmatched-destination-objects - name: Notify on failure if: failure() run: | job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="Package repo generation failed.\nCommit: ${{ 
github.sha }}\nJob: ${job_link}" curl -sSf -X POST -H 'Content-Type: application/json' '${{ secrets.GSPACES_BOT_DF_BUILD }}' -d '{"text": "'"${message}"'"}' ================================================ FILE: .github/workflows/heavy-tests.yml ================================================ name: Heavy Tests on: schedule: - cron: "0 0/6 * * *" workflow_dispatch: jobs: build: if: github.repository == 'dragonflydb/dragonfly' strategy: matrix: # Test on these containers container: ["ubuntu-dev:24"] proactor: [Uring] build-type: [Release] runner: [CI-LARGE-86, CI-LARGE-ARM] runs-on: ${{ matrix.runner }} permissions: id-token: write contents: read container: image: ghcr.io/romange/${{ matrix.container }} options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /var/crash:/var/crash - /:/hostroot - /mnt:/mnt steps: - uses: actions/checkout@v6 with: submodules: true - name: Print environment info run: | cat /proc/cpuinfo ulimit -a env lsblk -l - name: Build Dragonfly uses: ./.github/actions/builder with: build-type: ${{matrix.build-type}} targets: 'dragonfly' - name: Authenticate to AWS uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_CI_S3_ROLE_ARN }} aws-region: us-east-1 - name: Run heavy tests uses: ./.github/actions/regression-tests with: dfly-executable: dragonfly gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} build-folder-name: build filter: large s3-bucket: ${{ secrets.S3_REGTEST_BUCKET }} - name: Upload logs on failure if: failure() uses: actions/upload-artifact@v6 with: name: logs-${{ matrix.runner }} path: /tmp/failed/* ================================================ FILE: .github/workflows/ioloop-v2-regtests.yml ================================================ name: RegTests IoLoopV2 # Manually triggered only on: workflow_dispatch: jobs: build: strategy: matrix: # Test on these containers container: ["ubuntu-dev:20-gcc14"] proactor: [Uring] build-type: [Debug, Release] runner:
[ubuntu-latest, [self-hosted, linux, ARM64]] runs-on: ${{ matrix.runner }} permissions: id-token: write contents: read container: image: ghcr.io/romange/${{ matrix.container }} options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /var/crash:/var/crash - /:/hostroot - /mnt:/mnt steps: - uses: actions/checkout@v6 with: submodules: true - name: Print environment info run: | cat /proc/cpuinfo ulimit -a env - name: Build Dragonfly uses: ./.github/actions/builder with: build-type: ${{matrix.build-type}} targets: 'dragonfly' - name: Authenticate to AWS uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_CI_S3_ROLE_ARN }} aws-region: us-east-1 - name: Run regression tests action uses: ./.github/actions/regression-tests with: dfly-executable: dragonfly gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} build-folder-name: build filter: ${{ matrix.build-type == 'Release' && 'not debug_only and not tls' || 'not opt_only and not tls' }} s3-bucket: ${{ secrets.S3_REGTEST_BUCKET }} - name: Upload logs on failure if: failure() uses: actions/upload-artifact@v6 with: # Artifact names must be unique within a workflow run, so include the matrix job index. name: logs-${{ matrix.build-type }}-${{ strategy.job-index }} path: /tmp/failed/* lint-test-chart: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: ./.github/actions/lint-test-chart ================================================ FILE: .github/workflows/mastodon-ruby-tests.yml ================================================ name: Mastodon ruby tests on: schedule: - cron: '0 6 * * *' # run at 6 AM UTC workflow_dispatch: jobs: build-and-test: if: github.repository == 'dragonflydb/dragonfly' runs-on: ubuntu-latest name: Build and run tests services: postgres: image: postgres:14-alpine env: POSTGRES_PASSWORD: postgres POSTGRES_USER: postgres options: >- --health-cmd pg_isready --health-interval 10ms --health-timeout 3s --health-retries 50 ports: - 5432:5432 redis: image: docker.dragonflydb.io/dragonflydb/dragonfly:latest options: >- --health-cmd "redis-cli ping"
--health-interval 10ms --health-timeout 3s --health-retries 50 ports: - 6379:6379 env: DB_HOST: localhost DB_USER: postgres DB_PASS: postgres RAILS_ENV: test ALLOW_NOPAM: true PAM_ENABLED: true PAM_DEFAULT_SERVICE: pam_test PAM_CONTROLLED_SERVICE: pam_test_controlled OIDC_ENABLED: true OIDC_SCOPE: read SAML_ENABLED: true CAS_ENABLED: true BUNDLE_WITH: 'pam_authentication test' GITHUB_RSPEC: false steps: - name: Checkout mastodon uses: actions/checkout@v6 with: repository: mastodon/mastodon - name: Install pre-requisites run: | sudo apt update sudo apt install -y libicu-dev libidn11-dev libvips42 ffmpeg imagemagick libpam-dev - name: Set up Ruby uses: ruby/setup-ruby@v1 with: ruby-version: 3.4 bundler-cache: true - name: Enable corepack shell: bash run: corepack enable - name: Install all production yarn packages shell: bash run: yarn workspaces focus --production - name: Set up Node.js uses: actions/setup-node@v6 with: node-version-file: '.nvmrc' - name: Precompile assets run: |- bin/rails assets:precompile - name: Load database schema run: | bin/rails db:setup bin/flatware fan bin/rails db:test:prepare - name: Run tests env: SPEC_OPTS: '--exclude-pattern "**/self_destruct_scheduler_spec.rb"' run: | unset COVERAGE bin/flatware rspec -r ./spec/flatware_helper.rb - name: Notify on failures if: failure() shell: bash run: | job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="Mastodon ruby tests failed.\\n The commit is: ${{github.sha}}.\\n Job Link: ${job_link}\\n" curl -s \ -X POST \ -H 'Content-Type: application/json' \ '${{ secrets.GSPACES_BOT_DF_BUILD }}' \ -d '{"text": "'"${message}"'"}' ================================================ FILE: .github/workflows/package-install.yml ================================================ name: package-install-tests on: schedule: - cron: '0 6 * * *' workflow_dispatch: workflow_run: workflows: ["generate-site"] types: [completed] jobs: test-rpm: runs-on: ubuntu-latest if: 
github.repository == 'dragonflydb/dragonfly' && (github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success') container: image: ghcr.io/romange/fedora:30 steps: - name: Install on fedora run: | curl -Lo /etc/yum.repos.d/dragonfly.repo https://packages.dragonflydb.io/dragonfly.repo dnf clean all dnf makecache dnf -y install dragonfly dragonfly --version test-deb-ubuntu: runs-on: ubuntu-latest if: github.repository == 'dragonflydb/dragonfly' && (github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success') container: image: ghcr.io/romange/ubuntu:noble steps: - name: Install on ubuntu run: | apt update apt install -y curl curl -Lo /usr/share/keyrings/dragonfly-keyring.public https://packages.dragonflydb.io/pgp-key.public curl -Lo /etc/apt/sources.list.d/dragonfly.sources https://packages.dragonflydb.io/dragonfly.sources apt update apt install -y dragonfly dragonfly --version notify-on-failure: runs-on: ubuntu-latest needs: [test-rpm, test-deb-ubuntu] if: github.repository == 'dragonflydb/dragonfly' && always() && (needs.test-rpm.result == 'failure' || needs.test-deb-ubuntu.result == 'failure') steps: - name: Notify on failure run: | job_link="${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/actions/runs/${GITHUB_RUN_ID}" message="Package install tests failed.\nCommit: ${{ github.sha }}\nJob: ${job_link}" curl -sSf -X POST -H 'Content-Type: application/json' '${{ secrets.GSPACES_BOT_DF_BUILD }}' -d '{"text": "'"${message}"'"}' ================================================ FILE: .github/workflows/regression-tests.yml ================================================ name: Regression Tests on: schedule: - cron: "0 0/3 * * *" workflow_dispatch: jobs: build: if: github.repository == 'dragonflydb/dragonfly' strategy: matrix: # Test on these containers container: ["ubuntu-dev:24"] proactor: [Uring] build-type: [Debug, Release] runner: [ubuntu-latest, [self-hosted, linux, ARM64]] runs-on: ${{ matrix.runner }} permissions:
id-token: write contents: read container: image: ghcr.io/romange/${{ matrix.container }} options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /var/crash:/var/crash - /:/hostroot - /mnt:/mnt steps: - uses: actions/checkout@v6 with: submodules: true - name: Print environment info run: | cat /proc/cpuinfo ulimit -a env lsblk -l - name: Build Dragonfly uses: ./.github/actions/builder with: build-type: ${{matrix.build-type}} targets: 'dragonfly' - name: Authenticate to AWS uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_CI_S3_ROLE_ARN }} aws-region: us-east-1 - name: Run regression tests action uses: ./.github/actions/regression-tests with: dfly-executable: dragonfly gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} build-folder-name: build filter: ${{ matrix.build-type == 'Release' && 'not debug_only' || 'not opt_only' }} s3-bucket: ${{ secrets.S3_REGTEST_BUCKET }} - name: Upload logs on failure if: failure() uses: actions/upload-artifact@v6 with: # Artifact names must be unique within a workflow run, so include the matrix job index. name: logs-${{ matrix.build-type }}-${{ strategy.job-index }} path: /tmp/failed/* lint-test-chart: if: github.repository == 'dragonflydb/dragonfly' runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: ./.github/actions/lint-test-chart ================================================ FILE: .github/workflows/release.yml ================================================ name: Version Release on: push: tags: - 'v*' permissions: contents: write env: RELEASE_DIR: build-release jobs: create-release: runs-on: ubuntu-latest steps: - name: Create Release uses: ncipollo/release-action@v1 with: allowUpdates: true omitBody: true prerelease: true draft: true token: ${{ secrets.GITHUB_TOKEN }} build-arm: runs-on: ubuntu-24.04-arm name: Build arm64 on ubuntu-24.04-arm needs: create-release container: image: ghcr.io/romange/ubuntu-dev:20-gcc14 options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" steps: - uses: actions/checkout@v6 with: submodules: true - name: Build
artifacts run: | # Work around https://github.com/actions/checkout/issues/766 git config --global --add safe.directory "$GITHUB_WORKSPACE" git describe --always --tags ${{ github.sha }} ./tools/release.sh ./tools/packaging/generate_debian_package.sh ${{ env.RELEASE_DIR }}/dragonfly-aarch64 mv dragonfly_*.deb ${{ env.RELEASE_DIR }}/ - name: Upload uses: actions/upload-artifact@v6 with: name: dragonfly-aarch64 path: | ${{ env.RELEASE_DIR }}/dragonfly-*tar.gz ${{ env.RELEASE_DIR }}/dragonfly_*.deb ${{ env.RELEASE_DIR }}/dfly_bench-*tar.gz build-native: runs-on: ubuntu-latest needs: create-release strategy: matrix: include: # Build with these flags - name: debian container: ubuntu-dev:20-gcc14 - name: rpm container: fedora:30-gcc14 container: image: ghcr.io/romange/${{ matrix.container }} options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" # Some tests which launch their own containers need a mounted volume to write through files # into child containers volumes: - /mnt:/mnt steps: - uses: actions/checkout@v6 with: submodules: true - name: Configure run: | if [ -f /etc/redhat-release ]; then dnf install -y rpm-build libstdc++-static fi - name: Build artifacts timeout-minutes: 25 run: | # Work around https://github.com/actions/checkout/issues/766 git config --global --add safe.directory "$GITHUB_WORKSPACE" git describe --always --tags ${{ github.sha }} # set WITH_SIMSIMD=OFF for fedora:30 if [ "${{ matrix.name }}" == 'rpm' ]; then export WITH_SIMSIMD="OFF" fi ./tools/release.sh # once the build is over, we want to generate a Debian package if [ -f /etc/debian_version ]; then ./tools/packaging/generate_debian_package.sh ${{ env.RELEASE_DIR }}/dragonfly-x86_64 else echo "Creating package for ${{github.ref_name}}" ./tools/packaging/rpm/build_rpm.sh ${{ env.RELEASE_DIR }}/dragonfly-x86_64.tar.gz ${{github.ref_name}} fi - name: Save artifacts run: | # place all artifacts at the same location set -eu mkdir -p results-artifacts if [ -f 
/etc/debian_version ]; then mv ${{ env.RELEASE_DIR }}/dragonfly-*tar.gz results-artifacts mv dragonfly_*.deb results-artifacts mv ${{ env.RELEASE_DIR }}/dfly_bench-*tar.gz results-artifacts else ls -l *.rpm mv ./*.rpm ./results-artifacts/ fi - name: Upload uses: actions/upload-artifact@v6 with: name: dragonfly-amd64-${{ matrix.name }} path: results-artifacts/* test-regression: needs: [build-native, build-arm] runs-on: ${{ matrix.runner }} strategy: matrix: include: - name: amd64 runner: ubuntu-latest artifact: dragonfly-amd64-debian binary: dragonfly-x86_64 - name: arm64 runner: ubuntu-24.04-arm artifact: dragonfly-aarch64 binary: dragonfly-aarch64 container: image: ghcr.io/romange/ubuntu-dev:24 options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /mnt:/mnt steps: - uses: actions/checkout@v6 with: submodules: true - name: Download artifacts uses: actions/download-artifact@v7 with: name: ${{ matrix.artifact }} path: results-artifacts - name: Extract artifacts run: | set -eu mkdir -p ${{ env.RELEASE_DIR }} tar -xzf results-artifacts/dragonfly-*dbgsym.tar.gz -C ${{ env.RELEASE_DIR }} - name: Run regression tests uses: ./.github/actions/regression-tests with: dfly-executable: ${{ matrix.binary }} gspace-secret: ${{ secrets.GSPACES_BOT_DF_BUILD }} build-folder-name: ${{ env.RELEASE_DIR }} filter: 'not debug_only' publish_release: runs-on: ubuntu-latest needs: test-regression steps: - uses: actions/download-artifact@v7 name: Download files with: path: artifacts - name: See all the artifacts run: | ls -lR artifacts/ - uses: ncipollo/release-action@v1 with: artifacts: "artifacts/dragonfly-*/*" allowUpdates: true draft: true prerelease: true omitNameDuringUpdate: true token: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .github/workflows/repeat-tests.yml ================================================ name: Repeat Tests on: workflow_dispatch: inputs: branch: description: "The branch on which 
tests will be repeated" type: string required: false commit: description: "A specific commit SHA to test (takes precedence over branch)" type: string required: false count: description: "The number of times the tests will be repeated" type: number required: false default: 1 expression: description: "A pytest expression which will filter the tests" required: true type: string timeout: description: "Overall timeout for all test runs" required: false type: string default: "60m" epoll: description: "Force epoll mode in test" required: false type: string default: "no" use_release: description: "Use latest release instead of building dragonfly" required: false type: string default: "no" vmodule_expression: description: "Emit verbose dragonfly logs for modules, eg x=2,y=3" required: false type: string default: "" build_type: description: "Build type: Debug or Release" required: false type: choice options: - Debug - Release default: "Debug" jobs: build: strategy: matrix: container: ["ubuntu-dev:24"] proactor: [Uring] build-type: ["${{ inputs.build_type || 'Debug' }}"] runner: [ubuntu-latest] runs-on: ${{ matrix.runner }} permissions: id-token: write contents: read container: image: ghcr.io/romange/${{ matrix.container }} options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" volumes: - /var/crash:/var/crash steps: - uses: actions/checkout@v6 with: submodules: true ref: ${{ inputs.commit || inputs.branch }} - name: Print environment info run: | cat /proc/cpuinfo ulimit -a env - name: Fetch release shell: bash if: ${{ inputs.use_release == 'yes' }} run: | mkdir "${GITHUB_WORKSPACE}"/build cd "${GITHUB_WORKSPACE}"/build wget -q https://github.com/dragonflydb/dragonfly/releases/latest/download/dragonfly-x86_64.tar.gz tar xf dragonfly-x86_64.tar.gz mv dragonfly-x86_64 dragonfly ls -l - name: Build Dragonfly if: ${{ inputs.use_release != 'yes' }} uses: ./.github/actions/builder with: build-type: ${{matrix.build-type}} targets: 'dragonfly' - name: 
Sync valkey tests uses: ./.github/actions/sync-valkey-tests - name: Authenticate to AWS uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_CI_S3_ROLE_ARN }} aws-region: us-east-1 - name: Run tests on repeat uses: ./.github/actions/repeat with: run-only-on-ubuntu-latest: true dfly-executable: dragonfly build-folder-name: build s3-bucket: ${{ secrets.S3_REGTEST_BUCKET }} expression: ${{ inputs.expression }} count: ${{ inputs.count }} timeout: ${{ inputs.timeout }} epoll: ${{ inputs.epoll }} vmodule_expression: ${{ inputs.vmodule_expression }} - name: Upload logs on failure if: failure() uses: actions/upload-artifact@v6 with: name: logs path: /tmp/failed/* - name: Copy binary on a self hosted runner if: failure() run: | # We must use sh syntax. if [ "$RUNNER_ENVIRONMENT" = "self-hosted" ]; then cd ${GITHUB_WORKSPACE}/build timestamp=$(date +%Y-%m-%d_%H:%M:%S) mv ./dragonfly /var/crash/dragonfy_${timestamp} fi ================================================ FILE: .github/workflows/test-fakeredis.yml ================================================ --- name: Test Dragonfly/Fakeredis on: workflow_dispatch: pull_request: permissions: contents: read checks: write concurrency: group: dragonfly-${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: test: runs-on: ubuntu-latest container: image: ghcr.io/romange/ubuntu-dev:22 options: --security-opt seccomp=unconfined --sysctl "net.ipv6.conf.all.disable_ipv6=0" strategy: fail-fast: false name: "Run tests: " permissions: pull-requests: write checks: read steps: - uses: actions/checkout@v6 with: submodules: true - name: Install dependencies env: PYTHON_KEYRING_BACKEND: keyring.backends.null.Keyring shell: bash working-directory: tests/fakeredis run: | pip install poetry echo "$HOME/.poetry/bin" >> $GITHUB_PATH poetry install - name: Configure CMake run: | cmake -B ${GITHUB_WORKSPACE}/build \ -DCMAKE_BUILD_TYPE=Debug -DWITH_AWS:BOOL=OFF -DWITH_GCP:BOOL=OFF -DWITH_GPERF:BOOL=OFF 
\ -GNinja -L cd ${GITHUB_WORKSPACE}/build && pwd - name: Build run: | cd ${GITHUB_WORKSPACE}/build ninja dragonfly echo "-----------------------------" # The order of redirect is important ./dragonfly --proactor_threads=4 --noversion_check --port=6380 \ --lua_resp2_legacy_float 1> /tmp/dragonfly.log 2>&1 & - name: Run tests working-directory: tests/fakeredis run: | # Some tests are pending on #5383 poetry run pytest test/ \ --ignore test/test_hypothesis/test_transaction.py \ --ignore test/test_hypothesis/test_zset.py \ --ignore test/test_hypotesis_joint/test_joint.py \ --junit-xml=results-tests.xml --html=report-tests.html -v continue-on-error: false # Fail the job if tests fail - name: Show Dragonfly stats if: always() run: | redis-cli -p 6380 INFO ALL - name: Upload Tests Result xml if: always() uses: actions/upload-artifact@v6 with: name: tests-result-logs path: | /tmp/dragonfly.* - name: Upload Tests Result html if: always() uses: actions/upload-artifact@v6 with: name: report-tests.html path: tests/fakeredis/report-tests.html - name: Publish Test Report if: ${{ github.event_name == 'pull_request' }} uses: mikepenz/action-junit-report@v6 with: report_paths: tests/fakeredis/results-tests.xml # Do not create a check run # annotate_only: true publish-html-results: name: Publish HTML Test Results to GitHub Pages needs: test if: ${{ github.ref == 'refs/heads/main' }} runs-on: ubuntu-latest permissions: pages: write # to deploy to Pages id-token: write # to verify the deployment originates from an appropriate source environment: name: github-pages url: ${{ steps.deployment.outputs.page_url }} steps: - name: Bundle Tests Result to one artifact uses: actions/upload-artifact/merge@v6 with: delete-merged: true name: test-results-html pattern: '*.html' - name: Download html pages uses: actions/download-artifact@v7 with: name: test-results-html path: results/ - uses: actions/setup-python@v6 with: cache-dependency-path: tests/fakeredis/poetry.lock python-version: "3.10" - 
name: Merge html results run: | pip install pytest-html-merger && mkdir merged pytest_html_merger -i results/ -o merged/index.html - name: Publish to GitHub Pages uses: actions/upload-pages-artifact@v4 with: path: merged/ - name: Deploy to GitHub Pages id: deployment uses: actions/deploy-pages@v4 with: token: '${{ secrets.GITHUB_TOKEN }}' ================================================ FILE: .gitignore ================================================ build/* build-* clang/* clang-* .vscode/*.db .vscode/settings.json .vscode/launch.json third_party genfiles/* *.sublime-* *.orig .tags !third_party/include/* *.pyc /CMakeLists.txt.user _deps releases .DS_Store .idea/* .hypothesis .secrets cmake-build-debug .venv/ fuzz/artifacts/ fuzz/corpus/ tools/replay/traffic-replay # Valkey-search integration tests (synced from external repo) tests/dragonfly/valkey_search/integration/ _codeql_build_dir/ ================================================ FILE: .gitmodules ================================================ [submodule "helio"] path = helio url = https://github.com/romange/helio.git ================================================ FILE: .gitorderfile ================================================ *.py *.md *.in *.txt *.sh *.yml *.h *.cc *.lua *.go * ================================================ FILE: .nvmrc ================================================ 22.19 ================================================ FILE: .pre-commit-config.yaml ================================================ default_stages: [pre-commit] exclude: | (?x)( src/redis/.* | src/huff/.* | contrib/charts/dragonfly/ci/.* | patches/.* ) repos: - repo: local hooks: - id: conventional-commits name: Conventional Commit Minder entry: contrib/scripts/conventional-commits language: script stages: [commit-msg] - id: signed-commit name: Signed Commit Enforcer entry: contrib/scripts/signed-commit language: script stages: [commit-msg] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: 
- id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/pre-commit/mirrors-clang-format rev: v14.0.6 hooks: - id: clang-format name: Clang formatting - repo: https://github.com/psf/black rev: 25.1.0 hooks: - id: black ================================================ FILE: .pre-commit-hooks.yaml ================================================ - id: conventional-commits name: Conventional Commits Minder entry: contrib/scripts/conventional-commits language: script description: Conventional Commits Enforcement at the `git commit` client-side level always_run: true stages: [commit-msg] - id: signed-commit name: Signed Commit Enforcer entry: contrib/scripts/signed-commit language: script description: Ensures all commits contain a Signed-off-by line always_run: true stages: [commit-msg] ================================================ FILE: .snyk ================================================ # Snyk (https://snyk.io) policy file exclude: global: - tests/integration/** - contrib/charts/** ================================================ FILE: .vscode/c_cpp_properties.json ================================================ { "configurations": [ { "name": "Linux", "includePath": [ "${default}" ], "cStandard": "c17", "cppStandard": "c++17", "intelliSenseMode": "${default}", "compileCommands": "${workspaceFolder}/build-dbg/compile_commands.json", "configurationProvider": "ms-vscode.cmake-tools" } ], "version": 4 } ================================================ FILE: AGENTS.md ================================================ # Dragonfly Development Guide > **Essential reference for working with the Dragonfly codebase** > Architecture, build system, testing infrastructure, and development workflows. --- ## Table of Contents 1. [Critical Workflow Rules](#critical-workflow-rules) 2. [Quick Command Reference](#quick-command-reference) 3. [Project Overview](#project-overview) 4. [Repository Structure](#repository-structure) 5. 
[Build Instructions](#build-instructions) 6. [Testing](#testing) 7. [CI/CD Pipeline](#cicd-pipeline) 8. [Code Style & Pre-commit Hooks](#code-style--pre-commit-hooks) 9. [Third-Party Dependencies](#third-party-dependencies) 10. [Platform Support](#platform-support) 11. [CMake Build Options](#cmake-build-options) 12. [Key Files Reference](#key-files-reference) 13. [Common Pitfalls](#common-pitfalls) 14. [Debugging Tips](#debugging-tips) 15. [Validation Checklist](#validation-checklist) --- ## Critical Workflow Rules **MANDATORY - Always Follow This Order:** 1. ✅ **Read Before Edit** - Always read files before modifying 2. ✅ **Use Correct Build Commands** - See [Quick Command Reference](#quick-command-reference) below 3. ✅ **Test After Changes** - Build and run a relevant unit test - `ninja && ./unit_test` 4. ✅ **Format Code** - `pre-commit run --files <files>` 5. ✅ **Follow Architecture** - See [Architecture Patterns](#architecture-patterns) below ### Pull Request Guidelines **Conciseness is Key**: PR descriptions should be short, focused, and easy to scan. - **Title**: Imperative, descriptive (e.g., "Fix fiber stack overflow in test_reply_guard_oom") - **Summary**: 1-2 sentences explaining *what* changed and *why* - **Changes**: Bullet points for key changes - **Fixes**: Link issues (e.g., "Fixes #123") - **Commit messages**: Keep every line (subject and body) <= 100 characters; wrap long descriptions --- ## Quick Command Reference **CRITICAL: Read the full sections below for context.
These are shortcuts only.** ### Building (see [Build Instructions](#build-instructions) for details) ```bash # Debug build (for development) ./helio/blaze.sh cd build-dbg && ninja dragonfly # Build main binary cd build-dbg && ninja generic_family_test # Build specific test # Release build (for production/benchmarking) ./helio/blaze.sh -release cd build-opt && ninja dragonfly ``` ### Testing (see [Testing](#testing) for details) ```bash # C++ Unit Tests cd build-dbg ctest -V -L DFLY # Run all tests ./generic_family_test # Run specific test binary ./generic_family_test --gtest_filter="Set.*" # Run specific test case ``` ### Code Formatting ```bash # Setup (once) pipx install pre-commit clang-format black pre-commit install # Format code pre-commit run --files # Format specific files pre-commit run --all-files # Format all files ``` ### Common Operations ```bash # Check git status git status # Check current branch git branch # View recent commits git log --oneline -10 ``` --- ## Architecture Patterns **Code Style**: [.clang-format](.clang-format) - snake_case vars, PascalCase functions, kPascalCase constants **DO ✅**: - Fiber-aware: `util::fb2::Mutex`, `util::fb2::Fiber` → [helio/util/fibers/](helio/util/fibers/) - Per-shard ops (no global state) → [docs/df-share-nothing.md](docs/df-share-nothing.md) - Command pattern → [src/server/set_family.cc](src/server/set_family.cc) - Error handling: `OpStatus` → [src/server/common.h](src/server/common.h) - Test patterns → [tests/dragonfly/conftest.py](tests/dragonfly/conftest.py) **DON'T ❌**: - `std::thread`, `std::mutex` (deadlocks!) - Global mutable state - Edit without reading - Skip tests - Use `./tools/docker/build.sh` for local development (use `ninja` instead) - Use `make` for incremental builds (use `ninja` instead) --- ## Project Overview **Dragonfly** is a high-performance, Redis and Memcached compatible in-memory data store written in C++20. 
It delivers significantly higher throughput than traditional single-threaded Redis implementations through innovative architectural choices. ### Key Characteristics - **Language**: C++20 (Google C++ Style Guide 2020 version) - **Architecture**: Shared-nothing multi-threaded design (via `helio` library) - **Performance**: Uses io_uring (Linux 5.11+) for high-performance async I/O, with epoll fallback - **Threading Model**: Fiber-based cooperative multitasking with lock-free data structures - **Build System**: CMake + Ninja via `helio/blaze.sh` wrapper script - **Target Platform**: Linux (kernel 5.11+ recommended), FreeBSD support available - **Protocols**: Redis RESP2/RESP3, Memcached binary protocol - **Compatibility**: Drop-in replacement for Redis API coverage ### Architectural Highlights **For detailed architecture documentation, see [docs/df-share-nothing.md](docs/df-share-nothing.md)** 1. **Shared-Nothing Design**: Each thread operates independently with its own data structures, minimizing lock contention 2. **Helio Framework**: Custom I/O and threading library built on io_uring/epoll with fiber support 3. **DashTable**: Novel hash table implementation optimized for multi-core systems - see [docs/dashtable.md](docs/dashtable.md) 4. **Transaction Model**: Non-blocking optimistic transactions - see [docs/transaction.md](docs/transaction.md) 5. **Tiering Support**: Optional disk-backed storage for large datasets 6. 
**Search Module**: Full-text search capabilities (when enabled with WITH_SEARCH) --- ## Repository Structure ``` dragonfly/ ├── src/ # Main C++ source code │ ├── server/ # Core server implementation │ │ ├── dfly_main.cc # Main entry point │ │ ├── main_service.cc # Service lifecycle & command routing │ │ ├── db_slice.cc # Per-thread database shard │ │ ├── engine_shard_set.cc # Shard management │ │ ├── cluster/ # Cluster mode implementation │ │ ├── journal/ # Replication journal │ │ ├── tiering/ # Tiered storage │ │ ├── search/ # Search module │ │ └── acl/ # Access control lists │ ├── core/ # Core data structures │ │ ├── dash.h # DashTable hash table │ │ ├── dense_set.h # Compact set implementation │ │ ├── string_map.h # Optimized string-keyed maps │ │ ├── search/ # Search core algorithms │ │ └── json/ # JSON support │ ├── facade/ # Network & command handling │ │ ├── dragonfly_connection.cc # Connection management │ │ ├── redis_parser.cc # RESP protocol parser │ │ └── memcache_parser.cc # Memcached protocol │ └── redis/ # Redis-specific implementations │ └── lua/ # Lua scripting support │ ├── helio/ # Git submodule: I/O and threading library │ │ # ** DO NOT EDIT unless contributing to helio ** │ ├── util/ # Utilities: fibers, I/O, synchronization │ ├── io/ # io_uring & epoll abstraction │ └── blaze.sh # Build configuration wrapper │ ├── tests/ # Test suite │ ├── dragonfly/ # Python pytest integration/regression tests │ │ ├── conftest.py # Pytest fixtures & configuration │ │ ├── requirements.txt # Python test dependencies │ │ └── *.py # Test files │ └── pytest.ini # Pytest configuration & markers │ ├── docs/ # Documentation │ ├── build-from-source.md # Build instructions │ ├── dashtable.md # DashTable internals │ ├── transaction.md # Transaction model │ ├── df-share-nothing.md # Shared-nothing architecture │ └── differences.md # Differences from Redis │ ├── contrib/ # Utilities │ ├── docker/ # Docker configurations │ └── charts/dragonfly/ # Helm chart for Kubernetes │ 
├── tools/ # Benchmarking & utility tools │ └── packaging/ # Packaging scripts │ ├── CMakeLists.txt # Root CMake configuration ├── .clang-format # C++ formatting rules (clang-format v14.0.6) ├── .pre-commit-config.yaml # Pre-commit hooks configuration ├── pyproject.toml # Python formatting (Black, 100 chars) └── CONTRIBUTING.md # Contribution guidelines ``` ### Critical Paths to Remember - **Main entry**: `src/server/dfly_main.cc` - **Command dispatch**: `src/server/main_service.cc` - **Data storage**: `src/server/db_slice.cc` - **Networking**: `src/facade/dragonfly_connection.cc` - **Helio library**: `helio/` (I/O and threading library) --- ## Build Instructions **For complete build instructions, see [docs/build-from-source.md](docs/build-from-source.md)** ### Quick Start **Debug build** (for development): ```bash ./helio/blaze.sh cd build-dbg && ninja dragonfly ./dragonfly --alsologtostderr ``` **Release build** (for production/benchmarking): ```bash ./helio/blaze.sh -release cd build-opt && ninja dragonfly ``` **Production release build** (static linking, optimized): ```bash make release # Configure + build make package # Create release packages with debug symbols ``` The [Makefile](Makefile) builds production releases with: - Static linking: libstdc++, libgcc, Boost, OpenSSL - Architecture optimizations (x86_64: `-march=core2 -msse4.1 -mtune=skylake`) - Debug symbols (compressed) - Output: `build-release/dragonfly-{arch}.tar.gz` **Common build options**: - See [docs/build-from-source.md](docs/build-from-source.md) for all options --- ## Testing **For complete testing documentation, see [tests/README.md](tests/README.md)** ### Quick Reference **C++ Unit Tests**: ```bash cd build-dbg ctest -V -L DFLY # Run all tests ./generic_family_test # Run specific test binary ./generic_family_test --gtest_filter="Set.*" # Run specific test case ``` --- ## CI/CD Pipeline **For complete CI configuration, see [.github/workflows/ci.yml](.github/workflows/ci.yml)** The CI 
workflow runs on all PRs and includes: - **Pre-commit checks**: clang-format, black formatters - **Build matrix**: Multiple OS/compiler/sanitizer combinations (Ubuntu 20/24, Alpine, GCC/Clang, ASAN/UBSAN) - **Test execution**: C++ unit tests, Python integration tests, cluster mode tests - **Additional validations**: Helm charts, Docker image builds --- ## Code Style & Pre-commit Hooks **For complete contribution guidelines, see [CONTRIBUTING.md](CONTRIBUTING.md)** **Code style configuration files**: - **C++**: [.clang-format](.clang-format) - Google C++ Style Guide (2020), clang-format v14.0.6, 100 char limit - **Python**: [pyproject.toml](pyproject.toml) - Black formatter, 100 char limit, PEP 8 compliant - **Pre-commit hooks**: [.pre-commit-config.yaml](.pre-commit-config.yaml) - Automated formatting checks **Quick setup**: ```bash pipx install pre-commit clang-format black pre-commit install pre-commit run --all-files # Run all formatters ``` --- ## Third-Party Dependencies **Key Libraries**: Abseil (strings/flags), Boost 1.71+ (context/intrusive), mimalloc (allocator), jsoncons (JSON), OpenSSL (TLS), libunwind (traces) **Build artifacts**: `build-dbg/third_party/` - DO NOT edit **For complete dependency info, see [docs/build-from-source.md](docs/build-from-source.md)** --- ## Platform Support **Linux**: Primary platform. 
Kernel 5.11+ (io_uring), 5.1+ (basic), < 5.1 (epoll fallback) - Check: `uname -r` - Force epoll: `--proactor_type=epoll` - Docker: `--security-opt seccomp=unconfined` **FreeBSD**: Supported (kqueue backend) **macOS**: Not supported for production (use Docker/Linux) **For complete platform info, see [docs/build-from-source.md](docs/build-from-source.md)** --- ## CMake Build Options **For complete list of build options, see [docs/build-from-source.md](docs/build-from-source.md)** ### Common Options Pass options to `helio/blaze.sh` with `-D` prefix: ```bash ./helio/blaze.sh -DWITH_SEARCH=OFF -DWITH_AWS=ON ``` **Most useful options**: - `WITH_ASAN=ON` / `WITH_USAN=ON` - Enable sanitizers for debugging - `WITH_SEARCH=OFF` - Disable search module for faster builds - `WITH_AWS=OFF` / `WITH_GCP=OFF` - Disable cloud libraries - `WITH_TIERING=OFF` - Disable disk storage - `USE_MOLD=ON` - Faster linking with LTO (production builds) **Quick configurations**: ```bash # Minimal build (fast compilation) ./helio/blaze.sh -DWITH_GPERF=OFF -DWITH_AWS=OFF -DWITH_GCP=OFF -DWITH_TIERING=OFF -DWITH_SEARCH=OFF # Full-featured (all options ON by default) ./helio/blaze.sh # Production optimized ./helio/blaze.sh -release -DUSE_MOLD=ON ``` --- ## Key Files Reference Quick reference to the most important files in the codebase. 
| Purpose | File Path | |---------|-----------| | **Entry Points & Core** | | | Main entry point | `src/server/dfly_main.cc` | | Server lifecycle & command routing | `src/server/main_service.cc` | | Per-thread database shard | `src/server/db_slice.cc` | | Shard management | `src/server/engine_shard_set.cc` | | **Data Structures** | | | DashTable hash table | `src/core/dash.h` | | Dense set implementation | `src/core/dense_set.h` | | String map | `src/core/string_map.h` | | **Networking** | | | Connection handling | `src/facade/dragonfly_connection.cc` | | Redis protocol parser | `src/facade/redis_parser.cc` | | Memcached protocol parser | `src/facade/memcache_parser.cc` | | **Build System** | | | Root CMake config | `CMakeLists.txt` | | Build script wrapper | `helio/blaze.sh` | | Server CMake config | `src/server/CMakeLists.txt` | | **CI/CD** | | | Main CI workflow | `.github/workflows/ci.yml` | | Pre-commit config | `.pre-commit-config.yaml` | | **Code Style** | | | C++ formatting | `.clang-format` | | Python formatting | `pyproject.toml` | | **Testing** | | | Pytest configuration | `tests/pytest.ini` | | Pytest fixtures | `tests/dragonfly/conftest.py` | | Test requirements | `tests/dragonfly/requirements.txt` | | **Documentation** | | | Build instructions | `docs/build-from-source.md` | | Architecture overview | `docs/df-share-nothing.md` | | DashTable internals | `docs/dashtable.md` | | Transaction model | `docs/transaction.md` | | **Configuration** | | | Contributing guide | `CONTRIBUTING.md` | | CLA agreement | `CLA.txt` | --- ## Common Pitfalls 1. **Pre-commit not installed**: `pipx install pre-commit clang-format black && pre-commit install` 2. **Wrong binary**: Debug: `build-dbg/dragonfly`, Release: `build-opt/dragonfly` 3. **Wrong build command**: Use `cd build-dbg && ninja <target>`, NOT `./tools/docker/build.sh` 4. **Test timeouts**: `timeout 20m ctest -V -L DFLY` 5. **ASAN leaks**: Check CI, suppress in `helio/util/asan_suppressions.txt` 6.
**Helio modifications**: DON'T edit `helio/` (it's a git submodule - changes go upstream) 7. **CodeQL checks**: DON'T run codeql_checker when testing changes - it's slow and unnecessary for development --- ## Debugging Tips **Logging**: `--alsologtostderr --v=1 --vmodule=module=2` **ASAN**: `ASAN_OPTIONS=detect_leaks=1:symbolize=1`, suppressions: `helio/util/asan_suppressions.txt` **CI reproduction**: See [.github/workflows/ci.yml](.github/workflows/ci.yml) **Troubleshooting**: Check fiber deadlocks (use `util::fb2` not `std::mutex`), timeout issues (`--test_timeout`), ASAN reports --- ## Validation Checklist Before claiming a task is complete, verify: ### Code Quality - [ ] Code compiles without errors: `cd build-dbg && ninja dragonfly` - [ ] Code compiles without warnings (CI uses `-Werror`) - [ ] Code follows Google C++ Style Guide (run `clang-format`) - [ ] No new ASAN/UBSAN violations ### Testing - [ ] All existing C++ unit tests pass: `ctest -V -L DFLY` - [ ] New feature has corresponding test coverage - [ ] Tests pass in both Debug and Release builds - [ ] Tests pass with ASAN/UBSAN enabled (if applicable) - [ ] **DO NOT run codeql_checker** - it's slow and unnecessary for development testing ### Pre-commit & Style - [ ] Pre-commit hooks installed: `pre-commit install` - [ ] Code formatted with clang-format (C++) and black (Python) ### Documentation - [ ] Public APIs have comments explaining purpose - [ ] Complex algorithms have explanatory comments - [ ] README or docs updated if behavior changes - [ ] No commented-out code left in final commit ### Performance - [ ] No obvious performance regressions (run benchmarks if needed) - [ ] No unnecessary allocations in hot paths - [ ] Lock-free data structures used where appropriate ================================================ FILE: CLA.txt ================================================ Thanks for your interest in contributing to Dragonfly™. 
By contributing to this project in any way, form or media you grant DragonflyDB Ltd. and its affiliates a perpetual, worldwide, non-exclusive, free of charge, royalty-free, irrevocable license to use, modify, make available, reproduce, make derivatives, publicly display and perform, sublicense, sell, and distribute your contributions and any derivatives thereof as part of Dragonfly™. You represent that You are legally entitled to grant the above license. You acknowledge that DragonflyDB currently distributes Dragonfly™ under the Business Source License 1.1 (BSL-1.1) license, and agree that your contribution may be distributed under BSL-1.1 as part of Dragonfly™. You also represent that your contributions are your original work and that neither the content contributed, nor making the contribution to Dragonfly™ violates any third party’s rights. If you are making this contribution while being engaged by any other company or entity, please make sure you have the necessary permissions required to do so.
================================================ FILE: CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.15 FATAL_ERROR) set(PROJECT_CONTACT romange@gmail.com) include(CheckCXXCompilerFlag) enable_testing() set(CMAKE_EXPORT_COMPILE_COMMANDS 1) # AFL++ fuzzing support - must be set BEFORE project() command option(USE_AFL "Enable AFL++ fuzzing" OFF) if(USE_AFL) # Automatically set AFL++ compilers if not already set if(NOT CMAKE_C_COMPILER MATCHES "afl-" AND NOT CMAKE_CXX_COMPILER MATCHES "afl-") find_program(AFL_CC afl-clang-fast) find_program(AFL_CXX afl-clang-fast++) if(AFL_CC AND AFL_CXX) message(STATUS "AFL++ fuzzing enabled - setting compilers") set(CMAKE_C_COMPILER ${AFL_CC}) set(CMAKE_CXX_COMPILER ${AFL_CXX}) else() message(FATAL_ERROR "USE_AFL=ON but AFL++ compilers not found!\n" "Please install AFL++: apt install afl++ or build from source\n" "https://github.com/AFLplusplus/AFLplusplus") endif() endif() endif() # Set targets in folders set_property(GLOBAL PROPERTY USE_FOLDERS ON) project(DRAGONFLY C CXX) set(CMAKE_CXX_STANDARD 20) # Disabled because it has false positives with ref-counted intrusive pointers. CHECK_CXX_COMPILER_FLAG("-Wuse-after-free" HAS_USE_AFTER_FREE_WARN) if (HAS_USE_AFTER_FREE_WARN) set(CMAKE_CXX_FLAGS "-Wno-use-after-free ${CMAKE_CXX_FLAGS}") endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set(CMAKE_CXX_FLAGS "-Wthread-safety -Werror=thread-safety ${CMAKE_CXX_FLAGS}") endif() # We can not use here CHECK_CXX_COMPILER_FLAG because systems that do not support sanitizers # fail during linking time. set(CMAKE_REQUIRED_FLAGS "-fsanitize=address") check_cxx_source_compiles("int main() { return 0; }" SUPPORT_ASAN) set(CMAKE_REQUIRED_FLAGS "-fsanitize=undefined") check_cxx_source_compiles("int main() { return 0; }" SUPPORT_USAN) set(CMAKE_REQUIRED_FLAGS "") # We must define all the required variables from the root cmakefile, otherwise # they just disappear. 
set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/helio/cmake" ${CMAKE_MODULE_PATH}) option(BUILD_SHARED_LIBS "Build shared libraries" OFF) option(DF_USE_SSL "Provide support for SSL connections" ON) find_package(OpenSSL) # AFL++ configuration - must be before sanitizer checks if(USE_AFL) message(STATUS "AFL++ fuzzing mode active") message(STATUS " C compiler: ${CMAKE_C_COMPILER}") message(STATUS " C++ compiler: ${CMAKE_CXX_COMPILER}") # Add USE_AFL as compile definition so #ifdef USE_AFL works in code add_compile_definitions(USE_AFL) # AFL++ requires specific compiler flags for coverage instrumentation set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g") # Force disable sanitizers when fuzzing (AFL++ incompatible with ASAN/UBSAN) message(STATUS "Disabling sanitizers (incompatible with AFL++ fuzzing)") set(WITH_ASAN OFF CACHE BOOL "Disable ASAN for fuzzing" FORCE) set(WITH_USAN OFF CACHE BOOL "Disable UBSAN for fuzzing" FORCE) # Disable AWS and GCP for fuzzing builds (not needed, reduces build time) message(STATUS "Disabling AWS and GCP integrations for fuzzing") set(WITH_AWS OFF CACHE BOOL "Disable AWS for fuzzing" FORCE) set(WITH_GCP OFF CACHE BOOL "Disable GCP for fuzzing" FORCE) endif() option(WITH_ASAN "Enable -fsanitize=address" OFF) if (SUPPORT_ASAN AND WITH_ASAN) message(STATUS "address sanitizer enabled") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") endif() option(WITH_USAN "Enable -fsanitize=undefined" OFF) if (SUPPORT_USAN AND WITH_USAN) message(STATUS "ub sanitizer enabled") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=undefined") endif() include(third_party) include(internal) include_directories(src) include_directories(helio) add_subdirectory(helio) add_subdirectory(src) ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, 
and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. 
## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at https://github.com/dragonflydb/dragonfly/discussions. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. 
**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Dragonfly DB Thank you for your interest in Dragonfly DB. Feel free to browse our [Discussions](https://github.com/dragonflydb/dragonfly/discussions) and [Issues](https://github.com/dragonflydb/dragonfly/issues) ## Build from source See [building from source](./docs/build-from-source.md) Please note that to build a development/debug version, it's better to alter the configure and build steps above with: ```sh ./helio/blaze.sh # without '-release' flag. 
Creates build-dbg subfolder cd build-dbg && ninja dragonfly ``` ## Before you make your changes ```sh cd dragonfly # project root # Make sure you have 'pre-commit', 'clang-format' and black is installed pipx install pre-commit clang-format pipx install pre-commit black # IMPORTANT! Enable our pre-commit message hooks # This will ensure your commits match our formatting requirements pre-commit install ``` This step must be done on each machine you wish to develop and contribute from to activate the `commit-msg` and `commit` hooks client-side. Once you have done these things, we look forward to adding your contributions and improvements to the Dragonfly DB project. ## Unit testing ``` # Build a specific test cd build-dbg && ninja [test_name] # e.g cd build-dbg && ninja generic_family_test # Run ./[test_name] # e.g ./generic_family_test ``` ## Rendering Helm golden files A Golang golden test is included in the dragonfly helm chart. This test will render the chart and compare the output to a golden file. If the output has changed, the test will fail and the golden file will need to be updated. This can be done by running: ```bash cd contrib/charts/dragonfly go test -v ./... -update ``` This makes it easy to see the changes in the rendered output without having to manually run the `helm template` and diff the output. ## Signoff Commits All community submissions must include a signoff. ```bash git commit -s -m '...' ``` ## Squash Commits Please squash all commits for a change into a single commit (this can be done using "git rebase -i"). Do your best to have a well-formed commit message for the change. ## Use Conventional Commits This repo uses [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) The Conventional Commits specification is a lightweight convention on top of commit messages. It provides an easy set of rules for creating an explicit commit history; which makes it easier to write automated tools on top of. 
This convention dovetails with [SemVer](http://semver.org), by describing the features, fixes, and breaking changes made in commit messages. The commit message should be structured as follows: --- ``` <type>[optional scope]: <description> [optional body] [optional footer(s)] ``` --- This repo uses automated tools to standardize the formatting of code, text files, and commits. - [Pre-commit hooks](#pre-commit-hooks) validate and automatically apply code formatting rules. ## `pre-commit` hooks The Dragonfly DB team has agreed to systematically use several pre-commit hooks to normalize the formatting of code. You need to install and enable pre-commit to have these used when you do your commits. ## Codebase guidelines This repo conforms to Google's C++ Style Guide. Keep in mind we use an older version of the style guide which can be found [here](https://github.com/google/styleguide/blob/505ba68c74eb97e6966f60907ce893001bedc706/cppguide.html). Any exceptions to the rules specified in the style guide will be documented here. 
## License terms for contributions Please see our [CLA agreement](./CLA.txt) ## THANK YOU FOR YOUR CONTRIBUTIONS ================================================ FILE: CONTRIBUTORS.md ================================================ # Contributors (alphabetical by surname) * **[Amir Alperin](https://github.com/iko1)** * **[Philipp Born](https://github.com/tamcore)** * Helm Chart * **[Meng Chen](https://github.com/matchyc)** * **[Yuxuan Chen](https://github.com/YuxuanChen98)** * **[Pawel Kaplinski](https://github.com/pawelKapl)** * **[Redha Lhimeur](https://github.com/redhal)** * **[Braydn Moore](https://github.com/braydnm)** * **[Logan Raarup](https://github.com/logandk)** * **[Ryan Russell](https://github.com/ryanrussell)** * Docs & Code Readability * **[Ali-Akber Saifee](https://github.com/alisaifee)** * **[Elle Y](https://github.com/inohime)** * **[ATM SALEH](https://github.com/ATM-SALEH)** * **[Shohei Shiraki](https://github.com/highpon)** * **[Leonardo Mello](https://github.com/lsvmello)** * **[Nico Coetzee](https://github.com/nicc777)** ================================================ FILE: LICENSE.md ================================================ # Dragonfly Business Source License 1.1 License: Business Source License 1.1 [BSL 1.1](https://spdx.org/licenses/BUSL-1.1.html) Licensor: DragonflyDB, Ltd. Licensed Work: Dragonfly including the software components, or any portion of them, and any modification. Change Date: March 1, 2029 Change License: [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0), as published by the Apache Foundation. Additional Use Grant: You may make use of the Licensed Work (i) only as part of your own product or service, provided it is not an in-memory data store product or service; and (ii) provided that you do not use, provide, distribute, or make available the Licensed Work as a Service. 
A “Service” is a commercial offering, product, hosted, or managed service, that allows third parties (other than your own employees and contractors acting on your behalf) to access and/or use the Licensed Work or a substantial set of the features or functionality of the Licensed Work to third parties as a software-as-a-service, platform-as-a-service, infrastructure-as-a-service or other similar services that compete with Licensor products or services. Text of BSL 1.1 The Licensor hereby grants you the right to copy, modify, create derivative works, redistribute, and make non-production use of the Licensed Work. The Licensor may make an Additional Use Grant, above, permitting limited production use. Effective on the Change Date, or the fourth anniversary of the first publicly available distribution of a specific version of the Licensed Work under this License, whichever comes first, the Licensor hereby grants you rights under the terms of the Change License, and the rights granted in the paragraph above terminate. If your use of the Licensed Work does not comply with the requirements currently in effect as described in this License, you must purchase a commercial license from the Licensor, its affiliated entities, or authorized resellers, or you must refrain from using the Licensed Work. All copies of the original and modified Licensed Work, and derivative works of the Licensed Work, are subject to this License. This License applies separately for each version of the Licensed Work and the Change Date may vary for each version of the Licensed Work released by Licensor. You must conspicuously display this License on each original or modified copy of the Licensed Work. If you receive the Licensed Work in original or modified form from a third party, the terms and conditions set forth in this License apply to your use of that work. 
Any use of the Licensed Work in violation of this License will automatically terminate your rights under this License for the current and all other versions of the Licensed Work. This License does not grant you any right in any trademark or logo of Licensor or its affiliates (provided that you may use a trademark or logo of Licensor as expressly required by this License). TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON AN “AS IS” BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND TITLE. ================================================ FILE: Makefile ================================================
# Release build driver: configures a Ninja release build under $(RELEASE_DIR),
# builds the `dragonfly` and `dfly_bench` binaries and packages them as tarballs.
BUILD_ARCH := $(shell uname -m)
RELEASE_NAME := "dragonfly-${BUILD_ARCH}"
HELIO_RELEASE_FLAGS = -DHELIO_RELEASE_FLAGS="-g"
HELIO_USE_STATIC_LIBS = ON
HELIO_OPENSSL_USE_STATIC_LIBS = ON
HELIO_ENABLE_GIT_VERSION = ON
HELIO_WITH_UNWIND ?= OFF
RELEASE_DIR=build-release
WITH_SIMSIMD ?= ON

# Some distributions (old fedora) have incorrect dependencies for crypto
# so we add -lz for them.
LINKER_FLAGS=-lz

# equivalent to: if $(uname_m) == x86_64 || $(uname_m) == amd64
# Override HELIO_MARCH_OPT via environment: make HELIO_MARCH_OPT="-march=native"
ifneq (, $(filter $(BUILD_ARCH),x86_64 amd64))
HELIO_MARCH_OPT ?= -march=core2 -msse4.1 -mpopcnt -mtune=skylake
endif

# For release builds we link statically libstdc++ and libgcc. Currently,
# all the release builds are performed by gcc.
LINKER_FLAGS += -static-libstdc++ -static-libgcc

# Optional ASAN support: make ASAN=1 release
ifdef ASAN
SANITIZE_COMPILE_FLAGS = -fsanitize=address -Wno-maybe-uninitialized
SANITIZE_LINK_FLAGS = -fsanitize=address
endif

# Flags forwarded to cmake by the `configure` target.
HELIO_FLAGS = -DHELIO_RELEASE_FLAGS="-g" \
              -DCMAKE_CXX_FLAGS="$(SANITIZE_COMPILE_FLAGS)" \
              -DCMAKE_EXE_LINKER_FLAGS="$(LINKER_FLAGS) $(SANITIZE_LINK_FLAGS)" \
              -DBoost_USE_STATIC_LIBS=$(HELIO_USE_STATIC_LIBS) \
              -DOPENSSL_USE_STATIC_LIBS=$(HELIO_OPENSSL_USE_STATIC_LIBS) \
              -DENABLE_GIT_VERSION=$(HELIO_ENABLE_GIT_VERSION) \
              -DWITH_SIMSIMD=$(WITH_SIMSIMD) \
              -DWITH_UNWIND=$(HELIO_WITH_UNWIND) -DMARCH_OPT="$(HELIO_MARCH_OPT)"

# None of these targets produces a file with its own name. Declaring all of them
# phony keeps e.g. `make build` working even if a file or directory called
# `build` exists in the project root.
.PHONY: default configure build package release

# Generate the Ninja build tree in $(RELEASE_DIR).
configure:
	cmake -L -B $(RELEASE_DIR) -DCMAKE_BUILD_TYPE=Release -GNinja $(HELIO_FLAGS)

# Build both binaries, then print the shared libraries dragonfly links against.
build:
	cd $(RELEASE_DIR); \
	ninja dfly_bench dragonfly && ldd dragonfly

# Archive the unstripped dragonfly binary as the -dbgsym tarball, then write
# copies of dragonfly and dfly_bench with every .debug_* section except
# .debug_line removed and archive those as well.
package:
	cd $(RELEASE_DIR); \
	tar cvfz $(RELEASE_NAME)-dbgsym.tar.gz dragonfly ../LICENSE.md; \
	objcopy \
		--remove-section=".debug_*" \
		--remove-section="!.debug_line" \
		--compress-debug-sections \
		dragonfly \
		$(RELEASE_NAME); \
	tar cvfz $(RELEASE_NAME).tar.gz $(RELEASE_NAME) ../LICENSE.md; \
	objcopy \
		--remove-section=".debug_*" \
		--remove-section="!.debug_line" \
		--compress-debug-sections \
		dfly_bench \
		dfly_bench-$(BUILD_ARCH); \
	tar cvfz dfly_bench-$(BUILD_ARCH).tar.gz dfly_bench-$(BUILD_ARCH)

release: configure build

default: release
================================================ FILE: README.ja-JP.md ================================================

Dragonfly

[![ci-tests](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml/badge.svg)](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml) [![Twitter URL](https://img.shields.io/twitter/follow/dragonflydbio?style=social)](https://twitter.com/dragonflydbio) その他の言語: [English](README.md) [简体中文](README.zh-CN.md) [한국어](README.ko-KR.md) [Português](README.pt-BR.md) [Web サイト](https://www.dragonflydb.io/) • [ドキュメント](https://dragonflydb.io/docs) • [クイックスタート](https://www.dragonflydb.io/docs/getting-started) • [コミュニティ Discord](https://discord.gg/HsPjXGVH85) • [Dragonfly Forum](https://dragonfly.discourse.group/) • [Join the Dragonfly Community](https://www.dragonflydb.io/community) [GitHub Discussions](https://github.com/dragonflydb/dragonfly/discussions) • [GitHub Issues](https://github.com/dragonflydb/dragonfly/issues) • [コントリビュート](https://github.com/dragonflydb/dragonfly/blob/main/CONTRIBUTING.md) ## 世界最速のインメモリデータストア Dragonfly は最新のアプリケーションワークロードのために構築されたインメモリデータストアです。 Redis や Memcached の API と完全に互換性があるため、Dragonfly を採用するためにコードを変更する必要はありません。従来のインメモリデータストアと比較して、Dragonfly は 25 倍のスループット、より低いテールレイテンシでより高いキャッシュヒット率、そして容易な垂直スケーラビリティを提供します。 ## コンテンツ - [ベンチマーク](#ベンチマーク) - [クイックスタート](https://github.com/dragonflydb/dragonfly/tree/main/docs/quick-start) - [コンフィグ](#コンフィグ) - [ロードマップとステータス](#ロードマップとステータス) - [デザイン決定](#デザイン決定) - [バックグラウンド](#バックグラウンド) ## ベンチマーク ベンチマークでは、Dragonfly は Redis と比較して 25 倍のスループットを示し、c6gn.16xlarge で 3.8M QPS を超えました。 Dragonfly のピークスループットにおける 99 パーセンタイルのレイテンシ指標: | op | r6g | c6gn | c7g | |-------|-------|-------|-------| | set | 0.8ms | 1ms | 1ms | | get | 0.9ms | 0.9ms | 0.8ms | | setex | 0.9ms | 1.1ms | 1.3ms | *すべてのベンチマークは `memtier_benchmark` (下記参照) を使い、スレッド数はサーバーとインスタンスタイプごとに調整しました。`memtier` は別の c6gn.16xlarge マシンで実行した。SETEX ベンチマークの有効期限は 500 に設定し、テストが終了しても有効であることを確認しました。* ```bash memtier_benchmark --ratio ... -t -c 30 -n 200000 --distinct-client-seed -d 256 \ --expiry-range=... 
``` パイプラインモード `--pipeline=30` では、Dragonfly は SET 操作で **10M QPS**、GET 操作で **15M QPS** に達する。 ### Dragonfly vs. Memcached AWS 上の c6gn.16xlarge インスタンスで Dragonfly と Memcached を比較した。 同程度のレイテンシで、Dragonfly のスループットは Memcached のスループットを書き込みと読み込みの両方のワークロードで上回った。Dragonfly は、[Memcached の書き込みパス](docs/memcached_benchmark.md)での競合により、書き込みワークロードでより優れたレイテンシを示しました。 #### SET ベンチマーク | Server | QPS(thousands qps) | latency 99% | 99.9% | |:---------:|:------------------:|:-----------:|:-------:| | Dragonfly | 🟩 3844 |🟩 0.9ms | 🟩 2.4ms | | Memcached | 806 | 1.6ms | 3.2ms | #### GET ベンチマーク | Server | QPS(thousands qps) | latency 99% | 99.9% | |-----------|:------------------:|:-----------:|:-------:| | Dragonfly | 🟩 3717 | 1ms | 2.4ms | | Memcached | 2100 | 🟩 0.34ms | 🟩 0.6ms | Memcached は読み取りベンチマークでより低いレイテンシを示したが、スループットも低かった。 ### メモリ効率 メモリ効率をテストするために、`debug populate 5000000 key 1024` コマンドを使用して Dragonfly と Redis に ~5GB のデータを入れ、`memtier` コマンドで更新トラフィックを送信し、`bgsave` コマンドでスナップショットを開始しました。 この図は、各サーバがメモリ効率の面でどのような挙動を示したかを示している。 Dragonfly はアイドル状態では Redis よりも 30% メモリ効率が高く、スナップショットフェーズではメモリ使用量の目に見える増加は見られなかった。ピーク時には Redis のメモリ使用量は Dragonfly の 3 倍近くまで増加しました。 Dragonfly はスナップショットをより早く、数秒以内に終了させました。 Dragonfly のメモリ効率の詳細については、[Dashtable ドキュメント](/docs/dashtable.md)を参照してください。 ## コンフィグ Dragonfly は一般的な Redis の引数をサポートしています。例えば `dragonfly --requirepass=foo --bind localhost`。 Dragonfly は現在、以下の Redis 固有の引数をサポートしています: * `port`: Redis 接続ポート (`default: 6379`). * `bind`: ローカルホストからの接続のみを許可する場合は `localhost` を、**その IP** アドレスへの接続 (つまり外部からの接続) を許可する場合はパブリック IP アドレスを指定する。 * `requirepass`: AUTH 認証用のパスワード (`default: ""`)。 * `maxmemory`: データベースが使用するメモリの上限 (人間が読めるバイト数) (`default: 0`)。 `maxmemory` に `0` を指定すると、プログラムが自動的に最大メモリ使用量を決定する。 * `dir`: Dragonfly Docker はデフォルトで `/data` フォルダをスナップショットに使用し、CLI は `""` を使用する。`-v` の Docker オプションでホストフォルダにマッピングできる。 * `dbfilename`: データベースを保存・ロードするファイル名 (`default: dump`). 
Dragonfly 特有の引数もある: * `memcached_port`: Memcached 互換 API を有効にするポート (`default: disabled`)。 * `keys_output_limit`: `keys` コマンドで返されるキーの最大数(`default: 8192`)。`keys` は危険なコマンドであることに注意してください。あまりに多くのキーを取得するとメモリ使用量が増大するため、結果を切り捨てています。 * `dbnum`: `select` でサポートされるデータベースの最大数。 * `cache_mode`: 以下の[斬新なキャッシュデザイン](#斬新なキャッシュデザイン)のセクションを参照してください。 * `hz`: キーの有効期限評価頻度 (`default: 100`)。この頻度が低いと、アイドル時の CPU 使用量が少なくなるが、その分古くなったキーをクリアする速度が遅くなる。 * `primary_port_http_enabled`: もし `true` (`default: true`) なら、メイン TCP ポートで HTTP コンソールにアクセスできるようにする。 * `admin_port`: 割り当てられたポートのコンソールへの管理者アクセスを有効にする(`default: disabled`)。HTTP と RESP プロトコルの両方をサポートする。 * `admin_bind`: 管理コンソールの TCP 接続を指定されたアドレスにバインドする(`default: any`)。HTTP と RESP の両方のプロトコルをサポートする。 * `admin_nopass`: 割り当てられたポートで、認証トークンなしでコンソールへのオープン管理アクセスを有効にする (`default: false`)。HTTP と RESP の両方のプロトコルをサポートする。 * `cluster_mode`: サポートするクラスターモード (`default: ""`)。現在は `emulated` のみをサポートしている。 * `cluster_announce_ip`: クラスタコマンドがクライアントにアナウンスする IP。 ### 一般的なオプションを使用した開始スクリプトの例: ```bash ./dragonfly-x86_64 --logtostderr --requirepass=youshallnotpass --cache_mode=true -dbnum 1 --bind localhost --port 6379 --maxmemory=12gb --keys_output_limit=12288 --dbfilename dump.rdb ``` また、`dragonfly --flagfile <filename>` を実行することで、設定ファイルから引数を指定することもできる。ファイルには 1 行に 1 つのフラグを記述し、キーと値のフラグには空白の代わりに等号を記述します。 ログの管理や TLS のサポートなど、その他のオプションについては `dragonfly --help` を実行してください。 ## ロードマップとステータス Dragonfly は現在、~185 個の Redis コマンドと、`cas` 以外のすべての Memcached コマンドをサポートしている。ほぼ Redis 5 API と同等ですが、Dragonfly の次のマイルストーンは基本的な機能を安定させ、レプリケーション API を実装することです。まだ実装されていないコマンドで必要なものがあれば、issue を開いてください。 Dragonfly ネイティブのレプリケーションについては、桁違いに高速な分散ログフォーマットを設計中です。 レプリケーション機能に続いて、Redis バージョン 3-6 の API に不足しているコマンドを追加していく予定です。 現在 Dragonfly がサポートしているコマンドについては、[コマンドリファレンス](https://dragonflydb.io/docs/category/command-reference)をご覧ください。 ## デザイン決定 ### 斬新なキャッシュデザイン Dragonfly には、シンプルでメモリ効率の良い、単一の統一された適応型キャッシュアルゴリズムがあります。 `cache_mode=true` フラグを渡すことでキャッシュモードを有効にすることができます。このモードをオンにすると、Dragonfly は将来つまずく可能性が最も低いアイテムを退避させますが、`maxmemory` 
の限界に近づいたときのみ退避させます。 ### 比較的正確な有効期限 有効期限は 8 年以内。 ミリ秒精度の有効期限(PEXPIRE、PSETEX など)は、**2^28ms** を超える期限については、最も近い秒に丸められます。この誤差は 0.001% 以下であり、大きな範囲であれば許容範囲となります。 Dragonfly の期限と Redis の実装の詳細な違いについては、[こちら](docs/differences.md)を参照してください。 ### ネイティブ HTTP コンソールと Prometheus 互換メトリクス デフォルトでは、Dragonfly はメイン TCP ポート(6379)経由での HTTP アクセスを許可しています。その通り、Redis プロトコル経由でも HTTP プロトコル経由でも Dragonfly に接続することができます。ブラウザで試してみてください。HTTP アクセスには現在あまり情報がありませんが、将来的にはデバッグや管理に役立つ情報が含まれるようになる予定です。 Prometheus 互換のメトリクスを見るには、URL `:6379/metrics` にアクセスしてください。 Prometheus からエクスポートされたメトリクスは Grafana ダッシュボードと互換性があります[こちらを参照](tools/local/monitoring/grafana/provisioning/dashboards/dashboard.json)。 重要です!HTTP コンソールは安全なネットワーク内でアクセスすることを想定しています。Dragonfly の TCP ポートを外部に公開する場合は、`--http_admin_console=false` または `--nohttp_admin_console` でコンソールを無効にすることをお勧めします。 ## バックグラウンド Dragonfly は、インメモリデータストアを 2022 年に設計したらどのようになるかという実験から始まりました。メモリストアのユーザーとして、またクラウド企業で働いたエンジニアとしての経験から学んだ教訓をもとに、Dragonfly では 2 つの重要な特性を維持する必要があると考えました: それは、すべてのオペレーションにおける原子性の保証と、非常に高いスループットにおけるミリ秒以下の低レイテンシーです。 私たちの最初の課題は、パブリッククラウドで現在利用可能なサーバーを使用して、CPU、メモリー、I/O リソースをフルに活用する方法でした。これを解決するために、私たちは[シェアードナッシングアーキテクチャ](https://en.wikipedia.org/wiki/Shared-nothing_architecture)を使用しています。このアーキテクチャでは、各スレッドが辞書データのスライスを独自に管理できるように、スレッド間でメモリストアの鍵空間を分割することができます。これらのスライスを "shards" と呼ぶ。シェアードナッシングアーキテクチャのスレッドと I/O 管理のためのライブラリは、[こちら](https://github.com/romange/helio)でオープンソースで提供されています。 複数キー操作に対する原子性保証を提供するために、我々は最近の学術研究の進歩を利用している。Dragonfly のトランザクションフレームワークの開発には、論文 ["VLL: a lock manager redesign for main memory database systems"](https://www.cs.umd.edu/~abadi/papers/vldbj-vll.pdf) を選びました。シェアードナッシングアーキテクチャと VLL の選択により、ミューテックスやスピンロックを使用せずにアトミックなマルチキー操作を構成することができました。これは我々の PoC にとって大きなマイルストーンであり、その性能は他の商用やオープンソースのソリューションよりも際立っていました。 私たちの第二の課題は、新しいストアのために、より効率的なデータ構造を設計することだった。この目標を達成するために、我々は論文 ["Dash: Scalable Hashing on Persistent Memory"](https://arxiv.org/pdf/2003.07302.pdf) 
に基づいたハッシュテーブル構造を核とした。この論文自体は、永続メモリ領域を中心にしており、メインメモリストアとは直接関係ありませんが、それでも私たちの問題に最も当てはまります。この論文で提案されているハッシュテーブル設計により、Redis の辞書に存在する 2 つの特別な特性を維持することができました: それは、データストアの成長中にハッシュをインクリメンタルする機能と、ステートレススキャン操作を使って変更中の辞書をトラバースする機能です。これら2つの特性に加え、Dash は CPU とメモリの使用効率が高い。Dash の設計を活用することで、私たちは以下のような機能をさらに革新することができました: * TTL レコードの効率的なレコード期限切れ。 * LRU や LFU のような他のキャッシュ戦略よりも高いヒット率を、**ゼロメモリオーバーヘッド** で達成する新しいキャッシュエビクションアルゴリズム。 * 新しい **フォークレス** スナップショットアルゴリズム。 Dragonfly の基盤を構築し、[そのパフォーマンスに満足したら](#ベンチマーク)、Redis と Memcached の機能を実装していきました。現在までに 185 個の Redis コマンド(Redis 5.0 API とほぼ同等)と 13 個の Memcached コマンドを実装しました。 そして最後に、
私たちの使命は、最新のハードウェアの進歩を活用した、クラウドワークロード向けの、優れた設計、超高速、コスト効率の良いインメモリデータストアを構築することです。現在のソリューションの API と提案を維持しながら、その問題点を解決するつもりです。 ================================================ FILE: README.ko-KR.md ================================================

Dragonfly

[![ci-tests](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml/badge.svg)](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml) [![Twitter URL](https://img.shields.io/twitter/follow/dragonflydbio?style=social)](https://twitter.com/dragonflydbio) 다른 언어 번역본: [English](README.md) [简体中文](README.zh-CN.md) [日本語](README.ja-JP.md) [Português](README.pt-BR.md) [Website](https://www.dragonflydb.io/) • [Docs](https://dragonflydb.io/docs) • [Quick Start](https://www.dragonflydb.io/docs/getting-started) • [Community Discord](https://discord.gg/HsPjXGVH85) • [Dragonfly Forum](https://dragonfly.discourse.group/) • [Join the Dragonfly Community](https://www.dragonflydb.io/community) [GitHub Discussions](https://github.com/dragonflydb/dragonfly/discussions) • [GitHub Issues](https://github.com/dragonflydb/dragonfly/issues) • [Contributing](https://github.com/dragonflydb/dragonfly/blob/main/CONTRIBUTING.md) • [Dragonfly Cloud](https://www.dragonflydb.io/cloud) ## 세상에서 가장 빠른 인-메모리 스토어 Dragonfly는 현대 애플리케이션 작업을 위한 인-메모리 데이터스토어입니다. Dragonfly는 Redis와 Memcached API와 완벽하게 호환되며, 이를 적용하기 위한 코드 변경을 필요로 하지 않습니다. Dragonfly는 기존 레거시 인-메모리 데이터스토어와 비교하여 25배 이상의 높은 처리량과 캐시 히트율, 낮은 꼬리 지연시간을 갖고있으며 간편한 수직 확장성을 지니고 있습니다. ## 콘텐츠 - [벤치마크](#benchmarks) - [빠른 시작](https://github.com/dragonflydb/dragonfly/tree/main/docs/quick-start) - [설정](#configuration) - [로드맵과 상태](#roadmap-status) - [설계 의사결정](#design-decisions) - [개발 배경](#background) ## 벤치마크 벤치마크에 따르면, Dragonfly는 레디스와 비교하여 처리량이 25배이상 증가하였고, c6gn.16xlarge 인스턴스에서 3.8M QPS를 돌파하였음을 보여줍니다. Dragonfly의 피크 처리량에서의 99퍼센트 지연 시간 지표: | op | r6g | c6gn | c7g | |-------|-------|-------|-------| | set | 0.8ms | 1ms | 1ms | | get | 0.9ms | 0.9ms | 0.8ms | | setex | 0.9ms | 1.1ms | 1.3ms | *모든 벤치마크는 서버 및 인스턴스 유형별로 조정된 스레드 수를 사용하여 `memtier_benchmark`(아래를 참고) 수행되었습니다. `memtier`는 별도의 c6gn.16xlarge 머신에서 실행되었습니다. 저희는 테스트 종료 이후에도 유효하게 유지되도록 보장하기 위해 SETEX 벤치마크의 만료 시간을 500으로 설정하였습니다.* ```bash memtier_benchmark --ratio ... 
-t -c 30 -n 200000 --distinct-client-seed -d 256 \ --expiry-range=... ``` 파이프라인 모드에서 `--pipeline=30`은 Dragonfly가 SET 연산으로 **10M QPS**, GET 연산으로 **15M QPS**에 도달할 수 있음을 나타냅니다. ### Dragonfly vs. Memcached 저희는 AWS의 c6gn.16xlarge 인스턴스에서 Dragonfly와 Memcached를 비교하는 작업을 수행했습니다. 비슷한 지연시간을 가진 상황에서, Dragonfly의 처리량은 쓰기 및 읽기 작업 모두에서 Memcached보다 성능이 뛰어났습니다. 쓰기 작업에서는 [Memcached의 쓰기 경로](docs/memcached_benchmark.md)에서의 경합으로 인하여 Dragonfly가 보다 적은 지연시간을 보였다는 점이 입증되었습니다. #### SET 벤치마크 | Server | QPS(thousands qps) | latency 99% | 99.9% | |:---------:|:------------------:|:-----------:|:-------:| | Dragonfly | 🟩 3844 |🟩 0.9ms | 🟩 2.4ms | | Memcached | 806 | 1.6ms | 3.2ms | #### GET 벤치마크 | Server | QPS(thousands qps) | latency 99% | 99.9% | |-----------|:------------------:|:-----------:|:-------:| | Dragonfly | 🟩 3717 | 1ms | 2.4ms | | Memcached | 2100 | 🟩 0.34ms | 🟩 0.6ms | Memcached는 읽기 벤치마크의 지연 시간은 적었지만, 처리량도 낮았습니다. ### 메모리 효율 메모리 효율을 테스트하기 위해서, 저희는 `debug populate 5000000 key 1024` 명령어를 활용하여 Dragonfly와 Redis에 ~5GB 정도의 데이터를 채운 후, `memtier` 를 통하여 업데이트 트래픽을 전송한 후, `bgsave` 명령을 통하여 스냅샷을 시작했습니다. 이 그림은 메모리 효율 측면에서 각 서버가 어떻게 동작했는지 보여줍니다. Dragonfly는 유휴 상태에서 Redis보다 메모리 효율이 30% 더 좋았으며, 스냅샷 단계에서 메모리 사용량이 눈에 띄게 증가하지 않았습니다. Redis는 고점에서 Dragonfly에 비해 메모리 사용량이 약 3배 증가하였습니다. Dragonfly는 스냅샷 단계를 몇 초안에 더 빨리 마쳤습니다. Dragonfly의 메모리 효율에 대한 정보가 더 필요하시다면, 저희의 [Dashtable 문서](/docs/dashtable.md)를 참고하시기 바랍니다. ## 설정 Dragonfly는 적용 가능한 Redis 인수를 지원합니다. 예를 들면, `dragonfly --requirepass=foo --bind localhost`와 같은 명령어를 사용할 수 있습니다. Dragonfly는 현재 아래와 같은 Redis 인수들을 지원합니다 : * `port`: Redis 연결 포트 (`기본값: 6379`). * `bind`: `localhost`를 사용하여 로컬호스트 연결만 허용하거나 공용 IP 주소를 사용하여 해당 IP 주소에 연결을 허용합니다.(즉, 외부에서도 가능) * `requirepass`: AUTH 인증을 위한 패스워드 (`기본값: ""`). * `maxmemory`: 데이터베이스에서 사용하는 최대 메모리 제한(사람이 읽을 수 있는 바이트 단위) (`기본값: 0`). `maxmemory` 의 값이 `0` 이면 프로그램이 최대 메모리 사용량을 자동으로 결정합니다. * `dir`: Dragonfly Docker는 스냅샷을 위해 기본적으로 `/data` 폴더를 사용하고, CLI은 `""`을 사용합니다. Docker 옵션인 `-v` 을 통해서 호스트 폴더에 매핑할 수 있습니다. 
* `dbfilename`: 저장하고 불러올 데이터베이스 파일 이름 (`기본값: dump`). 아래는 Dragonfly 전용 인수 입니다 : * `memcached_port`: Memcached 호환 API를 위한 포트 (`기본값: disabled`). * `keys_output_limit`: `keys` 명령을 통해 반환 되는 최대 키의 수 (`기본값: 8192`). `keys` 명령은 위험하기 때문에, 너무 많은 키를 가져올 때 메모리 사용량이 급증하지 않도록 결과를 해당 인수만큼 잘라냅니다. * `dbnum`: `select` 명령에 대해 지원되는 최대 데이터베이스 수. * `cache_mode`: 아래의 섹션 [새로운 캐시 설계](#novel-cache-design)을 참고해주시기 바랍니다. * `hz`: 키가 만료되었는지를 판단하는 빈도(`기본값: 100`). 낮은 빈도는 키 방출이 느려지는 대신, 유휴 상태일 때 CPU 사용량을 줄입니다. * `primary_port_http_enabled`: `true` 인 경우 HTTP 콘솔로 메인 TCP 포트 접근을 허용합니다. (`기본값: true`). * `admin_port`: 할당된 포트에서 관리자 콘솔 접근을 활성화합니다. (`기본값: disabled`). HTTP와 RESP 프로토콜 모두를 지원합니다. * `admin_bind`: 주어진 주소에 관리자 콘솔 TCP 연결을 바인딩합니다. (`기본값: any`). HTTP와 RESP 프로토콜 모두를 지원합니다. * `admin_nopass`: 할당된 포트에 대해서 인증 토큰 없이 관리자 콘솔 접근을 활성화합니다. (`default: false`). HTTP와 RESP 프로토콜 모두를 지원합니다. * `cluster_mode`: 클러스터 모드가 지원됩니다. (`기본값: ""`). 현재는`emulated` 만 지원합니다. * `cluster_announce_ip`: 클러스터 명령을 클라이언트에게 알리는 IP 주소. ### 주요 옵션을 활용한 실행 스크립트 예시: ```bash ./dragonfly-x86_64 --logtostderr --requirepass=youshallnotpass --cache_mode=true -dbnum 1 --bind localhost --port 6379 --maxmemory=12gb --keys_output_limit=12288 --dbfilename dump.rdb ``` 인수들은 `dragonfly --flagfile <filename>`을 실행하여 설정 파일을 통해서도 전달할 수 있습니다. 전달될 파일은 각 줄에 키-값 형태의 플래그 나열 하기위해 등호를 사용합니다. 로그 관리나 TLS 지원과 같은 추가 옵션을 확인하고 싶다면, `dragonfly --help` 를 실행해보시길 바랍니다. ## 로드맵과 상태 Dragonfly는 현재 ~185개의 Redis 명령어들과 `cas`를 제외한 모든 Memcached 명령어를 지원합니다. 이는 거의 Redis 5 API와 동등하며, Dragonfly의 다음 마일스톤은 기본 기능을 안정화하고 복제 API를 구현하는 것입니다. 아직 구현되지 않은 필요한 명령어가 있다면, 이슈를 오픈해주세요. Dragonfly 고유 복제기능을 위해, 저희는 몇 배 높은 속도를 지원할 수 있는 분산 로그 형식을 설계하고 있습니다. 복제 기능을 추가한 뒤에 저희는 Redis 3-6 API에 해당되는 누락 명령어들을 계속 추가할 예정입니다. Dragonfly에 의해 현재 지원되는 명령어를 확인하기 위해서 [명령어 레퍼런스](https://dragonflydb.io/docs/category/command-reference)를 참고해주시기 바랍니다. ## 설계 의사결정 ### 새로운 캐시 설계 Dragonfly는 단순하고 메모리 효율적인 단일, 통합, 적응형 캐싱 알고리즘을 제공합니다. `--cache_mode=true` 플래그를 전달하여 캐싱 모드를 활성화할 수 있습니다. 
이 모드가 활성화되면, Dragonfly는 `maxmemory` 한도에 가까워질 때만, 미래에 재사용 될 가능성이 가장 낮은 항목을 방출합니다. ### 상대적인 정확성을 가진 만료 기한 만료 범위는 약 ~8년으로 제한됩니다. 밀리초 단위의 정밀한 만료 기한(PEXPIRE, PSETEX, 등)은 **2^28ms보다 큰 기한에 대해** 가장 가까운 초로 반올림됩니다. 이는 0.001% 미만의 오차를 가지며, 큰 범위에 대해 적용될 때는 수용 가능한 수준입니다. 만약 이런 방식이 사용사례에 적합하지 않다면, 문의를 주시거나 해당 사용사례를 설명하는 이슈를 오픈해주세요. Dragonfly와 Redis의 만료 기한에 대한 구현의 차이는 [여기서 확인하실 수 있습니다](docs/differences.md). ### 네이티브 HTTP 콘솔과 Prometheus 호환 매트릭 기본적으로, Dragonfly는 메인 TCP 포트(6379)에 HTTP 접근을 허용합니다. 즉, Redis 프로토콜과 HTTP 프로토콜 모두를 통해 Dragonfly에 연결할 수 있습니다. - 서버는 연결 초기화 과정에서 프로토콜을 자동으로 인식합니다. 웹 브라우저를 통하여 시도해보시기 바랍니다. 현재 HTTP 접근은 많은 정보를 제공하지 않지만, 유용한 디버깅 및 관리 정보를 향후 추가할 예정입니다. `:6379/metrics` 에 접근하게 되면, Prometheus 호환 매트릭을 확인할 수 있습니다. Prometheus에서 내보내는 매트릭들은 Grafana 대시보드와 호환됩니다. 자세한 내용은 [여기](tools/local/monitoring/grafana/provisioning/dashboards/dashboard.json)를 참조해주세요. 중요! HTTP 콘솔은 안전한 네트워크 내에서 접근하도록 설계되었습니다. Dragonfly의 TCP 포트를 외부로 노출한다면, `--http_admin_console=false` 혹은 `--nohttp_admin_console`과 같은 인수를 활용하여 콘솔을 비활성화하는 것을 조언해드립니다. ## 개발배경 Dragonfly는 2022년에 인-메모리 데이터스토어를 설계한다면 어땠을까에 대한 실험으로 시작되었습니다. 클라우드 회사에서 근무한 엔지니어 및 메모리 스토어 사용자의 경험을 바탕으로, 저희는 Dragonfly에 핵심적인 두 가지 핵심 특성을 보존해야함을 알았습니다: 모든 작업에 대한 원자성 보장과 매우 높은 처리량에 대한 밀리초 이하의 낮은 지연 시간을 보장하는 것이었습니다. 첫 번째 문제는 오늘날 퍼블릭 클라우드 환경에서 사용 가능한 서버를 사용하여 CPU, 메모리 및 I/O 자원을 어떻게 최대한 활용할 수 있을지였습니다. 이 문제를 해결하기 위해 저희는 [비공유 아키텍처(Shared Nothing Architecture)](https://en.wikipedia.org/wiki/Shared-nothing_architecture)를 사용했습니다. 이는 저희가 메모리 스토어의 각 스레드 사이의 키 공간을 분할할 수 있게하였습니다. 이를 통해 각 스레드들은 그들의 딕셔너리 데이터들의 조각을 관리할 수 있게 되었습니다. 저희는 이 조각들을 "샤드(shards)"라 불렀습니다. 비공유 아키텍처에 대한 스레드 및 I/O 관리를 위한 라이브러리는 [여기](https://github.com/romange/helio)에서 오픈소스로 제공됩니다. 멀티-키 작업에 대한 원자성 보장을 위해, Dragonfly의 트랜잭션 프레임워크를 개발하기 위해 저희는 최근 학계의 연구 발전을 활용했고 ["VLL: a lock manager redesign for main memory database systems”](https://www.cs.umd.edu/~abadi/papers/vldbj-vll.pdf) 논문을 채택했습니다. 비공유 아키텍처와 VLL의 선택은 우리가 뮤텍스나 스핀락을 사용하지 않고도 원자적 멀티-키 작업을 구성할 수 있게 했습니다. 
이것은 저희의 PoC에 있어서 주요한 마일스톤이었고, 그 성능은 다른 상용 및 오픈소스 솔루션보다 성능이 뛰어났습니다. 두 번째 문제는 새로운 저장소를 위하여 더 효율적인 데이터 구조를 설계하는 것이었습니다. 이 목표를 달성하기 위해서 저희는 핵심 해시테이블 구조를 ["Dash: Scalable Hashing on Persistent Memory"](https://arxiv.org/pdf/2003.07302.pdf) 논문을 기반으로 작업했습니다. 이 논문은 영속적인 메모리 도메인을 중심으로 다루며, 이는 메인-메모리 저장소와 직접적인 연관관계는 없었습니다. 하지만 여전히 저희 문제를 해결하기 위해서 가장 적합했습니다. 해당 논문의 제안된 해시테이블 설계는 저희가 레디스 딕셔너리에 표현된 두 가지 특별한 특성을 유지 가능하게 해줬습니다: 데이터스토어 확장 중 증분 해싱 기능과 상태 없는 스캔 작업을 사용하여 변화하는 딕셔너리를 순회하는 능력이었습니다. 이 두 가지 속성 외에도 Dash는 CPU와 메모리 사용에서 더 효율적입니다. 저희는 다음과 같은 기능들로 더욱 혁신할 수 있었습니다: * TTL 레코드에 대한 효율적인 만료 처리 * LRU와 LFU 같은 다른 캐시 전략보다 더 높은 히트율을 달성하는 새로운 캐시 방출 알고리즘과 **제로 메모리 오버헤드**. * 새로운 **fork-less** 스냅샷 알고리즘. 저희는 Dragonfly의 기반을 구축하고 성능에 만족하게 되었을 때, Redis와 Memcached의 기능을 구현하기 시작했습니다. 저희는 약 185개의 Redis 명령(대략적으로 Redis 5.0 API와 동등)과 13개의 Memcached 명령을 구현했습니다. 마지막으로,
저희의 임무는 최신 하드웨어 발전을 활용하는 클라우드 작업을 위한 멋진 설계와 초고속 처리량 그리고 비용효율적인 인-메모리 데이터스토어를 만드는 것입니다. 저희는 현재 솔루션의 제품 API들이나 제안을 유지하면서 당면 과제를 해결하고자 합니다. ================================================ FILE: README.md ================================================

Dragonfly

[![ci-tests](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml/badge.svg)](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml) [![Twitter URL](https://img.shields.io/twitter/follow/dragonflydbio?style=social)](https://twitter.com/dragonflydbio) > Before moving on, please consider giving us a GitHub star ⭐️. Thank you! Other languages: [简体中文](README.zh-CN.md) [日本語](README.ja-JP.md) [한국어](README.ko-KR.md) [Português](README.pt-BR.md) [Website](https://www.dragonflydb.io/) • [Docs](https://dragonflydb.io/docs) • [Quick Start](https://www.dragonflydb.io/docs/getting-started) • [Community Discord](https://discord.gg/HsPjXGVH85) • [Dragonfly Forum](https://dragonfly.discourse.group/) • [Join the Dragonfly Community](https://www.dragonflydb.io/community) [GitHub Discussions](https://github.com/dragonflydb/dragonfly/discussions) • [GitHub Issues](https://github.com/dragonflydb/dragonfly/issues) • [Contributing](https://github.com/dragonflydb/dragonfly/blob/main/CONTRIBUTING.md) • [AI Agents Guide](AGENTS.md) • [Dragonfly Cloud](https://www.dragonflydb.io/cloud) ## The world's most efficient in-memory data store Dragonfly is an in-memory data store built for modern application workloads. Fully compatible with Redis and Memcached APIs, Dragonfly requires no code changes to adopt. Compared to legacy in-memory datastores, Dragonfly delivers 25X more throughput, higher cache hit rates with lower tail latency, and can run on up to 80% less resources for the same sized workload. ## Contents - [Benchmarks](#benchmarks) - [Quick start](https://github.com/dragonflydb/dragonfly/tree/main/docs/quick-start) - [Configuration](#configuration) - [Roadmap and status](#roadmap-status) - [Design decisions](#design-decisions) - [Background](#background) - [Build from source](./docs/build-from-source.md) ## Benchmarks We first compare Dragonfly with Redis on `m5.large` instance which is commonly used to run Redis due to its single-threaded architecture. 
The benchmark program runs from another load-test instance (c5n) in the same AZ using `memtier_benchmark -c 20 --test-time 100 -t 4 -d 256 --distinct-client-seed` Dragonfly shows a comparable performance: 1. SETs (`--ratio 1:0`): | Redis | DF | | -----------------------------------------|----------------------------------------| | QPS: 159K, P99.9: 1.16ms, P99: 0.82ms | QPS:173K, P99.9: 1.26ms, P99: 0.9ms | | | | 2. GETs (`--ratio 0:1`): | Redis | DF | | ----------------------------------------|----------------------------------------| | QPS: 194K, P99.9: 0.8ms, P99: 0.65ms | QPS: 191K, P99.9: 0.95ms, P99: 0.8ms | The benchmark above shows that the algorithmic layer inside DF that allows it to scale vertically does not take a large toll when running single-threaded. However, if we take a bit stronger instance (m5.xlarge), the gap between DF and Redis starts growing. (`memtier_benchmark -c 20 --test-time 100 -t 6 -d 256 --distinct-client-seed`): 1. SETs (`--ratio 1:0`): | Redis | DF | | ----------------------------------------|----------------------------------------| | QPS: 190K, P99.9: 2.45ms, P99: 0.97ms | QPS: 279K , P99.9: 1.95ms, P99: 1.48ms| 2. GETs (`--ratio 0:1`): | Redis | DF | | ----------------------------------------|----------------------------------------| | QPS: 220K, P99.9: 0.98ms , P99: 0.8ms | QPS: 305K, P99.9: 1.03ms, P99: 0.87ms | Dragonfly throughput capacity continues to grow with instance size, while single-threaded Redis is bottlenecked on CPU and reaches local maxima in terms of performance. If we compare Dragonfly and Redis on the most network-capable instance c6gn.16xlarge, Dragonfly showed a 25X increase in throughput compared to Redis single process, crossing 3.8M QPS. 
Dragonfly's 99th percentile latency metrics at its peak throughput: | op | r6g | c6gn | c7g | |-------|-------|-------|-------| | set | 0.8ms | 1ms | 1ms | | get | 0.9ms | 0.9ms | 0.8ms | | setex | 0.9ms | 1.1ms | 1.3ms | *All benchmarks were performed using `memtier_benchmark` (see below) with number of threads tuned per server and instance type. `memtier` was run on a separate c6gn.16xlarge machine. We set the expiry time to 500 for the SETEX benchmark to ensure it would survive the end of the test.* ```bash memtier_benchmark --ratio ... -t -c 30 -n 200000 --distinct-client-seed -d 256 \ --expiry-range=... ``` In pipeline mode `--pipeline=30`, Dragonfly reaches **10M QPS** for SET and **15M QPS** for GET operations. ### Dragonfly vs. Memcached We compared Dragonfly with Memcached on a c6gn.16xlarge instance on AWS. With a comparable latency, Dragonfly throughput outperformed Memcached throughput in both write and read workloads. Dragonfly demonstrated better latency in write workloads due to contention on the [write path in Memcached](docs/memcached_benchmark.md). #### SET benchmark | Server | QPS(thousands qps) | latency 99% | 99.9% | |:---------:|:------------------:|:-----------:|:-------:| | Dragonfly | 🟩 3844 |🟩 0.9ms | 🟩 2.4ms | | Memcached | 806 | 1.6ms | 3.2ms | #### GET benchmark | Server | QPS(thousands qps) | latency 99% | 99.9% | |-----------|:------------------:|:-----------:|:-------:| | Dragonfly | 🟩 3717 | 1ms | 2.4ms | | Memcached | 2100 | 🟩 0.34ms | 🟩 0.6ms | Memcached exhibited lower latency for the read benchmark, but also lower throughput. ### Memory efficiency To test memory efficiency, we filled Dragonfly and Redis with ~5GB of data using the `debug populate 5000000 key 1024` command, sent update traffic with `memtier`, and kicked off the snapshotting with the `bgsave` command. This figure demonstrates how each server behaved in terms of memory efficiency. 
Dragonfly was 30% more memory efficient than Redis in the idle state and did not show any visible increase in memory use during the snapshot phase. At peak, Redis memory use increased to almost 3X that of Dragonfly. Dragonfly finished the snapshot faster, within a few seconds. For more info about memory efficiency in Dragonfly, see our [Dashtable doc](/docs/dashtable.md). ## Configuration Dragonfly supports common Redis arguments where applicable. For example, you can run: `dragonfly --requirepass=foo --bind localhost`. Dragonfly currently supports the following Redis-specific arguments: * `port`: Redis connection port (`default: 6379`). * `bind`: Use `localhost` to only allow localhost connections or a public IP address to allow connections **to that IP** address (i.e. from outside too). Use `0.0.0.0` to allow all IPv4. * `requirepass`: The password for AUTH authentication (`default: ""`). * `maxmemory`: Limit on maximum memory (in human-readable bytes) used by the database (`default: 0`). A `maxmemory` value of `0` means the program will automatically determine its maximum memory usage. * `dir`: Dragonfly Docker uses the `/data` folder for snapshotting by default, the CLI uses `""`. You can use the `-v` Docker option to map it to your host folder. * `dbfilename`: The filename to save and load the database (`default: dump`). There are also some Dragonfly-specific arguments: * `memcached_port`: The port to enable Memcached-compatible API on (`default: disabled`). * `keys_output_limit`: Maximum number of returned keys in `keys` command (`default: 8192`). Note that `keys` is a dangerous command. We truncate its result to avoid a blowup in memory use when fetching too many keys. * `dbnum`: Maximum number of supported databases for `select`. * `cache_mode`: See the [novel cache design](#novel-cache-design) section below. * `hz`: Key expiry evaluation frequency (`default: 100`). Lower frequency uses less CPU when idle at the expense of a slower eviction rate. 
* `snapshot_cron`: Cron schedule expression for automatic backup snapshots using standard cron syntax with the granularity of minutes (`default: ""`). Here are some cron schedule expression examples below, and feel free to read more about this argument in our [documentation](https://www.dragonflydb.io/docs/managing-dragonfly/backups#the-snapshot_cron-flag). | Cron Schedule Expression | Description | |--------------------------|--------------------------------------------| | `* * * * *` | At every minute | | `*/5 * * * *` | At every 5th minute | | `5 */2 * * *` | At minute 5 past every 2nd hour | | `0 0 * * *` | At 00:00 (midnight) every day | | `0 6 * * 1-5` | At 06:00 (dawn) from Monday through Friday | * `primary_port_http_enabled`: Allows accessing HTTP console on main TCP port if `true` (`default: true`). * `admin_port`: To enable admin access to the console on the assigned port (`default: disabled`). Supports both HTTP and RESP protocols. * `admin_bind`: To bind the admin console TCP connection to a given address (`default: any`). Supports both HTTP and RESP protocols. * `admin_nopass`: To enable open admin access to console on the assigned port, without auth token needed (`default: false`). Supports both HTTP and RESP protocols. * `cluster_mode`: Cluster mode supported (`default: ""`). Currently supports only `emulated`. * `cluster_announce_ip`: The IP that cluster commands announce to the client. * `announce_port`: The port that cluster commands announce to the client, and to replication master. ### Example start script with popular options: ```bash ./dragonfly-x86_64 --logtostderr --requirepass=youshallnotpass --cache_mode=true -dbnum 1 --bind localhost --port 6379 --maxmemory=12gb --keys_output_limit=12288 --dbfilename dump.rdb ``` Arguments can be also provided via: * `--flagfile `: The file should list one flag per line, with equal signs instead of spaces for key-value flags. No quotes are needed for flag values. * Setting environment variables. 
Set `DFLY_x`, where `x` is the exact name of the flag, case sensitive. For more options like logs management or TLS support, run `dragonfly --help`. ## Roadmap and status Dragonfly currently supports ~185 Redis commands and all Memcached commands besides `cas`. Almost on par with the Redis 5 API, Dragonfly's next milestone will be to stabilize basic functionality and implement the replication API. If there is a command you need that is not implemented yet, please open an issue. For Dragonfly-native replication, we are designing a distributed log format that will support order-of-magnitude higher speeds. Following the replication feature, we will continue adding missing commands for Redis versions 3-6 APIs. Please see our [Command Reference](https://dragonflydb.io/docs/category/command-reference) for the current commands supported by Dragonfly. ## Design decisions ### Novel cache design Dragonfly has a single, unified, adaptive caching algorithm that is simple and memory efficient. You can enable caching mode by passing the `--cache_mode=true` flag. Once this mode is on, Dragonfly will evict items least likely to be stumbled upon in the future but only when it is near the `maxmemory` limit. ### Expiration deadlines with relative accuracy Expiration ranges are limited to ~8 years. Expiration deadlines with millisecond precision (PEXPIRE, PSETEX, etc.) are rounded to the closest second **for deadlines greater than 2^28ms**, which has less than 0.001% error and should be acceptable for large ranges. If this is not suitable for your use case, get in touch or open an issue explaining your case. For more detailed differences between Dragonfly expiration deadlines and Redis implementations, [see here](docs/differences.md). ### Native HTTP console and Prometheus-compatible metrics By default, Dragonfly allows HTTP access via its main TCP port (6379). 
That's right, you can connect to Dragonfly via Redis protocol and via HTTP protocol — the server recognizes the protocol automatically during the connection initiation. Go ahead and try it with your browser. HTTP access currently does not have much info but will include useful debugging and management info in the future. Go to the URL `:6379/metrics` to view Prometheus-compatible metrics. The Prometheus exported metrics are compatible with the Grafana dashboard, [see here](tools/local/monitoring/grafana/provisioning/dashboards/dashboard.json). Important! The HTTP console is meant to be accessed within a safe network. If you expose Dragonfly's TCP port externally, we advise you to disable the console with `--primary_port_http_enabled=false` or `--noprimary_port_http_enabled`. ## Background Dragonfly started as an experiment to see how an in-memory datastore could look if it was designed in 2022. Based on lessons learned from our experience as users of memory stores and engineers who worked for cloud companies, we knew that we needed to preserve two key properties for Dragonfly: atomicity guarantees for all operations and low, sub-millisecond latency over very high throughput. Our first challenge was how to fully utilize CPU, memory, and I/O resources using servers that are available today in public clouds. To solve this, we use [shared-nothing architecture](https://en.wikipedia.org/wiki/Shared-nothing_architecture), which allows us to partition the keyspace of the memory store between threads so that each thread can manage its own slice of dictionary data. We call these slices "shards". The library that powers thread and I/O management for shared-nothing architecture is open-sourced [here](https://github.com/romange/helio). To provide atomicity guarantees for multi-key operations, we use the advancements from recent academic research.
We chose the paper ["VLL: a lock manager redesign for main memory database systems"](https://www.cs.umd.edu/~abadi/papers/vldbj-vll.pdf) to develop the transactional framework for Dragonfly. The choice of shared-nothing architecture and VLL allowed us to compose atomic multi-key operations without using mutexes or spinlocks. This was a major milestone for our PoC and its performance stood out from other commercial and open-source solutions. Our second challenge was to engineer more efficient data structures for the new store. To achieve this goal, we based our core hashtable structure on the paper ["Dash: Scalable Hashing on Persistent Memory"](https://arxiv.org/pdf/2003.07302.pdf). The paper itself is centered around the persistent memory domain and is not directly related to main-memory stores, but it's still most applicable to our problem. The hashtable design suggested in the paper allowed us to maintain two special properties that are present in the Redis dictionary: the incremental hashing ability during datastore growth, and the ability to traverse the dictionary under changes using a stateless scan operation. In addition to these two properties, Dash is more efficient in CPU and memory use. By leveraging Dash's design, we were able to innovate further with the following features: * Efficient record expiry for TTL records. * A novel cache eviction algorithm that achieves higher hit rates than other caching strategies like LRU and LFU with **zero memory overhead**. * A novel **fork-less** snapshotting algorithm. Once we had built the foundation for Dragonfly and [we were happy with its performance](#benchmarks), we went on to implement the Redis and Memcached functionality. We have to date implemented ~185 Redis commands (roughly equivalent to Redis 5.0 API) and 13 Memcached commands. And finally,
Our mission is to build a well-designed, ultra-fast, cost-efficient in-memory datastore for cloud workloads that takes advantage of the latest hardware advancements. We intend to address the pain points of current solutions while preserving their product APIs and propositions. ================================================ FILE: README.pt-BR.md ================================================

Dragonfly

[![ci-tests](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml/badge.svg)](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml) [![Twitter URL](https://img.shields.io/twitter/follow/dragonflydbio?style=social)](https://twitter.com/dragonflydbio) > Antes de continuar, considere deixar uma estrela no nosso repositório ⭐️. Obrigado! Outros idiomas: [简体中文](README.zh-CN.md) [日本語](README.ja-JP.md) [한국어](README.ko-KR.md) [English](README.md) [Site oficial](https://www.dragonflydb.io/) • [Documentação](https://dragonflydb.io/docs) • [Guia Rápido](https://www.dragonflydb.io/docs/getting-started) • [Discord da Comunidade](https://discord.gg/HsPjXGVH85) • [Fórum Dragonfly](https://dragonfly.discourse.group/) • [Participe da Comunidade](https://www.dragonflydb.io/community) [Discussões no GitHub](https://github.com/dragonflydb/dragonfly/discussions) • [Issues no GitHub](https://github.com/dragonflydb/dragonfly/issues) • [Contribuindo](https://github.com/dragonflydb/dragonfly/blob/main/CONTRIBUTING.md) • [Dragonfly Cloud](https://www.dragonflydb.io/cloud) ## O armazenamento de dados em memória mais eficiente do mundo Dragonfly é um armazenamento de dados em memória projetado para cargas de trabalho modernas. Totalmente compatível com as APIs do Redis e Memcached, o Dragonfly não requer alterações de código para adoção. Em comparação com armazenamentos legados, o Dragonfly oferece 25x mais throughput, maiores taxas de acerto em cache com menor latência de cauda e pode operar com até 80% menos recursos para a mesma carga. 
## Conteúdo - [Benchmarks](#benchmarks) - [Guia rápido](https://github.com/dragonflydb/dragonfly/tree/main/docs/quick-start) - [Configuração](#configuration) - [Roteiro e status](#roadmap-status) - [Decisões de design](#design-decisions) - [Contexto](#background) - [Compilação a partir do código-fonte](./docs/build-from-source.md) ## Benchmarks Primeiro comparamos o Dragonfly com o Redis em uma instância `m5.large`, frequentemente usada para rodar Redis devido à sua arquitetura single-threaded. O benchmark roda de outra instância de carga (c5n) na mesma AZ usando `memtier_benchmark -c 20 --test-time 100 -t 4 -d 256 --distinct-client-seed`. O Dragonfly mostra desempenho comparável: 1. SETs (`--ratio 1:0`): | Redis | DF | | ------------------------------------- | ------------------------------------ | | QPS: 159K, P99.9: 1.16ms, P99: 0.82ms | QPS: 173K, P99.9: 1.26ms, P99: 0.9ms | 2. GETs (`--ratio 0:1`): | Redis | DF | | ------------------------------------ | ------------------------------------ | | QPS: 194K, P99.9: 0.8ms, P99: 0.65ms | QPS: 191K, P99.9: 0.95ms, P99: 0.8ms | O benchmark mostra que a camada algorítmica do DF, que permite escalabilidade vertical, não gera sobrecarga significativa em execução single-thread. Com uma instância mais forte (m5.xlarge), a diferença entre DF e Redis cresce. (`memtier_benchmark -c 20 --test-time 100 -t 6 -d 256 --distinct-client-seed`): 1. SETs (`--ratio 1:0`): | Redis | DF | | ------------------------------------- | ------------------------------------- | | QPS: 190K, P99.9: 2.45ms, P99: 0.97ms | QPS: 279K, P99.9: 1.95ms, P99: 1.48ms | 2. GETs (`--ratio 0:1`): | Redis | DF | | ------------------------------------ | ------------------------------------- | | QPS: 220K, P99.9: 0.98ms, P99: 0.8ms | QPS: 305K, P99.9: 1.03ms, P99: 0.87ms | A capacidade de throughput do Dragonfly cresce com o tamanho da instância, enquanto o Redis single-thread atinge o limite de CPU. 
Na instância c6gn.16xlarge (maior capacidade de rede), o Dragonfly atinge 25x mais throughput que o Redis, superando 3.8M QPS. Latência de 99% no pico de throughput do Dragonfly: | op | r6g | c6gn | c7g | | ----- | ----- | ----- | ----- | | set | 0.8ms | 1ms | 1ms | | get | 0.9ms | 0.9ms | 0.8ms | | setex | 0.9ms | 1.1ms | 1.3ms | _Todos os benchmarks foram realizados com `memtier_benchmark`, ajustando o número de threads conforme a instância. O `memtier` rodava em uma c6gn.16xlarge separada. No benchmark SETEX, foi definido tempo de expiração de 500 para garantir sobrevivência até o final do teste._ ```bash memtier_benchmark --ratio ... -t -c 30 -n 200000 --distinct-client-seed -d 256 \ --expiry-range=... ``` Em modo pipeline `--pipeline=30`, o Dragonfly alcança **10M QPS** em SET e **15M QPS** em GET. ### Dragonfly vs. Memcached Comparamos Dragonfly e Memcached em uma c6gn.16xlarge na AWS. Com latência comparável, o throughput do Dragonfly superou o do Memcached tanto em leitura quanto escrita. Em escrita, a latência do Dragonfly foi melhor devido à contenção no [caminho de escrita do Memcached](docs/memcached_benchmark.md). #### Benchmark de SET | Servidor | QPS (milhares) | latência 99% | 99.9% | | :-------: | :------------: | :----------: | :------: | | Dragonfly | 🟩 3844 | 🟩 0.9ms | 🟩 2.4ms | | Memcached | 806 | 1.6ms | 3.2ms | #### Benchmark de GET | Servidor | QPS (milhares) | latência 99% | 99.9% | | --------- | :------------: | :----------: | :------: | | Dragonfly | 🟩 3717 | 1ms | 2.4ms | | Memcached | 2100 | 🟩 0.34ms | 🟩 0.6ms | Memcached teve menor latência em leitura, mas também menor throughput. ### Eficiência de memória Para testar a eficiência de memória, preenchemos o Dragonfly e o Redis com \~5GB de dados usando o comando `debug populate 5000000 key 1024`, enviamos tráfego de atualização com `memtier` e iniciamos o snapshot com o comando `bgsave`. A figura abaixo demonstra como cada servidor se comportou em termos de eficiência de memória. 
O Dragonfly foi 30% mais eficiente em memória que o Redis em estado ocioso e não apresentou aumento visível no uso de memória durante a fase de snapshot. No pico, o uso de memória do Redis aumentou para quase 3 vezes o do Dragonfly. O Dragonfly concluiu o snapshot mais rápido, em poucos segundos. Para mais informações sobre eficiência de memória no Dragonfly, veja nosso [documento sobre Dashtable](/docs/dashtable.md). ## Configuração O Dragonfly suporta argumentos comuns do Redis quando aplicável. Por exemplo, você pode executar: `dragonfly --requirepass=foo --bind localhost`. Atualmente, o Dragonfly suporta os seguintes argumentos específicos do Redis: - `port`: Porta de conexão Redis (`padrão: 6379`). - `bind`: Use `localhost` para permitir conexões apenas locais ou um IP público para permitir conexões **para esse IP** (ou seja, externas também). Use `0.0.0.0` para permitir todas as conexões IPv4. - `requirepass`: Senha para autenticação AUTH (`padrão: ""`). - `maxmemory`: Limite de memória máxima (em bytes legíveis) usada pelo banco (`padrão: 0`). Um valor `0` significa que o programa determinará automaticamente o uso máximo de memória. - `dir`: O Docker do Dragonfly usa a pasta `/data` para snapshots por padrão, o CLI usa `""`. Você pode usar a opção `-v` do Docker para mapear para uma pasta do host. - `dbfilename`: Nome do arquivo para salvar/carregar o banco de dados (`padrão: dump`). Também há argumentos específicos do Dragonfly: - `memcached_port`: Porta para habilitar API compatível com Memcached (`padrão: desabilitado`). - `keys_output_limit`: Número máximo de chaves retornadas no comando `keys` (`padrão: 8192`). Note que `keys` é um comando perigoso. Limitamos o resultado para evitar explosão de uso de memória ao buscar muitas chaves. - `dbnum`: Número máximo de bancos de dados suportados para `select`. - `cache_mode`: Veja a seção sobre [design de cache inovador](#novel-cache-design). 
- `hz`: Frequência de avaliação de expiração de chave (`padrão: 100`). Frequências menores usam menos CPU em idle, mas têm menor taxa de remoção. - `snapshot_cron`: Expressão cron para snapshots automáticos usando sintaxe cron padrão, com granularidade de minutos (`padrão: ""`). Exemplos: | Expressão Cron | Descrição | | -------------- | ----------------------------------- | | `* * * * *` | A cada minuto | | `*/5 * * * *` | A cada 5 minutos | | `5 */2 * * *` | No minuto 5 de cada 2 horas | | `0 0 * * *` | Às 00:00 (meia-noite) todos os dias | | `0 6 * * 1-5` | Às 06:00 (manhã) de segunda a sexta | - `primary_port_http_enabled`: Permite acesso ao console HTTP na porta TCP principal se `true` (`padrão: true`). - `admin_port`: Habilita acesso admin ao console na porta atribuída (`padrão: desabilitado`). Suporta protocolos HTTP e RESP. - `admin_bind`: Define o IP de binding do console admin (`padrão: qualquer`). Suporta HTTP e RESP. - `admin_nopass`: Habilita acesso admin sem autenticação (`padrão: false`). Suporta HTTP e RESP. - `cluster_mode`: Modo cluster suportado (`padrão: ""`). Atualmente só `emulated`. - `cluster_announce_ip`: IP que os comandos de cluster anunciam ao cliente. - `announce_port`: Porta que os comandos de cluster anunciam ao cliente e ao master de replicação. ### Exemplo de script de inicialização com opções populares: ```bash ./dragonfly-x86_64 --logtostderr --requirepass=youshallnotpass --cache_mode=true -dbnum 1 --bind localhost --port 6379 --maxmemory=12gb --keys_output_limit=12288 --dbfilename dump.rdb ``` Argumentos também podem ser passados via: - `--flagfile `: O arquivo deve conter um flag por linha, com `=` em vez de espaços para flags com valor. Não usar aspas. - Variáveis de ambiente. Use `DFLY_x`, onde `x` é o nome exato do flag (case sensitive). Para mais opções como logs ou suporte a TLS, execute `dragonfly --help`. ## Roadmap e status Atualmente o Dragonfly suporta \~185 comandos Redis e todos os comandos Memcached exceto `cas`. 
Já quase no nível da API do Redis 5, o próximo marco é estabilizar as funcionalidades básicas e implementar a API de replicação. Caso precise de um comando ainda não implementado, abra uma issue. Para replicação nativa do Dragonfly, estamos projetando um formato de log distribuído que suportará velocidades ordens de magnitude maiores. Após a replicação, continuaremos adicionando comandos faltantes das versões 3 a 6 do Redis. Consulte nossa [Referência de Comandos](https://dragonflydb.io/docs/category/command-reference) para a lista atual. ## Decisões de design ### Design de cache inovador O Dragonfly tem um algoritmo de cache adaptativo, unificado e simples, eficiente em memória. Você pode habilitar o modo cache com o flag `--cache_mode=true`. Esse modo remove itens menos prováveis de serem acessados no futuro, mas **somente** próximo ao limite de `maxmemory`. ### Expiração com precisão relativa Intervalos de expiração são limitados a \~8 anos. Deadlines com precisão de milissegundos (PEXPIRE, PSETEX etc.) são arredondadas para o segundo mais próximo **quando superiores a 2^28ms**, com erro menor que 0.001%. Se isso for inadequado, entre em contato ou abra uma issue explicando o caso. Para mais diferenças entre os deadlines do Dragonfly e do Redis, [clique aqui](docs/differences.md). ### Console HTTP nativo e métricas compatíveis com Prometheus Por padrão, o Dragonfly permite acesso HTTP via porta TCP principal (6379). Ou seja, você pode conectar via protocolo Redis ou HTTP — o servidor reconhece automaticamente o protocolo ao conectar. Acesse com o navegador. Hoje o console HTTP tem pouca informação, mas no futuro incluirá debug e info de gerenciamento. Acesse `:6379/metrics` para ver métricas Prometheus-compatíveis. As métricas são compatíveis com o dashboard do Grafana, [veja aqui](tools/local/monitoring/grafana/provisioning/dashboards/dashboard.json). Importante: o console HTTP deve ser acessado em rede segura. 
Se expor a porta TCP do Dragonfly externamente, desabilite o console com `--http_admin_console=false` ou `--nohttp_admin_console`. ## Contexto O Dragonfly começou como um experimento para repensar um datastore in-memory em 2022. Baseado em lições como usuários e engenheiros de cloud, sabíamos que dois princípios deveriam ser preservados: garantias de atomicidade e latência sub-millisecond sob alto throughput. Desafio 1: Utilizar ao máximo CPU, memória e I/O em servidores modernos. A solução foi adotar [arquitetura shared-nothing](https://en.wikipedia.org/wiki/Shared-nothing_architecture), particionando o keyspace entre threads. Chamamos os slices de “shards”. A biblioteca que gerencia threads e I/O foi open-sourceada [aqui](https://github.com/romange/helio). Para garantir atomicidade em operações multi-key, usamos avanços recentes da pesquisa acadêmica. Escolhemos o paper ["VLL: a lock manager redesign for main memory database systems"](https://www.cs.umd.edu/~abadi/papers/vldbj-vll.pdf) como base para o framework transacional. A combinação VLL + shared-nothing permitiu compor operações atômicas multi-key **sem mutex ou spinlock**. O resultado foi um PoC com performance superior a outras soluções. Desafio 2: Estruturas de dados mais eficientes. Baseamos o hashtable no paper ["Dash: Scalable Hashing on Persistent Memory"](https://arxiv.org/pdf/2003.07302.pdf). Mesmo voltado à memória persistente, foi aplicável. O design permitiu manter: - Hash incremental durante crescimento. - Scan stateless mesmo com mudanças. Além disso, o Dash é mais eficiente em uso de CPU/memória. Com esse design, inovamos ainda com: - Expiração eficiente para registros TTL. - Algoritmo de cache com mais hits que LRU/LFU com **zero overhead**. - Algoritmo de snapshot **sem fork**. Com essa base pronta e [performance satisfatória](#benchmarks), implementamos as APIs Redis e Memcached (\~185 comandos Redis, equivalente ao Redis 5.0, e 13 do Memcached). Por fim,
Nossa missão é construir um datastore in-memory rápido, eficiente e bem projetado para cargas em nuvem, aproveitando o hardware moderno. Queremos resolver as dores das soluções atuais mantendo APIs e propostas de valor. ================================================ FILE: README.zh-CN.md ================================================

Dragonfly

[![ci-tests](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml/badge.svg)](https://github.com/dragonflydb/dragonfly/actions/workflows/ci.yml) [![Twitter URL](https://img.shields.io/twitter/follow/dragonflydbio?style=social)](https://twitter.com/dragonflydbio) > 在您继续之前,请考虑给我们一个 GitHub 星标 ⭐️。谢谢! 其他语言: [English](README.md) [日本語](README.ja-JP.md) [한국어](README.ko-KR.md) [Português](README.pt-BR.md) [主页](https://dragonflydb.io/) • [快速入门](https://github.com/dragonflydb/dragonfly/tree/main/docs/quick-start) • [社区 Discord](https://discord.gg/HsPjXGVH85) • [Dragonfly 论坛](https://dragonfly.discourse.group/) • [加入 Dragonfly 社区](https://www.dragonflydb.io/community) [GitHub Discussions](https://github.com/dragonflydb/dragonfly/discussions) • [GitHub Issues](https://github.com/dragonflydb/dragonfly/issues) • [贡献指南](https://github.com/dragonflydb/dragonfly/blob/main/CONTRIBUTING.md) ## 全世界最快的内存数据库 Dragonfly是一种针对现代应用程序负荷需求而构建的内存数据库,完全兼容Redis和Memcached的 API,迁移时无需修改任何代码。相比于这些传统的内存数据库,Dragonfly提供了其25倍的吞吐量,高缓存命中率和低尾延迟,并且对于相同大小的工作负载运行资源最多可减少80%。 ## 目录 - [基准测试](#基准测试) - [快速入门](https://github.com/dragonflydb/dragonfly/tree/main/docs/quick-start) - [配置方法](#配置方法) - [开发路线和开发现状](#开发路线和开发现状) - [设计决策](#设计决策) - [开发背景](#开发背景) ## 基准测试 Dragonfly在c6gn.16xlarge上达到了每秒380万个查询(QPS),相比于Redis,吞吐量提高了25倍。 在Dragonfly的峰值吞吐量下,P99延迟如下: | op | r6g | c6gn | c7g | | ----- | ----- | ----- | ----- | | set | 0.8ms | 1ms | 1ms | | get | 0.9ms | 0.9ms | 0.8ms | | setex | 0.9ms | 1.1ms | 1.3ms | *所有基准测试均使用`memtier_benchmark`(见下文),根据服务器类型和实例类型调整线程数。`memtier`运行在独立的c6gn.16xlarge机器上。对于setex基准测试,我们使用了500的到期范围,以便其能够存活直到测试结束。* ```bash memtier_benchmark --ratio ... -t -c 30 -n 200000 --distinct-client-seed -d 256 \ --expiry-range=... 
``` 当以管道模式运行,并设置参数`--pipeline=30`时,Dragonfly可以实现**10M qps**的SET操作和 **15M qps**的GET操作。 ### Memcached / Dragonfly 我们在 AWS 的 `c6gn.16xlarge` 实例上比较了 memcached 和 Dragonfly。如下图所示,与 memcached 相比,Dragonfly 的吞吐量在读写两方面上都占据了优势,并且在延迟方面也还不错。对于写入工作,Dragonfly 的延迟更低,这是由于在 memcached 的写入路径上存在竞争(请参见[此处](docs/memcached_benchmark.md))。 #### SET benchmark | Server | QPS(thousands qps) | latency 99% | 99.9% | | :-------: | :----------------: | :---------: | :-----: | | Dragonfly | 🟩 3844 | 🟩 0.9ms | 🟩 2.4ms | | Memcached | 806 | 1.6ms | 3.2ms | #### GET benchmark | Server | QPS(thousands qps) | latency 99% | 99.9% | | --------- | :----------------: | :---------: | :-----: | | Dragonfly | 🟩 3717 | 1ms | 2.4ms | | Memcached | 2100 | 🟩 0.34ms | 🟩 0.6ms | 对于读取基准测试,Memcached 表现出了更低的延迟,但在吞吐量方面比不上Dragonfly。 ### 内存效率 在接下来的测试中,我们使用 `debug populate 5000000 key 1024` 命令向 Dragonfly 和 Redis 分别写入了约 5GB 的数据。然后我们使用 `memtier` 发送更新流量并使用 `bgsave` 命令启动快照。下图清楚地展示了这两个服务器在内存效率方面的表现。 在空闲状态下,Dragonfly 比 Redis 节省约 30% 的内存。 在快照阶段,Dragonfly 也没有显示出任何明显的内存增加。 但同时,Redis 在峰值时的内存几乎达到了 Dragonfly 的 3 倍。 Dragonfly 完成快照也很快,仅在启动后几秒钟内就完成了。 有关 Dragonfly 内存效率的更多信息,请参见 [dashtable 文档](/docs/dashtable.md)。 ## 配置方法 Dragonfly 支持 Redis 的常见参数。 例如,您可以运行:`dragonfly --requirepass=foo --bind localhost`。 目前,Dragonfly 支持以下 Redis 特定参数: * `port`:Redis 连接端口,默认为 `6379`。 * `bind`:使用本地主机名仅允许本地连接,使用公共 IP 地址允许外部连接到**该 IP 地址**。 * `requirepass`:AUTH 认证密码,默认为空 `""`。 * `maxmemory`:限制数据库使用的最大内存(以字节为单位)。`0` 表示程序将自动确定其最大内存使用量。默认为 `0`。 * `dir`:默认情况下,dragonfly docker 使用 `/data` 文件夹进行快照。CLI 使用的是 `""`。你可以使用 `-v` docker 选项将其映射到主机文件夹。 * `dbfilename`:保存/加载数据库的文件名。默认为 `dump`; 此外,还有 Dragonfly 特定的参数选项: * `memcached_port`:在此端口上启用 memcached 兼容的 API。默认禁用。 * `keys_output_limit`:在`keys` 命令中返回的最大键数。默认为 `8192`。 `keys` 命令是危险命令。我们会截断结果以避免在获取太多键时内存溢出。 * `dbnum`:`select` 支持的最大数据库数。 * `cache_mode`:请参见下面的 [缓存](#全新的缓存设计) 部分。 * `hz`:键到期评估频率。默认为 `100`。空闲时,使用较低的频率可以占用较少的 CPU资源,但这会导致清理过期键的速度下降。 * `snapshot_cron`:定时自动备份快照的 cron 表达式,使用标准的、精确到分钟的 cron 语法。默认为空 `""`。 下面是一些 cron 
表达式的示例,更多关于此参数的细节请参见[文档](https://www.dragonflydb.io/docs/managing-dragonfly/backups#the-snapshot_cron-flag)。 | Cron 表达式 | 描述 | |---------------|----------------------------------| | `* * * * *` | 每分钟 | | `*/5 * * * *` | 每隔 5 分钟 (00:00, 00:05, 00:10...) | | `5 */2 * * *` | 每隔 2 小时的第 5 分钟 | | `0 0 * * *` | 每天的 00:00 午夜 | | `0 6 * * 1-5` | 从星期一到星期五的每天 06:00 黎明 | * `primary_port_http_enabled`:如果为 true,则允许在主 TCP 端口上访问 HTTP 控制台。默认为 `true`。 * `admin_port`:如果设置,将在指定的端口上启用对控制台的管理访问。支持 HTTP 和 RESP 协议。默认禁用。 * `admin_bind`:如果设置,将管理控制台 TCP 连接绑定到给定地址。支持 HTTP 和 RESP 协议。默认为 `any`。 * `admin_nopass`: 如果设置,允许在不提供任何认证令牌的情况下,通过指定的端口访问管理控制台。同时支持 HTTP 和 RESP 协议。 默认为 `false`。 * `cluster_mode`:支持集群模式。目前仅支持 `emulated`。默认为空 `""`。 * `cluster_announce_ip`:集群模式下向客户端公开的 IP。 ### 启动脚本示例,包含常用选项: ```bash ./dragonfly-x86_64 --logtostderr --requirepass=youshallnotpass --cache_mode=true -dbnum 1 --bind localhost --port 6379 --maxmemory=12gb --keys_output_limit=12288 --dbfilename dump.rdb ``` 还可以通过运行 `dragonfly --flagfile ` 从配置文件中获取参数,配置文件的每行应该列出一个参数,并用等号代替键值参数的空格。 要获取更多选项,如日志管理或TLS支持,请运行 `dragonfly --help`。 ## 开发路线和开发现状 目前,Dragonfly支持约185个Redis命令以及除 `cas` 之外的所有 Memcached 命令。 我们几乎达到了Redis 5 API的水平。我们的下一个里程碑更新将会稳定基本功能并实现复刻API。 如果您发现您需要的命令尚未实现,请提出一个Issue。 对于dragonfly-native复制技术,我们正在设计一种分布式日志格式,该格式将支持更高的速度。 在实现复制功能之后,我们将继续实现API 3-6中其他缺失的Redis命令。 请参见[命令参考](https://dragonflydb.io/docs/category/command-reference)以了解Dragonfly当前支持的命令。 ## 设计决策 ### 全新的缓存设计 Dragonfly采用单一的自适应缓存算法,该算法非常简单且具备高内存效率。 你可以通过使用 `--cache_mode=true` 参数来启用缓存模式。一旦启用了此模式,Dragonfly将会删除最低概率可能被使用的内容,但这只会在接近最大内存限制时发生。 ### 相对准确的过期期限 过期范围限制最高为约8年。此外,**对于大于2^28ms的到期期限**,毫秒精度级别(PEXPIRE/PSETEX等)会被简化到秒级。 这种舍入的误差小于0.001%,我希望这在长时间范围情况下是可以接受的。 如果这不符合你的使用需求,请与我联系或提出一个Issue,并解释您的情况。 关于与Redis实现之间的更多差异,请参见[此处](docs/differences.md)。 ### 原生HTTP控制台和兼容Prometheus的标准 默认情况下,Dragonfly允许通过其主TCP端口(6379)进行HTTP访问。没错,您可以通过Redis协议或HTTP协议连接到Dragonfly - 服务器会在连接初始化期间自动识别协议。 不妨在你自己的浏览器中尝试一下。现在HTTP访问没有太多信息可供参考,但在将来,我们计划添加有用的调试和管理信息。如果您转到`: 6379/metrics` 
URL,您将看到一些兼容Prometheus的标准。 Prometheus导出的标准与Grafana仪表盘兼容,[请参见此处](tools/local/monitoring/grafana/provisioning/dashboards/dashboard.json)。 重要!HTTP控制台仅应在安全网络内访问。如果您将Dragonfly的TCP端口暴露在外部,则建议使用`--http_admin_console=false`或`--nohttp_admin_console`禁用控制台。 ## 开发背景 Dragonfly始于一项实验,旨在探索如果在2022年重新设计内存数据库,它会是什么样子。基于我们作为内存存储的用户以及作为云服务公司的工程师的经验教训,我们得知需要保留Dragonfly的两个关键属性:a) 为其所有操作提供原子性保证,b) 保证在非常高的吞吐量下实现低于毫秒的延迟。 我们面临的首要挑战是如何充分利用当今云服务器的CPU、内存和I/O资源。为了解决这个问题,我们使用了 [无共享式架构(shared-nothing architecture)](https://en.wikipedia.org/wiki/Shared-nothing_architecture),它允许我们在不同的线程之间分割内存存储的空间,使得每个线程可以管理自己的字典数据切片。我们称这些切片为“分片(shards)”。为无共享式架构提供线程和I/O管理功能的库在[这里](https://github.com/romange/helio)开源。 为了提供对多键并发操作的原子性保证,我们使用了最近学术研究的进展。我们选择了论文 ["VLL: a lock manager redesign for main memory database systems”](https://www.cs.umd.edu/~abadi/papers/vldbj-vll.pdf) 来开发Dragonfly的事务框架。无共享式架构和VLL的选择使我们能够在不使用互斥锁或自旋锁的情况下组合原子的多键操作。这是我们 PoC 的一个重要里程碑,它的性能在商业和开源解决方案中脱颖而出。 我们面临的第二个挑战是为新存储设计更高效的数据结构。为了实现这个目标,我们基于论文["Dash: Scalable Hashing on Persistent Memory"](https://arxiv.org/pdf/2003.07302.pdf)构建了核心哈希表结构。这篇论文本身是以持久性内存为中心的,与主存没有直接相关性。 然而,它非常适用于我们的问题。它提出了一种哈希表设计,允许我们维护Redis字典中存在的两个特殊属性:a) 数据存储增长时的渐进式哈希能力;b)使用无状态扫描操作时,遍历变化的字典的能力。除了这两个属性之外,Dash在CPU和内存方面都更加高效。通过利用Dash的设计,我们能够进一步创新,实现以下功能: - 针对TTL的高效记录过期功能。 - 一种新颖的缓存驱逐算法,具有比其他缓存策略(如LRU和LFU)更高的命中率,同时**零内存开销**。 - 一种新颖的无fork快照算法。 在我们为Dragonfly打下基础并满意其[性能](#基准测试)后,我们开始实现Redis和Memcached功能。 目前,我们已经实现了约185个Redis命令(大致相当于Redis 5.0 API)和13个Memcached命令。 最后,
我们的使命是构建一个设计良好、超高速、成本效益高的云工作负载内存数据存储系统,利用最新的硬件技术。我们旨在解决当前解决方案的痛点,同时保留其产品API和优势。 ================================================ FILE: TODO.md ================================================ 1. To move lua_project to dragonfly from helio (DONE) 2. To limit lua stack to something reasonable like 4096. 3. To inject our own allocator to lua to track its memory. ## Object lifecycle and thread-safety. Currently our transactional and locking model is based on an assumption that any READ or WRITE access to objects must be performed in a shard where they belong. However, this assumption can be relaxed to get significant gains for read-only queries. ### Explanation Our transactional framework prevents READ-locked objects from being mutated. It does not, of course, prevent their PrimaryTable from growing or changing. These objects can move to different entries inside the table. However, our CompactObject maintains the following property - its reference CompactObject.AsRef() is valid no matter where the master object moves and it's valid and safe for reading even from other threads. The exception regarding thread safety is SmallString which uses a translation table for its pointers. If we change the SmallString translation table to be global and thread-safe (it should not have lots of write contention anyway) we may access primetable keys and values from another thread and write them directly to sockets. Use-case: large strings that need to be copied. Sets that need to be serialized for SMEMBERS/HGETALL commands etc. Additional complexity - we will need to lock those variables even for single hop transactions and unlock them afterwards. The unlocking hop does not need to increase user-visible latency since it can be done after we send reply to the socket. ================================================ FILE: contrib/charts/dragonfly/.helmignore ================================================ # Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and # negation (prefixed with !). Only one pattern per line. .DS_Store # Common VCS dirs .git/ .gitignore .bzr/ .bzrignore .hg/ .hgignore .svn/ # Common backup files *.swp *.bak *.tmp *.orig *~ # Various IDEs .project .idea/ *.tmproj .vscode/ ci/ *.go go.mod go.sum ================================================ FILE: contrib/charts/dragonfly/Chart.yaml ================================================ apiVersion: v2 name: dragonfly description: Dragonfly is a modern in-memory datastore, fully compatible with Redis and Memcached APIs. # A chart can be either an 'application' or a 'library' chart. # # Application charts are a collection of templates that can be packaged into versioned archives # to be deployed. # # Library charts provide useful utilities or functions for the chart developer. They're included as # a dependency of application charts to inject those utilities and functions into the rendering # pipeline. Library charts do not define any templates and therefore cannot be deployed. type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) version: v1.37.0 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to # follow Semantic Versioning. They should reflect the version the application is using. # It is recommended to use it with quotes. 
appVersion: "v1.37.0" home: https://dragonflydb.io/ keywords: - database - keyvalue - cache sources: - https://github.com/dragonflydb/dragonfly kubeVersion: ">=1.23.0-0" ================================================ FILE: contrib/charts/dragonfly/README.md ================================================ # dragonfly ![Version: v0.12.0](https://img.shields.io/badge/Version-v0.12.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: v0.12.0](https://img.shields.io/badge/AppVersion-v0.12.0-informational?style=flat-square) Dragonfly is a modern in-memory datastore, fully compatible with Redis and Memcached APIs. **Homepage:** <https://dragonflydb.io/> ## Source Code * <https://github.com/dragonflydb/dragonfly> ## Requirements Kubernetes: `>=1.23.0-0` ## Installing from a pre-packaged OCI Pick a version from https://github.com/dragonflydb/dragonfly/pkgs/container/dragonfly%2Fhelm%2Fdragonfly Example: ```shell VERSION=v1.12.1 helm upgrade --install dragonfly oci://ghcr.io/dragonflydb/dragonfly/helm/dragonfly --version $VERSION ``` ## Values | Key | Type | Default | Description | |-----|------|---------|-------------| | affinity | object | `{}` | Affinity for pod assignment | | command | list | `[]` | Allow overriding the container's command | | commonLabels | object | `{}` | Common labels to add to all K8s resources | | extraArgs | list | `[]` | Extra arguments to pass to the dragonfly binary | | extraContainers | list | `[]` | Additional sidecar containers | | extraObjects | list | `[]` | extra K8s manifests to deploy | | extraVolumeMounts | list | `[]` | Extra volume mounts corresponding to the volumes mounted above | | extraVolumes | list | `[]` | Extra volumes to mount into the pods | | fullnameOverride | string | `""` | String to fully override dragonfly.fullname | | image.pullPolicy | string | `"IfNotPresent"` | Dragonfly image pull policy | | image.repository | string | `"docker.dragonflydb.io/dragonflydb/dragonfly"` | Container Image
Registry to pull the image from | | image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion. | | imagePullSecrets | list | `[]` | Container Registry Secret names in an array | | initContainers | list | `[]` | A list of initContainers to run before each pod starts | | nameOverride | string | `""` | String to partially override dragonfly.fullname | | nodeSelector | object | `{}` | Node labels for pod assignment | | podAnnotations | object | `{}` | Annotations for pods | | podSecurityContext | object | `{}` | Set securityContext for pod itself | | probes.livenessProbe.exec.command[0] | string | `"/bin/sh"` | | | probes.livenessProbe.exec.command[1] | string | `"/usr/local/bin/healthcheck.sh"` | | | probes.livenessProbe.failureThreshold | int | `3` | | | probes.livenessProbe.initialDelaySeconds | int | `10` | | | probes.livenessProbe.periodSeconds | int | `10` | | | probes.livenessProbe.successThreshold | int | `1` | | | probes.livenessProbe.timeoutSeconds | int | `5` | | | probes.readinessProbe.exec.command[0] | string | `"/bin/sh"` | | | probes.readinessProbe.exec.command[1] | string | `"/usr/local/bin/healthcheck.sh"` | | | probes.readinessProbe.failureThreshold | int | `3` | | | probes.readinessProbe.initialDelaySeconds | int | `10` | | | probes.readinessProbe.periodSeconds | int | `10` | | | probes.readinessProbe.successThreshold | int | `1` | | | probes.readinessProbe.timeoutSeconds | int | `5` | | | prometheusRule.enabled | bool | `false` | Deploy a PrometheusRule | | prometheusRule.spec | list | `[]` | PrometheusRule.Spec https://awesome-prometheus-alerts.grep.to/rules | | replicaCount | int | `1` | Number of replicas to deploy | | resources.limits | object | `{}` | The resource limits for the containers | | resources.requests | object | `{}` | The requested resources for the containers | | env | list | `[]` | Extra environment variables | | envFrom | list | `[]` | Extra environment variables from K8s objects | | 
securityContext | object | `{}` | Set securityContext for containers | | service.annotations | object | `{}` | Extra annotations for the service | | service.labels | object | `{}` | Extra labels for the service | | service.metrics.portName | string | `"metrics"` | name for the metrics port | | service.metrics.serviceType | string | `"ClusterIP"` | serviceType for the metrics service | | service.port | int | `6379` | Dragonfly service port | | service.type | string | `"ClusterIP"` | Service type to provision. Can be NodePort, ClusterIP or LoadBalancer | | serviceAccount.annotations | object | `{}` | Annotations to add to the service account | | serviceAccount.create | bool | `true` | Specifies whether a service account should be created | | serviceAccount.name | string | `""` | The name of the service account to use. If not set and create is true, a name is generated using the fullname template | | serviceMonitor.annotations | object | `{}` | additional annotations to apply to the metrics | | serviceMonitor.enabled | bool | `false` | If true, a ServiceMonitor CRD is created for a prometheus operator | | serviceMonitor.interval | string | `"10s"` | scrape interval | | serviceMonitor.labels | object | `{}` | additional labels to apply to the metrics | | serviceMonitor.namespace | string | `""` | namespace in which to deploy the ServiceMonitor CR. defaults to the application namespace | | serviceMonitor.scrapeTimeout | string | `"10s"` | scrape timeout | | storage.enabled | bool | `false` | If /data should persist. This will provision a StatefulSet instead. 
| | storage.requests | string | `"128Mi"` | Volume size to request for the PVC | | storage.storageClassName | string | `""` | Global StorageClass for Persistent Volume(s) | | tls.cert | string | `""` | TLS certificate | | tls.createCerts | bool | `false` | use cert-manager to automatically create the certificate | | tls.duration | string | `"87600h0m0s"` | duration or ttl of the validity of the created certificate | | tls.enabled | bool | `false` | enable TLS | | tls.existing_secret | string | `""` | use TLS certificates from existing secret | | tls.issuer.kind | string | `"ClusterIssuer"` | cert-manager issuer kind. Usually Issuer or ClusterIssuer | | tls.issuer.name | string | `"selfsigned"` | name of the referenced issuer | | tls.key | string | `""` | TLS private key | | tolerations | list | `[]` | Tolerations for pod assignment | ---------------------------------------------- Autogenerated from chart metadata using [helm-docs v1.11.0](https://github.com/norwoodj/helm-docs/releases/v1.11.0) ================================================ FILE: contrib/charts/dragonfly/ci/affinity-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: 
name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: affinity: podAntiAffinity: preferredDuringSchedulingIgnoredDuringExecution: - podAffinityTerm: labelSelector: matchExpressions: - key: app.kubernetes.io/name operator: In values: - dragonfly topologyKey: kubernetes.io/hostname weight: 100 serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/affinity-values.yaml ================================================ affinity: podAntiAffinity: preferredDuringSchedulingIgnoredDuringExecution: - podAffinityTerm: labelSelector: matchExpressions: - key: app.kubernetes.io/name operator: In values: - dragonfly topologyKey: kubernetes.io/hostname weight: 100 ================================================ FILE: contrib/charts/dragonfly/ci/command_extraargs-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: 
app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 command: - /usr/local/bin/dragonfly - --logtostderr args: - "--alsologtostderr" - --cache_mode=true resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/command_extraargs-values.yaml ================================================ command: - /usr/local/bin/dragonfly - --logtostderr 
extraArgs: - --cache_mode=true ================================================ FILE: contrib/charts/dragonfly/ci/commonlabels-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm project: cache-infrastructure team: platform --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm project: cache-infrastructure team: platform spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm project: cache-infrastructure team: platform spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test project: cache-infrastructure team: platform spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 
readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/commonlabels-values.yaml ================================================ commonLabels: team: platform project: cache-infrastructure ================================================ FILE: contrib/charts/dragonfly/ci/extracontainer-string-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - args: - -c - date; sleep 3600; command: - /bin/sh image: busybox:latest name: 
sidecar-string - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/extracontainer-string-values.yaml ================================================ extraContainers: - name: sidecar-string image: busybox:latest command: ["/bin/sh"] args: ["-c", "date; sleep 3600;"] ================================================ FILE: contrib/charts/dragonfly/ci/extracontainer-tpl-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test 
app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: sidecar-tpl image: docker.dragonflydb.io/dragonflydb/dragonfly:latest command: ["/bin/sh"] args: ["-c", "date; sleep 3600;"] - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/extracontainer-tpl-values.yaml ================================================ extraContainers: | - name: sidecar-tpl image: {{ .Values.image.repository }}:latest command: ["/bin/sh"] args: ["-c", "date; sleep 3600;"] ================================================ FILE: contrib/charts/dragonfly/ci/extraenv-and-passwordSecret-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar --- # Source: 
dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: my-secret stringData: password: password username: username type: Opaque --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 data: configKey1: configValue1 configKey2: configValue2 kind: ConfigMap metadata: name: my-configmap --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} env: - name: DFLY_requirepass valueFrom: secretKeyRef: name: dfly-password key: password - name: 
ENV_VAR43 value: value1 - name: ENV_VAR323 value: value2 envFrom: - configMapRef: name: my-configmap - secretRef: name: my-secret ================================================ FILE: contrib/charts/dragonfly/ci/extraenv-and-passwordSecret-values.yaml ================================================ extraObjects: - apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar - apiVersion: v1 kind: ConfigMap metadata: name: my-configmap data: configKey1: configValue1 configKey2: configValue2 - apiVersion: v1 kind: Secret metadata: name: my-secret type: Opaque stringData: username: username password: password env: - name: ENV_VAR43 value: value1 - name: ENV_VAR323 value: value2 envFrom: - configMapRef: name: my-configmap - secretRef: name: my-secret passwordFromSecret: enable: true existingSecret: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/extraenv-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: my-secret stringData: password: password username: username type: Opaque --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 data: configKey1: configValue1 configKey2: configValue2 kind: ConfigMap metadata: name: my-configmap --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly 
protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} env: - name: ENV_VAR43 value: value1 - name: ENV_VAR323 value: value2 envFrom: - configMapRef: name: my-configmap - secretRef: name: my-secret ================================================ FILE: contrib/charts/dragonfly/ci/extraenv-values.yaml ================================================ extraObjects: - apiVersion: v1 kind: ConfigMap metadata: name: my-configmap data: configKey1: configValue1 configKey2: configValue2 - apiVersion: v1 kind: Secret metadata: name: my-secret type: Opaque stringData: username: username password: password env: - name: ENV_VAR43 value: value1 - name: ENV_VAR323 value: value2 envFrom: - configMapRef: name: my-configmap - secretRef: name: my-secret ================================================ FILE: 
contrib/charts/dragonfly/ci/extravolumes-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} volumeMounts: - mountPath: 
/tmp name: tmp volumes: - emptyDir: sizeLimit: 500Mi name: tmp ================================================ FILE: contrib/charts/dragonfly/ci/extravolumes-values.yaml ================================================ extraVolumes: - name: tmp emptyDir: sizeLimit: 500Mi extraVolumeMounts: - mountPath: /tmp name: tmp ================================================ FILE: contrib/charts/dragonfly/ci/initcontainer-string-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly initContainers: - args: - -c - date; sleep 1; command: - /bin/sh image: busybox:1.28 name: initcontainer-string containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent 
ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/initcontainer-string-values.yaml ================================================ initContainers: - name: initcontainer-string image: busybox:1.28 command: ["/bin/sh"] args: ["-c", "date; sleep 1;"] ================================================ FILE: contrib/charts/dragonfly/ci/initcontainer-tpl-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: 
dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly initContainers: - name: initcontainer-tpl image: docker.dragonflydb.io/dragonflydb/dragonfly:latest command: ["/bin/sh"] args: ["-c", "date; sleep 1;"] containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/initcontainer-tpl-values.yaml ================================================ initContainers: | - name: initcontainer-tpl image: {{ .Values.image.repository }}:latest command: ["/bin/sh"] args: ["-c", "date; sleep 1;"] ================================================ FILE: contrib/charts/dragonfly/ci/password-old-env-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly 
app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.13.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} env: - name: DFLY_PASSWORD valueFrom: secretKeyRef: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/password-old-env-values.yaml ================================================ image: tag: "v1.13.0" extraObjects: - apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar passwordFromSecret: enable: true existingSecret: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/passwordsecret-values.golden.yaml 
================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - 
"--alsologtostderr" resources: limits: {} requests: {} env: - name: DFLY_requirepass valueFrom: secretKeyRef: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/passwordsecret-values.tpl.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: dragonfly-password stringData: password: foobar --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh 
failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} env: - name: DFLY_requirepass valueFrom: secretKeyRef: name: dragonfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/passwordsecret-values.tpl.yaml ================================================ extraObjects: - apiVersion: v1 kind: Secret metadata: name: dragonfly-password stringData: password: foobar passwordFromSecret: enable: true existingSecret: name: '{{ include "dragonfly.name" $ }}-password' key: password ================================================ FILE: contrib/charts/dragonfly/ci/passwordsecret-values.yaml ================================================ extraObjects: - apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar passwordFromSecret: enable: true existingSecret: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/persistence-and-existing-secret.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: 
"v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/statefulset.yaml apiVersion: apps/v1 kind: StatefulSet metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: serviceName: test replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} volumeMounts: - mountPath: /data name: "test-data" env: - name: DFLY_requirepass valueFrom: secretKeyRef: name: dfly-password key: password volumeClaimTemplates: - metadata: name: "test-data" spec: accessModes: [ "ReadWriteOnce" ] storageClassName: standard resources: requests: storage: 128Mi ================================================ FILE: contrib/charts/dragonfly/ci/persistence-and-existing-secret.yaml ================================================ storage: enabled: true storageClassName: "standard" requests: 128Mi extraObjects: - apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar 
passwordFromSecret: enable: true existingSecret: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/persistent-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/statefulset.yaml apiVersion: apps/v1 kind: StatefulSet metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: serviceName: test replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 
initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} volumeMounts: - mountPath: /data name: "test-data" volumeClaimTemplates: - metadata: name: "test-data" spec: accessModes: [ "ReadWriteOnce" ] storageClassName: standard resources: requests: storage: 128Mi ================================================ FILE: contrib/charts/dragonfly/ci/persistent-values.yaml ================================================ storage: enabled: true storageClassName: "standard" requests: 128Mi ================================================ FILE: contrib/charts/dragonfly/ci/priorityclassname-values.golden.yaml ================================================ --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: scheduling.k8s.io/v1 description: This priority class should be used only for tests. globalDefault: false kind: PriorityClass metadata: name: high-priority value: 1000000 --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: 
replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: priorityClassName: high-priority serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/priorityclassname-values.yaml ================================================ priorityClassName: "high-priority" extraObjects: - apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: name: high-priority value: 1000000 globalDefault: false description: "This priority class should be used only for tests." 
================================================ FILE: contrib/charts/dragonfly/ci/prometheusrules-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/metrics-service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly-metrics namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm type: metrics spec: type: ClusterIP ports: - name: metrics port: 6379 targetPort: 6379 protocol: TCP selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" 
imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} --- # Source: dragonfly/templates/servicemonitor.yaml apiVersion: monitoring.coreos.com/v1 kind: ServiceMonitor metadata: name: test-dragonfly-metrics namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: endpoints: - interval: 10s scrapeTimeout: 10s honorLabels: true port: metrics path: /metrics scheme: http jobLabel: "test" selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test type: metrics namespaceSelector: matchNames: - default ================================================ FILE: contrib/charts/dragonfly/ci/prometheusrules-values.yaml ================================================ serviceMonitor: enabled: true prometheusRule: enabled: true namespace: default spec: - alert: RedisDown expr: absent(dragonfly_master > 0) for: 0m labels: severity: critical annotations: summary: Redis instance is down description: > "Redis instance is down" runbook_url: "https://octopus.com/docs/runbooks/runbook-examples" ================================================ FILE: contrib/charts/dragonfly/ci/resources-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" 
app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: cpu: 100m memory: 400Mi requests: cpu: 100m memory: 300Mi ================================================ FILE: contrib/charts/dragonfly/ci/resources-values.yaml ================================================ resources: requests: cpu: 100m memory: 300Mi limits: cpu: 100m memory: 400Mi ================================================ FILE: contrib/charts/dragonfly/ci/securitycontext-values.golden.yaml 
================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly securityContext: allowPrivilegeEscalation: false readOnlyRootFilesystem: true image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} 
================================================ FILE: contrib/charts/dragonfly/ci/securitycontext-values.yaml ================================================ podSecurityContext: {} securityContext: allowPrivilegeEscalation: false readOnlyRootFilesystem: true ================================================ FILE: contrib/charts/dragonfly/ci/service-loadbalancer-ip.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: LoadBalancer loadBalancerIP: 127.0.0.1 ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 
3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/service-loadbalancer-ip.yaml ================================================ service: type: LoadBalancer loadBalancerIP: "127.0.0.1" ================================================ FILE: contrib/charts/dragonfly/ci/service-monitor-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/metrics-service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly-metrics namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm type: metrics spec: type: ClusterIP ports: - name: metrics port: 6379 targetPort: 6379 protocol: TCP selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: 
name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} --- # Source: dragonfly/templates/servicemonitor.yaml apiVersion: monitoring.coreos.com/v1 kind: ServiceMonitor metadata: name: test-dragonfly-metrics namespace: default labels: release: prometheus-stack app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: endpoints: - interval: 10s scrapeTimeout: 10s honorLabels: true port: metrics path: /metrics scheme: http jobLabel: "test" selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test type: metrics namespaceSelector: matchNames: - default ================================================ FILE: contrib/charts/dragonfly/ci/service-monitor-values.yaml ================================================ serviceMonitor: enabled: true namespace: "" labels: release: prometheus-stack annotations: {} interval: 10s scrapeTimeout: 10s ================================================ FILE: 
contrib/charts/dragonfly/ci/taints-tolerations-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: tolerations: - effect: NoSchedule key: key/high-memory operator: Equal value: "true" - effect: PreferNoSchedule key: key/high-memory operator: Equal value: "true" affinity: nodeAffinity: requiredDuringSchedulingIgnoredDuringExecution: nodeSelectorTerms: - matchExpressions: - key: key/node-kind operator: In values: - high-memory serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 
initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/taints-tolerations-values.yaml ================================================ tolerations: - key: key/high-memory operator: "Equal" value: "true" effect: "NoSchedule" - key: key/high-memory operator: "Equal" value: "true" effect: "PreferNoSchedule" affinity: nodeAffinity: requiredDuringSchedulingIgnoredDuringExecution: nodeSelectorTerms: - matchExpressions: - key: key/node-kind operator: In values: - high-memory ================================================ FILE: contrib/charts/dragonfly/ci/tls-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/extra-manifests.yaml apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar --- # Source: dragonfly/templates/tls-secret.yaml apiVersion: v1 kind: Secret metadata: name: test-dragonfly-tls namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm type: kubernetes.io/tls data: tls.crt: 
"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUI4ekNDQVpxZ0F3SUJBZ0lFYmIyWjJqQUtCZ2dxaGtqT1BRUURBekJaTVFzd0NRWURWUVFHRXdKR1R6RWcKTUI0R0ExVUVBd3dYWkhKaFoyOXVabXg1TG1SeVlXZHZibVpzZVM1emRtTXhEREFLQmdOVkJBZ01BMlp2YnpFTQpNQW9HQTFVRUJ3d0RabTl2TVF3d0NnWURWUVFLREFObWIyOHdIaGNOTWpJeE1qSTVNVEl3TXpJM1doY05Nekl4Ck1qSTJNVEl3TXpJM1dqQlpNUXN3Q1FZRFZRUUdFd0pHVHpFZ01CNEdBMVVFQXd3WFpISmhaMjl1Wm14NUxtUnkKWVdkdmJtWnNlUzV6ZG1NeEREQUtCZ05WQkFnTUEyWnZiekVNTUFvR0ExVUVCd3dEWm05dk1Rd3dDZ1lEVlFRSwpEQU5tYjI4d1dUQVRCZ2NxaGtqT1BRSUJCZ2dxaGtqT1BRTUJCd05DQUFRV05mVHVOamhQRWk3aDFjaUNTMEl0CmZLZ2lCaHhMR2xGM010amxGVGpDcnpreW5TU0FCb010TmxqY0RFMGhtL2l6YlJVb2dBY0RGY3ZrbnZDaHp4YXEKbzFBd1RqQWRCZ05WSFE0RUZnUVVTTjZGYnNKWjJFVWZYM2JlQ2g1Y0VvNmNrdFF3SHdZRFZSMGpCQmd3Rm9BVQpTTjZGYnNKWjJFVWZYM2JlQ2g1Y0VvNmNrdFF3REFZRFZSMFRCQVV3QXdFQi96QUtCZ2dxaGtqT1BRUURBd05ICkFEQkVBaUI2dEc1eHp5ajRpVC9lMHdwQ01SSE92bFFLUWV4QnloeU5QQWhybzlaQ1JnSWdhRGNkOXZNOHJDYmIKSlBSeXptMGlOOU9XTS9BMjRubW0zaXRuM0k0cmNEMD0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=" tls.key: "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSU5oNmVNRHJCbEFpVDY4VDhvdnpHbjZKWmJKZXZVZWZZa0lJWU5Xd3c1NXlvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFRmpYMDdqWTRUeEl1NGRYSWdrdENMWHlvSWdZY1N4cFJkekxZNVJVNHdxODVNcDBrZ0FhRApMVFpZM0F4TkladjRzMjBWS0lBSEF4WEw1Sjd3b2M4V3FnPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" 
app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: checksum/tls-secret: b97190b6585f160d4f709b965d275564bb51cd19202c6e014e1d42a972446a5c labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" - "--tls" - "--tls_cert_file=/etc/dragonfly/tls/tls.crt" - "--tls_key_file=/etc/dragonfly/tls/tls.key" resources: limits: {} requests: {} volumeMounts: - mountPath: /etc/dragonfly/tls name: tls env: - name: DFLY_requirepass valueFrom: secretKeyRef: name: dfly-password key: password volumes: - name: tls secret: secretName: test-dragonfly-tls ================================================ FILE: contrib/charts/dragonfly/ci/tls-values.yaml ================================================ tls: enabled: true existing_secret: "" cert: | -----BEGIN CERTIFICATE----- MIIB8zCCAZqgAwIBAgIEbb2Z2jAKBggqhkjOPQQDAzBZMQswCQYDVQQGEwJGTzEg MB4GA1UEAwwXZHJhZ29uZmx5LmRyYWdvbmZseS5zdmMxDDAKBgNVBAgMA2ZvbzEM MAoGA1UEBwwDZm9vMQwwCgYDVQQKDANmb28wHhcNMjIxMjI5MTIwMzI3WhcNMzIx MjI2MTIwMzI3WjBZMQswCQYDVQQGEwJGTzEgMB4GA1UEAwwXZHJhZ29uZmx5LmRy YWdvbmZseS5zdmMxDDAKBgNVBAgMA2ZvbzEMMAoGA1UEBwwDZm9vMQwwCgYDVQQK DANmb28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQWNfTuNjhPEi7h1ciCS0It fKgiBhxLGlF3MtjlFTjCrzkynSSABoMtNljcDE0hm/izbRUogAcDFcvknvChzxaq 
o1AwTjAdBgNVHQ4EFgQUSN6FbsJZ2EUfX3beCh5cEo6cktQwHwYDVR0jBBgwFoAU SN6FbsJZ2EUfX3beCh5cEo6cktQwDAYDVR0TBAUwAwEB/zAKBggqhkjOPQQDAwNH ADBEAiB6tG5xzyj4iT/e0wpCMRHOvlQKQexByhyNPAhro9ZCRgIgaDcd9vM8rCbb JPRyzm0iN9OWM/A24nmm3itn3I4rcD0= -----END CERTIFICATE----- key: | -----BEGIN EC PRIVATE KEY----- MHcCAQEEINh6eMDrBlAiT68T8ovzGn6JZbJevUefYkIIYNWww55yoAoGCCqGSM49 AwEHoUQDQgAEFjX07jY4TxIu4dXIgktCLXyoIgYcSxpRdzLY5RU4wq85Mp0kgAaD LTZY3AxNIZv4s20VKIAHAxXL5J7woc8Wqg== -----END EC PRIVATE KEY----- extraObjects: - apiVersion: v1 kind: Secret metadata: name: dfly-password stringData: password: foobar passwordFromSecret: enable: true existingSecret: name: dfly-password key: password ================================================ FILE: contrib/charts/dragonfly/ci/tolerations-values.golden.yaml ================================================ --- # Source: dragonfly/templates/serviceaccount.yaml apiVersion: v1 kind: ServiceAccount metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm --- # Source: dragonfly/templates/service.yaml apiVersion: v1 kind: Service metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: type: ClusterIP ports: - port: 6379 targetPort: dragonfly protocol: TCP name: dragonfly selector: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test --- # Source: dragonfly/templates/deployment.yaml apiVersion: apps/v1 kind: Deployment metadata: name: test-dragonfly namespace: default labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test app.kubernetes.io/version: "v1.37.0" app.kubernetes.io/managed-by: Helm spec: replicas: 1 selector: matchLabels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test template: metadata: annotations: 
labels: app.kubernetes.io/name: dragonfly app.kubernetes.io/instance: test spec: tolerations: - effect: NoSchedule operator: Exists serviceAccountName: test-dragonfly containers: - name: dragonfly image: "docker.dragonflydb.io/dragonflydb/dragonfly:v1.37.0" imagePullPolicy: IfNotPresent ports: - name: dragonfly containerPort: 6379 protocol: TCP livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh failureThreshold: 3 initialDelaySeconds: 10 periodSeconds: 10 successThreshold: 1 timeoutSeconds: 5 args: - "--alsologtostderr" resources: limits: {} requests: {} ================================================ FILE: contrib/charts/dragonfly/ci/tolerations-values.yaml ================================================ tolerations: - effect: NoSchedule operator: Exists ================================================ FILE: contrib/charts/dragonfly/go.mod ================================================ module dragonfly go 1.24.0 toolchain go1.24.7 require github.com/gruntwork-io/terratest v0.51.0 require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/BurntSushi/toml v1.5.0 // indirect github.com/aws/aws-sdk-go-v2 v1.39.1 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect github.com/aws/aws-sdk-go-v2/config v1.31.10 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.18.14 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.8 // indirect github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.8 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.8 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.8 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.8 // indirect github.com/aws/aws-sdk-go-v2/service/acm v1.37.5 // indirect 
github.com/aws/aws-sdk-go-v2/service/autoscaling v1.59.2 // indirect github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.58.1 // indirect github.com/aws/aws-sdk-go-v2/service/dynamodb v1.50.4 // indirect github.com/aws/aws-sdk-go-v2/service/ec2 v1.254.0 // indirect github.com/aws/aws-sdk-go-v2/service/ecr v1.50.4 // indirect github.com/aws/aws-sdk-go-v2/service/ecs v1.64.1 // indirect github.com/aws/aws-sdk-go-v2/service/iam v1.47.6 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.8 // indirect github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.8 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.8 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.8 // indirect github.com/aws/aws-sdk-go-v2/service/kms v1.45.5 // indirect github.com/aws/aws-sdk-go-v2/service/lambda v1.77.5 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.107.1 // indirect github.com/aws/aws-sdk-go-v2/service/route53 v1.58.3 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.88.2 // indirect github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.39.5 // indirect github.com/aws/aws-sdk-go-v2/service/sns v1.38.4 // indirect github.com/aws/aws-sdk-go-v2/service/sqs v1.42.7 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.65.0 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.29.4 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.0 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.38.5 // indirect github.com/aws/smithy-go v1.23.0 // indirect github.com/boombuler/barcode v1.1.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/go-errors/errors v1.5.1 // indirect github.com/go-logr/logr 
v1.4.3 // indirect github.com/go-openapi/jsonpointer v0.22.0 // indirect github.com/go-openapi/jsonreference v0.21.1 // indirect github.com/go-openapi/swag v0.25.0 // indirect github.com/go-openapi/swag/cmdutils v0.25.0 // indirect github.com/go-openapi/swag/conv v0.25.0 // indirect github.com/go-openapi/swag/fileutils v0.25.0 // indirect github.com/go-openapi/swag/jsonname v0.25.0 // indirect github.com/go-openapi/swag/jsonutils v0.25.0 // indirect github.com/go-openapi/swag/loading v0.25.0 // indirect github.com/go-openapi/swag/mangling v0.25.0 // indirect github.com/go-openapi/swag/netutils v0.25.0 // indirect github.com/go-openapi/swag/stringutils v0.25.0 // indirect github.com/go-openapi/swag/typeutils v0.25.0 // indirect github.com/go-openapi/swag/yamlutils v0.25.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/gonvenience/bunt v1.4.2 // indirect github.com/gonvenience/idem v0.0.2 // indirect github.com/gonvenience/neat v1.3.16 // indirect github.com/gonvenience/term v1.0.4 // indirect github.com/gonvenience/text v1.0.9 // indirect github.com/gonvenience/ytbx v1.4.7 // indirect github.com/google/gnostic-models v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/gruntwork-io/go-commons v0.17.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/homeport/dyff v1.10.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.6 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-ciede2000 v0.0.0-20170301095244-782e8c62fec3 // indirect github.com/mattn/go-isatty v0.0.20 // indirect 
github.com/mattn/go-zglob v0.0.6 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect github.com/mitchellh/hashstructure v1.1.0 // indirect github.com/moby/spdystream v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pquerna/otp v1.5.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sergi/go-diff v1.4.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/testify v1.11.1 // indirect github.com/texttheater/golang-levenshtein v1.0.1 // indirect github.com/urfave/cli/v2 v2.27.7 // indirect github.com/virtuald/go-ordered-json v0.0.0-20170621173500-b18e6e673d74 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.42.0 // indirect golang.org/x/exp v0.0.0-20250911091902-df9299821621 // indirect golang.org/x/net v0.44.0 // indirect golang.org/x/oauth2 v0.31.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/term v0.35.0 // indirect golang.org/x/text v0.29.0 // indirect golang.org/x/time v0.13.0 // indirect google.golang.org/protobuf v1.36.9 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.34.1 // indirect k8s.io/apimachinery v0.34.1 // indirect k8s.io/client-go v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi 
v0.0.0-20250910181357-589584f1c912 // indirect k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect sigs.k8s.io/yaml v1.6.0 // indirect ) ================================================ FILE: contrib/charts/dragonfly/go.sum ================================================ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/aws/aws-sdk-go-v2 v1.39.1 h1:fWZhGAwVRK/fAN2tmt7ilH4PPAE11rDj7HytrmbZ2FE= github.com/aws/aws-sdk-go-v2 v1.39.1/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1/go.mod h1:ddqbooRZYNoJ2dsTwOty16rM+/Aqmk/GOXrK8cg7V00= github.com/aws/aws-sdk-go-v2/config v1.31.10 h1:7LllDZAegXU3yk41mwM6KcPu0wmjKGQB1bg99bNdQm4= github.com/aws/aws-sdk-go-v2/config v1.31.10/go.mod h1:Ge6gzXPjqu4v0oHvgAwvGzYcK921GU0hQM25WF/Kl+8= github.com/aws/aws-sdk-go-v2/credentials v1.18.14 h1:TxkI7QI+sFkTItN/6cJuMZEIVMFXeu2dI1ZffkXngKI= github.com/aws/aws-sdk-go-v2/credentials v1.18.14/go.mod h1:12x4Uw/vijC11XkctTjy92TNCQ+UnNJkT7fzX0Yd93E= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.8 h1:gLD09eaJUdiszm7vd1btiQUYE0Hj+0I2b8AS+75z9AY= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.8/go.mod 
h1:4RW3oMPt1POR74qVOC4SbubxAwdP4pCT0nSw3jycOU4= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.8 h1:QcAh/TNGM3MWe95ilMWwnieXWXsyM33Mb/RuTGlWLm4= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.8/go.mod h1:72m/ZCCgYpXJzsgI8uJFYMnXEjtZ4kkaolL9NRXLSnU= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.8 h1:6bgAZgRyT4RoFWhxS+aoGMFyE0cD1bSzFnEEi4bFPGI= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.8/go.mod h1:KcGkXFVU8U28qS4KvLEcPxytPZPBcRawaH2Pf/0jptE= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.8 h1:HhJYoES3zOz34yWEpGENqJvRVPqpmJyR3+AFg9ybhdY= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.8/go.mod h1:JnA+hPWeYAVbDssp83tv+ysAG8lTfLVXvSsyKg/7xNA= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.8 h1:1/bT9kDdLQzfZ1e6J6hpW+SfNDd6xrV8F3M2CuGyUz8= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.8/go.mod h1:RbdwTONAIi59ej/+1H+QzZORt5bcyAtbrS7FQb2pvz0= github.com/aws/aws-sdk-go-v2/service/acm v1.37.5 h1:vTmyvkmMJEKZgyhSuaEv8gZCJJlgNpSpYy/4CExjHoA= github.com/aws/aws-sdk-go-v2/service/acm v1.37.5/go.mod h1:TmyW/AiLmFEXwFsm5hh2T86BpgFbcB1icshuzFu8LgY= github.com/aws/aws-sdk-go-v2/service/autoscaling v1.59.2 h1:YOWVoIjUoiwAVIRVU3PG2yNldh9dQT5OegnO99RO4ls= github.com/aws/aws-sdk-go-v2/service/autoscaling v1.59.2/go.mod h1:t08UbddtoRQcKiIW2ZTfxX5x6vRaTj6KrKcf1R0I4tw= github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.58.1 h1:JMYpgsJ31l0wjJCerJtIBo39HznZJ/ENJJzOSTcJh68= github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs v1.58.1/go.mod h1:zqtpx8Y/EydPCFy5MA9AJJBfJ+mCQz8BNHj2CvDvaYA= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.50.4 h1:3EE5TTeBHPTKQNNeIHdXcJ6ENDsN7c2rCQUtbdolwV8= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.50.4/go.mod h1:8rWv4Lq/jrlspgd/wpdFeKrxLByJlfpFEk9g0Tw5iOw= github.com/aws/aws-sdk-go-v2/service/ec2 
v1.254.0 h1:fTLR6dLDTGChAjecRPlVrKeznT0rVdzR4yn9Z68MTGk= github.com/aws/aws-sdk-go-v2/service/ec2 v1.254.0/go.mod h1:V0jbRy1/IPapnkqgXSwVOFB+u5pnCwd9S+R3pKWULC4= github.com/aws/aws-sdk-go-v2/service/ecr v1.50.4 h1:kPe1ZLqERYZxxDi6ysoX4oYavSJ6lkGaadsN1ogg3I8= github.com/aws/aws-sdk-go-v2/service/ecr v1.50.4/go.mod h1:cAJR/1pLXISKFSSJsrsTZPw05PLL5xOIpbbzxM7GLiI= github.com/aws/aws-sdk-go-v2/service/ecs v1.64.1 h1:kAzHjjqQnu3ET5/cX1N5tKPqtExYk97wpD6MpRadq/A= github.com/aws/aws-sdk-go-v2/service/ecs v1.64.1/go.mod h1:HIaZTpBD7+mgQEIv2wMzXYJw2T23sMFVNp2Mkw/ODFk= github.com/aws/aws-sdk-go-v2/service/iam v1.47.6 h1:EWehQXACWr+6hzfZPwZChlfoVhiUCfLHE0Xh3kAfzWQ= github.com/aws/aws-sdk-go-v2/service/iam v1.47.6/go.mod h1:qRXgEBWPIltrWHQwU+HkyBvwh1QgeigFcaCGCIVrWk0= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.8 h1:tIN8MFT1z5STK5kTdOT1TCfMN/bn5fSEnlKsTL8qBOU= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.8/go.mod h1:VKS56txtNWjKI8FqD/hliL0BcshyF4ZaLBa1rm2Y+5s= github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.8 h1:0lJ7+zL81zesTu1nd1ocKpEoYi6BqDppjoAJLn18Vr0= github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.8/go.mod h1:5t+iImUczd3RYSVnc20t/ohBrmrkpdcy89pm62BSDQo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.8 h1:M6JI2aGFEzYxsF6CXIuRBnkge9Wf9a2xU39rNeXgu10= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.8/go.mod h1:Fw+MyTwlwjFsSTE31mH211Np+CUslml8mzc0AFEG09s= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.8 h1:AgYCo1Rb8XChJXA871BXHDNxNWOTAr6V5YdsRIBbgv0= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.8/go.mod h1:Au9dvIGm1Hbqnt29d3VakOCQuN9l0WrkDDTRq8biWS4= github.com/aws/aws-sdk-go-v2/service/kms 
v1.45.5 h1:5AsmehPcxIp+Y8GVRa91UKpu3AO1gxhdckippth6bnA= github.com/aws/aws-sdk-go-v2/service/kms v1.45.5/go.mod h1:ooAdc5n3rjgEznIXncCYY6V9+YQDcJAYyZDJ4TwLSDM= github.com/aws/aws-sdk-go-v2/service/lambda v1.77.5 h1:rKc5Ad3PJlXGo5pigWii+m/hSPgxbNJtOicEP5nbV2E= github.com/aws/aws-sdk-go-v2/service/lambda v1.77.5/go.mod h1:fPYDox6U6puh6xhMyWpUWd19QIIqMlcQ6iCdC1jk2cE= github.com/aws/aws-sdk-go-v2/service/rds v1.107.1 h1:j7GQZWF0CbHCObPEZUK6QuP3yUQwjBJmlaojHPRZ6f8= github.com/aws/aws-sdk-go-v2/service/rds v1.107.1/go.mod h1:OW/mwGWAs6l1HnZpJupatcUFt1V0y6OiUMUp+Wd0DEc= github.com/aws/aws-sdk-go-v2/service/route53 v1.58.3 h1:jQzRC+0eI/l5mFXVoPTyyolrqyZtKIYaKHSuKJoIJKs= github.com/aws/aws-sdk-go-v2/service/route53 v1.58.3/go.mod h1:1GNaojT/gG4Ru9tT39ton6kRZ3FvptJ/QRKBoqUOVX4= github.com/aws/aws-sdk-go-v2/service/s3 v1.88.2 h1:T7b3qniouutV5Wwa9B1q7gW+Y8s1B3g9RE9qa7zLBIM= github.com/aws/aws-sdk-go-v2/service/s3 v1.88.2/go.mod h1:tW9TsLb6t1eaTdBE6LITyJW1m/+DjQPU78Q/jT2FJu8= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.39.5 h1:ssRo1z8FdFaoZc1AWz1R6/amdsxy56akVPql15/AYSs= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.39.5/go.mod h1:ut4ISJEOb5t2M1DNfx1787tF3UJGlwF3Q97uEulV/lU= github.com/aws/aws-sdk-go-v2/service/sns v1.38.4 h1:MkaMcZGwW9vt0cW+N2i5JSF/zkxKyDqpGCP1VWip3YM= github.com/aws/aws-sdk-go-v2/service/sns v1.38.4/go.mod h1:S0rwG+VHP1/jKoT6xJDe8f8Apz9HO42dUI8DmnOzYYU= github.com/aws/aws-sdk-go-v2/service/sqs v1.42.7 h1:KZldI+77SMG8vHDE55HYSjPcKSeOy2WIRo+HtIz2IY8= github.com/aws/aws-sdk-go-v2/service/sqs v1.42.7/go.mod h1:wbgNsM9psd+xQtLSDUAICjFCT/HXNZIgx3qyjqQNt88= github.com/aws/aws-sdk-go-v2/service/ssm v1.65.0 h1:6bPuMpky+qG4L7VQ1RyYVkBrEix1JRC/JPweTRfRDko= github.com/aws/aws-sdk-go-v2/service/ssm v1.65.0/go.mod h1:mbnkxOJSgkV4YHA5dWSlLolvC1EuxNcaGfn0Gf4e9UU= github.com/aws/aws-sdk-go-v2/service/sso v1.29.4 h1:FTdEN9dtWPB0EOURNtDPmwGp6GGvMqRJCAihkSl/1No= github.com/aws/aws-sdk-go-v2/service/sso v1.29.4/go.mod 
h1:mYubxV9Ff42fZH4kexj43gFPhgc/LyC7KqvUKt1watc= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.0 h1:I7ghctfGXrscr7r1Ga/mDqSJKm7Fkpl5Mwq79Z+rZqU= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.0/go.mod h1:Zo9id81XP6jbayIFWNuDpA6lMBWhsVy+3ou2jLa4JnA= github.com/aws/aws-sdk-go-v2/service/sts v1.38.5 h1:+LVB0xBqEgjQoqr9bGZbRzvg212B0f17JdflleJRNR4= github.com/aws/aws-sdk-go-v2/service/sts v1.38.5/go.mod h1:xoaxeqnnUaZjPjaICgIy5B+MHCSb/ZSOn4MvkFNOUA0= github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes= github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8bk= github.com/go-errors/errors 
v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-openapi/jsonpointer v0.22.0 h1:TmMhghgNef9YXxTu1tOopo+0BGEytxA+okbry0HjZsM= github.com/go-openapi/jsonpointer v0.22.0/go.mod h1:xt3jV88UtExdIkkL7NloURjRQjbeUgcxFblMjq2iaiU= github.com/go-openapi/jsonreference v0.21.1 h1:bSKrcl8819zKiOgxkbVNRUBIr6Wwj9KYrDbMjRs0cDA= github.com/go-openapi/jsonreference v0.21.1/go.mod h1:PWs8rO4xxTUqKGu+lEvvCxD5k2X7QYkKAepJyCmSTT8= github.com/go-openapi/swag v0.25.0 h1:xyZhlgInBg6wOtyTD5b+pzwVqHSOliAvgvKW+POFUts= github.com/go-openapi/swag v0.25.0/go.mod h1:yhsa7GJvO1JBFZccLq9uh/MawsC0PQd8sNz88VBXQlU= github.com/go-openapi/swag/cmdutils v0.25.0 h1:iYZ24DEGPEk6L1jO09vw39KfpxbG7KhS+WeQexS8U5A= github.com/go-openapi/swag/cmdutils v0.25.0/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0= github.com/go-openapi/swag/conv v0.25.0 h1:5K+e44HkOgCVE0IJTbivurzHahT62DPr2DEJqR/+4pA= github.com/go-openapi/swag/conv v0.25.0/go.mod h1:oa1ZZnb1jubNdZlD1iAhGXt6Ic4hHtuO23MwTgAXR88= github.com/go-openapi/swag/fileutils v0.25.0 h1:t7aQRuRfsP29dY4vfrNvDZv7RurwRHuyjUedtYVDmYY= github.com/go-openapi/swag/fileutils v0.25.0/go.mod h1:+NXtt5xNZZqmpIpjqcujqojGFek9/w55b3ecmOdtg8M= github.com/go-openapi/swag/jsonname v0.25.0 h1:+fuNs9gdkb2w10hgsgOBx9jtx0pvtUaDRYxD91BEpEQ= github.com/go-openapi/swag/jsonname v0.25.0/go.mod h1:71Tekow6UOLBD3wS7XhdT98g5J5GR13NOTQ9/6Q11Zo= github.com/go-openapi/swag/jsonutils v0.25.0 h1:ELKpJT29T4N/AvmDqMeDFLx2QRZQOYFthzctbIX30+A= github.com/go-openapi/swag/jsonutils v0.25.0/go.mod h1:KYL8GyGoi6tek9ajpvn0le4BWmKoUVVv8yPxklViIMo= github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.0 h1:ca9vKxLnJegL2bzqXRWNabKdqVGxBzrnO8/UZnr5W0Y= github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.0/go.mod h1:kjmweouyPwRUEYMSrbAidoLMGeJ5p6zdHi9BgZiqmsg= github.com/go-openapi/swag/loading v0.25.0 
h1:e9mjE5fJeaK0LTepHMtG0Ief+9ETXLFhWCx7ZfiI6LI= github.com/go-openapi/swag/loading v0.25.0/go.mod h1:2ZCWXwVY1XYuoue8Bdjbn5GJK4/ufXbCfcvoSPFQJqM= github.com/go-openapi/swag/mangling v0.25.0 h1:VdTfDWX5lS3yURxYHF5SK7kYelSK69Lv2xEAeudTzM8= github.com/go-openapi/swag/mangling v0.25.0/go.mod h1:CdiMQ6pnfAgyQGSOIYnZkXvqhnnwOn997uXZMAd/7mQ= github.com/go-openapi/swag/netutils v0.25.0 h1:/e1LPmXfF9fcOYbbaP3+SQgon1fRwe5EZ0FjpR4vAjs= github.com/go-openapi/swag/netutils v0.25.0/go.mod h1:CAkkvqnUJX8NV96tNhEQvKz8SQo2KF0f7LleiJwIeRE= github.com/go-openapi/swag/stringutils v0.25.0 h1:iYfCF45GUeI/1Yrh8rQtTFCp5K1ToqWhUdzJZwvXvv8= github.com/go-openapi/swag/stringutils v0.25.0/go.mod h1:JLdSAq5169HaiDUbTvArA2yQxmgn4D6h4A+4HqVvAYg= github.com/go-openapi/swag/typeutils v0.25.0 h1:iUTsxu3F3h9v6CBzVFGXKPSBQt6d8XXgYy1YAlu+HJ8= github.com/go-openapi/swag/typeutils v0.25.0/go.mod h1:9McMC/oCdS4BKwk2shEB7x17P6HmMmA6dQRtAkSnNb8= github.com/go-openapi/swag/yamlutils v0.25.0 h1:apgy77seWLEM9HKDcieIgW8bG9aSZgH6nQ9THlHYgHA= github.com/go-openapi/swag/yamlutils v0.25.0/go.mod h1:0JvBRtc0mR02IqHURUeGgS9cG+Dfms4FCGXCnsgnt7c= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gonvenience/bunt v1.4.2 h1:nTgkFZsw38SIJKABhLj8aXj2rqion9Zo1so/EBkbFBY= github.com/gonvenience/bunt v1.4.2/go.mod h1:WjyEO2rSYR+OLZg67Ucl+gjdXPs8GpFl63SCA02XDyI= github.com/gonvenience/idem v0.0.2 h1:jWHknjPfSbiWgYKre9wB2FhMgVLd1RWXCXzVq+7VIWg= github.com/gonvenience/idem v0.0.2/go.mod h1:0Xv1MpnNL40+dsyOxaJFa7L8ekeTRr63WaWXpiWLFFM= 
github.com/gonvenience/neat v1.3.16 h1:Vb0iCkSHGWaA+ry69RY3HpQ6Ooo6o/g2wjI80db8DjI= github.com/gonvenience/neat v1.3.16/go.mod h1:sLxdQNNluxbpROxTTHs3XBSJX8fwFX5toEULUy74ODA= github.com/gonvenience/term v1.0.4 h1:qkCGfmUtpzs9W4jWgNijaGF6dg3oSIh+kZCzT5cPNZY= github.com/gonvenience/term v1.0.4/go.mod h1:OzNdQC5NVBou9AifaHd1QG6EP8iDdpaT7GFm1bVgslg= github.com/gonvenience/text v1.0.9 h1:U29BxT3NZnNPcfiEnAwt6yHXe38fQs2Q+WTqs1X+atI= github.com/gonvenience/text v1.0.9/go.mod h1:JQF1ifXNRaa66jnPLqoITA+y8WATlG0eJzFC9ElJS3s= github.com/gonvenience/ytbx v1.4.7 h1:3wJ7EOfdv3Lg+h0mzKo7f8d1zMY1EJtVzzYrA3UhjHQ= github.com/gonvenience/ytbx v1.4.7/go.mod h1:ZmAU727eOTYeC4aUJuqyb9vogNAN7NiSKfw6Aoxbqys= github.com/google/gnostic-models v0.7.0 h1:qwTtogB15McXDaNqTZdzPJRHvaVJlAl+HVQnLmJEJxo= github.com/google/gnostic-models v0.7.0/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= github.com/gruntwork-io/go-commons v0.17.2 
h1:14dsCJ7M5Vv2X3BIPKeG9Kdy6vTMGhM8L4WZazxfTuY= github.com/gruntwork-io/go-commons v0.17.2/go.mod h1:zs7Q2AbUKuTarBPy19CIxJVUX/rBamfW8IwuWKniWkE= github.com/gruntwork-io/terratest v0.51.0 h1:RCXlCwWlHqhUoxgF6n3hvywvbvrsTXqoqt34BrnLekw= github.com/gruntwork-io/terratest v0.51.0/go.mod h1:evZHXb8VWDgv5O5zEEwfkwMhkx9I53QR/RB11cISrpg= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/homeport/dyff v1.10.2 h1:XyB+D0KVwjbUFTZYIkvPtsImwkfh+ObH2CEdEHTqdr4= github.com/homeport/dyff v1.10.2/go.mod h1:0kIjL/JOGaXigzrLY6kcl5esSStbAa99r6GzEvr7lrs= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool 
v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-ciede2000 v0.0.0-20170301095244-782e8c62fec3 h1:BXxTozrOU8zgC5dkpn3J6NTRdoP+hjok/e+ACr4Hibk= github.com/mattn/go-ciede2000 v0.0.0-20170301095244-782e8c62fec3/go.mod h1:x1uk6vxTiVuNt6S5R2UYgdhpj3oKojXvOXauHZ7dEnI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-zglob v0.0.6 h1:mP8RnmCgho4oaUYDIDn6GNxYk+qJGUs8fJLn+twYj2A= github.com/mattn/go-zglob v0.0.6/go.mod h1:MxxjyoXXnMxfIpxTK2GAkw1w8glPsQILx3N5wrKakiY= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc= github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/mitchellh/hashstructure v1.1.0 h1:P6P1hdjqAAknpY/M1CGipelZgp+4y9ja9kmUZPXP+H0= github.com/mitchellh/hashstructure v1.1.0/go.mod h1:xUDAozZz0Wmdiufv0uyhnHkUTN6/6d8ulp4AwfLKrmA= github.com/moby/spdystream v0.5.0 
h1:7r0J1Si3QO/kjRitvSLVVFUjxMEb/YLj6S9FF62JBCU= github.com/moby/spdystream v0.5.0/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8= github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= github.com/onsi/gomega v1.38.0 h1:c/WX+w8SLAinvuKKQFh77WEucCnPk4j2OTUr7lt7BeY= github.com/onsi/gomega v1.38.0/go.mod h1:OcXcwId0b9QsE7Y49u+BTrL4IdKOBOKnD6VQNTJEB6o= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/texttheater/golang-levenshtein v1.0.1 h1:+cRNoVrfiwufQPhoMzB6N0Yf/Mqajr6t1lOv8GyGE2U= github.com/texttheater/golang-levenshtein v1.0.1/go.mod h1:PYAKrbF5sAiq9wd+H82hs7gNaen0CplQ9uvm6+enD/8= github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 
v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/virtuald/go-ordered-json v0.0.0-20170621173500-b18e6e673d74 h1:JwtAtbp7r/7QSyGz8mKUbYJBg2+6Cd7OjM8o/GNOcVo= github.com/virtuald/go-ordered-json v0.0.0-20170621173500-b18e6e673d74/go.mod h1:RmMWU37GKR2s6pgrIEB4ixgpVCt/cf7dnJv3fuH1J1c= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/exp v0.0.0-20250911091902-df9299821621 h1:2id6c1/gto0kaHYyrixvknJ8tUK/Qs5IsmBtrc+FtgU= golang.org/x/exp v0.0.0-20250911091902-df9299821621/go.mod 
h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo= golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term 
v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/evanphx/json-patch.v4 v4.13.0 h1:czT3CmqEaQ1aanPc5SdlgQrrEIb8w/wwCvWWnfEbYzo= gopkg.in/evanphx/json-patch.v4 v4.13.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/api v0.34.1 h1:jC+153630BMdlFukegoEL8E/yT7aLyQkIVuwhmwDgJM= k8s.io/api v0.34.1/go.mod h1:SB80FxFtXn5/gwzCoN6QCtPD7Vbu5w2n1S0J5gFfTYk= k8s.io/apimachinery v0.34.1 h1:dTlxFls/eikpJxmAC7MVE8oOeP1zryV7iRyIjB0gky4= k8s.io/apimachinery v0.34.1/go.mod h1:/GwIlEcWuTX9zKIg2mbw0LRFIsXwrfoVxn+ef0X13lw= k8s.io/client-go v0.34.1 h1:ZUPJKgXsnKwVwmKKdPfw4tB58+7/Ik3CrjOEhsiZ7mY= k8s.io/client-go v0.34.1/go.mod h1:kA8v0FP+tk6sZA0yKLRG67LWjqufAoSHA2xVGKw9Of8= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 h1:Y3gxNAuB0OBLImH611+UDZcmKS3g6CthxToOb37KgwE= k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d h1:wAhiDyZ4Tdtt7e46e9M5ZSAJ/MnPGPs+Ki1gHw4w1R0= 
k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= sigs.k8s.io/structured-merge-diff/v6 v6.3.0 h1:jTijUJbW353oVOd9oTlifJqOGEkUw2jB/fXCbTiQEco= sigs.k8s.io/structured-merge-diff/v6 v6.3.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= ================================================ FILE: contrib/charts/dragonfly/golden_test.go ================================================ package golden import ( "flag" "fmt" "os" "path/filepath" "regexp" "strings" "testing" "github.com/gruntwork-io/terratest/modules/helm" ) var update = flag.Bool("update", false, "update golden test output files") func TestHelmRender(t *testing.T) { files, err := os.ReadDir("./ci") if err != nil { t.Fatal(err) } for _, f := range files { if !f.IsDir() && strings.HasSuffix(f.Name(), ".yaml") && !strings.HasSuffix(f.Name(), ".golden.yaml") { // Render this values.yaml file output := helm.RenderTemplate(t, &helm.Options{ ValuesFiles: []string{"ci/" + f.Name()}, }, "../dragonfly", "test", nil, ) goldenFile := "ci/" + strings.TrimSuffix(f.Name(), filepath.Ext(".yaml")) + ".golden.yaml" regex := regexp.MustCompile(`\s+helm.sh/chart:\s+.*`) bytes := regex.ReplaceAll([]byte(output), []byte("")) output = fmt.Sprintf("%s\n", string(bytes)) if *update { err := os.WriteFile(goldenFile, []byte(output), 0644) if err != nil { t.Fatal(err) } } expected, err := os.ReadFile(goldenFile) if err != nil { t.Fatal(err) } if string(expected) != output { 
t.Fatalf("Expected %s, but got %s\n. Update golden files by running `go test -v ./... -update`", string(expected), output) } } } } ================================================ FILE: contrib/charts/dragonfly/templates/NOTES.txt ================================================ 1. Get the application URL by running these commands: {{- if contains "NodePort" .Values.service.type }} export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "dragonfly.fullname" . }}) export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}") echo http://$NODE_IP:$NODE_PORT {{- else if contains "LoadBalancer" .Values.service.type }} NOTE: It may take a few minutes for the LoadBalancer IP to be available. You can watch its status by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "dragonfly.fullname" . }}' export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "dragonfly.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}") echo http://$SERVICE_IP:{{ .Values.service.port }} {{- else if contains "ClusterIP" .Values.service.type }} export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "dragonfly.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}") export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}") echo "You can use redis-cli to connect against localhost:6379" kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 6379:$CONTAINER_PORT {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/_helpers.tpl ================================================ {{/* Expand the name of the chart.
*/}} {{- define "dragonfly.name" -}} {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} {{- end }} {{/* Create a default fully qualified app name. We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). If release name contains chart name it will be used as a full name. */}} {{- define "dragonfly.fullname" -}} {{- if .Values.fullnameOverride }} {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} {{- else }} {{- $name := default .Chart.Name .Values.nameOverride }} {{- if contains $name .Release.Name }} {{- .Release.Name | trunc 63 | trimSuffix "-" }} {{- else }} {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} {{- end }} {{- end }} {{- end }} {{/* Create chart name and version as used by the chart label. */}} {{- define "dragonfly.chart" -}} {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} {{- end }} {{/* Common labels */}} {{- define "dragonfly.labels" -}} helm.sh/chart: {{ include "dragonfly.chart" . }} {{ include "dragonfly.selectorLabels" . }} {{- if .Chart.AppVersion }} app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} {{- end }} app.kubernetes.io/managed-by: {{ .Release.Service }} {{- include "dragonfly.commonLabels" . }} {{- end }} {{/* User-defined common labels */}} {{- define "dragonfly.commonLabels" -}} {{- if .Values.commonLabels }} {{- range $key, $value := .Values.commonLabels }} {{ $key }}: {{ $value }} {{- end }} {{- end }} {{- end }} {{/* Selector labels */}} {{- define "dragonfly.selectorLabels" -}} app.kubernetes.io/name: {{ include "dragonfly.name" . }} app.kubernetes.io/instance: {{ .Release.Name }} {{- end }} {{/* Create the name of the service account to use */}} {{- define "dragonfly.serviceAccountName" -}} {{- if .Values.serviceAccount.create }} {{- default (include "dragonfly.fullname" .) 
.Values.serviceAccount.name }} {{- else }} {{- default "default" .Values.serviceAccount.name }} {{- end }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/_pod.tpl ================================================ {{- define "dragonfly.volumemounts" -}} {{- if or (.Values.storage.enabled) (.Values.extraVolumeMounts) (.Values.tls.enabled) }} volumeMounts: {{- if .Values.storage.enabled }} - mountPath: /data name: "{{ .Release.Name }}-data" {{- end }} {{- if and .Values.tls .Values.tls.enabled }} - mountPath: /etc/dragonfly/tls name: tls {{- end }} {{- with .Values.extraVolumeMounts }} {{- toYaml . | trim | nindent 2 }} {{- end }} {{- end }} {{- end }} {{- define "dragonfly.pod" -}} {{- if ne .Values.priorityClassName "" }} priorityClassName: {{ .Values.priorityClassName }} {{- end }} {{- with .Values.tolerations }} tolerations: {{- toYaml . | trim | nindent 2 -}} {{- end }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | trim | nindent 2 -}} {{- end }} {{- with .Values.affinity }} affinity: {{- toYaml . | trim | nindent 2 -}} {{- end }} serviceAccountName: {{ include "dragonfly.serviceAccountName" . }} {{- with .Values.imagePullSecrets }} imagePullSecrets: {{- toYaml . | trim | nindent 2 }} {{- end }} {{- with .Values.podSecurityContext }} securityContext: {{- toYaml . | trim | nindent 2 }} {{- end }} {{- if and (eq (typeOf .Values.hostNetwork) "bool") .Values.hostNetwork }} hostNetwork: true {{- end }} {{- with .Values.topologySpreadConstraints }} topologySpreadConstraints: {{- toYaml . | trim | nindent 2 }} {{- end }} {{- with .Values.initContainers }} initContainers: {{- if eq (typeOf .) "string" }} {{- tpl . $ | trim | nindent 2 }} {{- else }} {{- toYaml . | trim | nindent 2 }} {{- end }} {{- end }} containers: {{- with .Values.extraContainers }} {{- if eq (typeOf .) "string" -}} {{- tpl . $ | trim | nindent 2 }} {{- else }} {{- toYaml . 
| trim | nindent 2 }} {{- end }} {{- end }} - name: {{ .Chart.Name }} {{- with .Values.securityContext }} securityContext: {{- toYaml . | trim | nindent 6 }} {{- end }} image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: - name: dragonfly containerPort: 6379 protocol: TCP {{- with .Values.probes }} {{- toYaml . | trim | nindent 4 }} {{- end }} {{- with .Values.command }} command: {{- toYaml . | trim | nindent 6 }} {{- end }} args: - "--alsologtostderr" {{- with .Values.extraArgs }} {{- toYaml . | trim | nindent 6 }} {{- end }} {{- if .Values.tls.enabled }} - "--tls" - "--tls_cert_file=/etc/dragonfly/tls/tls.crt" - "--tls_key_file=/etc/dragonfly/tls/tls.key" {{- end }} {{- with .Values.resources }} resources: {{- toYaml . | trim | nindent 6 }} {{- end }} {{- include "dragonfly.volumemounts" . | trim | nindent 4 }} {{- if or .Values.passwordFromSecret.enable .Values.env }} env: {{- if .Values.passwordFromSecret.enable }} {{- $appVersion := .Chart.AppVersion | trimPrefix "v" }} {{- $imageTag := .Values.image.tag | trimPrefix "v" }} {{- $effectiveVersion := $appVersion }} {{- if and $imageTag (ne $imageTag "") }} {{- $effectiveVersion = $imageTag }} {{- end }} {{- if semverCompare ">=1.14.0" $effectiveVersion }} - name: DFLY_requirepass {{- else }} - name: DFLY_PASSWORD {{- end }} valueFrom: secretKeyRef: name: {{ tpl .Values.passwordFromSecret.existingSecret.name $ }} key: {{ .Values.passwordFromSecret.existingSecret.key }} {{- end }} {{- with .Values.env }} {{- toYaml . | trim | nindent 6 }} {{- end }} {{- end }} {{- with .Values.envFrom }} envFrom: {{- toYaml . 
| trim | nindent 6 }} {{- end }} {{- if or (.Values.tls.enabled) (.Values.extraVolumes) }} volumes: {{- if and .Values.tls .Values.tls.enabled }} {{- if .Values.tls.existing_secret }} - name: tls secret: secretName: {{ .Values.tls.existing_secret }} {{- else if .Values.tls.createCerts }} - name: tls secret: secretName: '{{ include "dragonfly.fullname" . }}-server-tls' {{- else }} - name: tls secret: secretName: {{ include "dragonfly.fullname" . }}-tls {{- end }} {{- end }} {{- with .Values.extraVolumes }} {{- toYaml . | trim | nindent 2 }} {{- end }} {{- end }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/certificate.yaml ================================================ {{- if and .Values.tls.enabled .Values.tls.createCerts }} apiVersion: cert-manager.io/v1 kind: Certificate metadata: name: {{ include "dragonfly.fullname" . }} namespace: {{ .Release.Namespace }} labels: {{- include "dragonfly.labels" . | nindent 4 }} spec: commonName: '{{ include "dragonfly.fullname" . }}' dnsNames: - '*.{{ include "dragonfly.fullname" . }}.{{ .Release.Namespace }}.svc.cluster.local' - '{{ include "dragonfly.fullname" . }}.{{ .Release.Namespace }}.svc.cluster.local' - '{{ include "dragonfly.fullname" . }}.{{ .Release.Namespace }}.svc' - '{{ include "dragonfly.fullname" . }}.{{ .Release.Namespace }}' - '{{ include "dragonfly.fullname" . }}' - localhost duration: {{ required "tls.duration is required, if createCerts is enabled" .Values.tls.duration }} ipAddresses: - 127.0.0.1 issuerRef: kind: {{ required "tls.issuer.kind is required, if createCerts is enabled" .Values.tls.issuer.kind }} name: {{ required "tls.issuer.name is required, if createCerts is enabled" .Values.tls.issuer.name }} group: {{ .Values.tls.issuer.group }} secretName: '{{ include "dragonfly.fullname" . 
}}-server-tls' usages: - client auth - server auth - signing - key encipherment {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/deployment.yaml ================================================ {{- if not .Values.storage.enabled }} apiVersion: apps/v1 kind: Deployment metadata: name: {{ include "dragonfly.fullname" . }} namespace: {{ .Release.Namespace }} labels: {{- include "dragonfly.labels" . | nindent 4 }} spec: replicas: {{ .Values.replicaCount }} selector: matchLabels: {{- include "dragonfly.selectorLabels" . | nindent 6 }} template: metadata: annotations: {{- if and (.Values.tls.enabled) (not .Values.tls.existing_secret) }} checksum/tls-secret: {{ include (print $.Template.BasePath "/tls-secret.yaml") . | sha256sum }} {{- end }} {{- with .Values.podAnnotations }} {{- toYaml . | nindent 8 }} {{- end }} labels: {{- include "dragonfly.selectorLabels" . | nindent 8 }} {{- if .Values.commonLabels }} {{- include "dragonfly.commonLabels" . | trim | nindent 8 }} {{- end }} spec: {{- include "dragonfly.pod" . | trim | nindent 6 }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/extra-manifests.yaml ================================================ {{ range .Values.extraObjects }} --- {{ tpl (toYaml .) $ }} {{ end }} ================================================ FILE: contrib/charts/dragonfly/templates/metrics-service.yaml ================================================ {{- if .Values.serviceMonitor.enabled }} apiVersion: v1 kind: Service metadata: name: {{ include "dragonfly.fullname" . }}-metrics namespace: {{ .Release.Namespace }} labels: {{- include "dragonfly.labels" . | nindent 4 }} type: metrics spec: type: {{ .Values.service.metrics.serviceType }} ports: - name: {{ .Values.service.metrics.portName }} port: {{ .Values.service.port }} targetPort: {{ .Values.service.port }} protocol: TCP selector: {{- include "dragonfly.selectorLabels" . 
| nindent 4 }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/prometheusrule.yaml ================================================ {{- if and ( .Capabilities.APIVersions.Has "monitoring.coreos.com/v1" ) .Values.serviceMonitor.enabled .Values.prometheusRule.enabled }} apiVersion: monitoring.coreos.com/v1 kind: PrometheusRule metadata: name: {{ template "dragonfly.fullname" . }}-metrics namespace: {{ .Values.prometheusRule.namespace | default .Release.Namespace }} labels: {{- include "dragonfly.labels" . | nindent 4 }} spec: groups: - name: {{ template "dragonfly.name" . }} rules: {{- toYaml .Values.prometheusRule.spec | nindent 6 }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/service.yaml ================================================ apiVersion: v1 kind: Service metadata: name: {{ include "dragonfly.fullname" . }} namespace: {{ .Release.Namespace }} {{- with .Values.service.annotations }} annotations: {{- toYaml . | nindent 4 }} {{- end }} labels: {{- with .Values.service.labels }} {{- toYaml . | nindent 4 }} {{- end }} {{- include "dragonfly.labels" . | nindent 4 }} spec: type: {{ .Values.service.type }} {{- if and (eq .Values.service.type "LoadBalancer") (ne .Values.service.loadBalancerIP "") }} loadBalancerIP: {{ .Values.service.loadBalancerIP }} {{- end }} {{- if and (eq .Values.service.type "ClusterIP") (ne .Values.service.clusterIP "") }} clusterIP: {{ .Values.service.clusterIP }} {{- end }} ports: - port: {{ .Values.service.port }} targetPort: dragonfly protocol: TCP name: dragonfly selector: {{- include "dragonfly.selectorLabels" . | nindent 4 }} ================================================ FILE: contrib/charts/dragonfly/templates/serviceaccount.yaml ================================================ {{- if .Values.serviceAccount.create -}} apiVersion: v1 kind: ServiceAccount metadata: name: {{ include "dragonfly.serviceAccountName" . 
}} namespace: {{ .Release.Namespace }} {{- with .Values.serviceAccount.annotations }} annotations: {{- toYaml . | nindent 4 }} {{- end }} labels: {{- include "dragonfly.labels" . | nindent 4 }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/servicemonitor.yaml ================================================ {{- if .Values.serviceMonitor.enabled }} apiVersion: monitoring.coreos.com/v1 kind: ServiceMonitor metadata: name: {{ template "dragonfly.fullname" . }}-metrics {{- if .Values.serviceMonitor.namespace }} namespace: {{ .Values.serviceMonitor.namespace }} {{- else }} namespace: {{ .Release.Namespace }} {{- end }} {{- with .Values.serviceMonitor.annotations }} annotations: {{- toYaml . | nindent 4 }} {{- end }} labels: {{- with .Values.serviceMonitor.labels }} {{- toYaml . | nindent 4 }} {{- end }} {{- include "dragonfly.labels" . | nindent 4 }} spec: endpoints: - interval: {{ .Values.serviceMonitor.interval }} {{- with .Values.serviceMonitor.scrapeTimeout }} scrapeTimeout: {{ . }} {{- end }} honorLabels: true port: {{ default "metrics" .Values.service.metrics.portName }} path: /metrics {{- if .Values.tls.enabled }} scheme: https tlsConfig: insecureSkipVerify: true {{- else }} scheme: http {{- end }} jobLabel: "{{ .Release.Name }}" selector: matchLabels: {{- include "dragonfly.selectorLabels" . | nindent 6 }} type: metrics namespaceSelector: matchNames: - {{ .Release.Namespace }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/statefulset.yaml ================================================ {{- if .Values.storage.enabled }} apiVersion: apps/v1 kind: StatefulSet metadata: name: {{ include "dragonfly.fullname" . }} namespace: {{ .Release.Namespace }} labels: {{- include "dragonfly.labels" . | nindent 4 }} spec: serviceName: {{ .Release.Name }} replicas: {{ .Values.replicaCount }} selector: matchLabels: {{- include "dragonfly.selectorLabels" . 
| nindent 6 }} template: metadata: annotations: {{- if and (.Values.tls.enabled) (not .Values.tls.existing_secret) }} checksum/tls-secret: {{ include (print $.Template.BasePath "/tls-secret.yaml") . | sha256sum }} {{- end }} {{- with .Values.podAnnotations }} {{- toYaml . | nindent 8 }} {{- end }} labels: {{- include "dragonfly.selectorLabels" . | nindent 8 }} {{- if .Values.commonLabels }} {{- include "dragonfly.commonLabels" . | trim | nindent 8 }} {{- end }} spec: {{- include "dragonfly.pod" . | trim | nindent 6 }} volumeClaimTemplates: - metadata: name: "{{ .Release.Name }}-data" spec: accessModes: [ "ReadWriteOnce" ] storageClassName: {{ .Values.storage.storageClassName }} resources: requests: storage: {{ .Values.storage.requests }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/templates/tls-secret.yaml ================================================ {{- if and (.Values.tls.enabled) (.Values.tls.cert) (.Values.tls.key) (not .Values.tls.existing_secret) }} apiVersion: v1 kind: Secret metadata: name: {{ include "dragonfly.fullname" . }}-tls namespace: {{ .Release.Namespace }} labels: {{- include "dragonfly.labels" . | nindent 4 }} type: kubernetes.io/tls data: tls.crt: {{ default "" .Values.tls.cert | b64enc | quote }} tls.key: {{ default "" .Values.tls.key | b64enc | quote }} {{- end }} ================================================ FILE: contrib/charts/dragonfly/values.yaml ================================================ # Default values for dragonfly. # This is a YAML-formatted file. # Declare variables to be passed into your templates. # -- Number of replicas to deploy replicaCount: 1 image: # -- Container Image Registry to pull the image from repository: docker.dragonflydb.io/dragonflydb/dragonfly # -- Dragonfly image pull policy pullPolicy: IfNotPresent # -- Overrides the image tag whose default is the chart appVersion. 
tag: "" # -- Container Registry Secret names in an array imagePullSecrets: [] # -- String to partially override dragonfly.fullname nameOverride: "" # -- String to fully override dragonfly.fullname fullnameOverride: "" # -- Common labels to add to all resources commonLabels: {} serviceAccount: # -- Specifies whether a service account should be created create: true # -- Annotations to add to the service account annotations: {} # -- The name of the service account to use. # If not set and create is true, a name is generated using the fullname template name: "" # -- Annotations for pods podAnnotations: {} # -- Set securityContext for pod itself podSecurityContext: {} # fsGroup: 2000 # -- Set securityContext for containers securityContext: {} # capabilities: # drop: # - ALL # readOnlyRootFilesystem: true # runAsNonRoot: true # runAsUser: 1000 # -- Set hostNetwork for pod hostNetwork: false service: # -- Service type to provision. Can be NodePort, ClusterIP or LoadBalancer type: ClusterIP # -- Load balancer static ip to use when service type is set to LoadBalancer loadBalancerIP: "" # -- Cluster IP address to assign to the service. Leave empty to auto-allocate clusterIP: "" # -- Dragonfly service port port: 6379 # -- Extra annotations for the service annotations: {} # -- Extra labels for the service labels: {} metrics: # -- name for the metrics port portName: metrics # -- serviceType for the metrics service serviceType: ClusterIP serviceMonitor: # -- If true, a ServiceMonitor CRD is created for a prometheus operator enabled: false # -- namespace in which to deploy the ServiceMonitor CR. 
defaults to the application namespace namespace: "" # -- additional labels to apply to the metrics labels: {} # -- additional annotations to apply to the metrics annotations: {} # -- scrape interval interval: 10s # -- scrape timeout scrapeTimeout: 10s prometheusRule: # -- Deploy a PrometheusRule enabled: false # -- PrometheusRule.Spec # https://awesome-prometheus-alerts.grep.to/rules spec: [] storage: # -- If /data should persist. This will provision a StatefulSet instead. enabled: false # -- Global StorageClass for Persistent Volume(s) storageClassName: "" # -- Volume size to request for the PVC requests: 128Mi tls: # -- enable TLS enabled: false # -- use cert-manager to automatically create the certificate createCerts: false # -- duration or ttl of the validity of the created certificate duration: 87600h0m0s issuer: # -- cert-manager issuer kind. Usually Issuer or ClusterIssuer kind: ClusterIssuer # -- name of the referenced issuer name: selfsigned # -- group of the referenced issuer # if you are using an external issuer, change this to that issuer group. group: cert-manager.io # -- use TLS certificates from existing secret existing_secret: "" # -- TLS certificate cert: "" # cert: | # -----BEGIN CERTIFICATE----- # MIIDazCCAlOgAwIBAgIUfV3ygaaVW3+yzK5Dq6Aw6TsZ494wDQYJKoZIhvcNAQEL # ... # BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM # zJAL4hNw4Tr6E52fqdmX # -----END CERTIFICATE----- # -- TLS private key key: "" # key: | # -----BEGIN RSA PRIVATE KEY----- # MIIEpAIBAAKCAQEAxeD5iQGQpCUlksFvjzzAxPTw6DMJd3MpifV+HoBY4LiTyDer # ... 
# HLunol88AeTOcKfD6hBYGvcRfu5NV29jJxZCOBfbFQXjnNlnrhRCag== # -----END RSA PRIVATE KEY----- # If enabled will set DFLY_PASSWORD environment variable with the specified existing secret value # Note that if enabled and the secret does not exist pods will not start passwordFromSecret: enable: false existingSecret: name: "" key: "" probes: livenessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh initialDelaySeconds: 10 periodSeconds: 10 timeoutSeconds: 5 failureThreshold: 3 successThreshold: 1 readinessProbe: exec: command: - /bin/sh - /usr/local/bin/healthcheck.sh initialDelaySeconds: 10 periodSeconds: 10 timeoutSeconds: 5 failureThreshold: 3 successThreshold: 1 # -- Allow overriding the container's command command: [] # -- Extra arguments to pass to the dragonfly binary extraArgs: [] # -- Extra volumes to mount into the pods extraVolumes: [] # -- Extra volume mounts corresponding to the volumes mounted above extraVolumeMounts: [] # -- A list of initContainers to run before each pod starts initContainers: [] # -- Additional sidecar containers extraContainers: [] # -- extra K8s manifests to deploy extraObjects: [] # - apiVersion: cert-manager.io/v1 # kind: ClusterIssuer # metadata: # name: selfsigned # spec: # selfSigned: {} resources: # -- The requested resources for the containers requests: {} # cpu: 100m # memory: 128Mi # -- The resource limits for the containers limits: {} # cpu: 100m # memory: 128Mi # -- extra environment variables env: [] # -- extra environment variables from K8s objects envFrom: [] # -- Priority class name for pod assignment priorityClassName: "" # -- Node labels for pod assignment nodeSelector: {} # -- Tolerations for pod assignment tolerations: [] # -- Affinity for pod assignment affinity: {} # -- Topology Spread Constraints for pod assignment topologySpreadConstraints: [] ================================================ FILE: contrib/docker/README.md ================================================

Dragonfly

# Dragonfly DB with Docker Compose This guide will have you up and running with DragonflyDB and `docker-compose` in just a few minutes. | This guide assumes you have `docker` and `docker-compose` installed on your machine. If not, [Install Docker](https://docs.docker.com/get-docker/) and [Install Docker Compose](https://docs.docker.com/compose/install/) before continuing. ## Step 1 ```bash # Download Official Dragonfly DB Docker Compose File wget https://raw.githubusercontent.com/dragonflydb/dragonfly/main/contrib/docker/docker-compose.yml # Launch the Dragonfly DB Instance docker-compose up -d # Confirm image is up docker ps | grep dragonfly # ac94b5ba30a0 docker.dragonflydb.io/dragonflydb/dragonfly "entrypoint.sh drago…" 45 seconds ago Up 31 seconds 0.0.0.0:6379->6379/tcp, :::6379->6379/tcp docker_dragonfly_1 # Follow the logs of the dragonfly container docker logs -f docker_dragonfly_1 ``` Dragonfly DB will answer to both `http` and `redis` requests out of the box! You can use `redis-cli` to connect to `localhost:6379` or open a browser and visit `http://localhost:6379` ## Step 2 Connect with a redis client. From a new terminal: ```bash redis-cli 127.0.0.1:6379> set hello world OK 127.0.0.1:6379> keys * 1) "hello" 127.0.0.1:6379> get hello "world" 127.0.0.1:6379> ``` ## Step 3 Continue being great and build your app with the power of DragonflyDB! ## Tuning Dragonfly DB If you are attempting to tune Dragonfly DB for performance, consider `NAT` performance costs associated with containerization. > ## Performance Tuning > --- > In `docker-compose`, there is a meaningful difference between an `overlay` network (which relies on docker `NAT` traversal on every request) and using the `host` network (see [`docker-compose.yml`](https://github.com/dragonflydb/dragonfly/blob/main/contrib/docker/docker-compose.yml)). 
>   > For more information, see the [official docker-compose network_mode Docs](https://docs.docker.com/compose/compose-file/compose-file-v3/#network_mode) >   ### More Build Options - [Docker Quick Start](/docs/quick-start/) - [Kubernetes Deployment with Helm Chart](/contrib/charts/dragonfly/) - [Build From Source](/docs/build-from-source.md) ================================================ FILE: contrib/docker/docker-compose.yml ================================================ services: dragonfly: image: 'docker.dragonflydb.io/dragonflydb/dragonfly' ulimits: memlock: -1 ports: - "6379:6379" # For better performance, consider `host` mode instead of `port` to avoid docker NAT. # `host` mode is NOT currently supported in Swarm Mode. # https://docs.docker.com/compose/compose-file/compose-file-v3/#network_mode # network_mode: "host" volumes: - dragonflydata:/data volumes: dragonflydata: ================================================ FILE: contrib/scripts/conventional-commits ================================================ #!/usr/bin/env bash # list of Conventional Commits types cc_types=("feat" "fix") default_types=("build" "chore" "ci" "docs" "${cc_types[@]}" "perf" "refactor" "revert" "style" "test") types=( "${cc_types[@]}" ) if [ $# -eq 1 ]; then types=( "${default_types[@]}" ) else while [ $# -gt 1 ]; do types+=( "$1" ) shift done fi msg_file="$1" r_types="($(IFS='|'; echo "${types[*]}"))" r_scope="(\([[:alnum:] \/-]+\))?" r_delim='!?:' r_subject=" [[:print:]].+" pattern="^$r_types$r_scope$r_delim$r_subject$" if grep -Eq "$pattern" "$msg_file"; then exit 0 fi echo "[Commit message] $( cat "$msg_file" )" echo " Thank you for your interest in Dragonfly DB. To keep things clean, we ask all commits to meet the following criteria: - Be Signed (git commit -s -m ...) - Valid Conventional Commit https://www.conventionalcommits.org/ Special Commit Words are correlated to versioning. 
Specifically \"fix\" and \"feat\" - fix: a commit of the type fix patches a bug in your codebase (this correlates with PATCH in Semantic Versioning). - feat: a commit of the type feat introduces a new feature to the codebase (this correlates with MINOR in Semantic Versioning). - Breaking changes have a ! before the \":\" Finally, If there is an Issue for this Commit, Please add it to the end of the commit message. - Reference Issue Number at End of Commit Message (Optional) Thank you for helping us label a \`fix\` and \`feat\` properly so that our commits, issues and semantic versioning are all aligned! A Signed Conventional Commit with Issue Number look like: git commit -s -m \"type(scope): description #112\" Valid types: $(IFS=' '; echo "${types[*]}") Example Document Change: docs(readme): Fix Example Links #121 Example Breaking New Feature feat(ingest)!: Add new ingest # 122 This is an example of a fix with an Issue # fix(ingest): Refactor for loop to list comprehension #123 Thank you for your contribution! Sincerely, The Dragonfly DB Contributors " exit 1 ================================================ FILE: contrib/scripts/signed-commit ================================================ #!/usr/bin/env bash if [[ -z "$1" ]] || [[ ! -f "$1" ]]; then echo "ERROR: Commit message file not provided or does not exist." exit 1 fi # Check if signed-off-by line is present (automatically added using -s flag) if ! grep -q 'Signed-off-by:' "$1"; then echo "ERROR: Commit message must contain a Signed-off-by line." echo "" echo "To sign your commits, use the -s flag:" echo " git commit -s -m \"your commit message\"" exit 1 fi exit 0 ================================================ FILE: docs/README.md ================================================

Dragonfly

# Quick Start The easiest way to get started with Dragonfly is with Docker. ## Deployment Method First, choose a deployment method. If you are new to Dragonfly, we recommend the [DragonflyDB Docker Quick Start Guide](/docs/quick-start/) Other options: ### - [Docker Compose](/contrib/docker/) ### - [Helm Chart for Kubernetes](/contrib/charts/dragonfly/) # Learn About DragonflyDB ## [FAQ](/docs/faq.md) ## [Differences Between DragonflyDB and Redis](/docs/differences.md) ## [API Commands Reference](https://dragonflydb.io/docs/category/command-reference) ================================================ FILE: docs/async-tiering.md ================================================ # Async Tiering Design Document ## Background Our current tiered storage component performs disk operations inline as part of executing shard-local operations. This approach introduces latency when processing commands, impacting both the system's throughput and overall command latency. The following document discusses a potential redesign that addresses this issue and enables the execution of operations without I/O blocking. ```mermaid graph LR %% Left Side: No Tiering subgraph S1 [Shard queue no tiering] direction TB A1[get] --- B1[set] B1 --- C1[get] C1 --- D1[" "] end %% Spacing and Arrows S1 --- Space1[ ] Space1 -.-> Space2[ ] Space2 --- S2 %% Right Side: With Tiering subgraph S2 [Shard queue with tiering] direction TB A2["get
I/O read"] --- B2["set
I/O write"] B2 --- C2["get
I/O read"] C2 --- D2[" "] end %% Styling style S1 fill:#fff,stroke:#ffcc00,stroke-width:2px style S2 fill:#fff,stroke:#ffcc00,stroke-width:2px style A1 fill:#fff,stroke:#ffcc00 style B1 fill:#fff,stroke:#ffcc00 style C1 fill:#fff,stroke:#ffcc00 style D1 fill:none,stroke:none style A2 fill:#fff,stroke:#ffcc00 style B2 fill:#fff,stroke:#ffcc00 style C2 fill:#fff,stroke:#ffcc00 style D2 fill:none,stroke:none %% Hide the spacer nodes style Space1 fill:none,stroke:none style Space2 fill:none,stroke:none ``` ## High level design The core goal is to perform tiered I/O operations concurrently while maintaining transparency for the transaction framework designed for instant RAM operations. Transactions issue asynchronous requests to the tiered storage, returning futures that the coordinating fiber awaits. Operations on the same key execute strictly in order, relying on the transactional framework for correctness, while operations on different keys can be interleaved for efficiency. ### The following diagram depicts a simplified flow for a GET operation: ```mermaid sequenceDiagram participant Coordinator participant Shard participant Disk Coordinator->>Shard: Get Shard->>Disk: IO_Read Shard-->>Coordinator: ResultFuture Disk-->>Shard: ReadCallback Shard-->>Coordinator: ResultFulfilled ``` The coordinator fiber schedules a command on a shard thread. The command performs initial work, issues an asynchronous read, and returns a `ResultFuture` to the coordinator. The coordinator waits for fulfillment before replying. This parallelism hides most I/O latency (assuming non-saturated SSDs). For complex operations like `APPEND`, that require reading the value and modifying it, a post-read handler runs on the shard thread. Since in-place disk modification isn't supported, `APPEND` becomes an IO-READ followed by a handler that modifies the value in memory. The result is returned to the coordinator and the modified value is uploaded to memory and is deleted on disk. 
It is important to note that only a single read is issued for all pending asynchronous commands for a given key. Once the read finishes, all callbacks are executed consecutively and atomically. This guarantees correct operation order, including from the point of view of outside observers. This execution loop is aided by specialized Decoder classes that keep an intermediary value in-between modifications or avoid creating it at all for read-only sequences. Unlike the previous design where `DbSlice::Find(...)` handled tiering transparently, command implementations handling offloaded values must now use callbacks or futures (e.g., via `TieredStorage::Read` or `Modify`). ### Tiered Storage Component The `TieredStorage` component manages the lifecycle of offloaded items. Externalized blobs are immutable on disk; operations involve stashing new blobs, reading existing ones, or marking them for deletion. #### Upstream API (TieredStorage) The primary interface used by commands includes: 1. `Read(DbIndex, Key, Value) -> Future`: Asynchronously fetch an offloaded value. 2. `Modify(DbIndex, Key, Value, ModFunc) -> Future`: Fetch, modify in memory (via callback), and update. 3. `TryStash(DbIndex, Key, Value) -> Future`: Schedule a value for offloading. 4. `Delete(DbIndex, Value)`: Remove an offloaded value. 5. `CancelStash(DbIndex, Key, Value)`: Start cancelling a pending stash operation. #### Downstream API (DiskStorage) `DiskStorage` handles file management and async I/O: 1. `Read(DiskSegment, ReadCb)`: Read a segment from the backing file. 2. `PrepareStash(Length) -> Result>`: Allocate a segment and prepare a buffer. 3. `Stash(DiskSegment, UringBuf, StashCb)`: Write the buffer to the allocated segment. 4. `MarkAsFree(DiskSegment)`: Mark a segment for reuse. `DiskStorage` manages the underlying file growth and page allocation via an `ExternalAllocator`. 
```mermaid graph TB subgraph Commands["called by commands or db_slice"] READ[READ] REMOVE[REMOVE] STASH[STASH] end subgraph TieredStorage["TieredStorage"] %% Invisible node to act as a landing point for the box TS_TOP[ ]:::invisible PR[pending reads
+ remove?
offset -> futures] PS[pending stashes
key -> version] TS_BOTTOM[ ]:::invisible end subgraph DiskStorage["DiskStorage"] DS_TOP[ ]:::invisible EA[external
allocator] IM[io manager] end %% Interactions between Commands and TieredStorage READ -.-> |"Future<string>"| TS_TOP TS_TOP -.-> READ REMOVE -.-> TS_TOP STASH -.-> TS_TOP %% Interactions between TieredStorage and DiskStorage TS_BOTTOM -.-> |"callback based i/o operations"| DS_TOP DS_TOP -.-> TS_BOTTOM %% Notes Note1[pending reads for a specific
offset are tracked to avoid
duplicate reads and removal
of segments still in use] Note2[pending stashes use incremental
versions to discard results of
outdated operations] Note1 -.-> TieredStorage Note2 -.-> TieredStorage %% Styling to make landing nodes invisible classDef invisible fill:none,stroke:none,color:none,width:0px,height:0px; ``` Consider, for example, two high level `Read` operations for two different keys K1 and K2 residing on the same page. For K1, we issue a page read from `DiskStorage` tracked by its offset. For K2, if we check and find an active operation fetching that offset, we link the K2 callback to the K1 completion, avoiding duplicate I/O. Consider issuing a `Read` request for a key (e.g., during `GET`). This triggers a disk read for the corresponding page. If `Delete` is called for the same key (e.g., via `DEL` or `SET` overwriting the key) while the read is in progress, we must be careful. Immediately calling `DiskStorage::MarkAsFree` could allow a subsequent `Stash` to overwrite the page while it's being read. To prevent this race condition, `MarkAsFree` calls are queued until concurrent reads on the affected segment complete. These problems do not exist for `Stash` operations because they write to newly allocated pages that no other actor references yet. ## API->Ops translation table Those that require I/O are colored in **bold**. | API Sequence | I/O Ops Sequence | Explanation | |---|---|---| | `SET` (overwrite) | `Delete` | We remove the reference to the blob stashed on disk. No overwrite of existing entry. | | `GET` | **`Read`**, `Delete` (optional) | Reads the value. Depending on policy, we might then remove the blob from storage and keep it in RAM ("warm up"). | | `DEL`, `GET` | `Delete` | `DEL` removes the entry. Subsequent `GET` won't find it in TieredStorage. | | `APPEND` | **`Read`**, `Delete` | Modify not done in place. Read to memory, append, then remove old disk entry. | | `GET`, `SET` | **`Read`**, `Delete` | `GET` triggers `Read`. `SET` triggers `Delete`. 
If `Read` is in-flight, `DiskStorage::MarkAsFree` is delayed until `Read` completes to avoid reusing the page prematurely. | | `SET`, `DEL` | **`TryStash`**, `Delete` | `SET` may be followed by `TryStash` in case we decide to offload an in-memory entry. In case `DEL` is processed when stash is still in flight, `CancelStash()` will be called. Otherwise, `MarkAsFree` will be called to mark the page as available. | ================================================ FILE: docs/cluster-node-health.md ================================================ # Cluster Node Health **Node health is passive metadata provided by the cluster manager (control plane) via the `DFLYCLUSTER CONFIG` command.** Dragonfly nodes do not actively determine their own health status; instead, the cluster orchestrator monitors node states and communicates health information to each node through the cluster configuration. Dragonfly supports node health status reporting for cluster configurations, providing Valkey-compatible behavior for cluster management commands. This feature allows the cluster manager to track the health state of each node and communicate it to clients through various cluster commands. ## Overview The node health feature was introduced in [PR #4758](https://github.com/dragonflydb/dragonfly/pull/4758) and [PR #4767](https://github.com/dragonflydb/dragonfly/pull/4767) to address [issue #4741](https://github.com/dragonflydb/dragonfly/issues/4741). The health status is part of the cluster configuration and can be set for both master and replica nodes. Different cluster commands use this information to filter or display nodes based on their health state. 
## Health States Dragonfly supports four health states for cluster nodes: | State | Description | Visible in Commands | |-----------|-------------------------------------------------------------------------------------------|---------------------| | `online` | Node is fully operational and ready to serve requests | All commands | | `loading` | Node is still loading data (e.g., during initial sync or restart) | `CLUSTER SHARDS`, `CLUSTER NODES` | | `fail` | Node has failed or is unreachable | `CLUSTER SHARDS`, `CLUSTER NODES` | | `hidden` | Replica exists but should not be exposed to clients (internal use by cluster manager) | Masters: all commands; Replicas: none | ### Default State When no health status is specified in the configuration, nodes default to the `online` state. ## Configuration Node health is specified in the cluster configuration JSON that is passed via the `DFLYCLUSTER CONFIG` command. The health status is set using the `health` field for each node. ### Configuration Format ```json [ { "slot_ranges": [ { "start": 0, "end": 16383 } ], "master": { "id": "node-master-1", "ip": "10.0.0.1", "port": 7000, "health": "online" }, "replicas": [ { "id": "node-replica-1", "ip": "10.0.0.2", "port": 7001, "health": "online" }, { "id": "node-replica-2", "ip": "10.0.0.3", "port": 7002, "health": "loading" }, { "id": "node-replica-3", "ip": "10.0.0.4", "port": 7003, "health": "fail" }, { "id": "node-replica-4", "ip": "10.0.0.5", "port": 7004, "health": "hidden" } ] } ] ``` ### Setting Configuration Use the `DFLYCLUSTER CONFIG` command to set the cluster configuration with health information: ```bash DFLYCLUSTER CONFIG ``` The health field is optional and case-insensitive. Valid values are: `online`, `loading`, `fail`, and `hidden`. 
## Command Behavior Different cluster commands handle node health status in different ways: ### CLUSTER SHARDS The `CLUSTER SHARDS` command returns detailed information about cluster shards, including the health status of all nodes except those marked as `hidden`. **Example:** ```bash 127.0.0.1:6379> CLUSTER SHARDS 1) 1) "slots" 2) 1) (integer) 0 2) (integer) 16383 3) "nodes" 4) 1) 1) "id" 2) "node-master-1" 3) "endpoint" 4) "10.0.0.1" 5) "ip" 6) "10.0.0.1" 7) "port" 8) (integer) 7000 9) "role" 10) "master" 11) "replication-offset" 12) (integer) 0 13) "health" 14) "online" 2) 1) "id" 2) "node-replica-1" 3) "endpoint" 4) "10.0.0.2" 5) "ip" 6) "10.0.0.2" 7) "port" 8) (integer) 7001 9) "role" 10) "replica" 11) "replication-offset" 12) (integer) 0 13) "health" 14) "online" 3) 1) "id" 2) "node-replica-2" 3) "endpoint" 4) "10.0.0.3" 5) "ip" 6) "10.0.0.3" 7) "port" 8) (integer) 7002 9) "role" 10) "replica" 11) "replication-offset" 12) (integer) 0 13) "health" 14) "loading" 4) 1) "id" 2) "node-replica-3" 3) "endpoint" 4) "10.0.0.4" 5) "ip" 6) "10.0.0.4" 7) "port" 8) (integer) 7003 9) "role" 10) "replica" 11) "replication-offset" 12) (integer) 0 13) "health" 14) "fail" ``` **Note:** Nodes with `hidden` health status are filtered out and do not appear in the output. ### CLUSTER SLOTS The `CLUSTER SLOTS` command returns slot distribution information. This command filters out replicas that are not ready to serve requests. **Filtering behavior:** - Includes replicas with `online` health status - Excludes replicas with `loading`, `fail`, or `hidden` health status **Example:** ```bash 127.0.0.1:6379> CLUSTER SLOTS 1) 1) (integer) 0 2) (integer) 16383 3) 1) "10.0.0.1" 2) (integer) 7000 3) "node-master-1" 4) 1) "10.0.0.2" 2) (integer) 7001 3) "node-replica-1" ``` In this example, only the master and the `online` replica (`node-replica-1`) are shown. Replicas with `loading`, `fail`, or `hidden` status are not included. 
### CLUSTER NODES The `CLUSTER NODES` command returns a list of all cluster nodes in a space-separated format. This command shows nodes with most health states but excludes `hidden` nodes. **Connection state mapping:** - `online` and `loading` nodes: shown as `connected` - `fail` nodes: shown as `disconnected` - `hidden` nodes: not shown in output **Example:** ```bash 127.0.0.1:6379> CLUSTER NODES node-master-1 10.0.0.1:7000@7000 master - 0 0 0 connected 0-16383 node-replica-1 10.0.0.2:7001@7001 slave node-master-1 0 0 0 connected node-replica-2 10.0.0.3:7002@7002 slave node-master-1 0 0 0 connected node-replica-3 10.0.0.4:7003@7003 slave node-master-1 0 0 0 disconnected ``` **Note:** - `node-replica-1` (online): appears as `connected` - `node-replica-2` (loading): appears as `connected` - `node-replica-3` (fail): appears as `disconnected` - `node-replica-4` (hidden): not shown in output ## Use Cases ### 1. Gradual Node Addition When adding a new replica to a cluster, you can set its health status to `loading` while it's syncing data. This allows the cluster manager to track the node but prevents clients from redirecting read requests to it via `CLUSTER SLOTS`. ### 2. Failed Node Handling When a node fails or becomes unreachable, the cluster manager can mark it as `fail`. This provides visibility in `CLUSTER SHARDS` and `CLUSTER NODES` while excluding it from `CLUSTER SLOTS` responses. ### 3. Internal Replicas The `hidden` health status is useful for replica nodes that are managed internally by the cluster orchestrator but should not be visible to external clients. Hidden replicas are filtered out from all cluster commands (`CLUSTER SHARDS`, `CLUSTER SLOTS`, and `CLUSTER NODES`). Note that masters marked as `hidden` are still visible in all commands; the filtering only applies to replicas. ### 4. 
Valkey Compatibility This feature provides Valkey-compatible behavior for cluster client APIs: - `CLUSTER SHARDS` returns the health status of replica nodes - `CLUSTER SLOTS` does not return replicas that have not finished loading ## Implementation Details For developers interested in the implementation: 1. **Data Structure**: The `NodeHealth` enum is defined in `src/server/cluster/cluster_defs.h` with four values: `FAIL`, `LOADING`, `ONLINE`, and `HIDDEN`. 2. **Configuration Parsing**: Health status is parsed from JSON in `src/server/cluster/cluster_config.cc` in the `ParseClusterNode` function. 3. **Command Handlers**: The cluster commands in `src/server/cluster/cluster_family.cc` implement filtering logic based on health status: - `ClusterShards`: Filters out replicas with `HIDDEN` health before calling `ClusterShardsImpl` (masters are still included even if marked `HIDDEN`) - `ClusterSlotsImpl`: Filters out `HIDDEN`, `FAIL`, and `LOADING` replicas (masters are always included) - `ClusterNodesImpl`: Filters out replicas with `HIDDEN` health when listing replicas (masters with `HIDDEN` health are still included) and maps health to connection state 4. **Default Value**: When not specified in configuration, nodes default to `ONLINE` state as defined in `ClusterExtendedNodeInfo`. 
## See Also - [Dragonfly Cluster Mode Documentation](https://www.dragonflydb.io/docs/cluster) - [CLUSTER SHARDS Command](https://redis.io/commands/cluster-shards/) - [CLUSTER SLOTS Command](https://redis.io/commands/cluster-slots/) - [CLUSTER NODES Command](https://redis.io/commands/cluster-nodes/) ================================================ FILE: docs/coordinator.excalidraw ================================================ { "type": "excalidraw", "version": 2, "source": "https://excalidraw.com", "elements": [ { "type": "rectangle", "version": 498, "versionNonce": 987480120, "isDeleted": false, "id": "jPwIU_a9_nxvuDFAcbzxM", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "dotted", "roughness": 1, "opacity": 100, "angle": 0, "x": 712.375, "y": 510.2500000000001, "strokeColor": "#000000", "backgroundColor": "#15aabf", "width": 307, "height": 30, "seed": 1029717964, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "type": "text", "id": "U2-I9a2X4amHnB7NZFWGv" }, { "id": "MJoeQ6ylkFi5Z7UCzD-r-", "type": "arrow" }, { "id": "KpIRIBeGsT3yzCPp6jbEN", "type": "arrow" }, { "id": "Qnatw_Uix7cMFwAuW1DkJ", "type": "arrow" }, { "id": "TLS6mZEI7BXyUdiiYHdrg", "type": "arrow" }, { "id": "h_hyKP8N7nmD1NiZNa3ez", "type": "arrow" }, { "id": "CrT6zZ8CKm_MSDw-CmcPG", "type": "arrow" } ], "updated": 1660733356396, "link": null, "locked": false }, { "type": "text", "version": 389, "versionNonce": 1321365816, "isDeleted": false, "id": "U2-I9a2X4amHnB7NZFWGv", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 717.375, "y": 515.2500000000001, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 297, "height": 20, "seed": 1592449524, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1660733269433, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "coordinator", "baseline": 14, "textAlign": "center", "verticalAlign": "middle", 
"containerId": "jPwIU_a9_nxvuDFAcbzxM", "originalText": "coordinator" }, { "type": "rectangle", "version": 469, "versionNonce": 684925752, "isDeleted": false, "id": "BY5OdEEKT0Y_DTy9Zgr9C", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 714.375, "y": 217.41666666666669, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 77, "height": 192, "seed": 1621471436, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "MJoeQ6ylkFi5Z7UCzD-r-", "type": "arrow" }, { "id": "KpIRIBeGsT3yzCPp6jbEN", "type": "arrow" } ], "updated": 1660733316757, "link": null, "locked": false }, { "type": "text", "version": 113, "versionNonce": 2140069448, "isDeleted": false, "id": "45U617mr0L9ob4mc7Xozt", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 724.875, "y": 171.0865384615385, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 56, "height": 40, "seed": 1285924468, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1660733195706, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "shard 1\n", "baseline": 34, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "shard 1\n" }, { "type": "text", "version": 123, "versionNonce": 738921016, "isDeleted": false, "id": "vY-LnNlhD3qWMEtRPoU0t", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 840.4375, "y": 171.0865384615385, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 64, "height": 20, "seed": 817296972, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1660733195706, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "shard 2", "baseline": 14, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "shard 2" }, { "type": "rectangle", 
"version": 499, "versionNonce": 1256651064, "isDeleted": false, "id": "xvkm28eoejETjF3M78jpN", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 943.125, "y": 221.875, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 77, "height": 187, "seed": 1482008524, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "h_hyKP8N7nmD1NiZNa3ez", "type": "arrow" }, { "id": "CrT6zZ8CKm_MSDw-CmcPG", "type": "arrow" } ], "updated": 1660733356396, "link": null, "locked": false }, { "type": "text", "version": 193, "versionNonce": 731710264, "isDeleted": false, "id": "H72xWL9unzb1mQiLvx7L4", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 950.125, "y": 176.7115384615385, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 63, "height": 20, "seed": 1704611020, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1660733195706, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "shard 3", "baseline": 14, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "shard 3" }, { "type": "rectangle", "version": 547, "versionNonce": 1963108408, "isDeleted": false, "id": "jj-MVcNrzcH0DbFFo9noF", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 833.9375, "y": 221.16666666666669, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 77, "height": 193, "seed": 1374694167, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "Qnatw_Uix7cMFwAuW1DkJ", "type": "arrow" }, { "id": "TLS6mZEI7BXyUdiiYHdrg", "type": "arrow" } ], "updated": 1660733333008, "link": null, "locked": false }, { "id": "MJoeQ6ylkFi5Z7UCzD-r-", "type": "arrow", "x": 717.875, "y": 501.1682692307693, "width": 24, "height": 87, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#15aabf", 
"fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 6593352, "version": 99, "versionNonce": 1021163848, "isDeleted": false, "boundElements": null, "updated": 1660733308793, "link": null, "locked": false, "points": [ [ 0, 0 ], [ -24, -44 ], [ -3, -87 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "jPwIU_a9_nxvuDFAcbzxM", "focus": -0.8341352911917994, "gap": 9.08173076923083 }, "endBinding": { "elementId": "BY5OdEEKT0Y_DTy9Zgr9C", "focus": -0.13122256675640864, "gap": 4.751602564102598 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "KpIRIBeGsT3yzCPp6jbEN", "type": "arrow", "x": 752.875, "y": 419.1682692307693, "width": 16, "height": 90, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#15aabf", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 1407934264, "version": 74, "versionNonce": 1205666632, "isDeleted": false, "boundElements": null, "updated": 1660733316764, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 7, 42 ], [ -9, 90 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "BY5OdEEKT0Y_DTy9Zgr9C", "focus": 0.3233993962204972, "gap": 9.751602564102598 }, "endBinding": { "elementId": "jPwIU_a9_nxvuDFAcbzxM", "focus": -0.8035367629216211, "gap": 1.0817307692308304 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "Qnatw_Uix7cMFwAuW1DkJ", "type": "arrow", "x": 837.875, "y": 506.1682692307693, "width": 7, "height": 83, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#15aabf", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 1927132472, "version": 74, "versionNonce": 1840565576, "isDeleted": false, "boundElements": null, "updated": 1660733325799, "link": null, "locked": false, "points": [ [ 0, 0 
], [ 7, -83 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "jPwIU_a9_nxvuDFAcbzxM", "focus": -0.191317746711659, "gap": 4.0817307692308304 }, "endBinding": { "elementId": "jj-MVcNrzcH0DbFFo9noF", "focus": 0.4002005378587657, "gap": 9.001602564102598 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "TLS6mZEI7BXyUdiiYHdrg", "type": "arrow", "x": 872.875, "y": 423.1682692307693, "width": 13, "height": 82, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#15aabf", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 247434040, "version": 76, "versionNonce": 1827860040, "isDeleted": false, "boundElements": null, "updated": 1660733333013, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 9, 41 ], [ -4, 82 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "jj-MVcNrzcH0DbFFo9noF", "focus": 0.38070164408537926, "gap": 9.001602564102598 }, "endBinding": { "elementId": "jPwIU_a9_nxvuDFAcbzxM", "focus": -0.02127803036140877, "gap": 5.0817307692308304 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "h_hyKP8N7nmD1NiZNa3ez", "type": "arrow", "x": 995.875, "y": 418.1682692307693, "width": 13, "height": 90, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#15aabf", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 2138692424, "version": 57, "versionNonce": 178091592, "isDeleted": false, "boundElements": null, "updated": 1660733348048, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 12, 47 ], [ -1, 90 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "xvkm28eoejETjF3M78jpN", "focus": 0.19231425235177602, "gap": 9.293269230769283 }, "endBinding": { "elementId": "jPwIU_a9_nxvuDFAcbzxM", "focus": 0.7835976013538369, "gap": 2.0817307692308304 }, "startArrowhead": null, "endArrowhead": "arrow" }, { 
"id": "CrT6zZ8CKm_MSDw-CmcPG", "type": "arrow", "x": 957.875, "y": 502.1682692307693, "width": 18, "height": 91, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#15aabf", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 1991558200, "version": 58, "versionNonce": 1980388936, "isDeleted": false, "boundElements": null, "updated": 1660733356402, "link": null, "locked": false, "points": [ [ 0, 0 ], [ -11, -39 ], [ 7, -91 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "jPwIU_a9_nxvuDFAcbzxM", "focus": 0.6245467021802061, "gap": 8.08173076923083 }, "endBinding": { "elementId": "xvkm28eoejETjF3M78jpN", "focus": -0.23155463939046053, "gap": 2.2932692307692832 }, "startArrowhead": null, "endArrowhead": "arrow" } ], "appState": { "gridSize": null, "viewBackgroundColor": "#ffffff" }, "files": {} } ================================================ FILE: docs/dashtable.md ================================================ # Dashtable in Dragonfly Dashtable is a very important data structure in Dragonfly. This document explains how it fits inside the engine. Each selectable database holds a primary dashtable that contains all its entries. Another instance of Dashtable holds an optional expiry information, for keys that have TTL expiry on them. Dashtable is equivalent to Redis dictionary but have some wonderful properties that make Dragonfly memory efficient in various situations. ![Database Overview](./db.svg) ## Redis dictionary *“All problems in computer science can be solved by another level of indirection”* This section is a brief refresher of how redis dictionary (RD) is implemented. We shamelessly "borrowed" a diagram from [this blogpost](https://codeburst.io/a-closer-look-at-redis-dictionary-implementation-internals-3fd815aae535), so if you want a deep-dive, you can read the original article. 
Each `RD` is in fact two hash-tables (see the `ht` field in the diagram below). The second instance is used for incremental resizes of the dictionary. Each hash-table `dictht` is implemented as a [classic hashtable with separate chaining](https://en.wikipedia.org/wiki/Hash_table#Separate_chaining). `dictEntry` is the linked-list entry that wraps each key/value pair inside the table. Each dictEntry has three pointers and takes up 24 bytes of space. The bucket array of `dictht` is resized at powers of two, so usually its utilization is in the [50%, 100%] range. ![RD structure](https://miro.medium.com/max/1400/1*gNc8VzCknWRxXTBP9cVEHQ.png)
Let's estimate the overhead of the `dictht` table inside RD. *Case 1*: it has `N` items at 100% load factor, in other words, the bucket count equals the number of items. Each bucket holds a pointer to dictEntry, i.e. it's 8 bytes. In total we need: $8N + 24N = 32N$ bytes, i.e. 32 bytes per record.
*Case 2*: `N` items at 75% load factor, in other words, the number of buckets is 1.33 times higher than the number of items. In total we need: $N\*1.33\*8 + 24N \approx 34.6N$ bytes, i.e. about 34.6 bytes per record.
*Case 3*: `N` items at 50% load factor, say right after table growth. The number of buckets is twice the number of items, hence we need $N\*2\*8 + 24N = 40N$ bytes, i.e. 40 bytes per record. In the best possible case we need at least 16 bytes to store a key/value pair in the table, therefore the overhead of `dictht` is on average about 16-24 bytes per item. Now let's take incremental growth into account. When `ht[0]` is full (i.e. RD needs to migrate data to a bigger table), it will instantiate a second temporary instance `ht[1]` that will hold an additional `2N` buckets. Both instances will live in parallel until all data is migrated to `ht[1]` and then the `ht[0]` bucket array will be deleted. All this complexity is hidden from the user by the well engineered API of RD. Let's combine case 3 and case 1 to analyze the memory spike at this point: `ht[0]` holds `N` items and it is fully utilized. `ht[1]` is allocated with `2N` buckets. Overall, the memory needed during the spike is $32N + 16N=48N$ bytes. To summarize, RD requires between **16-32 bytes overhead**. ## Dash table [Dashtable](https://arxiv.org/abs/2003.07302) is an evolution of an algorithm from 1979 called [extendible hashing](https://en.wikipedia.org/wiki/Extendible_hashing). Similarly to a classic hashtable, dashtable (DT) also holds an array of pointers at front. However, unlike with classic tables, it points to `segments` and not to linked lists of items. Each `segment` is, in fact, a mini-hashtable of constant size. The front array of pointers to segments is called `directory`. Similarly to a classic table, when an item is inserted into a DT, it first determines the destination segment based on the item's hash value. The segment is implemented as a hashtable with an open-addressed hashing scheme and, as I said, is constant in size. Once the segment is determined, the item is inserted into one of its buckets. If the item was successfully inserted, we are finished; otherwise, the segment is "full" and needs splitting.
The DT splits the contents of a full segment into two segments, and the additional segment is added to the directory. Then it tries to insert the item again. To summarize, the classic chaining hash-table is built upon a dynamic array of linked-lists while dashtable is more like a dynamic array of flat hash-tables of constant size. ![Dashtable Diagram](./dashtable.svg) In the diagram above you can see what a dashtable looks like. Each segment is comprised of `K` buckets. For example, in our implementation a dashtable has 60 buckets per segment (it's a compile-time parameter that can be configured). ### Segment zoom-in Below you can see the diagram of a segment. It is comprised of regular buckets and stash buckets. Each bucket has `k` slots and each slot can host a key-value record. ![Segment](./dashsegment.svg) In our implementation, each segment has 56 regular buckets, 4 stash buckets and each bucket contains 14 slots. Overall, each dashtable segment has capacity to host 840 records. When an item is inserted into a segment, DT first determines its home bucket based on the item's hash value. The home bucket is one of 56 regular buckets that reside in the table. Each bucket has 14 available slots and the item can reside in any free slot. If the home bucket is full, then DT tries to insert to the regular bucket on the right. And if that bucket is also full, it tries to insert into one of 4 stash buckets. These are kept deliberately aside to gather spillovers from the regular buckets. The segment is "full" when the insertion fails, i.e. the home bucket and the neighbour bucket and all 4 stash buckets are full. Please note that the segment is not necessarily at full capacity: it can be that other buckets are not yet full, but unfortunately, that item can go only into these 6 buckets, so the segment contents must be split.
In case of a split event, DT creates a new segment and adds it to the directory; the items from the old segment are partly moved to the new one, and partly rebalanced within the old one. Only two segments are touched during the split event. Now we can explain why a seemingly similar data structure has an advantage over a classic hashtable in terms of memory and cpu. 1. Memory: we need `~N/840` entries or `8N/840` bytes in the dashtable directory to host N items on average. Basically, the overhead of the directory almost disappears in DT. Say for 1M items we will need ~1200 segments or 9600 bytes for the main array. That's in contrast to RD where we will need a solid `8N` bucket array overhead - no matter what. For 1M items, it will obviously be 8MB. In addition, dash segments use an open addressing collision scheme with probing, which means that they do not need anything like `dictEntry`. Dashtable uses lots of tricks to make its own metadata small. In our implementation, the average `tax` per entry is short of 20 bits compared to 64 bits in RD (dictEntry.next). In addition, DT incremental resize does not allocate a bigger table - instead it adds a single segment per split event. Assuming that a key/value entry is two 8 byte pointers like in RD, then DT requires $16N + (8N/840) + 2.5N + O(1) \approx 19N$ bytes at 100% utilization. This number is very close to the optimum of $16N$ bytes. In the unlikely case when all segments just doubled in size, i.e. DT is at 50% utilization, we may need $38N$ bytes in total. In practice, each segment grows independently from others, so the table has smooth memory usage of 22-32 bytes per item or **6-16 bytes overhead**. 1. Speed: RD requires an allocation for dictEntry per insertion and deallocation per deletion. In addition, RD uses chaining, which is cache unfriendly on modern hardware. There is a consensus in engineering and research communities that classic chaining schemes are slower than open addressing alternatives.
Having said that, DT also needs to go through a single level of indirection when fetching a segment pointer. However, DT's directory size is relatively small: in the example above, all 9K could reside in L1 cache. Once the segment is determined, the rest of the insertion is very fast and mostly operates on 1-3 memory cache lines. Finally, during resizes, RD needs to allocate a bucket array of size `2N`. That could be time consuming - imagine an allocation of 100M buckets for example. DT on the other hand requires an allocation of constant size per new segment. DT is faster and what's more important - its incremental ability is better. It eliminates latency spikes and reduces tail latency of the operations above. Please note that with all the efficiency of Dashtable, it can not drastically decrease the overall memory usage. Its primary goal is to reduce waste around dictionary management. Having said that, by reducing metadata waste we could insert dragonfly-specific attributes into a table's metadata in order to implement other intelligent algorithms like forkless save. This is where some of the Dragonfly's disrupting qualities [can be seen](#forkless-save). ## Benchmarks There are many other improvements in dragonfly that save memory besides DT. I will not be able to cover them all here. The results below show the final result as of May 2022. ### Populate single-threaded To compare RD vs DT I often use an internal debugging command "debug populate" that quickly fills both datastores with data. It just saves time and gives more consistent results compared to memtier_benchmark. It also shows the raw speed at which each dictionary gets filled without intermediary factors like networking, parsing etc. I deliberately fill datasets with small data to show how the overhead of metadata differs between the two data structures. I run "debug populate 20000000" (20M) on both engines on my home machine "AMD Ryzen 5 3400G with 8 cores".
| | Dragonfly | Redis 6 | |-------------|-----------|---------| | Time | 10.8s | 16.0s | | Memory used | 1GB | 1.73G | When looking at Redis6 "info memory" stats, you can see that the `used_memory_overhead` field equals `1.0GB`. That means that out of 1.73GB allocated, a whopping 1.0GB is used for the metadata. For small data use-cases the cost of metadata in Redis is larger than the data itself. ### Populate multi-threaded Now I run Dragonfly on all 8 cores. Redis has the same results, of course. | | Dragonfly | Redis 6 | |-------------|-----------|---------| | Time | 2.43s | 16.0s | | Memory used | 896MB | 1.73G | Due to shared-nothing architecture, Dragonfly maintains a dashtable per thread with its own slice of data. Each thread fills 1/8th of 20M range it owns - and it is much faster, almost 8 times faster. You can see that the total usage is even smaller, because now we maintain smaller tables in each thread (it's not always the case though - we could get slightly worse memory usage than with single-threaded case, depends where we stand compared to hash table utilization). ### Forkless Save This example shows how much memory Dragonfly uses during BGSAVE under load compared to Redis. Btw, BGSAVE and SAVE in Dragonfly are the same procedure because it's implemented using a fully asynchronous algorithm that maintains point-in-time snapshot guarantees. This test consists of 3 steps: 1. Execute `debug populate 5000000 key 1024` command on both servers to quickly fill them up with ~5GB of data. 2. Run `memtier_benchmark --ratio 1:0 -n 600000 --threads=2 -c 20 --distinct-client-seed --key-prefix="key:" --hide-histogram --key-maximum=5000000 -d 1024` command in order to send constant update traffic. This traffic should not affect substantially the memory usage of both servers. 3. Finally, run `bgsave` on both servers while measuring their memory.
Technically, it's very hard to measure the exact memory usage of Redis during BGSAVE because it creates a child process that shares its parent memory in part. We chose `cgroupsv2` as a tool to measure the memory. We put each server into a separate cgroup and we sampled the `memory.current` attribute for each cgroup. Since a forked Redis process inherits the cgroup of the parent, we get an accurate estimation of their total memory usage. Although we did not need this for Dragonfly we applied the same approach for consistency. ![BGSAVE](./bgsave_memusage.svg) As you can see on the graph, Redis uses 50% more memory even before BGSAVE starts. Around second 14, BGSAVE kicks off on both servers. Visually you can not see this event on the Dragonfly graph, but it's seen very well on the Redis graph. It took just a few seconds for Dragonfly to finish its snapshot (again, not possible to see) and around second 20 Dragonfly has already finished BGSAVE. You can see a distinguishable cliff at second 39 where Redis finishes its snapshot, reaching almost 3 times more memory usage at peak. ### Expiry of items during writes Efficient Expiry is very important for many scenarios. See, for example, [Pelikan paper'21](https://pelikan.io/2021/segcache.html). Twitter team says that their memory footprint could be reduced by as much as 60% by employing better expiry methodology. The authors of the post above show pros and cons of expiration methods in the table below: They argue that proactive expiration is very important for timely deletion of expired items. Dragonfly employs its own intelligent garbage collection procedure. By leveraging DashTable compartmentalized structure it can actually employ a very efficient passive expiry algorithm with low CPU overhead. Our passive procedure is complemented with proactive gradual scanning of the table in background. The procedure is as follows: A dashtable grows when its segment becomes full during the insertion and needs to be split.
This is a convenient point to perform garbage collection, but only for that segment. We scan its buckets for the expired items. If we delete some of them, we may avoid growing the table altogether! The cost of scanning the segment before a potential split is no more than the cost of the split itself, so it can be estimated as `O(1)`. We use `memtier_benchmark` for the experiment to demonstrate Dragonfly vs Redis expiry efficiency. We run locally the following command: ```bash memtier_benchmark --ratio 1:0 -n 600000 --threads=2 -c 20 --distinct-client-seed \ --key-prefix="key:" --hide-histogram --expiry-range=30-30 --key-maximum=100000000 -d 256 ``` We load larger values (256 bytes) to reduce the impact of metadata savings of Dragonfly. | | Dragonfly | Redis 6 | |----------------------|-----------|---------| | Memory peak usage | 1.45GB | 1.95GB | | Avg SET qps | 131K | 100K | Please note that Redis could sustain 30% less qps. That means that the optimal working sets for Dragonfly and Redis are different - the former needed to host at least `20s*131K` items at any point of time and the latter only needed to keep `20s*100K` items. So for `30%` bigger working set Dragonfly needed `25%` less memory at peak. *Please ignore the performance advantage of Dragonfly over Redis in this test - it has no meaning. I run it locally on my machine and it does not represent a real throughput benchmark.*
*All diagrams in this doc are created in [drawio app](https://app.diagrams.net/).* ================================================ FILE: docs/dense_set.excalidraw ================================================ { "type": "excalidraw", "version": 2, "source": "https://excalidraw.com", "elements": [ { "id": "LdnS4utc0Co8ZQl0k_99q", "type": "rectangle", "x": 278.57142857142867, "y": 767.857142857143, "width": 157, "height": 42, "angle": 0, "strokeColor": "#364fc7", "backgroundColor": "#4c6ef5", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 34309611, "version": 379, "versionNonce": 490192843, "isDeleted": false, "boundElements": [ { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow" }, { "type": "text", "id": "BV0b6Du7Nu_TpcyHxOq9M" } ], "updated": 1662257477282, "link": null, "locked": false }, { "id": "6iemTDX54UBvWAow6YZUm", "type": "ellipse", "x": 785.5714285714287, "y": 670.857142857143, "width": 151, "height": 65, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 2110505739, "version": 615, "versionNonce": 1697849797, "isDeleted": false, "boundElements": [ { "type": "text", "id": "CsENpV2URO6_T9J1e_EWv" }, { "id": "h4EkHYMe6b4cxIpFk2aJ1", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "id": "CsENpV2URO6_T9J1e_EWv", "type": "text", "x": 790.5714285714287, "y": 689.357142857143, "width": 141, "height": 28, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 1128884549, "version": 556, "versionNonce": 341608715, "isDeleted": false, "boundElements": null, "updated": 1662257477283, "link": null, 
"locked": false, "text": "\"abcd...\"", "fontSize": 20, "fontFamily": 1, "textAlign": "center", "verticalAlign": "middle", "baseline": 19, "containerId": "6iemTDX54UBvWAow6YZUm", "originalText": "\"abcd...\"" }, { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow", "x": 436.80362915161936, "y": 789.7627222797395, "width": 81.53559883961861, "height": 0.030478424363479917, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 1941199909, "version": 1319, "versionNonce": 319409605, "isDeleted": false, "boundElements": null, "updated": 1662257477403, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 81.53559883961861, 0.030478424363479917 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "LdnS4utc0Co8ZQl0k_99q", "focus": 0.041645385141281355, "gap": 1.2322005801906926 }, "endBinding": { "elementId": "9mWjCy5sUe-mID6u6k7Ll", "focus": -0.08136851610313917, "gap": 1.2322005801906926 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "type": "ellipse", "version": 1317, "versionNonce": 365900933, "isDeleted": false, "id": "tbWakWx-QT3DCK-_FZhx-", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1687.5714285714287, "y": 681.857142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 151, "height": 65, "seed": 429183979, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "RqWMMUkOMQWtnqqIo_0RK", "type": "text" }, { "id": "ZD_EGEh1PSlEhdhmPUGm3", "type": "arrow" }, { "type": "text", "id": "RqWMMUkOMQWtnqqIo_0RK" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 1258, "versionNonce": 1498415691, "isDeleted": false, "id": "RqWMMUkOMQWtnqqIo_0RK", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, 
"angle": 0, "x": 1692.5714285714287, "y": 700.357142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 141, "height": 28, "seed": 365098053, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "\"abcd...\"", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "tbWakWx-QT3DCK-_FZhx-", "originalText": "\"abcd...\"" }, { "type": "arrow", "version": 3623, "versionNonce": 1786105125, "isDeleted": false, "id": "ZD_EGEh1PSlEhdhmPUGm3", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1593.5714285714287, "y": 767.857142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 98.02649903619977, "height": 41.55714530998347, "seed": 1874017893, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1662257477405, "link": null, "locked": false, "startBinding": { "elementId": "RyzbgdtiyAgDl_Gg-xKD6", "focus": 0.44304364520670675, "gap": 2 }, "endBinding": { "elementId": "tbWakWx-QT3DCK-_FZhx-", "focus": 0.6558676754700489, "gap": 1 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 57.5, -9.5 ], [ 98.02649903619977, -41.55714530998347 ] ] }, { "type": "ellipse", "version": 1594, "versionNonce": 722835269, "isDeleted": false, "id": "hls1kkVvTEbIVUoHV9YjB", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1688.0714285714287, "y": 848.357142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 151, "height": 65, "seed": 464754437, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "7OALlCUSo8C4wRunATj7i", "type": "text" }, { "id": "ZD_EGEh1PSlEhdhmPUGm3", "type": "arrow" }, { "id": "7OALlCUSo8C4wRunATj7i", "type": "text" }, { "type": "text", "id": 
"7OALlCUSo8C4wRunATj7i" }, { "id": "PtndVbqi061kx-2QVmX9B", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 1533, "versionNonce": 1814440843, "isDeleted": false, "id": "7OALlCUSo8C4wRunATj7i", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1693.0714285714287, "y": 866.857142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 141, "height": 28, "seed": 1547241419, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "\"abcd...\"", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "hls1kkVvTEbIVUoHV9YjB", "originalText": "\"abcd...\"" }, { "id": "PtndVbqi061kx-2QVmX9B", "type": "arrow", "x": 1595.5714285714287, "y": 818.857142857143, "width": 128.1422939788249, "height": 32.825682301479674, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 310414827, "version": 1513, "versionNonce": 652927109, "isDeleted": false, "boundElements": null, "updated": 1662257477406, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 49.5, -8 ], [ 128.1422939788249, 24.825682301479674 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "LUhivcEGaeW_fHoMkT5PY", "focus": 0.2744377811094453, "gap": 5 }, "endBinding": { "elementId": "hls1kkVvTEbIVUoHV9YjB", "focus": 0.453665660258198, "gap": 9.270374749825422 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "type": "rectangle", "version": 539, "versionNonce": 447891973, "isDeleted": false, "id": "BmVwp90EOf01pxoCqayka", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 779.0714285714287, "y": 
273.8571428571429, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 157, "height": 42, "seed": 817120651, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "7-os26TSlkxMhDb-ALHK8", "type": "text" }, { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow" }, { "type": "text", "id": "7-os26TSlkxMhDb-ALHK8" }, { "id": "2-4BatkaFqKxOF9ikfE9M", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 496, "versionNonce": 1475544267, "isDeleted": false, "id": "7-os26TSlkxMhDb-ALHK8", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 784.0714285714287, "y": 280.8571428571429, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 147, "height": 28, "seed": 398781605, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DensePtr", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "BmVwp90EOf01pxoCqayka", "originalText": "DensePtr" }, { "id": "BV0b6Du7Nu_TpcyHxOq9M", "type": "text", "x": 283.57142857142867, "y": 774.857142857143, "width": 147, "height": 28, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 84057963, "version": 301, "versionNonce": 1123890789, "isDeleted": false, "boundElements": null, "updated": 1662257477283, "link": null, "locked": false, "text": "DenseLinkKey", "fontSize": 20, "fontFamily": 1, "textAlign": "center", "verticalAlign": "middle", "baseline": 19, "containerId": "LdnS4utc0Co8ZQl0k_99q", "originalText": "DenseLinkKey" }, { "type": "ellipse", "version": 1392, "versionNonce": 208548715, "isDeleted": false, "id": "oUfTPCoNOMOVUScypl9ov", "fillStyle": "hachure", "strokeWidth": 
1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1038.0714285714287, "y": 262.3571428571429, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 151, "height": 65, "seed": 1274116613, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "_ZWQLsSXL62nm9Vxybs_T", "type": "text" }, { "id": "ZD_EGEh1PSlEhdhmPUGm3", "type": "arrow" }, { "id": "_ZWQLsSXL62nm9Vxybs_T", "type": "text" }, { "id": "_ZWQLsSXL62nm9Vxybs_T", "type": "text" }, { "id": "PtndVbqi061kx-2QVmX9B", "type": "arrow" }, { "type": "text", "id": "_ZWQLsSXL62nm9Vxybs_T" }, { "id": "2-4BatkaFqKxOF9ikfE9M", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 1329, "versionNonce": 1599163589, "isDeleted": false, "id": "_ZWQLsSXL62nm9Vxybs_T", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1043.0714285714287, "y": 280.8571428571429, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 141, "height": 28, "seed": 1098831051, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "\"abcd...\"", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "oUfTPCoNOMOVUScypl9ov", "originalText": "\"abcd...\"" }, { "id": "2-4BatkaFqKxOF9ikfE9M", "type": "arrow", "x": 937.5714285714287, "y": 296.8571428571429, "width": 97, "height": 2, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 2010023531, "version": 543, "versionNonce": 1848018917, "isDeleted": false, "boundElements": null, "updated": 1662257477407, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 41, 1 ], [ 97, -1 ] ], "lastCommittedPoint": null, 
"startBinding": { "elementId": "BmVwp90EOf01pxoCqayka", "focus": 0.0021287919105907396, "gap": 1.5 }, "endBinding": { "elementId": "oUfTPCoNOMOVUScypl9ov", "focus": 0.05585205610314286, "gap": 3.5285491921035828 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "type": "rectangle", "version": 751, "versionNonce": 235050155, "isDeleted": false, "id": "Suj1TA3n75lniv8ZthhOy", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 780.5714285714287, "y": 481.857142857143, "strokeColor": "#2b8a3e", "backgroundColor": "#12b886", "width": 157, "height": 42, "seed": 1337311947, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "5l8sQoeycml7y43c3H6j4", "type": "text" }, { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow" }, { "id": "5l8sQoeycml7y43c3H6j4", "type": "text" }, { "id": "XNzXS4nhlngVv4LqrpGWH", "type": "arrow" }, { "type": "text", "id": "5l8sQoeycml7y43c3H6j4" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 707, "versionNonce": 949919621, "isDeleted": false, "id": "5l8sQoeycml7y43c3H6j4", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 785.5714285714287, "y": 488.8571428571431, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 147, "height": 28, "seed": 26534757, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DensePtr", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "Suj1TA3n75lniv8ZthhOy", "originalText": "DensePtr" }, { "type": "ellipse", "version": 1601, "versionNonce": 822765285, "isDeleted": false, "id": "e0Z3-_Eg_DtzWKAJ00uZx", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1039.5714285714287, "y": 470.357142857143, 
"strokeColor": "#000000", "backgroundColor": "transparent", "width": 151, "height": 65, "seed": 959000939, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "_Bita6RwDhub4HiG4-vAe", "type": "text" }, { "id": "ZD_EGEh1PSlEhdhmPUGm3", "type": "arrow" }, { "id": "_Bita6RwDhub4HiG4-vAe", "type": "text" }, { "id": "_Bita6RwDhub4HiG4-vAe", "type": "text" }, { "id": "PtndVbqi061kx-2QVmX9B", "type": "arrow" }, { "id": "_Bita6RwDhub4HiG4-vAe", "type": "text" }, { "id": "XNzXS4nhlngVv4LqrpGWH", "type": "arrow" }, { "type": "text", "id": "_Bita6RwDhub4HiG4-vAe" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 1537, "versionNonce": 140486123, "isDeleted": false, "id": "_Bita6RwDhub4HiG4-vAe", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1044.5714285714287, "y": 488.8571428571431, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 141, "height": 28, "seed": 776810181, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "\"abcd...\"", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "e0Z3-_Eg_DtzWKAJ00uZx", "originalText": "\"abcd...\"" }, { "type": "arrow", "version": 1219, "versionNonce": 1379287877, "isDeleted": false, "id": "XNzXS4nhlngVv4LqrpGWH", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 939.0714285714287, "y": 504.8571428571431, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 97, "height": 2, "seed": 192269323, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1662257477408, "link": null, "locked": false, "startBinding": { "elementId": "Suj1TA3n75lniv8ZthhOy", "focus": 0.002128791910595701, "gap": 1.5 }, "endBinding": { "elementId": "e0Z3-_Eg_DtzWKAJ00uZx", 
"focus": 0.055852056103139376, "gap": 3.5285491921035685 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 41, 1 ], [ 97, -1 ] ] }, { "id": "RGj3Y6CtyijvehUeHVywF", "type": "text", "x": 749.5714285714287, "y": 560.857142857143, "width": 474, "height": 46, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 130163083, "version": 363, "versionNonce": 1546629541, "isDeleted": false, "boundElements": null, "updated": 1662257477283, "link": null, "locked": false, "text": "Chain With Multiple Entries", "fontSize": 36, "fontFamily": 1, "textAlign": "center", "verticalAlign": "top", "baseline": 32, "containerId": null, "originalText": "Chain With Multiple Entries" }, { "type": "rectangle", "version": 696, "versionNonce": 85550891, "isDeleted": false, "id": "1S3pVzBUuYFr-RbDX1FXv", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 536.0714285714287, "y": 747.857142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 157, "height": 42, "seed": 715872619, "groupIds": [ "kaOwSxozJCF4g6QcHxA1q" ], "strokeSharpness": "sharp", "boundElements": [ { "type": "text", "id": "P0w2r72h8lTF4KC0S8iK0" }, { "id": "h4EkHYMe6b4cxIpFk2aJ1", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "rectangle", "version": 695, "versionNonce": 1599687115, "isDeleted": false, "id": "CdJtzp6w0n0rveWC1BWuQ", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 535.0714285714287, "y": 804.857142857143, "strokeColor": "#364fc7", "backgroundColor": "#4c6ef5", "width": 157, "height": 42, "seed": 582081413, "groupIds": [ "kaOwSxozJCF4g6QcHxA1q" ], "strokeSharpness": "sharp", "boundElements": [ { "type": 
"text", "id": "LFc4k25ArlLXtoygEGUU6" }, { "id": "iKBu85WHY4IL_69QEvdyT", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "id": "9mWjCy5sUe-mID6u6k7Ll", "type": "rectangle", "x": 519.5714285714287, "y": 703.857142857143, "width": 188, "height": 159, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [ "kaOwSxozJCF4g6QcHxA1q" ], "strokeSharpness": "sharp", "seed": 1146093355, "version": 582, "versionNonce": 952359019, "isDeleted": false, "boundElements": [ { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow" }, { "id": "5d6SPvHw2keIDl-5kNmEb", "type": "arrow" }, { "type": "text", "id": "2HOa22It8IfsktBdjpTwo" } ], "updated": 1662257477283, "link": null, "locked": false }, { "id": "LFc4k25ArlLXtoygEGUU6", "type": "text", "x": 540.0714285714287, "y": 811.857142857143, "width": 147, "height": 28, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [ "kaOwSxozJCF4g6QcHxA1q" ], "strokeSharpness": "sharp", "seed": 1252566219, "version": 555, "versionNonce": 790860555, "isDeleted": false, "boundElements": null, "updated": 1662257477283, "link": null, "locked": false, "text": "DenseLinkKey", "fontSize": 20, "fontFamily": 1, "textAlign": "center", "verticalAlign": "middle", "baseline": 19, "containerId": "CdJtzp6w0n0rveWC1BWuQ", "originalText": "DenseLinkKey" }, { "id": "P0w2r72h8lTF4KC0S8iK0", "type": "text", "x": 541.0714285714287, "y": 754.857142857143, "width": 147, "height": 28, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [ "kaOwSxozJCF4g6QcHxA1q" ], "strokeSharpness": "sharp", "seed": 1221832197, "version": 556, "versionNonce": 1960523557, 
"isDeleted": false, "boundElements": null, "updated": 1662257477283, "link": null, "locked": false, "text": "DensePtr", "fontSize": 20, "fontFamily": 1, "textAlign": "center", "verticalAlign": "middle", "baseline": 19, "containerId": "1S3pVzBUuYFr-RbDX1FXv", "originalText": "DensePtr" }, { "id": "2HOa22It8IfsktBdjpTwo", "type": "text", "x": 524.5714285714287, "y": 708.857142857143, "width": 178, "height": 28, "angle": 0, "strokeColor": "#a61e4d", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [ "kaOwSxozJCF4g6QcHxA1q" ], "strokeSharpness": "sharp", "seed": 971803051, "version": 325, "versionNonce": 1123219883, "isDeleted": false, "boundElements": null, "updated": 1662257477283, "link": null, "locked": false, "text": "DenseLinkKey", "fontSize": 20, "fontFamily": 1, "textAlign": "center", "verticalAlign": "top", "baseline": 19, "containerId": "9mWjCy5sUe-mID6u6k7Ll", "originalText": "DenseLinkKey" }, { "type": "rectangle", "version": 877, "versionNonce": 1427311237, "isDeleted": false, "id": "RyzbgdtiyAgDl_Gg-xKD6", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1434.5714285714287, "y": 745.107142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 157, "height": 42, "seed": 1200981765, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "K8fVzXRPoMTnOm4BIQdpC", "type": "text" }, { "type": "text", "id": "K8fVzXRPoMTnOm4BIQdpC" }, { "id": "ZD_EGEh1PSlEhdhmPUGm3", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "rectangle", "version": 876, "versionNonce": 1849346795, "isDeleted": false, "id": "LUhivcEGaeW_fHoMkT5PY", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1433.5714285714287, "y": 802.107142857143, "strokeColor": "#000000", "backgroundColor": 
"transparent", "width": 157, "height": 42, "seed": 1269700555, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "d68fKPrkXvutq5CIgpKu0", "type": "text" }, { "type": "text", "id": "d68fKPrkXvutq5CIgpKu0" }, { "id": "PtndVbqi061kx-2QVmX9B", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "rectangle", "version": 767, "versionNonce": 114818213, "isDeleted": false, "id": "nqXx_jG0SMox2AHT2L5F2", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1418.0714285714287, "y": 701.107142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 188, "height": 159, "seed": 532176485, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow" }, { "id": "5d6SPvHw2keIDl-5kNmEb", "type": "arrow" }, { "id": "P46vozsH8hY3lX5pMtPk8", "type": "text" }, { "type": "text", "id": "P46vozsH8hY3lX5pMtPk8" }, { "id": "CVx1AqnNI76hVX9-ObrtA", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 745, "versionNonce": 1071377733, "isDeleted": false, "id": "d68fKPrkXvutq5CIgpKu0", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1438.5714285714287, "y": 809.107142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 147, "height": 28, "seed": 491710059, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DensePtr", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "LUhivcEGaeW_fHoMkT5PY", "originalText": "DensePtr" }, { "type": "text", "version": 736, "versionNonce": 22842443, "isDeleted": false, "id": "K8fVzXRPoMTnOm4BIQdpC", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 
100, "angle": 0, "x": 1439.5714285714287, "y": 752.107142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 147, "height": 28, "seed": 1346980293, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DensePtr", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "RyzbgdtiyAgDl_Gg-xKD6", "originalText": "DensePtr" }, { "type": "text", "version": 505, "versionNonce": 1882147883, "isDeleted": false, "id": "P46vozsH8hY3lX5pMtPk8", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1423.0714285714287, "y": 706.107142857143, "strokeColor": "#a61e4d", "backgroundColor": "#12b886", "width": 178, "height": 28, "seed": 723312907, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DenseLinkKey", "baseline": 19, "textAlign": "center", "verticalAlign": "top", "containerId": "nqXx_jG0SMox2AHT2L5F2", "originalText": "DenseLinkKey" }, { "id": "h4EkHYMe6b4cxIpFk2aJ1", "type": "arrow", "x": 698.5714285714287, "y": 769.857142857143, "width": 108, "height": 43, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 313455205, "version": 995, "versionNonce": 1199018661, "isDeleted": false, "boundElements": null, "updated": 1662257477410, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 67, -2.5 ], [ 108, -43 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "1S3pVzBUuYFr-RbDX1FXv", "focus": 0.17277405270544205, "gap": 5.5 }, "endBinding": { "elementId": "6iemTDX54UBvWAow6YZUm", "focus": 0.37288545736724105, "gap": 1 }, "startArrowhead": null, "endArrowhead": "arrow" 
}, { "type": "text", "version": 735, "versionNonce": 1684199269, "isDeleted": false, "id": "C55jJitM19fp12H5lRwCI", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 798.5714285714287, "y": 170.8571428571429, "strokeColor": "#000000", "backgroundColor": "#12b886", "width": 374, "height": 46, "seed": 22775717, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 36, "fontFamily": 1, "text": "Chain With One Entry", "baseline": 32, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "Chain With One Entry" }, { "type": "text", "version": 957, "versionNonce": 694451563, "isDeleted": false, "id": "hGFgRua4wpyTtp4D694Ud", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 726.5714285714287, "y": 392.8571428571429, "strokeColor": "#000000", "backgroundColor": "#12b886", "width": 518, "height": 46, "seed": 840351749, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 36, "fontFamily": 1, "text": "Chain With a Displaced Entry", "baseline": 32, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "Chain With a Displaced Entry" }, { "type": "ellipse", "version": 766, "versionNonce": 1799135941, "isDeleted": false, "id": "LGXZp6X5oRRKg9gzSJIcd", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1229.8214285714287, "y": 669.982142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 151, "height": 65, "seed": 1902833605, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "uxrkX0MrXRysMixOCLf86", "type": "text" }, { "id": "sGH1mRBDDfdaZOORzbU1h", "type": "arrow" }, { "type": "text", "id": "uxrkX0MrXRysMixOCLf86" } ], 
"updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 706, "versionNonce": 1852282891, "isDeleted": false, "id": "uxrkX0MrXRysMixOCLf86", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1234.8214285714287, "y": 688.482142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 141, "height": 28, "seed": 1657838347, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "\"abcd...\"", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "LGXZp6X5oRRKg9gzSJIcd", "originalText": "\"abcd...\"" }, { "type": "rectangle", "version": 848, "versionNonce": 1701732011, "isDeleted": false, "id": "6SF7SEj50JLrJxpeJopdp", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 980.3214285714287, "y": 746.982142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 157, "height": 42, "seed": 368070501, "groupIds": [ "7648kMiz63bJLV7GO8sve" ], "strokeSharpness": "sharp", "boundElements": [ { "id": "rmScJhxvevICKMmx6PYQF", "type": "text" }, { "id": "sGH1mRBDDfdaZOORzbU1h", "type": "arrow" }, { "type": "text", "id": "rmScJhxvevICKMmx6PYQF" }, { "id": "iKBu85WHY4IL_69QEvdyT", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "rectangle", "version": 844, "versionNonce": 1763922251, "isDeleted": false, "id": "Nz45mnUTSGpaOgsVNIEr-", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 979.3214285714287, "y": 803.982142857143, "strokeColor": "#364fc7", "backgroundColor": "#4c6ef5", "width": 157, "height": 42, "seed": 998933867, "groupIds": [ "7648kMiz63bJLV7GO8sve" ], "strokeSharpness": "sharp", "boundElements": [ { "id": "dkP-6jzTOX9bdt8GCvOJw", "type": 
"text" }, { "id": "CVx1AqnNI76hVX9-ObrtA", "type": "arrow" }, { "type": "text", "id": "dkP-6jzTOX9bdt8GCvOJw" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "rectangle", "version": 734, "versionNonce": 450949099, "isDeleted": false, "id": "QvUMauaFoUm7amxqdJy2z", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 963.8214285714287, "y": 702.982142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 188, "height": 159, "seed": 2086345413, "groupIds": [ "7648kMiz63bJLV7GO8sve" ], "strokeSharpness": "sharp", "boundElements": [ { "id": "wIo5IjqjKx5agDWM2U6y9", "type": "arrow" }, { "id": "5d6SPvHw2keIDl-5kNmEb", "type": "arrow" }, { "id": "sEfRctJpRk7foZK9c0IAH", "type": "text" }, { "type": "text", "id": "sEfRctJpRk7foZK9c0IAH" }, { "id": "iKBu85WHY4IL_69QEvdyT", "type": "arrow" } ], "updated": 1662257477283, "link": null, "locked": false }, { "type": "text", "version": 705, "versionNonce": 212974219, "isDeleted": false, "id": "dkP-6jzTOX9bdt8GCvOJw", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 984.3214285714287, "y": 810.982142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 147, "height": 28, "seed": 1586274315, "groupIds": [ "7648kMiz63bJLV7GO8sve" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DenseLinkKey", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "Nz45mnUTSGpaOgsVNIEr-", "originalText": "DenseLinkKey" }, { "type": "text", "version": 706, "versionNonce": 474302373, "isDeleted": false, "id": "rmScJhxvevICKMmx6PYQF", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 985.3214285714287, "y": 753.982142857143, "strokeColor": "#000000", "backgroundColor": 
"transparent", "width": 147, "height": 28, "seed": 370506277, "groupIds": [ "7648kMiz63bJLV7GO8sve" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DensePtr", "baseline": 19, "textAlign": "center", "verticalAlign": "middle", "containerId": "6SF7SEj50JLrJxpeJopdp", "originalText": "DensePtr" }, { "type": "text", "version": 475, "versionNonce": 726675755, "isDeleted": false, "id": "sEfRctJpRk7foZK9c0IAH", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 968.8214285714287, "y": 707.982142857143, "strokeColor": "#a61e4d", "backgroundColor": "#12b886", "width": 178, "height": 28, "seed": 464676523, "groupIds": [ "7648kMiz63bJLV7GO8sve" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1662257477283, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "DenseLinkKey", "baseline": 19, "textAlign": "center", "verticalAlign": "top", "containerId": "QvUMauaFoUm7amxqdJy2z", "originalText": "DenseLinkKey" }, { "type": "arrow", "version": 1465, "versionNonce": 1519908357, "isDeleted": false, "id": "sGH1mRBDDfdaZOORzbU1h", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1142.8214285714287, "y": 768.982142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 108, "height": 43, "seed": 1144014277, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1662257477411, "link": null, "locked": false, "startBinding": { "elementId": "6SF7SEj50JLrJxpeJopdp", "focus": 0.17277405270544205, "gap": 5.5 }, "endBinding": { "elementId": "LGXZp6X5oRRKg9gzSJIcd", "focus": 0.37288545736724105, "gap": 1 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 67, -2.5 ], [ 108, -43 ] ] }, { "type": "arrow", "version": 1414, "versionNonce": 
647294309, "isDeleted": false, "id": "CVx1AqnNI76hVX9-ObrtA", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1141.8214285714287, "y": 826.982142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 272, "height": 42, "seed": 171156747, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1662257477411, "link": null, "locked": false, "startBinding": { "elementId": "Nz45mnUTSGpaOgsVNIEr-", "focus": 0.5583475858439679, "gap": 5.5 }, "endBinding": { "elementId": "nqXx_jG0SMox2AHT2L5F2", "focus": 0.06888696200536025, "gap": 4.25 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 95, -23 ], [ 272, -42 ] ] }, { "type": "arrow", "version": 1469, "versionNonce": 180091077, "isDeleted": false, "id": "iKBu85WHY4IL_69QEvdyT", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 693.5714285714287, "y": 826.857142857143, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 269, "height": 43, "seed": 1191246795, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1662257477412, "link": null, "locked": false, "startBinding": { "elementId": "CdJtzp6w0n0rveWC1BWuQ", "focus": 0.5091435337455598, "gap": 1.5 }, "endBinding": { "elementId": "QvUMauaFoUm7amxqdJy2z", "focus": 0.10601094635015593, "gap": 1.25 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 95, -23 ], [ 269, -43 ] ] } ], "appState": { "gridSize": null, "viewBackgroundColor": "#ffffff" }, "files": {} } ================================================ FILE: docs/dense_set.md ================================================ # DenseSet in Dragonfly `DenseSet` uses [classic hashtable with separate chaining](https://en.wikipedia.org/wiki/Hash_table#Separate_chaining) similar to the Redis dictionary for lookup 
of items within the set. The main optimization present in `DenseSet` is the ability for a pointer to **point to either an object or a link key**, removing the need to allocate a set entry for every entry. This is accomplished by using [pointer tagging](https://en.wikipedia.org/wiki/Tagged_pointer), exploiting the fact that the top 12 bits of any userspace address are not used and can be set to indicate if the current pointer points to nothing, a link key, or an object. The following table shows what each bit in a pointer is used for: | Bit Index (from LSB) | Meaning | | -------------------- |-------- | | 0 - 52 | Memory address of data in the userspace | | 53 | Indicates if this `DensePtr` points to data stored in the `DenseSet` or the next link in a chain | | 54 | Displacement bit. Indicates if the current entry is in the correct list defined by the data's hash | | 55 | Direction displaced. This only has meaning if the Displacement bit is set. 0 indicates the entry is to the left of its correct list, 1 indicates it is to the right of the correct list. | | 56 - 63 | Unused | Further, to reduce collisions and the number of unused spaces, items may be inserted into empty neighbors of the home chain (the chain determined by the hash). These entries are then marked as displaced using pointer tagging. An example of possible bucket configurations can be seen below. ![Dense Set Visualization](./dense_set.svg) *Created using [excalidraw](https://excalidraw.com)* ### Insertion To insert an entry, a `DenseSet` will take the following steps: 1. Check if the entry already exists in the set; if so, return false 2. If the entry does not exist, look for an empty chain at the hash index ± 1, prioritizing the home chain. If an empty entry is found, the item is inserted and true is returned 3. If step 2 fails and the growth prerequisites are met, increase the number of buckets in the table and repeat step 2 4. If step 3 fails, attempt to insert the entry in the home chain.
- If the home chain is not occupied by a displaced entry, insert the new entry at the front of the list - If the home chain is occupied by a displaced entry, move the displaced entry to its home chain. This may cause a domino effect if the home chain of the displaced entry is occupied by a second displaced entry, resulting in up to `O(N)` "fixes" ### Searching To find an entry in a `DenseSet`: 1. Check the first entry in the home and neighbor cells for matching entries 2. If step 1 fails, iterate the home chain of the searched entry and check for equality ### Pending Improvements Some further improvements to `DenseSet` include allowing entries to be inserted in their home chain without having to perform the current `O(N)` steps to fix displaced entries. By inserting an entry in its home chain after the displaced entry instead of fixing up displaced entries, searching incurs minimal added overhead and there is no domino effect when inserting a new entry. To eventually move a displaced entry to its home chain, multiple heuristics may be implemented, including: - When an entry is erased, if the chain becomes empty and there is a displaced entry in the neighbor chains, move it to the now empty home chain - If a displaced entry is found as a result of a search and is the root of a chain with multiple entries, the displaced node should be moved to its home bucket ## Benchmarks At 100% utilization the Redis dictionary implementation uses approximately 32 bytes per record ([read the breakdown for more information](./dashtable.md#redis-dictionary)). In comparison, using the neighbor cell optimization, `DenseSet` has ~21% of spaces unused at full utilization, resulting in $N\*8 + 0.2\*16N \approx 11.2N$ or ~12 bytes per record, yielding ~20 byte savings. The number of bytes per record saved grows as utilization decreases.
The command `memtier_benchmark -p 6379 --command "sadd __key__ __data__" -n 10000000 --threads=1 -c 1 --command-key-pattern=R --data-size=10 --key-prefix="key:" --hide-histogram --random-data --key-maximum=1 --randomize --pipeline 20` produces two set entries with lots of small records in them. This is what memory usage looks like with `DenseSet`: | Server | Memory (RSS) | |:---------------------:|:------: | | Dragonfly/DenseSet | 323MB 🟩 | | Redis | 586MB | | Dragonfly/RedisDict | 663MB | ================================================ FILE: docs/df-share-nothing.md ================================================ # Dragonfly Architecture Dragonfly is a modern replacement for memory stores like Redis and Memcached. It scales vertically on a single instance to support millions of requests per second. It is more memory efficient, has been designed with reliability in mind, and includes a better caching design. ## Threading model Dragonfly uses a single process with a multiple-thread architecture. Each Dragonfly thread is indirectly assigned several responsibilities via fibers. One such responsibility is handling incoming connections. Once a socket listener accepts a client connection, the connection spends its entire lifetime bound to a single thread inside a fiber. Dragonfly is written to be 100% non-blocking; it uses fibers to provide asynchronicity in each thread. One of the essential properties of asynchronicity is that a thread cannot be blocked as long as it has pending CPU tasks. Dragonfly preserves this property by wrapping each unit of execution context in a fiber; we wrap units of execution that can potentially be blocked on I/O. For example, a connection loop runs within a fiber; a function that writes a snapshot runs inside a fiber, and so on. As a side note, asynchronicity and parallelism are different concepts. Node.js, for example, provides asynchronous execution but is single-threaded.
Similarly, each Dragonfly thread is asynchronous on its own; therefore, Dragonfly is responsive to incoming events even when it handles long-running commands like saving to disk or running Lua scripts. ### Thread actors in DF The DF in-memory database is sharded into `N` parts, where `N` is less than or equal to the number of threads in the system. Each database shard is owned and accessed by a single thread. The same thread can handle TCP connections and simultaneously host a database shard. See the diagram below.
Here, our DF process spawns 4 threads, where threads 1 through 3 handle I/O (i.e., manage client connections) and threads 2 through 4 manage DB shards. Thread 2, for example, divides its CPU time between handling incoming requests and processing DB operations on the shard it owns. So when we say that thread 1 is an I/O thread, we mean that Dragonfly can pin fibers that manage client connections to thread 1. In general, any thread can have many responsibilities that require CPU time; database management and connection handling are only two of those responsibilities. ## Fibers I suggest reading my [intro post](https://www.romange.com/2018/12/15/introduction-to-fibers-in-c-/) about `Boost.Fibers` to learn more about fibers. By the way, I want to compliment the `Boost.Fibers` library–it has been exceptionally well designed: it's non-intrusive, lightweight, and efficient. Moreover, its default scheduler can be overridden. In the case of `helio`, the I/O library that powers Dragonfly, we overrode the `Boost.Fibers` scheduler to support shared-nothing architecture and integrate it with the I/O polling loop. Importantly, fibers require bottom-up support in the application layer to preserve their asynchronicity. For example, in the snippet below, a blocking write into `fd` won't magically allow a fiber to preempt and switch to another fiber. No, the whole thread will be blocked. ```cpp ... write(fd, buf, 1000000); ... pthread_mutex_lock(...); ``` Similarly, with a `pthread_mutex_lock` call, the whole thread might be blocked, wasting precious CPU time. Therefore, the Dragonfly code uses *fiber-friendly* primitives for I/O, communication, and coordination. These primitives are supplied by the `helio` and `Boost.Fibers` libraries. ## Life of a command request This section explains how Dragonfly handles a command in the context of shared-nothing architecture.
In most architectures used today, multi-threaded servers use mutex locks to protect their data structures, but Dragonfly does not. Why is this? Inter-thread interactions in Dragonfly occur only via passing messages from thread to thread. For example, consider the following sequence diagram of handling a SET request: ```uml @startuml actor User as A1 boundary connection as B1 entity "Shard K" as E1 A1 -> B1 : SET KEY VAL B1 -> E1 : SET KEY VAL / k = HASH(KEY) % N E1 -> B1 : OK B1 -> A1 : Response @enduml ``` Here, a connection fiber resides in a thread different from one that handles the `KEY` entity. We use hashing to decide which shard owns which key. Another way to think of this flow is that a connection fiber serves as a coordinator for issuing transactional commands to other threads. In this simple example, the external "SET" command requires a single message passed from the coordinator to the destination shard thread. When we think of the Dragonfly model in the context of a single command request, I prefer to use the following diagram instead of the [one above](#thread-actors-in-df).
Here, a coordinator (or connection fiber) might even reside on one of the threads that coincidentally owns one of the shards. However, it is easier to think of it as a separate entity that never directly accesses any shard data. The coordinator serves as a virtualization layer that hides all the complexity of talking to multiple shards. It employs state-of-the-art algorithms to provide atomicity (and strict serializability) semantics for multi-key commands like "mset, mget, and blpop." It also offers strict serializability for Lua scripts and multi-command transactions. Hiding such complexity is valuable to the end customer, but it comes with some CPU and latency costs. We believe the trade-off is worthwhile given the value that Dragonfly provides. If you want to deep dive into Dragonfly architecture without the complexities of transactional code, it's worth checking [Midi Redis](https://github.com/romange/midi-redis/), which implements a toy backend supporting `PING`, `SET`, and `GET` [commands](https://github.com/romange/midi-redis/blob/main/server/main_service.cc#L239). In fact, Dragonfly grew from that project; they share a common commit history. By the way, to learn how to build even simpler TCP backends than `midi-redis`, the `helio` library provides sample backends like these: [echo_server](https://github.com/romange/helio/blob/master/examples/echo_server.cc) and [ping_iouring_server.cc](https://github.com/romange/helio/blob/master/examples/pingserver/ping_iouring_server.cc). These backends reach millions of QPS on multi-core servers much like Dragonfly and midi-redis do. ================================================ FILE: docs/differences.md ================================================ # Differences with Redis ## String lengths, indices. String sizes are limited to 256MB. Indices (say in GETRANGE and SETRANGE commands) should be signed 32-bit integers in range [-2147483648, 2147483647]. ### String handling. SORT does not take any locale into account.
## Expiry ranges. Expirations are limited to 8 years. For commands with millisecond precision like PEXPIRE or PSETEX, expirations greater than 2^28ms are quietly rounded to the nearest second, losing precision of less than 0.001%. ## Lua We use Lua 5.4.4, which was released in 2022. That means we also support [Lua integers](https://github.com/redis/redis/issues/5261). ================================================ FILE: docs/faq.md ================================================ # Dragonfly Frequently Asked Questions - [Dragonfly Frequently Asked Questions](#dragonfly-frequently-asked-questions) - [What is the license model of Dragonfly? Is it an open source?](#what-is-the-license-model-of-dragonfly-is-it-an-open-source) - [Can I use dragonfly in production?](#can-i-use-dragonfly-in-production) - [We benchmarked Dragonfly and we have not reached 4M qps throughput as you advertised.](#we-benchmarked-dragonfly-and-we-have-not-reached-4m-qps-throughput-as-you-advertised) - [Dragonfly provides vertical scale, but we can achieve similar throughput with X nodes in a Redis cluster.](#dragonfly-provides-vertical-scale-but-we-can-achieve-similar-throughput-with-x-nodes-in-a-redis-cluster) - [If only Dragonfly had this command I would use it for sure](#if-only-dragonfly-had-this-command-i-would-use-it-for-sure) ## What is the license model of Dragonfly? Is it an open source? Dragonfly is released under [BSL 1.1](../LICENSE.md) (Business Source License). BSL 1.1 is considered to be a "source available" license, and it is not strictly an open-source license. We believe that a [BSL 1.1](https://spdx.org/licenses/BUSL-1.1.html) license is more permissive than licenses like AGPL, and it will allow us to provide a competitive commercial service using our technology. In general terms, it means that Dragonfly's code is free to use and free to change as long as you do not sell services directly related to Dragonfly or in-memory datastores.
We followed the trend of other technological companies like Elastic, Redis, MongoDB, Cockroach Labs, Redpanda Data to protect our rights to provide service and support for the software we are building. ## Can I use dragonfly in production? License-wise, you are free to use dragonfly in your production as long as you do not provide Dragonfly as a managed service. From a code maturity point of view, Dragonfly's code is covered by unit tests and regression tests. However, as with any new software, there are use cases that are hard to test and predict. We advise you to run your own particular use case on dragonfly for a few days before considering production usage. ## We benchmarked Dragonfly and we have not reached 4M qps throughput as you advertised. We conducted our experiments using a load-test generator called `memtier_benchmark`, and we ran benchmarks on the AWS network-enhanced instance `c6gn.16xlarge` on recent Linux kernel versions. Dragonfly might reach smaller throughput on other instances, but we would still expect to reach around 1M+ qps on instances with 16-32 vCPUs. ## Dragonfly provides vertical scale, but we can achieve similar throughput with X nodes in a Redis cluster. Dragonfly optimizes the use of underlying hardware, allowing it to run efficiently on instances as small as 8GB, and scale vertically to large 2TB machines with 128 cores. This versatility significantly reduces the complexity of running cluster workloads on a single node, saving hardware resources and costs. More importantly, it diminishes the total cost of ownership associated with managing multi-node clusters. In contrast, Redis in cluster mode imposes limitations on multi-key and transactional operations, whereas Dragonfly maintains the same semantics as a single-node Redis system. Furthermore, scaling out horizontally with small instances can lead to instability in production environments.
We believe that large-scale deployments of in-memory stores require both vertical and horizontal scaling, which is not efficiently achievable with an in-memory store like Redis. ## If only Dragonfly had this command I would use it for sure Dragonfly implements ~190 Redis commands, which we think represents good coverage of the market. However, this is not based on empirical data. Having said that, if you have commands that are not covered, please feel free to open an issue for that or vote for an existing issue. We will do our best to prioritise those commands according to their popularity. ================================================ FILE: docs/memcached_benchmark.md ================================================ Contention in memcached under the high write throughput. Overall CPU usage of memcached when performing SETS benchmark: ================================================ FILE: docs/memory_bgsave.tsv ================================================ Time Dragonfly Redis 4 4738531328 6819917824 5 4738637824 6819917824 6 4738658304 6819913728 7 4738777088 6820589568 8 4738781184 6820638720 9 4738768896 6820769792 10 4738494464 6820777984 11 4738756608 6820683776 12 4740325376 6820687872 13 4740243456 6820691968 14 4740194304 6820687872 15 4740194304 7429746688 16 4740734976 7942115328 17 4740370432 8400957440 18 4740366336 8863305728 19 4740390912 9302515712 20 4740399104 9697935360 21 4740423680 10074103808 22 4748312576 10362601472 23 4750438400 10649939968 24 4750315520 10926985216 25 4750426112 11195555840 26 4750180352 11444666368 27 4750417920 11665764352 28 4750131200 11872944128 29 4750233600 12060946432 30 4750475264 12232212480 31 12379299840 32 12521598976 33 12647915520 34 12756508672 35 12848570368 36 12944240640 37 13025046528 38 13105799168 39 13181427712 40 8000053248 41 7048486912 42 7048507392 ================================================ FILE: docs/namespaces.md ================================================ # Namespaces in Dragonfly
Dragonfly added an _experimental_ feature, allowing complete separation of data by different users. We call this feature _namespaces_, and it allows using a single Dragonfly server with multiple tenants, each using their own data, without being able to mix them together. Note that this feature can alternatively be achieved by having each user `SELECT` a different (numeric) database, or by asking that each user uses a unique prefix for their keys. This approach has several disadvantages, like users forgetting to `SELECT` / use their prefix, accessing data logically belonging to other users. The advantage of using Namespaces is that data is completely isolated, and users cannot accidentally use data they do not own. A user must authenticate in order to access the namespace it was assigned. And as a bonus, each namespace can have multiple databases, switched via `SELECT` like any regular data store. However, before using this feature, please note that it is experimental. This means that: * Some features are not supported for non-default namespaces, such as replication and save to RDB * Some tools are missing, like breakdown of memory / load per namespace * We do not yet consider this production ready, and it might still have some uncovered bugs So kindly use it at your own risk. ## Usage This section describes how, as a Dragonfly user / administrator, you could use namespaces. A namespace is identified by a unique string id, defined by the user / admin. Each Dragonfly user is associated with a single namespace. If not set explicitly, then the default namespace is used, which is the empty string id. Multiple users can use the same namespace if they are all assigned the same namespace id. This can allow, for example, creating a read-only user as well as a mutating user over the same data. 
To associate user `user1` with the namespace `namespace1`, use the `ACL` command with the `NAMESPACE:namespace1` flag: ``` ACL SETUSER user1 NAMESPACE:namespace1 ON >user_pass +@all ~* ``` This sets / creates user `user1`, using password `user_pass`, using namespace `namespace1`. For more examples check out `tests/dragonfly/acl_family_test.py` - specifically the `test_namespaces` function. ## Technical Details This section describes how we _implemented_ namespaces in Dragonfly. It is meant to be used by those who wish to contribute pull requests to Dragonfly. Prior to adding namespaces to Dragonfly, each _shard_ had a single `DbSlice` that it owned. They were thread-local, global-scope instances. To support namespaces, we created a `Namespace` class (see `src/server/namespaces.h`) which contains a `vector`, with a `DbSlice` per shard. When first used, a `Namespace` calls the engine shard set to initialize the array of `DbSlice`s. To access all `Namespace`s, we also added a registry with the original name `Namespaces`. It is a global, thread-safe class that allows accessing all registered namespaces, and registering new ones on the fly. Note that, while it is thread-safe, it shouldn't be a bottleneck because it is supposed to only be used during the authentication of a connection (or when adding new namespaces). When a new connection is authenticated with Dragonfly, we look up (and create, if needed) the namespace it is associated with. We then save a `Namespace* ns` inside the `dfly::ConnectionContext` class to associate the user with the namespaces. Because we removed the global `DbSlice` objects, this is now the only way to access namespaces, which protects users from accessing unowned data. Currently, we do not have any support for removing namespaces, so they hang in memory until the server exits. ================================================ FILE: docs/quick-start/README.md ================================================

Dragonfly

# Quick Start Starting with `docker run` is the simplest way to get up and running with DragonflyDB. If you do not have docker on your machine, [Install Docker](https://docs.docker.com/get-docker/) before continuing. ## Step 1 ### On linux ```bash docker run --network=host --ulimit memlock=-1 docker.dragonflydb.io/dragonflydb/dragonfly ``` ### On macOS _`network=host` doesn't work well on macOS, see [this issue](https://github.com/docker/for-mac/issues/1031)_ ```bash docker run -p 6379:6379 --ulimit memlock=-1 docker.dragonflydb.io/dragonflydb/dragonfly ``` Dragonfly DB will answer to both `http` and `redis` requests out of the box! You can use `redis-cli` to connect to `localhost:6379` or open a browser and visit `http://localhost:6379` **NOTE**: On some configurations, running with the `docker run --privileged ...` flag can fix some initialization errors. ## Step 2 Connect with a redis client ```bash redis-cli 127.0.0.1:6379> set hello world OK 127.0.0.1:6379> keys * 1) "hello" 127.0.0.1:6379> get hello "world" 127.0.0.1:6379> ``` ## Step 3 Continue being great and build your app with the power of DragonflyDB! 
## Known issues ## More Build Options - [Docker Compose Deployment](/contrib/docker/) - [Kubernetes Deployment with Helm Chart](/contrib/charts/dragonfly/) - [Build From Source](/docs/build-from-source.md) ================================================ FILE: docs/rdbsave.excalidraw ================================================ { "type": "excalidraw", "version": 2, "source": "https://excalidraw.com", "elements": [ { "type": "rectangle", "version": 586, "versionNonce": 345912761, "isDeleted": false, "id": "BY5OdEEKT0Y_DTy9Zgr9C", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 654.7020016982203, "y": 187.24519230769243, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 165, "height": 199, "seed": 1621471436, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "sIrssFTnnb9f1o26g1j88", "type": "text" }, { "type": "text", "id": "sIrssFTnnb9f1o26g1j88" }, { "id": "1cq4mAkO92nzlk-wjAy0a", "type": "arrow" } ], "updated": 1661620421120, "link": null, "locked": false }, { "type": "text", "version": 514, "versionNonce": 869523031, "isDeleted": false, "id": "sIrssFTnnb9f1o26g1j88", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 659.7020016982203, "y": 261.74519230769243, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 155, "height": 50, "seed": 711168500, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1661620421121, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "Thread-local\nSnapshot 1", "baseline": 43, "textAlign": "center", "verticalAlign": "middle", "containerId": "BY5OdEEKT0Y_DTy9Zgr9C", "originalText": "Thread-local\nSnapshot 1" }, { "type": "rectangle", "version": 622, "versionNonce": 1016232663, "isDeleted": false, "id": "OiDY20ES-4wBxFVAzHkHt", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, 
"opacity": 100, "angle": 0, "x": 866.0673076923077, "y": 187.24519230769243, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 165, "height": 199, "seed": 1937655639, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "nTSFevnRPYnvrSc57ZrgV", "type": "text" }, { "id": "nTSFevnRPYnvrSc57ZrgV", "type": "text" }, { "type": "text", "id": "nTSFevnRPYnvrSc57ZrgV" }, { "id": "NGMUGV32wJmpMyvB3YQTx", "type": "arrow" } ], "updated": 1661620421121, "link": null, "locked": false }, { "type": "text", "version": 539, "versionNonce": 941214039, "isDeleted": false, "id": "nTSFevnRPYnvrSc57ZrgV", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 871.0673076923077, "y": 256.74519230769243, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 155, "height": 60, "seed": 1072545177, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1661620424002, "link": null, "locked": false, "fontSize": 23.932285237126536, "fontFamily": 1, "text": "Thread-local\nSnapshot 2", "baseline": 51, "textAlign": "center", "verticalAlign": "middle", "containerId": "OiDY20ES-4wBxFVAzHkHt", "originalText": "Thread-local\nSnapshot 2" }, { "type": "rectangle", "version": 608, "versionNonce": 1548421111, "isDeleted": false, "id": "0DuGwtSiWQDXGbVDx_Yq4", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1085.2980769230767, "y": 187.24519230769243, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 165, "height": 199, "seed": 1695403735, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "dcrIif4WgKLztfzWXXskR", "type": "text" }, { "id": "dcrIif4WgKLztfzWXXskR", "type": "text" }, { "type": "text", "id": "dcrIif4WgKLztfzWXXskR" }, { "id": "hgq3HgiDoEU1A13Sax2A5", "type": "arrow" } ], "updated": 1661620421121, "link": null, "locked": false }, { "type": "text", "version": 530, 
"versionNonce": 667080441, "isDeleted": false, "id": "dcrIif4WgKLztfzWXXskR", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1090.2980769230767, "y": 256.74519230769243, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 155, "height": 60, "seed": 379350553, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1661620421122, "link": null, "locked": false, "fontSize": 23.932285237126536, "fontFamily": 1, "text": "Thread-local\nSnapshot 3", "baseline": 51, "textAlign": "center", "verticalAlign": "middle", "containerId": "0DuGwtSiWQDXGbVDx_Yq4", "originalText": "Thread-local\nSnapshot 3" }, { "id": "577abnzpQuxk_hrNgIMkV", "type": "diamond", "x": 689.3365384615385, "y": 437.86057692307713, "width": 92, "height": 157, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 520181305, "version": 125, "versionNonce": 1270149399, "isDeleted": false, "boundElements": [ { "id": "1cq4mAkO92nzlk-wjAy0a", "type": "arrow" }, { "type": "text", "id": "YWzMoutOj3POKIhzoAb6q" }, { "id": "HjlV2QEoKO1Najg9D1xnm", "type": "arrow" } ], "updated": 1661620421122, "link": null, "locked": false }, { "id": "1cq4mAkO92nzlk-wjAy0a", "type": "arrow", "x": 728.5673076923077, "y": 395.9759615384616, "width": 32.307692307692264, "height": 36.04730445962048, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 2032795417, "version": 139, "versionNonce": 1145353783, "isDeleted": false, "boundElements": null, "updated": 1661620421122, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 32.307692307692264, 11.538461538461547 ], [ 9.869210911479854, 36.04730445962048 ] 
], "lastCommittedPoint": null, "startBinding": { "elementId": "BY5OdEEKT0Y_DTy9Zgr9C", "focus": 0.8708968370314767, "gap": 9.73076923076917 }, "endBinding": { "elementId": "577abnzpQuxk_hrNgIMkV", "focus": -1.6111525113388454, "gap": 5.625821498015291 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "YWzMoutOj3POKIhzoAb6q", "type": "text", "x": 694.3365384615385, "y": 498.36057692307713, "width": 82, "height": 36, "angle": 0, "strokeColor": "#000000", "backgroundColor": "transparent", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 381921847, "version": 39, "versionNonce": 405941433, "isDeleted": false, "boundElements": null, "updated": 1661620421122, "link": null, "locked": false, "text": "Rdb\nSerializer", "fontSize": 16, "fontFamily": 2, "textAlign": "center", "verticalAlign": "middle", "baseline": 32, "containerId": "577abnzpQuxk_hrNgIMkV", "originalText": "Rdb\nSerializer" }, { "id": "Ig1qNk-AOw_VTS_xlELs5", "type": "rectangle", "x": 717.798076923077, "y": 641.3605769230771, "width": 477, "height": 67, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#fa5252", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 1664303159, "version": 124, "versionNonce": 111029657, "isDeleted": false, "boundElements": [ { "type": "text", "id": "jE5wNvo8TFk1wC4v8bQ6s" }, { "id": "HjlV2QEoKO1Najg9D1xnm", "type": "arrow" }, { "id": "hLcR_BUncIusv-IFL2ucM", "type": "arrow" }, { "id": "WHRznFJAFjpXbmv35tCsY", "type": "arrow" }, { "id": "yVBhfXkyFmu2rg16oRlxu", "type": "arrow" } ], "updated": 1661620421122, "link": null, "locked": false }, { "type": "diamond", "version": 140, "versionNonce": 1301746297, "isDeleted": false, "id": "MclWY93u6fXaKcMyYF-Jy", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 
896.4134615384614, "y": 437.8605769230771, "strokeColor": "#000000", "backgroundColor": "#12b886", "width": 92, "height": 157, "seed": 755813689, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "NGMUGV32wJmpMyvB3YQTx", "type": "arrow" }, { "id": "_xhHeDkg3dVxrIbXlln8Z", "type": "text" }, { "type": "text", "id": "_xhHeDkg3dVxrIbXlln8Z" }, { "id": "hLcR_BUncIusv-IFL2ucM", "type": "arrow" } ], "updated": 1661620421122, "link": null, "locked": false }, { "type": "arrow", "version": 167, "versionNonce": 1223962007, "isDeleted": false, "id": "NGMUGV32wJmpMyvB3YQTx", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 944.8750000000002, "y": 387.86057692307696, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 15.10726263633046, "height": 47.58370911007313, "seed": 282885847, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1661620421122, "link": null, "locked": false, "startBinding": { "elementId": "OiDY20ES-4wBxFVAzHkHt", "focus": 0.48198474540576314, "gap": 1.615384615384528 }, "endBinding": { "elementId": "MclWY93u6fXaKcMyYF-Jy", "focus": -0.9774990043807243, "gap": 2.921009509018951 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 14.615384615384528, 21.538461538461547 ], [ -0.4918780209459328, 47.58370911007313 ] ] }, { "type": "text", "version": 51, "versionNonce": 299916121, "isDeleted": false, "id": "_xhHeDkg3dVxrIbXlln8Z", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 901.4134615384614, "y": 498.3605769230771, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 82, "height": 36, "seed": 1481686553, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1661620421122, "link": null, "locked": false, "fontSize": 16, "fontFamily": 2, "text": "Rdb\nSerializer", "baseline": 32, 
"textAlign": "center", "verticalAlign": "middle", "containerId": "MclWY93u6fXaKcMyYF-Jy", "originalText": "Rdb\nSerializer" }, { "type": "diamond", "version": 225, "versionNonce": 1063805623, "isDeleted": false, "id": "jGf5xxZ5eve-AtPae7Yly", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1119.4903846153848, "y": 437.8605769230772, "strokeColor": "#000000", "backgroundColor": "#12b886", "width": 92, "height": 157, "seed": 538175673, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "hgq3HgiDoEU1A13Sax2A5", "type": "arrow" }, { "id": "WQcx4-r2uMVAquWROfq1l", "type": "text" }, { "type": "text", "id": "WQcx4-r2uMVAquWROfq1l" }, { "id": "WHRznFJAFjpXbmv35tCsY", "type": "arrow" } ], "updated": 1661620421122, "link": null, "locked": false }, { "type": "arrow", "version": 390, "versionNonce": 332236857, "isDeleted": false, "id": "hgq3HgiDoEU1A13Sax2A5", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1141.6872098880729, "y": 392.47596153846166, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 70.88009780423477, "height": 61.500951281640766, "seed": 168221527, "groupIds": [], "strokeSharpness": "round", "boundElements": [], "updated": 1661620421122, "link": null, "locked": false, "startBinding": { "elementId": "0DuGwtSiWQDXGbVDx_Yq4", "focus": 0.9791425008071145, "gap": 6.230769230769226 }, "endBinding": { "elementId": "jGf5xxZ5eve-AtPae7Yly", "focus": -0.5445868784908863, "gap": 4.55886494843503 }, "lastCommittedPoint": null, "startArrowhead": null, "endArrowhead": "arrow", "points": [ [ 0, 0 ], [ 70.88009780423477, 10.76923076923083 ], [ 38.5310635413573, 61.500951281640766 ] ] }, { "type": "text", "version": 138, "versionNonce": 2144924631, "isDeleted": false, "id": "WQcx4-r2uMVAquWROfq1l", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 
0, "x": 1124.4903846153848, "y": 498.3605769230772, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 82, "height": 36, "seed": 585656729, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1661620421122, "link": null, "locked": false, "fontSize": 16, "fontFamily": 2, "text": "Rdb\nSerializer", "baseline": 32, "textAlign": "center", "verticalAlign": "middle", "containerId": "jGf5xxZ5eve-AtPae7Yly", "originalText": "Rdb\nSerializer" }, { "id": "jE5wNvo8TFk1wC4v8bQ6s", "type": "text", "x": 722.798076923077, "y": 656.8605769230771, "width": 467, "height": 36, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 320154873, "version": 98, "versionNonce": 1177598807, "isDeleted": false, "boundElements": null, "updated": 1661620459622, "link": null, "locked": false, "text": "Blob Channel (SliceSnapshot::RecordChannel)\nBucket-level granularity", "fontSize": 16, "fontFamily": 2, "textAlign": "center", "verticalAlign": "middle", "baseline": 32, "containerId": "Ig1qNk-AOw_VTS_xlELs5", "originalText": "Blob Channel (SliceSnapshot::RecordChannel)\nBucket-level granularity" }, { "id": "HjlV2QEoKO1Najg9D1xnm", "type": "arrow", "x": 741.2581209970564, "y": 588.5811776062717, "width": 31.351415988958138, "height": 44.98870164238667, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 1489149785, "version": 105, "versionNonce": 1873907193, "isDeleted": false, "boundElements": null, "updated": 1661620421122, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 31.351415988958138, 44.98870164238667 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "577abnzpQuxk_hrNgIMkV", "gap": 1.9342976914014673, 
"focus": 0.8117909371106269 }, "endBinding": { "elementId": "Ig1qNk-AOw_VTS_xlELs5", "gap": 7.790697674418787, "focus": -0.593178549414425 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "hLcR_BUncIusv-IFL2ucM", "type": "arrow", "x": 919.3365384615385, "y": 574.4375, "width": 31.736196893864076, "height": 60.69051878354196, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 328800759, "version": 85, "versionNonce": 304047833, "isDeleted": false, "boundElements": null, "updated": 1661620421122, "link": null, "locked": false, "points": [ [ 0, 0 ], [ -14.615384615384642, 25.384615384615472 ], [ 17.120812278479434, 60.69051878354196 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "MclWY93u6fXaKcMyYF-Jy", "focus": -0.22524576872402804, "gap": 9.584854518692971 }, "endBinding": { "elementId": "Ig1qNk-AOw_VTS_xlELs5", "gap": 6.232558139535168, "focus": 0.05517004727771827 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "WHRznFJAFjpXbmv35tCsY", "type": "arrow", "x": 1123.951923076923, "y": 553.6682692307693, "width": 32.30769230769238, "height": 81.53846153846143, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#12b886", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 971531865, "version": 66, "versionNonce": 789696311, "isDeleted": false, "boundElements": null, "updated": 1661620421122, "link": null, "locked": false, "points": [ [ 0, 0 ], [ -32.30769230769238, 38.46153846153834 ], [ -23.84615384615404, 81.53846153846143 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "jGf5xxZ5eve-AtPae7Yly", "focus": 0.2217391304347844, "gap": 15.012636648887266 }, "endBinding": { "elementId": "Ig1qNk-AOw_VTS_xlELs5", "focus": 0.6185597345566728, "gap": 
6.153846153846416 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "yVBhfXkyFmu2rg16oRlxu", "type": "arrow", "x": 864.7211538461538, "y": 717.5144230769231, "width": 67.97279116285586, "height": 64.8374913674163, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#228be6", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 651147575, "version": 635, "versionNonce": 116567415, "isDeleted": false, "boundElements": null, "updated": 1661620421122, "link": null, "locked": false, "points": [ [ 0, 0 ], [ -42.30769230769215, 16.923076923076792 ], [ -67.97279116285586, 64.8374913674163 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "Ig1qNk-AOw_VTS_xlELs5", "focus": -0.04672674106343535, "gap": 9.153846153845961 }, "endBinding": { "elementId": "HK8F6p6Adyxvgasi9uzJo", "focus": -0.17323237259147364, "gap": 5.1625086325837515 }, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "HK8F6p6Adyxvgasi9uzJo", "type": "rectangle", "x": 707.7980769230769, "y": 784.4375, "width": 155.84615384615387, "height": 98.27507912481072, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#4c6ef5", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 2031988567, "version": 164, "versionNonce": 418531705, "isDeleted": false, "boundElements": [ { "id": "yVBhfXkyFmu2rg16oRlxu", "type": "arrow" }, { "type": "text", "id": "fB6sqnJqDlolUIDrydMk5" }, { "id": "YVK4Nv0Onos-JNSI9I5YI", "type": "arrow" } ], "updated": 1661620421122, "link": null, "locked": false }, { "id": "fB6sqnJqDlolUIDrydMk5", "type": "text", "x": 712.7980769230769, "y": 825.5750395624053, "width": 145.84615384615387, "height": 16, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#4c6ef5", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, 
"groupIds": [], "strokeSharpness": "sharp", "seed": 1340401175, "version": 194, "versionNonce": 1565255319, "isDeleted": false, "boundElements": null, "updated": 1661620421123, "link": null, "locked": false, "text": "SaveBody", "fontSize": 14.404558404558403, "fontFamily": 2, "textAlign": "center", "verticalAlign": "middle", "baseline": 13, "containerId": "HK8F6p6Adyxvgasi9uzJo", "originalText": "SaveBody" }, { "type": "rectangle", "version": 216, "versionNonce": 1292304185, "isDeleted": false, "id": "w6yJKrh_ucB0qKWLRrPA1", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 921.4134615384612, "y": 785.2230373606715, "strokeColor": "#000000", "backgroundColor": "#15aabf", "width": 156, "height": 98.27507912481072, "seed": 1894727609, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "id": "yVBhfXkyFmu2rg16oRlxu", "type": "arrow" }, { "id": "JClqLh6OUtndfrUc-BbHt", "type": "text" }, { "type": "text", "id": "JClqLh6OUtndfrUc-BbHt" }, { "id": "XiGmqFegyOE2IKWoIo40s", "type": "arrow" } ], "updated": 1661620421123, "link": null, "locked": false }, { "type": "text", "version": 259, "versionNonce": 710307031, "isDeleted": false, "id": "JClqLh6OUtndfrUc-BbHt", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 926.4134615384612, "y": 826.3605769230768, "strokeColor": "#000000", "backgroundColor": "#4c6ef5", "width": 146, "height": 16, "seed": 1215329367, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1661620421123, "link": null, "locked": false, "fontSize": 14.404558404558403, "fontFamily": 2, "text": "AlignedBuffer", "baseline": 13, "textAlign": "center", "verticalAlign": "middle", "containerId": "w6yJKrh_ucB0qKWLRrPA1", "originalText": "AlignedBuffer" }, { "id": "YVK4Nv0Onos-JNSI9I5YI", "type": "arrow", "x": 867.7980769230768, "y": 836.7451923076923, "width": 55.38461538461536, "height": 0, "angle": 0, 
"strokeColor": "#000000", "backgroundColor": "#15aabf", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 2028321497, "version": 86, "versionNonce": 506769433, "isDeleted": false, "boundElements": null, "updated": 1661620421123, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 55.38461538461536, 0 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "HK8F6p6Adyxvgasi9uzJo", "focus": 0.0018973206471872748, "gap": 4.153846153846075 }, "endBinding": null, "startArrowhead": null, "endArrowhead": "arrow" }, { "id": "cqCQRIsxqHSsV_j5V6fMA", "type": "ellipse", "x": 1165.490384615384, "y": 781.3605769230769, "width": 128, "height": 106, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#e64980", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 1621295255, "version": 67, "versionNonce": 281065975, "isDeleted": false, "boundElements": [ { "type": "text", "id": "6N8Vr1qw1YKDs9h0ze2LI" }, { "id": "XiGmqFegyOE2IKWoIo40s", "type": "arrow" } ], "updated": 1661620421123, "link": null, "locked": false }, { "id": "6N8Vr1qw1YKDs9h0ze2LI", "type": "text", "x": 1170.490384615384, "y": 816.3605769230769, "width": 118, "height": 36, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#e64980", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "sharp", "seed": 1910738841, "version": 45, "versionNonce": 1681474809, "isDeleted": false, "boundElements": null, "updated": 1661620421123, "link": null, "locked": false, "text": "Direct I/O\nFile", "fontSize": 16, "fontFamily": 2, "textAlign": "center", "verticalAlign": "middle", "baseline": 32, "containerId": "cqCQRIsxqHSsV_j5V6fMA", "originalText": "Direct I/O\nFile" }, { "id": "XiGmqFegyOE2IKWoIo40s", "type": "arrow", "x": 
1082.4134615384614, "y": 834.4375, "width": 69.23076923076928, "height": 0.7692307692308304, "angle": 0, "strokeColor": "#000000", "backgroundColor": "#e64980", "fillStyle": "solid", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "groupIds": [], "strokeSharpness": "round", "seed": 1724070359, "version": 21, "versionNonce": 178545431, "isDeleted": false, "boundElements": null, "updated": 1661620421123, "link": null, "locked": false, "points": [ [ 0, 0 ], [ 69.23076923076928, -0.7692307692308304 ] ], "lastCommittedPoint": null, "startBinding": { "elementId": "w6yJKrh_ucB0qKWLRrPA1", "focus": 0.01998122507071207, "gap": 5.000000000000227 }, "endBinding": { "elementId": "cqCQRIsxqHSsV_j5V6fMA", "focus": 0.029379713320443476, "gap": 13.85030430804018 }, "startArrowhead": null, "endArrowhead": "arrow" } ], "appState": { "gridSize": null, "viewBackgroundColor": "#ffffff" }, "files": {} } ================================================ FILE: docs/rdbsave.md ================================================ # RDB Snapshot design The following document describes Dragonfly's point in time, forkless snapshotting procedure, including all its configurations. ## Redis-compatible RDB snapshot This snapshot is serialized into a single file or into a network socket. This configuration is used to create redis-compatible backup snapshots. The algorithm utilizes the shared-nothing architecture of Dragonfly and makes sure that each shard-thread serializes only its own data. Below is the high description of the flow. 1. The `RdbSave` class instantiates a single blocking channel (in red). Its purpose is to gather all the blobs from all the shards. 2. In addition it creates thread-local snapshot instances in each DF shard. TODO: to rename them in the codebase to another name (SnapshotShard?) since `snapshot` word creates ambiguity here. 3. 
Each SnapshotShard instantiates its own RdbSerializer that is used to serialize each K/V entry into a binary representation according to the Redis format spec. SnapshotShards combine multiple blobs from the same Dash bucket into a single blob. They always send blob data at bucket granularity, i.e. they never send a blob into the channel that only partially covers the bucket. This is needed in order to guarantee snapshot isolation. 4. The RdbSerializer uses `io::Sink` to emit binary data. The SnapshotShard instance passes into it a `StringFile` which is just a memory-only based sink that wraps a `std::string` object. Once the `StringFile` instance becomes large, it's flushed into the channel (as long as it follows the rules above). 5. RdbSave also creates a fiber (SaveBody) that pulls all the blobs from the channel. Blobs might come in unspecified order, though it's guaranteed that each blob is self-sufficient by itself. 6. DF uses direct I/O, to improve I/O throughput, which, in turn, requires properly aligned memory buffers to work. Unfortunately, blobs that come from the rdb channel come in different sizes and they are not aligned by OS page granularity. Therefore, DF passes all the data from the rdb channel through the AlignedBuffer transformation. The purpose of this class is to copy the incoming data into a properly aligned buffer. Once it accumulates enough data, it flushes it into the output file. To summarize, this configuration employs a single sink to create one file or one stream of data that represents the whole database. ## Dragonfly Snapshot (TBD) Required for replication. Creates multiple files, one file per SnapshotShard. Does not require a central sink. Each SnapshotShard still uses RdbSerializer together with StringFile to guarantee bucket level granularity. We still need AlignedBuffer if we want to use direct I/O. For a DF process with N shards, it will create N files.
Will probably require an additional metadata file to provide file-level consistency, but for now we can assume that only N files are created, since our use-case will be network based replication. How is it going to be used? Replica (slave) will hand-shake with the master and find out how many shards it has. Then it will open `N` sockets and each one of them will pull shard data. First, they will pull snapshot data, and replay it by distributing entries among `K` replica shards. After all the snapshot data is replayed, they will continue with replaying the change log (stable state replication), which is outside the scope of this document. ## Relaxed point-in-time (TBD) When DF saves its snapshot file on disk, it maintains snapshot isolation by applying a virtual cut through all the process shards. Snapshotting may take time, during which DF may process many write requests. These mutations won't be part of the snapshot, because the cut captures data up to the point **it has started**. This is perfect for backups. I call this variation - conservative snapshotting. However, when we perform snapshotting for replication, we would like to produce a snapshot that includes all the data up to the point in time when the snapshotting **finishes**. I call this *relaxed snapshotting*. The reason for relaxed snapshotting is to avoid keeping the changelog of all mutations during the snapshot creation. As a side comment - we could, in theory, support the same (relaxed) semantics for file snapshots, but it's not necessary since it might increase the snapshot sizes. The snapshotting phase (full-sync) can take up lots of time, which adds lots of memory pressure on the system. Keeping the change-log aside during the full-sync phase will only add more pressure. We achieve relaxed snapshotting by pushing the changes into the replication sockets without saving them aside.
Of course, we would still need point-in-time consistency, in order to know when the snapshotting finished and the stable state replication started. ## Conservative and relaxed snapshotting variations Both algorithms maintain a scanning process (fiber) that iteratively goes over the main dictionary and serializes its data. Before starting the process, the SnapshotShard captures the change epoch of its shard (this epoch is increased with each write request). ```cpp SnapshotShard.epoch = shard.epoch++; ``` For the sake of simplicity, we can assume that each entry in the shard maintains its own version counter. By capturing the epoch number we establish a cut: all entries with `version <= SnapshotShard.epoch` have not been serialized yet and were not modified by the concurrent writes. The DashTable iteration algorithm guarantees convergence and coverage ("at least once"), but it does not guarantee that each entry is visited *exactly once*. Therefore, we use entry versions for two things: 1) to avoid serialization of the same entry multiple times, and 2) to correctly serialize entries that need to change due to concurrent writes. Serialization Fiber: ```cpp for (entry : table) { if (entry.version <= cut.epoch) { entry.version = cut.epoch + 1; SendToSerializationSink(entry); } } ``` To allow concurrent writes during the snapshotting phase, we set up a hook that is triggered on each entry mutation in the table: OnWriteHook: ```cpp .... if (entry.version <= cut.epoch) { SendToSerializationSink(entry); } ... entry = new_entry; entry.version = shard.epoch++; // guaranteed to become > cut.epoch ``` Please note that this hook maintains point-in-time semantics for the conservative variation by pushing the previous value of the entry into the sink before changing it. However, for the relaxed point-in-time, we do not have to store the old value.
Therefore, we can do the following: OnWriteHook: ```cpp if (entry.version <= cut.epoch) { SendToSerializationSink(new_entry); // do not have to send the old value } else { // Keep sending the changes. SendToSerializationSink(IncrementalDiff(entry, new_entry)); } entry = new_entry; entry.version = shard.epoch++; ``` The change data is sent along with the rest of the contents, and it requires extending the existing RDB format to support differential operations (such as hset, append, etc.). The Serialization Fiber loop is the same for this variation. ================================================ FILE: docs/shard-serialization.md ================================================ # Shard Serialization This document describes how Dragonfly serializes a single shard's data via `SliceSnapshot`. It covers both point-in-time (PIT) and non-PIT serialization modes, their correctness guarantees, and the mechanisms used to coordinate concurrent mutations with the serialization process. ## Overview Shard serialization is used for two purposes: 1. **Backups (RDB save)** — Must produce a consistent point-in-time snapshot. Always uses PIT mode. 2. **Replication (full sync)** — Serializes baseline data and then streams journal changes. Can use either PIT or non-PIT mode, controlled by the `--point_in_time_snapshot` flag (default: true). Both modes share the same traversal infrastructure (`IterateBucketsFb` → `BucketSaveCb` → `SerializeBucket` → `SerializeEntry`) and the same flushing/backpressure machinery (`HandleFlushData` → `consumer_->ConsumeData`). They differ in **how they handle concurrent mutations** during the traversal.
| | PIT mode | Non-PIT mode | |---|----------|-------------| | Flag | `use_snapshot_version_ == true` | `use_snapshot_version_ == false` | | Used for | Backups and replication | Replication only | | Consistency | Exact point-in-time snapshot | Eventual consistency (baseline + journal) | | `OnDbChange` | Serializes bucket before mutation | Barrier only (no serialization) | | `OnMoved` | Not registered | Handles DashTable item reshuffling | | Bucket versioning | Yes — skip already-serialized buckets | No — serialize every bucket visited | | Throughput | Lower (mutation path does serialization work) | Higher (mutation path only acquires mutex) | ## Core Types | Type | Location | Role | |------|----------|------| | `SliceSnapshot` | `src/server/snapshot.h` | Orchestrates shard serialization | | `RdbSerializer` | `src/server/rdb_save.h` | Serializes entries into RDB-format buffers | | `SnapshotDataConsumerInterface` | `src/server/snapshot.h` | Downstream sink interface | | `RdbSaver::Impl` | `src/server/rdb_save.cc` | Consumer impl: writes to socket or channel | | `ThreadLocalMutex` | `src/server/synchronization.h` | Fiber-aware mutex for atomicity barrier | | `ChangeReq` | `src/server/table.h` | Describes a table mutation (update or insert) | ## Data Flow Overview ```mermaid flowchart TD subgraph ShardThread[Shard thread / fibers] MUT[DB mutation] -->|change callback| ODC[OnDbChange] ODC -->|lock big_value_mu_| SB1["SerializeBucket
(PIT only)"] SB1 --> SE1[SerializeEntry] SE1 --> SAVE1[RdbSerializer::SaveEntry] TRAV[Snapshot fiber: IterateBucketsFb] --> BSCB[BucketSaveCb] BSCB -->|lock big_value_mu_ + GetLatch| SB2[SerializeBucket] SB2 --> SE2[SerializeEntry] SE2 --> SAVE2[RdbSerializer::SaveEntry] MOV[DashTable move] -->|non-PIT only| OMV[OnMoved] OMV -->|lock big_value_mu_| SB3["SerializeBucket
(if moved across cursor)"] EXP["Expiry / Eviction
(heartbeat, inline, lazy)"] -->|"RecordDelete
(no OnDbChange)"| JRN_DIRECT["journal::RecordEntry
(DEL)"] JRN_DIRECT --> CJC JRN[Journal change] --> CJC[ConsumeJournalChange] CJC -->|lock big_value_mu_| WJE[serializer_->WriteJournalEntry] end SAVE1 -->|consume_fun_ if buffer > threshold| HFD[HandleFlushData] SAVE2 -->|consume_fun_ if buffer > threshold| HFD TRAV -->|between buckets| PS[PushSerialized] PS --> FS[FlushSerialized] FS --> HFD HFD --> SEQ[seq_cond_.wait - ordering gate] SEQ --> CD[consumer_->ConsumeData] CD --> SINK[(Replica socket / sink)] ``` ## PIT Mode (Point-in-Time Snapshot) PIT mode captures an exact snapshot of the shard at the logical moment `snapshot_version_` was assigned. It is the default for both backups and replication. ### Bucket Versioning Dragonfly's `DashTable` ([dashtable.md](dashtable.md)) maintains a version counter per physical bucket. The snapshot must serialize all buckets with version `< snapshot_version_`. - `SerializeBucket` sets the bucket version to `snapshot_version_`, ensuring each bucket is serialized exactly once. - Mutations bump bucket versions, so buckets mutated after the snapshot started will have version `>= snapshot_version_` and are skipped by the traversal. - Buckets not yet traversed but about to be mutated require **serialize-before-mutate**, enforced by `OnDbChange()`. ### Ordering Invariant > For any key, the replica must receive the baseline value **strictly before** any journal entry > that mutates that key. We will use two terms for journal changes: - **Self-contained**: the journal entry fully determines the resulting logical state and can be replayed without the prior value (for example `SET`, `DEL`). - **Baseline-dependent**: the journal entry describes a mutation of an existing value and requires the baseline state to be reconstructed first (for example `HSET`, `LPUSH`). For **transaction-driven mutations** this is guaranteed because: 1. `OnDbChange` runs before the mutation commits and serializes the bucket if needed. 2. 
`OnDbChange` unconditionally acquires `big_value_mu_` first, so the mutation and its subsequent journal emission cannot overtake an in-progress bucket serialization. **Important caveat:** not all journal entries follow the `OnDbChange` → mutation → `RecordJournal` → `ConsumeJournalChange` sequence. Several code paths emit journal entries via `journal::RecordEntry` directly, bypassing `PreUpdateBlocking` and `OnDbChange` entirely. See [Journal Entries Without `OnDbChange`](#journal-entries-without-ondbchange) below. ### Journal Entries Without `OnDbChange` Not all journal entries follow the transaction-driven `PreUpdateBlocking` → `OnDbChange` → mutation → `RecordJournal` → `ConsumeJournalChange` sequence. Several code paths call `journal::RecordEntry` directly (→ `JournalSlice::AddLogRecord` → `ConsumeJournalChange`), bypassing `OnDbChange` entirely: | Source | Journal command | Trigger | |--------|----------------|---------| | `ExpireIfNeeded` (`db_slice.cc`) | `DEL` | Lazy expiry during key lookup, active expiry sweep (`DeleteExpiredStep`), heartbeat-driven eviction (`FreeMemWithEvictionStepAtomic`) | | `PrimeEvictionPolicy::Evict` (`db_slice.cc`) | `DEL` | Inline eviction when a DashTable bucket overflows during insert | | `generic_family.cc` (SCAN-based deletion) | `DEL` | `RecordDelete` after `DbSlice::Del` in the RM command | | `dflycmd.cc`, `replica.cc`, `cluster_family.cc` | `PING` / `DFLYCLUSTER` | Control signals: takeover sync, PING propagation, cluster config | All data-mutating entries above are self-contained `DEL` commands. The non-mutating entries (`PING`, `DFLYCLUSTER`) carry no key-level semantics. **Why this matters for `ConsumeJournalChange` and `big_value_mu_`:** these journal entries still flow through `ConsumeJournalChange`, which acquires `big_value_mu_`. Today the mutex serves two purposes on these paths: 1. 
**Serializer buffer exclusivity** — preventing a journal write from interleaving with an in-progress `SerializeBucket` call that shares the same `serializer_` instance. 2. **Baseline-before-journal ordering** — a `DEL K` must not reach the output stream (or a separate journal stream) while K's baseline is still being serialized. Even with separate serializer buffers and tagged-chunk interleaving, the consumer could process `DEL K` before receiving the full baseline, violating the ordering invariant. The mutex prevents this today by blocking the journal write until `SerializeBucket` completes. The lock is *not* needed for transaction-style ordering against `OnDbChange` (these paths bypass it entirely), but it is needed for both concerns above. Removing it requires (a) separate serializer buffers (Phase 2, item 7) **and** (b) a mechanism to defer the `DEL` until the bucket's baseline is fully emitted (Phase 1, item 6 — deferred deletion queue). **Could these paths call `OnDbChange` before deleting?** Not safely: - **`ExpireIfNeeded`:** `SerializeBucket` (called from `OnDbChange`) can preempt, but `ExpireIfNeeded` must not — `ExpireAllIfNeeded` calls `serialization_latch_.Wait()` and lazy expiry in `FindInternal` relies on cooperative scheduling. - **`PrimeEvictionPolicy::Evict`:** `Evict` runs inside DashTable's insert path while the table is mid-structural-mutation. `OnDbChange` calls `SerializeBucket` (iterates the bucket) and `CVCUponInsert` (probes the table) — both unsafe here. Re-entrancy risk. - **`FreeMemWithEvictionStepAtomic`:** runs from heartbeat with `serialization_latch_` held; `OnDbChange` per evicted key would add overhead and preemption points inside the loop. The ordering issue is twofold: byte-stream integrity ([§1](#1-shard-wide-stall-under-big_value_mu_)) and baseline-before-journal correctness — a `DEL` must not be emitted (even to a separate stream) while the same key's baseline is still being serialized. 
Roadmap item 6 proposes a **deferred deletion queue** to address this without blocking or re-entrancy. ### Mutation Path: `OnDbChange` (PIT) ``` OnDbChange(db_index, req) lock(big_value_mu_) if req is update (existing bucket): bit = *req.update() if !bit.is_done() && bit.GetVersion() < snapshot_version_: -> SerializeBucket(db_index, *bit) else (insert, new key): key = get(req.change) -> table->CVCUponInsert(snapshot_version_, key, callback) callback(bucket_iterator): -> SerializeBucket(db_index, it) unlock(big_value_mu_) ``` For updates, `ChangeReq::update()` returns a `PrimeTable::bucket_iterator`. If the bucket has not been serialized yet (version `< snapshot_version_`), it is serialized now. For inserts, `CVCUponInsert` (`src/core/dash.h`) simulates the insert to identify which buckets' versions would change, and serializes each one with version `< snapshot_version_` via the callback. ### Traversal Path: `BucketSaveCb` (PIT) ``` BucketSaveCb(db_index, bucket_iterator) lock(big_value_mu_) if bucket version >= snapshot_version_: skip (already serialized by OnDbChange or a previous visit) FlushChangeToEarlierCallbacks(...) lock(*db_slice_->GetLatch()) -> SerializeBucket(db_index, bucket_iterator) set bucket version = snapshot_version_ for each occupied slot: -> SerializeEntry -> SaveEntry -> PushToConsumerIfNeeded ``` The version check is the key optimization: buckets already serialized by `OnDbChange` are skipped. ## Non-PIT Mode (Eventual Consistency) Non-PIT mode is available **only for replication** (`stream_journal == true`) and is enabled by setting `--point_in_time_snapshot=false`. It improves server throughput during full sync by eliminating serialization work from the mutation path. ### Design Rationale A replica does not need an exact point-in-time snapshot. It needs to reach eventual consistency: after the full sync baseline is delivered and the journal stream catches up, the replica's state must match the master's current state. 
This weaker guarantee allows the snapshot to be "fuzzy" — it may include some mutations that happened after the snapshot started and miss others, as long as the journal stream fills in the gaps. ### How It Differs from PIT **`OnDbChange` does no serialization.** In non-PIT mode, the `if (use_snapshot_version_)` block is skipped entirely. `OnDbChange` only acquires `big_value_mu_` and returns immediately. This serves as a **barrier** — it prevents mutations from modifying a bucket while it is being serialized by the traversal fiber — but it does not serialize anything itself. **No bucket version tracking.** `SerializeBucket` does not set the bucket version. `BucketSaveCb` does not check or skip based on version. Every bucket visited by the traversal is serialized unconditionally. **`OnMoved` handles DashTable reshuffling.** When items are inserted into DashTable, existing items may be moved between buckets (due to hash table splitting/merging). In PIT mode this is handled by `OnDbChange` + bucket versioning. In non-PIT mode, since `OnDbChange` does no serialization, a separate `OnMoved` callback is needed to catch items that "jump" across the traversal cursor: ``` OnMoved(db_index, items) lock(big_value_mu_) for each (source_cursor, dest_cursor) in items: if IsPositionSerialized(dest_cursor) && !IsPositionSerialized(source_cursor): -> SerializeBucket(db_index, CursorToBucketIt(dest_cursor)) ``` An item must be serialized by `OnMoved` when it moves **from** a not-yet-visited bucket **to** an already-visited bucket. Without this, the item would be missed entirely: the traversal has already passed the destination, and the source bucket no longer contains the item by the time the traversal reaches it. **`CVCUponInsert` is not used.** In PIT mode, `OnDbChange` calls `CVCUponInsert` for inserts to proactively serialize *all* buckets the insert would touch (home, neighbor, stash — or the entire segment on a split) **before** the insert commits.
This is necessary because PIT must capture the pre-mutation state of every affected bucket. Non-PIT has no such requirement. Instead, the insert proceeds, and `OnMoved` reactively handles any items that were displaced across the traversal cursor. For truly new keys (not displaced existing items), non-PIT relies on the cursor visiting the key's bucket later, or on the journal stream capturing the insert. ### `IsPositionSerialized` — Cursor-Based Position Tracking ```cpp bool IsPositionSerialized(DbIndex id, PrimeTable::Cursor cursor) { uint8_t depth = db_slice_->GetTables(id).first->depth(); return id < snapshot_db_index_ || (id == snapshot_db_index_ && (cursor.bucket_id() < snapshot_cursor_.bucket_id() || (cursor.bucket_id() == snapshot_cursor_.bucket_id() && cursor.segment_id(depth) < snapshot_cursor_.segment_id(depth)))); } ``` Compares a cursor position against the current traversal position (`snapshot_cursor_`, `snapshot_db_index_`). A position is "serialized" if it is behind the cursor — i.e., the traversal has already visited it. ### Traversal Path: `BucketSaveCb` (Non-PIT) ``` BucketSaveCb(db_index, bucket_iterator) lock(big_value_mu_) // no version check — serialize every bucket unconditionally lock(*db_slice_->GetLatch()) -> SerializeBucket(db_index, bucket_iterator) // no version update for each occupied slot: -> SerializeEntry -> SaveEntry -> PushToConsumerIfNeeded ``` ### Correctness in Non-PIT Mode Non-PIT mode guarantees: - Every key that existed when the traversal started and was not deleted before being visited will be serialized at least once (by the traversal or by `OnMoved`). - Keys inserted after the traversal started will appear in the journal stream. - Keys may be serialized in a state newer than the snapshot start (since mutations are not blocked by `OnDbChange` serialization, only by the mutex barrier). - The journal stream, combined with the baseline, produces an eventually consistent replica. 
What it does **not** guarantee: - Point-in-time consistency. The serialized baseline is a "fuzzy" view spanning the traversal duration. ## Shared Infrastructure The following sections apply to both PIT and non-PIT modes. ### Traversal: `IterateBucketsFb` ``` IterateBucketsFb(send_full_sync_cut) for each database: for each logical bucket via PrimeTable::TraverseBuckets(): -> BucketSaveCb(db_index, bucket_iterator) PushSerialized(false) // explicit flush between buckets yield if CPU time > ~15us PushSerialized(true) // force-flush after each database if send_full_sync_cut: serializer_->SendFullSyncCut() PushSerialized(true) ``` ### Serialization: `SerializeBucket` and `SerializeEntry` `SerializeBucket` iterates all occupied slots in a physical bucket and calls `SerializeEntry` for each. `SerializeEntry` looks up expiry and memcache flags, then calls `serializer_->SaveEntry(pk, pv, expire_time, mc_flags, db_index)`. ### Journal Path: `ConsumeJournalChange` ``` ConsumeJournalChange(item) lock(big_value_mu_) serializer_->WriteJournalEntry(item.journal_item.data) unlock(big_value_mu_) ``` Active in both modes when `stream_journal == true`. Acquires `big_value_mu_` to ensure journal entries are not interleaved with bucket serialization. Does **not** flush data — only appends to the serializer buffer. Flushing happens later via `ThrottleIfNeeded` → `PushSerialized(false)`, called from `JournalSlice` after the journal callback returns. ### Flushing and Backpressure #### `HandleFlushData(std::string data)` — Common Blocking Sink All serialized data ultimately flows through `HandleFlushData`: 1. Assigns monotonically increasing record ID (`rec_id_++`). 2. Optionally yields (background mode). 3. **Blocks** on `seq_cond_.wait` until `id == last_pushed_id_ + 1` (sequential ordering). 4. **Blocks** on `consumer_->ConsumeData(data, cntx_)` (downstream write). 5. Updates `last_pushed_id_`, notifies waiters via `seq_cond_.notify_all()`. 6. 
Optionally sleeps to throttle CPU (non-background mode, up to 2ms proportional to CPU spent). #### `FlushSerialized(RdbSerializer* serializer)` Calls `serializer->Flush(kFlushEndEntry)` to extract and optionally compress the buffer, then passes the result to `HandleFlushData`. Uses the main `serializer_` if no argument is given. #### `PushSerialized(bool force)` Skips if `!force` and `serializer_->SerializedLen() < kMinBlobSize` (8KB). Otherwise calls `FlushSerialized()` to drain the main serializer buffer. #### `RdbSerializer::PushToConsumerIfNeeded(FlushState flush_state)` ```cpp void RdbSerializer::PushToConsumerIfNeeded(SerializerBase::FlushState flush_state) { if (consume_fun_ && SerializedLen() > flush_threshold_) { string blob = Flush(flush_state); consume_fun_(std::move(blob)); // synchronous! } } ``` Only fires when `consume_fun_` is set **and** the buffer exceeds `flush_threshold_`. When it fires, it **synchronously** invokes the callback, which for `SliceSnapshot` is `HandleFlushData`. ## All Code Paths That Acquire `big_value_mu_` Currently there are **five** call sites in `snapshot.cc` that lock `big_value_mu_`. The diagrams below show the complete call chain from lock acquisition to potential blocking points. ### Path 1: `BucketSaveCb` (traversal fiber, both modes) ```mermaid flowchart LR A[IterateBucketsFb] --> B["BucketSaveCb
lock big_value_mu_
lock GetLatch()"] B --> C[SerializeBucket] C --> D[SerializeEntry] D --> E[SaveEntry] E -->|"if buffer > threshold"| F["consume_fun_()
= HandleFlushData"] F --> G["seq_cond_.wait
consumer_->ConsumeData
BLOCKS"] classDef lock fill:#FFF3E0,stroke:#EF6C00; classDef block fill:#FFEBEE,stroke:#C62828; class B lock; class G block; ``` ### Path 2: `OnDbChange` (mutation fiber, PIT only) ```mermaid flowchart LR A[DB mutation] --> B["OnDbChange
lock big_value_mu_"] B -->|PIT| C[SerializeBucket] C --> D[SerializeEntry] D --> E[SaveEntry] E -->|"if buffer > threshold"| F["consume_fun_()
= HandleFlushData"] F --> G["seq_cond_.wait
consumer_->ConsumeData
BLOCKS"] B -->|non-PIT| H["return
(barrier only)"] classDef lock fill:#FFF3E0,stroke:#EF6C00; classDef block fill:#FFEBEE,stroke:#C62828; classDef safe fill:#E8F5E9,stroke:#2E7D32; class B lock; class G block; class H safe; ``` ### Path 3: `OnMoved` (non-PIT only) ```mermaid flowchart LR A[DashTable move] --> B["OnMoved
lock big_value_mu_"] B -->|"moved across cursor"| C[SerializeBucket] C --> D[SerializeEntry] D --> E[SaveEntry] E -->|"if buffer > threshold"| F["consume_fun_()
= HandleFlushData"] F --> G["seq_cond_.wait
consumer_->ConsumeData
BLOCKS"] B -->|"same side of cursor"| H[skip] classDef lock fill:#FFF3E0,stroke:#EF6C00; classDef block fill:#FFEBEE,stroke:#C62828; class B lock; class G block; ``` ### Path 4: `ConsumeJournalChange` (journal callback, both modes) ```mermaid flowchart LR A[Journal change] --> B["ConsumeJournalChange
lock big_value_mu_"] B --> C["serializer_->WriteJournalEntry
(buffer append only)"] C --> D[returns] classDef lock fill:#FFF3E0,stroke:#EF6C00; class B lock; ``` This path does **not** reach `HandleFlushData`. It only appends to the serializer buffer. ## All Code Paths That Reach `HandleFlushData` ```mermaid flowchart TD subgraph HAZARD["Under big_value_mu_ (HAZARD)"] A1["OnDbChange — PIT only
lock big_value_mu_"] --> SB1["SerializeBucket → SerializeEntry → SaveEntry"] A2["BucketSaveCb — both modes
lock big_value_mu_ + GetLatch()"] --> SB2["SerializeBucket → SerializeEntry → SaveEntry"] A3["OnMoved — non-PIT only
lock big_value_mu_"] --> SB3["SerializeBucket → SerializeEntry → SaveEntry"] SB1 --> CF["PushToConsumerIfNeeded
consume_fun_()"] SB2 --> CF SB3 --> CF CF --> HFD1[HandleFlushData] end subgraph SAFE["Outside big_value_mu_ (SAFE)"] B1["IterateBucketsFb loop
(between buckets)"] --> PS1["PushSerialized(false)"] B2["IterateBucketsFb
(end of database)"] --> PS2["PushSerialized(true)"] B3["IterateBucketsFb
(full sync cut)"] --> PS3["PushSerialized(true)"] B4[FinalizeJournalStream] --> PS4["PushSerialized(true)"] B5["ThrottleIfNeeded
(from JournalSlice)"] --> PS5["PushSerialized(false)"] PS1 --> FS[FlushSerialized] PS2 --> FS PS3 --> FS PS4 --> FS PS5 --> FS FS --> HFD2[HandleFlushData] end HFD1 --> BLOCK["seq_cond_.wait
consumer_->ConsumeData
(BLOCKING)"] HFD2 --> BLOCK classDef hazard fill:#FFEBEE,stroke:#C62828,stroke-width:2px,color:#B71C1C; classDef safe fill:#E8F5E9,stroke:#2E7D32,color:#1B5E20; classDef block fill:#FFF3E0,stroke:#EF6C00; class A1,A2,A3,CF,HFD1 hazard; class B1,B2,B3,B4,B5,PS1,PS2,PS3,PS4,PS5,FS,HFD2 safe; class BLOCK block; ``` ## Delayed Serialization of tiered entities Tiered string values are not read synchronously under `big_value_mu_`. Instead, `SerializeExternal` pushes a `TieredDelayedEntry` into `delayed_entries_`; the actual read and serialization happen later in `PushSerialized()`, outside the bucket-serialization critical section. The current implementation is fragile — delayed entries live in a global side queue rather than being associated with their originating bucket, and this can corrupt the output stream — a delayed tiered value may be emitted after a journal entry for the same key, violating baseline-before-journal ordering (see PR #6824). Note: `RestoreStreamer` (used for slot migration) has its own delayed-entry mechanism via `CmdSerializer`, which uses a keyed `flat_hash_map` rather than a plain deque. The analysis below focuses on `SliceSnapshot`; the `RestoreStreamer` path has analogous concerns but a different data structure. This creates two distinct notions of "bucket finished": 1. **Traversal finished** — `SerializeBucket` has iterated every entry and returned. 2. **Baseline fully emitted** — all delayed tiered entries from that bucket have also been read, serialized, and flushed. For in-memory values these coincide; for tiered values they do not. The ordering invariant (`baseline(K)` before `journal(K)`) still applies. Because the baseline for a tiered key `K` may only materialize when `PushSerialized()` drains `delayed_entries_`, a bucket's completion point extends from "finished iterating" to "all delayed values serialized and flushed". 
## Locking and Synchronization ### `big_value_mu_` (ThreadLocalMutex) A `ThreadLocalMutex` (`src/server/synchronization.cc`) serving as the primary synchronization barrier. **Important:** `ThreadLocalMutex::lock()` and `unlock()` are **no-ops** when `serialization_max_chunk_size == 0`. This means `big_value_mu_` only provides actual synchronization when big-value streaming is enabled. When it is disabled, all `lock_guard` calls on this mutex are effectively free, and the system relies on cooperative scheduling (no preemption during serialization) for correctness. Its role differs by mode: **PIT mode:** Prevents mutations from modifying a bucket while it is being serialized, and prevents journal entries from being written during bucket serialization. This enforces both serialize-before-mutate and the ordering invariant. **Non-PIT mode:** Prevents mutations from modifying a bucket while `BucketSaveCb` is serializing it (data consistency within a single bucket). Also serves as a barrier for `ConsumeJournalChange` and `OnMoved`. | Path | Mode | Lock held | Additional locks | |------|------|-----------|-----------------| | `BucketSaveCb` | Both | `big_value_mu_` | `GetLatch()` | | `OnDbChange` | Both | `big_value_mu_` | none | | `OnMoved` | Non-PIT | `big_value_mu_` | none | | `ConsumeJournalChange` | Both | `big_value_mu_` | none | ### `GetLatch()` (LocalLatch) Acquired by `BucketSaveCb` in addition to `big_value_mu_`. This is a non-preempting latch (`src/server/synchronization.h`) that increments a blocking counter, preventing `Heartbeat()` from running if `SerializeBucket` preempts (e.g., during large value serialization). ### `seq_cond_` (CondVarAny) Condition variable used in `HandleFlushData` to ensure records are pushed to the consumer in sequential order of their `rec_id_`. If fiber A has `id=5` and fiber B has `id=6`, B waits until A finishes pushing and updates `last_pushed_id_` to 5. 
This is needed because fibers are awakened in arbitrary order and reordering flushed chunks breaks the wire protocol. ## Inefficiencies and Improvement Goals This section identifies concrete problems in the current serialization design and the improvements that address them. The [Technical Roadmap](#technical-roadmap) maps these into an ordered execution plan. **Hard constraints** (apply to all improvements): - **Backpressure must be maintained.** A slow consumer must slow down the producer; we cannot buffer unboundedly. - **Bounded serialization memory.** Intermediate buffers must not grow proportionally to the dataset size. ### 1. Shard-wide stall under `big_value_mu_` **Problem.** `big_value_mu_` is a single shard-wide mutex that guards three distinct concerns simultaneously: 1. **Bucket atomicity** — the bucket must not be mutated while `SerializeBucket` iterates it. 2. **Serializer buffer exclusivity** — `serializer_` must not be written to by two fibers. 3. **Journal ordering** — journal entries must not interleave with bucket serialization. When `consume_fun_` fires under the lock (large value → `PushToConsumerIfNeeded` → `HandleFlushData`), the mutex is held across blocking I/O (`seq_cond_.wait`, `consumer_->ConsumeData`). This stalls the entire shard: traversal, mutations, journal writes, and `OnMoved` all contend on the same lock. **Why the mutex is needed in `ConsumeJournalChange`.** Transaction paths are already ordered by `OnDbChange` (it acquires `big_value_mu_` first, so `ConsumeJournalChange` on the same fiber cannot start while traversal holds the lock). The mutex matters for [paths that bypass `OnDbChange`](#journal-entries-without-ondbchange) — inline eviction and heartbeat-driven deletions. Without it, inline eviction could produce: **Counter-example without the `ConsumeJournalChange` mutex — inline eviction via `PrimeEvictionPolicy::Evict`:** 1. 
Traversal calls `SerializeBucket(B)` and begins iterating it; the bucket contains key `K` (a large hash, serialized element-by-element). The traversal preempts mid-entry via `consume_fun_`. 2. While the traversal is preempted, a client command triggers a DashTable insert on a different bucket. The insert finds no free slot in its home bucket and calls `PrimeEvictionPolicy::Evict`, which selects `K` as the victim. 3. `Evict` removes `K` from the table and — still on the same fiber, inside the DashTable insert — calls `journal::RecordEntry(DEL K)` directly, bypassing `OnDbChange`. 4. `ConsumeJournalChange` appends `DEL K` to the shared serializer buffer immediately, even though traversal has already emitted only a prefix of `K`'s baseline. 5. Traversal resumes and appends the remaining bytes of `K`'s baseline. Result: the replica's byte stream contains `[partial baseline of K] [DEL K] [rest of baseline of K]`. The RDB decoder sees a truncated entry followed by an unexpected journal opcode, or parses garbage if the lengths happen to align. Even if the `DEL` is parsed out-of-band, the subsequent baseline bytes reconstruct `K` on the replica, reversing the deletion. **Goal.** Separate the three concerns so that: - bucket atomicity uses bucket-level mechanisms (versioning + bucket completion state); - buffer exclusivity uses per-serializer isolation (each producer owns its buffer); - journal ordering uses bucket completion state and deferred deletion queues; - no code path blocks on downstream I/O while holding a shard-wide lock. **Approach.** See [§5 summary table](#5-summary-mutex-roles-and-their-replacements) for the full mapping. Key mechanisms: bucket completion state ([§2](#2-imprecise-bucket-completion-tracking)), separate serializer instances ([§3](#3-shared-serializer-buffer-and-wire-format-coupling)), and non-preempting chunk production. See Roadmap items 6, 7, 8, 9. ### 2. 
Imprecise bucket completion tracking **Problem.** The system has no explicit notion of when a bucket's baseline is *fully emitted* (see [Delayed Serialization of tiered entities](#delayed-serialization-of-tiered-entities) for details on how tiered values extend bucket completion beyond `SerializeBucket`'s return). This creates two issues: - A journal entry for key K can reach the output buffer (via `ConsumeJournalChange`) before K's delayed tiered baseline is drained — violating the [ordering invariant](#ordering-invariant) (see PR #6824). - [Non-transaction journal entries](#journal-entries-without-ondbchange) (expiry, eviction) bypass `OnDbChange` entirely. Since there is no bucket completion state to consult, `DEL` entries can interleave mid-serialization of the deleted key's baseline. **Goal.** Make "baseline fully emitted" precise for every bucket — including tiered values — so that ordering decisions can be expressed through per-bucket state rather than shard-wide mutex exclusion. **Approach.** - Introduce a per-bucket state machine, local to each snapshot instance: `NotVisited` → `Serializing` → `DelayedPending` → `Covered`. Each bucket is identified by a stable `BucketIdentity`. A bucket must remain in the tracking map (`currently_serialized_`, a map from `BucketIdentity` to the bucket's state) until all work completes; otherwise `version >= snapshot_version_` + absent-from-map would falsely read as `Covered`. State encoding: | State | Encoding | Meaning | |-------|----------|---------| | **NotVisited** | `version < snapshot_version_`, not in map | Traversal has not reached this bucket | | **Serializing** | `version >= snapshot_version_`, in map as `Serializing` | Traversal is iterating this bucket | | **DelayedPending** | `version >= snapshot_version_`, in map as `DelayedPending` | Iteration done, tiered entries still pending | | **Covered** | `version >= snapshot_version_`, not in map | Baseline fully emitted | - Associate delayed tiered entries with their originating bucket instead of the global queue.
Transition to `Covered` only after all delayed entries are flushed. - **Transaction-driven mutations:** `OnDbChange` blocks (fiber-aware wait) on `Serializing`/`DelayedPending` buckets; proceeds immediately on `NotVisited` (serialize now) or `Covered` (baseline already emitted). Since `OnDbChange` → mutation → `RecordJournal` → `ConsumeJournalChange` is sequential on the mutation fiber, blocking `OnDbChange` guarantees baseline-before-journal. - **Non-transaction deletions (expiry, eviction):** `OnDbChange` is [infeasible on these paths](#journal-entries-without-ondbchange). Instead, use a **deferred deletion queue**: enqueue the key when the bucket is `Serializing`/`DelayedPending`; drain (emit `DEL`) when the bucket transitions to `Covered`. See roadmap item 6 for details. - **Latency tradeoff:** blocking `OnDbChange` on `DelayedPending` means a mutation fiber can stall for the duration of a tiered disk read (see roadmap item 6 for mitigation). See Roadmap items 3, 5, 6. ### 3. Shared serializer buffer and wire-format coupling **Problem.** `ConsumeJournalChange` and `SerializeBucket` write to the same `serializer_` buffer (the "buffer exclusivity" role from [§1](#1-shard-wide-stall-under-big_value_mu_)). Even with separate buffers, interleaved output from two serializers cannot be demuxed by the consumer without a framing protocol — a journal entry injected mid-RDB-entry produces an unparseable byte stream (see the [eviction counter-example](#1-shard-wide-stall-under-big_value_mu_) for a concrete scenario). **Goal.** Decouple journal and bucket serialization so they can produce data independently, without sharing a buffer or requiring a shard-wide lock for output integrity. **Approach.** - **Tagged-chunk wire format.** Extend the serialization format with tagged chunks: each mid-entry flush produces a chunk tagged with a stream ID. The consumer reassembles same-ID chunks before decoding. Small values (single chunk) use the existing format unchanged — no overhead. 
Controlled by a master-side flag (`--serialization_tagged_chunks`). - **Separate `RdbSerializer` per producer.** Give journal entries and bucket serialization their own serializer instances. Each produces tagged chunks independently. With separate buffers, `ConsumeJournalChange` no longer needs `big_value_mu_` for buffer exclusivity. - **Flushing strategy:** small values serialize the entire bucket without preemption; large values release the lock between chunks and apply backpressure outside the critical section. Bucket contents remain stable across the gap because PIT versioning prevents re-serialization and `OnDbChange` blocking (§2) prevents mutation. See Roadmap items 4, 7. ### 4. Non-PIT redundant journal traffic **Problem.** Non-PIT mode (eventual consistency for replication) emits every journal entry regardless of whether the snapshot traversal will cover the mutation. For self-contained entries (`SET`, `DEL`) this is redundant but harmless. For baseline-dependent entries (`HSET`, `LPUSH`, etc.) the system emits both the baseline value and the journal entry for every mutation, even when the traversal has not yet reached the bucket and will serialize the post-mutation value. **Goal.** In non-PIT mode, reduce journal traffic by skipping entries that are guaranteed to be covered by the traversal, without compromising eventual consistency. **Approach.** Use the bucket completion state machine (§2) to classify mutations: - **Self-contained entries** (`SET`, `DEL`, `EXPIRE`): skip for `NotVisited` buckets (traversal will see post-mutation value); emit for `Covered` buckets; emit conservatively for `Serializing`/`DelayedPending`. Classification is by **emitted journal command form**, not the user-facing command — commands like `JSON.SET` may be self-contained or not depending on arguments and must be validated individually.
- **Baseline-dependent entries** (`HSET`, `LPUSH`, `SADD`, `ZADD`, `XADD`, `APPEND`, etc.): **SkipBoth** — suppress both baseline serialization and journal entry — when the bucket is `NotVisited`/`Serializing`, the mutation is a single-key in-memory update (no delete, no rehash, no insert), and no delayed tiered entry is in flight. Otherwise fall back to emit journal only or keep both. Each `SliceSnapshot` instance marks suppressed mutations locally; `ConsumeJournalChange` skips them without cross-instance coordination. See Roadmap items 10–15. ### 5. Summary: mutex roles and their replacements The previous subsections identify `big_value_mu_`'s three roles and the mechanisms that replace each: | Mutex role | Replacement | Source | |-----------|-------------|--------| | Journal ordering | Bucket completion state + deferred deletion queue | §1 | | Buffer exclusivity | Separate `RdbSerializer` per producer + tagged chunks | §3 | | Bucket atomicity (PIT) | Bucket versioning + `OnDbChange` blocking | §1, §2 | | Bucket atomicity (non-PIT) | Non-preempting chunk production | §2, §3 | Once all replacements are in place and validated, the mutex can be narrowed per mode and path, and eventually removed entirely. The roadmap structures this as a sequence of incremental steps (Phases 0–4), each validated before the next begins. ## Technical Roadmap The improvements identified above are interdependent. The safest path is to split them into small, verifiable steps that first improve observability and correctness scaffolding, then improve PIT and PIT+tiered correctness/robustness, and only after that tackle non-PIT optimizations and deeper serializer / lock-removal changes. Some of the groundwork — especially bucket-level completion state — is shared and should be laid early even if the first consumers are PIT-oriented. Because non-PIT is currently experimental and unused, the roadmap below does **not** treat current non-PIT behavior as a compatibility constraint. 
Later non-PIT phases may simplify, replace, or remove experimental behavior rather than preserving it. ### Phase 0 — Baseline and guardrails 1. **Document current invariants in code comments and tests.** - Make the key ordering rules explicit near `SliceSnapshot::OnDbChange`, `SliceSnapshot::ConsumeJournalChange`, `RestoreStreamer::OnDbChange`, and `DbSlice::FlushChangeToEarlierCallbacks`. - Prefer focused replication tests over purely end-to-end hash comparisons. The current broad replication suite is useful, but Phase 0 needs tests that fail specifically when an ordering invariant is broken. - Add focused tests for: - PIT: baseline-before-journal for baseline-dependent mutations. - tiered values: delayed serialization still preserves baseline-before-journal. - Suggested test strategy: - **PIT ordering guardrail:** add a test in `tests/dragonfly/replication_test.py` that starts full sync with `point_in_time_snapshot=true`, performs a small controlled set of baseline-dependent updates during full sync (`HSET`, `LPUSH`, `APPEND`, `XADD`), waits for stable sync, and then asserts exact key/value equality for only those keys. The intent is to make a baseline-before-journal violation fail on a tiny, debuggable workload. - **tiered delayed-entry guardrail:** rehabilitate the currently skipped tiered replication test in `tests/dragonfly/tiering_test.py` and make it assert not just final equivalence, but that concurrent writes to tiered keys during full sync do not lose updates. - Suggested assertions: - assert exact values for a small curated key set, not just whole-dataset hashes; - assert replica reaches stable sync and catches up via `check_all_replicas_finished`; - assert path-activation counters from logs where available (`side_saved`, `moved_saved`); - for tricky cases, prefer deterministic key-level checks over probabilistic stress-only validation. 
- Suggested scope split: - keep the existing large/stress replication tests as coarse regression coverage; - add a handful of small, deterministic Phase 0 tests whose only purpose is to guard the invariants this roadmap depends on. - Goal: freeze the current correctness contract before changing behavior. 2. **Add lightweight observability for snapshot/journal interleavings.** - Count how often `ConsumeJournalChange` runs while a bucket is being serialized. - Count flushes triggered under `big_value_mu_` versus outside it. - Suggested locations for counters / debug stats: - increment a counter when `ConsumeJournalChange` acquires the barrier while `serialize_bucket_running_` is true; - increment separate counters for `HandleFlushData` reached from under `big_value_mu_` versus from `PushSerialized` outside the critical section; - Suggested exposure: - start with log lines in the existing `Exit SnapshotSerializer` / replication progress logs; - if the signals become broadly useful, promote them to INFO/stats fields later. - Suggested rollout rule: - add observability before optimization, and require each new fast path to demonstrate that the expected path was actually exercised in tests. - Goal: validate which paths are actually hot and which optimizations are worth the risk. ### Phase 1 — PIT and PIT+tiered foundation 3. **Introduce explicit bucket-level completion state.** - **Prerequisites:** Phase 0.1–0.2. - Implement the per-snapshot-instance state machine described in [§2](#2-imprecise-bucket-completion-tracking): `NotVisited` → `Serializing` → `DelayedPending` → `Covered`, keyed by `BucketIdentity`. - Keep this state entirely instance-local to `SliceSnapshot` / `RestoreStreamer`. - Goal: replace vague "bucket iteration finished" reasoning with an explicit state machine that will later serve both PIT+tiered correctness and non-PIT decisions. 4. **Extend the wire format with tagged chunks.** - **Prerequisites:** none.
- Implements the tagged-chunk format described in [§3](#3-shared-serializer-buffer-and-wire-format-coupling). Entries that may be split across preemption points are wrapped in a per-stream-tag envelope; single-chunk entries use the existing format unchanged (no overhead). - **Wire format:** `RDB_OPCODE_DF_MASK`-style flag bit (`DF_MASK_FLAG_CHUNKED`). When set, payload is `stream_tag: uint32, payload_length: uint32, payload: bytes`. Entries without the flag are unchanged. - **Enablement:** master-side flag (`--serialization_tagged_chunks`), not `DflyVersion` (which doesn't apply to DFS backups). The loader detects tagged chunks by the flag bit and reassembles transparently. - Pure format + loader-side work — no changes to serialization logic or locking. Can be developed independently of Phases 0–1. - **Scope:** replication and DFS backups. Only legacy `.rdb` format does not need tagged chunks (`SnapshotFlush::kDisallow`, no concurrent bucket serialization). - Why early: Phase 2 (item 7) needs separate serializers whose interleaved output requires tagged chunks for demuxing. - Goal: have the wire-format infrastructure ready before Phase 2 needs it. 5. **Associate delayed tiered serialization with bucket state.** - **Prerequisites:** 1.3. - Address the [tiered completion gap](#delayed-serialization-of-tiered-entities): associate `delayed_entries_` with their originating bucket instead of the global queue. - Only transition a bucket to `Covered` once its delayed tiered entries are emitted. - Goal: make "baseline fully emitted" precise, not just "bucket iteration finished". 6. **Use bucket completion state to harden PIT ordering guarantees.** - **Prerequisites:** 1.3 and 1.5. - Re-express the PIT ordering rule in terms of bucket completion state, not just mutex exclusion and `bucket.version`. - For in-memory values, PIT ordering is already sound by construction (sequential `OnDbChange` → mutation → `ConsumeJournalChange` on the same fiber). 
The real gap is **tiered delayed entries** (see [Delayed Serialization](#delayed-serialization-of-tiered-entities)): a journal entry can reach the buffer before the delayed baseline is drained. - **`OnDbChange` blocking:** block (fiber-aware wait) when the bucket is `Serializing` or `DelayedPending`; proceed on `NotVisited` (serialize now → `Covered`) or `Covered` (baseline already emitted). Because `OnDbChange` → mutation → `RecordJournal` → `ConsumeJournalChange` is sequential on the mutation fiber, blocking `OnDbChange` guarantees baseline-before-journal for all transaction-driven mutations. - **Deferred deletion queue** for [non-transaction journal paths](#journal-entries-without-ondbchange) (expiry, eviction — where `OnDbChange` is infeasible). When a deletion encounters a bucket in `Serializing`/`DelayedPending`, enqueue the key into a per-bucket `pending_deletions` vector of keys (bounded by bucket capacity, typically 12–14 slots). The traversal fiber drains the queue — emitting deferred `DEL` entries — when transitioning the bucket to `Covered`. For `NotVisited`/`Covered` buckets, `DEL` is emitted immediately as today. Properties: - no blocking, re-entrancy, or preemption on the deletion fiber; - baseline-before-journal ordering preserved by construction. - After this item, `big_value_mu_` is no longer needed for journal ordering, but is still needed for [buffer exclusivity](#3-shared-serializer-buffer-and-wire-format-coupling) (items 7–8). - **Latency tradeoff:** blocking `OnDbChange` on `DelayedPending` can stall a mutation fiber for the duration of a tiered disk read (the read completes asynchronously via a `Future`). Acceptable for correctness; monitor and consider `KeepBoth` fallback if latency is excessive. - Use Phase 0 tests to validate PIT+tiered behavior under preemption and backpressure. - Goal: make the existing production path easier to reason about before adding new behavior. ### Phase 2 — Reduce PIT blocking and serializer fragility 7.
**Give journal and bucket serialization separate `RdbSerializer` instances.** - **Prerequisites:** 1.4 and 1.6. - NOTE: this may be unnecessary if we rely on 1.4. - Addresses the [shared buffer problem](#3-shared-serializer-buffer-and-wire-format-coupling) and the primary [shard-wide stall hazard](#blocking-under-big_value_mu_). - The fix: give journal entries their own `RdbSerializer` instance. Bucket serialization and journal serialization never share a buffer. Each produces tagged chunks (item 4) that the consumer (replica or DFS loader) reassembles by stream tag. - The same separation is needed for **DFS backups** (no journal, but still PIT): once per-bucket locks (item 6) replace the shard-wide `big_value_mu_`, two concurrent `SerializeBucket` calls can run on different buckets (traversal fiber on bucket A preempts mid-entry via `consume_fun_`, `OnDbChange` serializes bucket B). Each call needs its own buffer; tagged chunks allow their interleaved output to be reassembled. - With separate serializers, `big_value_mu_` is no longer needed for buffer exclusivity. `ConsumeJournalChange` writes to its own serializer without acquiring `big_value_mu_` at all (journal ordering is already guaranteed by bucket completion state from item 6). - The flushing strategy depends on value size: - **Small values (typical case):** `consume_fun_` is disabled (or made a no-op) while the lock is held. `SerializeBucket` serializes the entire bucket into the bucket serializer's buffer without preempting — the buffer grows but stays bounded because most buckets contain only small entries. After `SerializeBucket` returns and the lock is released, the accumulated buffer is flushed as a tagged chunk outside the lock. - **Large values (e.g., a 1 GB set):** the existing `kFlushMidEntry` boundaries become lock-release points.
After serializing a bounded batch of elements, the lock is released, the accumulated chunk is flushed (with backpressure) outside the lock, and the lock is re-acquired for the next batch. Bucket contents remain stable across the gap because (a) PIT versioning prevents re-serialization and (b) `OnDbChange` blocking (item 6) prevents the mutation from committing. Both are required: (a) alone prevents double-serialization but not mid-value mutation; (b) alone prevents mutation but not concurrent `SerializeBucket` entry. - Goal: eliminate blocking under `big_value_mu_` by removing the shared-buffer reason for holding it, rather than by restructuring the lock/unlock pattern around the same buffer. 8. **Simplify `rec_id_` / `seq_cond_` ordering once tagged-chunk delivery is proven.** - **Prerequisites:** 2.7, 1.4. - With tagged chunks support, we may not need a consistent global order between different fibers. In that case `rec_id_` / `seq_cond_.wait` become redundant. - Remove `rec_id_` / `seq_cond_` only after demonstrating (via tests and observability) that we do not corrupt the replication stream. - Goal: avoid removing an ordering mechanism before its replacement is demonstrably sound. 9. **Narrow `big_value_mu_` for PIT only after the above is proven.** - **Prerequisites:** 2.7–2.8. - Keep serialize-before-mutate semantics intact. - Remove or narrow mutex roles only where bucket state, serializer isolation, and tagged-chunk delivery already provide an equivalent correctness guarantee. - Goal: simplify the active production path incrementally, not speculatively. ### Phase 3 — Bring non-PIT onto the new foundation 10. **Add non-PIT-specific guardrails before changing non-PIT behavior.** - **Prerequisites:** 1.3 and 1.5. 
- Add focused tests for: - self-contained journal entries produce correct final state when baseline is fully emitted before or after the journal entry (no mid-entry interleaving); - moved items that cross the cursor are not lost; - any first non-PIT bucket-state redesign still converges under concurrent full-sync writes. - Suggested test strategy: - add a dedicated test with `point_in_time_snapshot=false` that mutates only with self-contained emitted commands (`SET`, `DEL`, `BITOP` rewritten to `SET`/`DEL`); - rehabilitate the currently skipped `test_replication_onmove_flow` instead of replacing it; if it is too flaky for CI, first reduce it to a smaller deterministic reproducer that still asserts both replica equality and `moved_saved > 0` from snapshot logs; - add non-PIT-specific observability such as counting how often `OnMoved` actually serializes a bucket and optionally classifying self-contained vs baseline-dependent journal entries by emitted command. - Goal: avoid touching experimental non-PIT behavior without dedicated guardrails. 11. **Stamp bucket version in non-PIT mode behind a feature flag.** - **Prerequisites:** 1.3 and 1.5 and 3.10. - Teach non-PIT `SerializeBucket` to call `SetVersion(snapshot_version_)`. - Since non-PIT is experimental, prefer the simplest implementation that matches the new bucket-state model rather than preserving legacy bookkeeping. - Validate that traversal, `OnMoved`, and any remaining bucket-version assumptions remain correct under the new design. - Goal: align non-PIT with the new foundation, not preserve its old implementation details. 12. **Implement self-contained journal classification in `ConsumeJournalChange`.** - **Prerequisites:** 3.11. - Classify emitted journal commands as self-contained vs baseline-dependent. - Initially use a conservative allowlist (`SET`, `DEL`, rewritten `BITOP`). - Skip `big_value_mu_` only for self-contained entries in non-PIT mode. 
- Goal: harvest the simplest safe non-PIT redesign win first, on top of the PIT-hardened foundation. 13. **Add instance-local suppression state for `SkipBoth`.** - **Prerequisites:** 1.3 and 1.5 and 3.11. - Let `OnDbChange` record a local suppression decision for mutations whose effects will be covered by future traversal. - Let the same snapshot instance's `ConsumeJournalChange` consult and clear that state. - Do not introduce shard-wide or cross-instance aggregation. - Goal: keep the redesign entirely within the existing per-instance callback pair, without carrying forward unnecessary experimental structure. 14. **Implement `SkipBoth` for the narrowest safe mutation subset.** - **Prerequisites:** 3.13. - Start with single-key, single-bucket, in-memory updates only. - Exclude inserts, deletes, rehash-triggering operations, and tiered cases. - Require bucket state to be `NotVisited` or `Serializing`. - Goal: prove the mechanism on a subset where correctness is easy to reason about. 15. **Expand `SkipBoth` eligibility only after targeted validation.** - **Prerequisites:** 3.14. - Re-evaluate `DelayedPending` once delayed-entry ownership is explicit. - Re-evaluate inserts only if bucket-touch coverage can be proven cheaply. - Re-evaluate tiered keys only if suppression can be tied to delayed-entry completion. - Goal: expand cautiously instead of generalizing the hard cases upfront. ### Phase 4 — Reassess `big_value_mu_` globally 16. **Narrow the lock's role by mode and path.** - **Prerequisites:** 2.9 for PIT changes; 3.12–3.15 for non-PIT changes. - PIT: keep only what is still required for serialize-before-mutate correctness. - non-PIT: remove it from self-contained journal entries first; then reconsider `OnMoved` and traversal interactions once serialization becomes non-preempting. 
- Goal: shrink the lock surface incrementally instead of attempting full removal at once; for non-PIT there is no obligation to preserve locking structure that exists only because of the experimental implementation. 17. **Attempt full `big_value_mu_` removal only after all prerequisites are in place.** - **Prerequisites:** 4.16. - Preconditions: - non-preempting bounded serialization chunks, - precise bucket coverage state, - delayed tiered ownership tracked to completion, - journal ordering independent of the mutex, - tests covering PIT, non-PIT, `OnMoved`, and tiered cases. - Goal: ensure lock removal is the final simplification step, not the first risky rewrite. ================================================ FILE: docs/thread-per-core.excalidraw ================================================ { "type": "excalidraw", "version": 2, "source": "https://excalidraw.com", "elements": [ { "type": "text", "version": 158, "versionNonce": 1897755639, "isDeleted": false, "id": "N2nJ6OaFNRqcFW23SO0u2", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 714.625, "y": 507.5390625000001, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 90, "height": 20, "seed": 1339600844, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676475959, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "I/O thread", "baseline": 14, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "I/O thread" }, { "type": "text", "version": 212, "versionNonce": 1838113753, "isDeleted": false, "id": "pZs66qxoJlWQcWuBsvAxk", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 829.125, "y": 509.4140625000001, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 90, "height": 20, "seed": 1172993740, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 
1658676475959, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "I/O thread", "baseline": 14, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "I/O thread" }, { "type": "text", "version": 223, "versionNonce": 1421110391, "isDeleted": false, "id": "qhrDskacRkr-tNl2Q3atR", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 948.6875, "y": 508.02455357142867, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 90, "height": 20, "seed": 1936794996, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676504307, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "I/O thread", "baseline": 14, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "I/O thread" }, { "type": "rectangle", "version": 344, "versionNonce": 1641244985, "isDeleted": false, "id": "jPwIU_a9_nxvuDFAcbzxM", "fillStyle": "cross-hatch", "strokeWidth": 1, "strokeStyle": "dotted", "roughness": 1, "opacity": 100, "angle": 0, "x": 712.375, "y": 537.2500000000001, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 431, "height": 30, "seed": 1029717964, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [ { "type": "text", "id": "U2-I9a2X4amHnB7NZFWGv" } ], "updated": 1658676541606, "link": null, "locked": false }, { "type": "text", "version": 239, "versionNonce": 1717412567, "isDeleted": false, "id": "U2-I9a2X4amHnB7NZFWGv", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 717.375, "y": 542.2500000000001, "strokeColor": "#000000", "backgroundColor": "transparent", "width": 421, "height": 20, "seed": 1592449524, "groupIds": [], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676541606, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "message bus", "baseline": 14, 
"textAlign": "center", "verticalAlign": "middle", "containerId": "jPwIU_a9_nxvuDFAcbzxM", "originalText": "message bus" }, { "type": "rectangle", "version": 315, "versionNonce": 208875257, "isDeleted": false, "id": "mBFE2wiT175ZxMSdmWcvQ", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 712.375, "y": 305.7916666666667, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 77, "height": 192, "seed": 352036980, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [ { "type": "text", "id": "tK1EcrkpG35slJ07z1dTT" } ], "updated": 1658676546251, "link": null, "locked": false }, { "type": "text", "version": 194, "versionNonce": 181803287, "isDeleted": false, "id": "tK1EcrkpG35slJ07z1dTT", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 717.375, "y": 376.7916666666667, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 67, "height": 50, "seed": 1251432308, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "thread\n1", "baseline": 43, "textAlign": "center", "verticalAlign": "middle", "containerId": "mBFE2wiT175ZxMSdmWcvQ", "originalText": "thread\n1" }, { "type": "rectangle", "version": 430, "versionNonce": 1426120247, "isDeleted": false, "id": "BY5OdEEKT0Y_DTy9Zgr9C", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 833.375, "y": 306.4166666666667, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 77, "height": 192, "seed": 1621471436, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [ { "id": "sIrssFTnnb9f1o26g1j88", "type": "text" }, { "type": "text", "id": "sIrssFTnnb9f1o26g1j88" } ], "updated": 1658676546251, "link": null, 
"locked": false }, { "type": "text", "version": 310, "versionNonce": 514622649, "isDeleted": false, "id": "sIrssFTnnb9f1o26g1j88", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 838.375, "y": 377.4166666666667, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 67, "height": 50, "seed": 711168500, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "thread\n2", "baseline": 43, "textAlign": "center", "verticalAlign": "middle", "containerId": "BY5OdEEKT0Y_DTy9Zgr9C", "originalText": "thread\n2" }, { "type": "text", "version": 76, "versionNonce": 1406533463, "isDeleted": false, "id": "45U617mr0L9ob4mc7Xozt", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 845.375, "y": 260.0865384615385, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 53, "height": 40, "seed": 1285924468, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "shard\nthread", "baseline": 34, "textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "shard\nthread" }, { "type": "text", "version": 85, "versionNonce": 2081260953, "isDeleted": false, "id": "vY-LnNlhD3qWMEtRPoU0t", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 964.9375, "y": 260.0865384615385, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 53, "height": 40, "seed": 817296972, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "shard\nthread", "baseline": 34, 
"textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "shard\nthread" }, { "type": "rectangle", "version": 458, "versionNonce": 190540409, "isDeleted": false, "id": "xvkm28eoejETjF3M78jpN", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1062.125, "y": 310.875, "strokeColor": "#000000", "backgroundColor": "#fa5252", "width": 77, "height": 187, "seed": 1482008524, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [ { "id": "nSQOBHdmN0bLo5OeoOD0P", "type": "text" }, { "type": "text", "id": "nSQOBHdmN0bLo5OeoOD0P" } ], "updated": 1658676546251, "link": null, "locked": false }, { "type": "text", "version": 337, "versionNonce": 2051102103, "isDeleted": false, "id": "nSQOBHdmN0bLo5OeoOD0P", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1067.125, "y": 379.375, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 67, "height": 50, "seed": 1058179828, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "thread\n4", "baseline": 43, "textAlign": "center", "verticalAlign": "middle", "containerId": "xvkm28eoejETjF3M78jpN", "originalText": "thread\n4" }, { "type": "text", "version": 156, "versionNonce": 1163506521, "isDeleted": false, "id": "H72xWL9unzb1mQiLvx7L4", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 1074.125, "y": 265.7115384615385, "strokeColor": "#000000", "backgroundColor": "#fab005", "width": 53, "height": 40, "seed": 1704611020, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 16, "fontFamily": 1, "text": "shard\nthread", "baseline": 34, 
"textAlign": "center", "verticalAlign": "top", "containerId": null, "originalText": "shard\nthread" }, { "type": "rectangle", "version": 510, "versionNonce": 1046208569, "isDeleted": false, "id": "jj-MVcNrzcH0DbFFo9noF", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 952.9375, "y": 310.1666666666667, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 77, "height": 193, "seed": 1374694167, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [ { "id": "NxhycN5eOsL0I52k0H-lh", "type": "text" }, { "id": "NxhycN5eOsL0I52k0H-lh", "type": "text" }, { "type": "text", "id": "NxhycN5eOsL0I52k0H-lh" } ], "updated": 1658676546251, "link": null, "locked": false }, { "type": "text", "version": 391, "versionNonce": 1308367831, "isDeleted": false, "id": "NxhycN5eOsL0I52k0H-lh", "fillStyle": "hachure", "strokeWidth": 1, "strokeStyle": "solid", "roughness": 1, "opacity": 100, "angle": 0, "x": 957.9375, "y": 381.6666666666667, "strokeColor": "#000000", "backgroundColor": "#fd7e14", "width": 67, "height": 50, "seed": 617412057, "groupIds": [ "DYa5vdmfX68EvWPAq2Beo" ], "strokeSharpness": "sharp", "boundElements": [], "updated": 1658676546251, "link": null, "locked": false, "fontSize": 20, "fontFamily": 1, "text": "thread\n3", "baseline": 43, "textAlign": "center", "verticalAlign": "middle", "containerId": "jj-MVcNrzcH0DbFFo9noF", "originalText": "thread\n3" } ], "appState": { "gridSize": null, "viewBackgroundColor": "#ffffff" }, "files": {} } ================================================ FILE: docs/transaction.md ================================================ # Life of a transaction This document describes how Dragonfly transactions provide atomicity and serializability for its multi-key and multi-command operations. ## Definitions ### Serializability Serializability is an isolation level for database transactions. 
Serializability describes multiple transactions, where a transaction is usually composed of multiple operations on multiple objects. A database can execute transactions in parallel (and the operations in parallel). Serializability guarantees the result is the same as if the transactions were executed one by one, i.e. the system behaves as if they were executed in a serial order. Serializability doesn’t guarantee the resulting serial order respects recency. I.e. the serial order can be different from the order in which transactions were actually executed. E.g. Tx1 begins earlier than Tx2, but the result behaves as if Tx2 executed before Tx1. That is also to say, to satisfy the same Serializability, there can be more than one possible execution schedule. ### Strict Serializability Strict serializability means that operations appear to have occurred in some order, consistent with the real-time ordering of those operations; e.g. if operation A completes before operation B begins, then A should appear to precede B in the serialization order. Strict serializability implies atomicity, meaning a transaction’s sub-operations do not appear to interleave with sub-operations from other transactions. It also implies serializability by definition (appear in some order...). Note that simple, single-key operations in Dragonfly are already strictly serializable because in a shared-nothing architecture each shard-thread performs operations on its keys sequentially. The complexity rises when we need to provide strict-serializability (aka serializability and linearizability) for operations spanning multiple keys. ## Transactions high level overview Transactions in Dragonfly are orchestrated by an abstract entity, called the coordination layer. In reality, a client connection instance takes on itself the role of a coordinator: it coordinates a transaction every time it drives a redis or memcached command to completion.
The algorithm behind Dragonfly transactions is based on the [VLL paper](https://www.cs.umd.edu/~abadi/papers/vldbj-vll.pdf). Every step within a coordinator is done sequentially. Therefore, it's easier to describe the flow using a sequence diagram. Below is a sequence diagram of a generic transaction consisting of multiple execution steps. In this diagram, the operation it executes touches keys in two different shards: `Shard1` and `Shard2`. ```mermaid %%{init: {'theme':'base'}}%% sequenceDiagram participant C as Coordinator participant S1 as Data Shard 1 participant S2 as Data Shard 2 par hop1 C->>+S1: Schedule and C->>+S2: Schedule S1--)C: Ack S2--)C: Ack end par hop2 C->>S1: Exec1 and C->>S2: Exec1 S1--)C: Ack S2--)C: Ack end par hop N+1 C->>S1: Exec N+Fin and C->>S2: Exec N+Fin S1--)-C: Ack S2--)-C: Ack end ``` The shared-nothing architecture of Dragonfly does not allow accessing each shard data directly from a coordinator fiber. Instead, the coordinator sends messages to the shards and instructs them what to do at each step. Every time, the coordinator sends a message, it blocks until it gets an answer. We call such interaction a *message hop* or a *hop* in short. The flow consists of two different phases: *scheduling* a transaction, and *executing* it. The execution phase may consist of one or more hops, depending on the complexity of the operation we model. *Note, that only the coordinator fiber is blocked. Its thread can still execute other fibers - like processing requests on other connections or handling operations for the shard it owns. This is the advantage of adopting fibers - they allow us to separate the execution context from OS threads.* ## Scheduling a transaction The transaction initiates with a scheduling hop, during which the coordinator sends to each shard the keys that shards handle. 
The coordinator sends messages to multiple shards asynchronously but it waits until all shards ack and confirm that the scheduling succeeded before it proceeds to the next steps. When the scheduling message is processed by a data shard, it adds the transaction to its local transaction queue (tx-queue). In order to provide serializability, i.e. to make sure that all shards order their scheduled transactions in the same order, Dragonfly maintains a global sequence counter that is used to induce a total order for all its transactions. This global counter is shared by all coordinator entities and is represented by an atomic integer. *This counter may be a source of contention - it breaks the shared nothing model, after all. However, in practice, we have not observed a significant impact on Dragonfly performance due to other optimizations we added. These will be detailed in the [Optimization](#optimizations) section below. Transactions in tx-queue in each shard are arranged by their sequence counter. As shown in the snippet below, a shard thread may receive transactions in a different sequence, so a transaction with a smaller id can be added to the tx-queue after a transaction with a larger id. If the scheduling algorithm running on the data shard, can not reorder the last added transaction, it fails the scheduling request. In that case, the coordinator reverts the scheduling operation by removing the tx from the shards, and retries the whole hop again by allocating a new sequence number. In reality the fail-rate of a scheduling attempt is low and the retries are rare (subject to contention on the keys). 
Note, inconsistent reordering happens when two coordinators try to schedule multi-shard transactions concurrently: ``` C1: enqueue msg to Shard1 to schedule T1 C2: enqueue msg to Shard1 to schedule T2 # enqueued earlier than C1 C1: enqueue msg to Shard2 to schedule T1 C2: enqueue msg to Shard2 to schedule T2 # enqueued later than C1 shard1: pull T2, add it to TxQueue, pull T1, add it to TxQueue shard2: pull T1, add it to TxQueue, pull T2, add it to TxQueue TxQueue1: T2, T1 # wrong order TxQueue2: T1, T2 ``` Once the transaction is added to the tx-queue, the shard also marks the tx-keys using the *intent* locks. Those locks do not block the flow of the underlying operation but merely express the intent to touch or modify the key. In reality, they are represented by a map: `lock:str->counter`. If `lock[key] == 2` it means the tx-queue has 2 pending transactions that plan to modify `key`. These intent locks are used for optimizations detailed below and are not required to implement the naive version of the VLL algorithm. Once the scheduling hop converges, it means that the transaction has entered the execution phase, in which it never rolls back or retries. Once it's been scheduled, VLL guarantees the progress of subsequent execution operations while providing strict-serializability guarantees. It's important to note that a scheduled transaction does not hold exclusivity on its keys. There could be other transactions that still mutate the keys it touches - these transactions were scheduled earlier and have not finished running yet, or have not even started running. ## Executing a transaction Once the transaction is scheduled, the coordinator starts sending the execution messages. We break each command into one or more micro-ops and each operation corresponds to a single message hop. For example, "MSET" corresponds to a single micro-op "mset" that has the same semantics, but runs in parallel on all the involved shards.
However, "RENAME" requires two micro-ops: fetching the data from two keys, and then the second hop - deleting/writing a key (depending whether the key is a source or a destination). Once a coordinator sends the micro-op request to all the shards, it waits for an answer. Only when all shards executed the micro-op and return the result, the coordinator is unblocked and it can proceed to the next hop. The coordinator is allowed to process the intermediary responses from the previous hops in order to define the next execution request. When a coordinator sends an execution request to data shards, it also specifies whether this execution is the last hop for that command. This is necessary, so that shards could do clean-up operations when running the last execution request: unlocking the keys and removing the transaction from the tx-queue. The shards always execute transactions at the head of the tx-queue. When the last execution hop for that transaction is executed the transaction is removed from the queue and the next one can be executed. This way we maintain the ordering guarantees specified by the scheduling order of the transactions and we maintain the serializability of operations across multiple shards. ## Multi-op transactions (Redis transactions) Redis transactions (MULTI/EXEC sequences) and commands produced by Lua scripts are modelled as consecutive commands within a Dragonfly transaction. In order to avoid ambiguity with terms, we call a Redis transaction - a multi-transaction in Dragonfly. The multi feature of the transactional framework allows running consecutive commands without rescheduling the transaction for each command as if they are part of one single transaction. This feature is transparent to the commands itself, so no changes are required for them to be used in a multi-transaction. There are three modes called "multi modes" in which a multi transaction can be executed, each with its own benefits and drawbacks. __1. 
Global mode__ The transaction is equivalent to a global transaction with multiple hops. It is scheduled globally and the commands are executed as a series of consecutive hops. This mode is required for global commands (like MOVE) and for accessing undeclared keys in Lua scripts. Otherwise, it should be avoided, because it prevents Dragonfly from running concurrently and thus greatly decreases throughput. __2. Lock ahead mode__ The transaction is equivalent to a regular transaction with multiple hops. It is scheduled on all keys used by the commands in the transaction block, or Lua script, and the commands are executed as a series of consecutive hops. __3. Non atomic mode__ All commands are executed as separate transactions making the multi-transaction not atomic. It vastly improves the throughput with contended keys, as locks are acquired only for single commands. This mode is useful for Lua scripts without atomicity requirements. ## Multi-op command squashing There are two fundamental problems with executing a series of consecutive commands on Dragonfly: * each command invocation requires an expensive hop * executing commands sequentially makes no use of our multi-threaded architecture Luckily we can make one important observation about command sequences. Given a sequence of commands _where each command needs to access only a single shard_, we can conclude that as long as they are part of one atomic transaction: * each command needs to preserve its order only relative to other commands accessing the same shard * commands accessing different shards can run in parallel The basic idea behind command squashing is identifying consecutive series of single-shard commands and separating them by shards, while maintaining their relative order within each shard. Once the commands are separated, we can execute a single hop on all relevant shards. Within each shard the hop callback will execute, one by one, only those commands that are assigned to its respective shard.
Because all commands are already placed on their relevant threads, no further hops are required and all command callbacks are executed inline. Reviewing our initial problems, command squashing: * Allows executing many commands with only one hop * Allows executing commands in parallel ## Optimizations Out of order transactions - TBD ## Blocking commands (BLPOP) Redis has a rich API with around 200 commands. A few of those commands provide blocking semantics, which allow using Redis as a publisher/subscriber broker. Redis (when running as a single node) is famously single threaded, and all its operations are strictly serializable. In order to build a multi-threaded memory store with semantics equivalent to Redis, we had to design an algorithm that can parallelize potentially blocking operations and still provide strict serializability guarantees. This section focuses mainly on how to solve this challenge for the BLPOP (BRPOP) command since it involves coordinating multiple keys and is considered the more complicated case. Other blocking commands can benefit from the same principles. ### BLPOP spec BLPOP key1 key2 key3 0 *BLPOP is a blocking list pop primitive. It is the blocking version of LPOP because it blocks the client connection when there are no elements to pop from any of the given lists. An element is popped from the head of the first list that is non-empty, with the given keys being checked in the order that they are given.* ### Non-blocking behavior of BLPOP When BLPOP is called, if at least one of the specified keys contains a non-empty list, an element is popped from the head of the list and returned to the caller together with the key it was popped from. Keys are checked in the order that they are given. Let's say that key1 doesn't exist and key2 and key3 hold non-empty lists. Therefore, in the example above, BLPOP returns the element from key2.
### Blocking behavior If none of the specified keys exist, BLPOP blocks the connection until another client performs a LPUSH or RPUSH operation against one of the keys. Once new data is present on one of the lists, the client returns with the name of the key unblocking it and the popped value. ### Ordering semantics If a client tries to wait on multiple keys, but at least one key contains elements, the returned key / element pair is the first key from left to right that has one or more elements. In this case the client will not be blocked. So for instance, BLPOP key1 key2 key3 key4 0, assuming that both key2 and key4 are non-empty, will always return an element from key2. If multiple clients are blocked for the same key, the first client to be served is the one that was waiting longer (the first that was blocked for the key). Once a client is unblocked it does not retain any priority, when it blocks again with the next call to BLPOP, it will be served according to the queue order of clients already waiting for the same key. When a client is blocking on multiple keys at the same time, and elements are becoming available at the same time in multiple keys (because of a transaction), the client will be unblocked with the first key on the left that received data via push operation (assuming it has enough elements to serve our client, as there could be earlier clients waiting for this key as well). ### BLPOP and transactions If multiple elements are pushed either via a transaction or via variadic arguments of LPUSH command then BLPOP is waked after that transaction or command completely finished. Specifically, when a client performs `LPUSH listkey a b c`, then `BLPOP listkey 0` will pop `c`, because `lpush` pushes first `a`, then `b` and then `c` which will be the first one on the left. 
If a client executes a transaction that first pushes into a list and then pops from it atomically, then another client blocked on `BLPOP` won’t pop anything, because it waits for the transaction to finish. When BLPOP itself is run in a transaction its blocking behavior is disabled and it returns the “timed-out” response if there is no element to pop. ### Complexity of implementing BLPOP in Dragonfly The ordering semantics of BLPOP assume total order of the underlying operations. BLPOP must “observe” multiple keys simultaneously in order to determine which one is non-empty in left-to-right order. If there are no keys with items, BLPOP blocks, waits, and “observes” which key is being filled first. For the single-threaded Redis the order is determined by following the natural execution of operations inside the main execution thread. However, for a multi-threaded, shared-nothing execution, there is no concept of total order or a global synchronized timeline. For the non-blocking scenario, "observing" keys is atomic because we lock the keys when executing a command in Dragonfly. However, with the blocking scenario for BLPOP, we do not have a built-in mechanism to determine which key was filled earlier - since, as stated, the concept of total order does not exist for multiple shards. ### Interesting examples to consider: **Ex1:** ``` client1: blpop X, Y // blocks client2: lpush X A client3: exists X Y ``` Client3 should always return 0. **Ex2:** ``` client1: BLPOP X Y Z client2: RPUSH X A client3: RPUSH X B; RPUSH Y B ``` **Ex3:** ``` client1: BLPOP X Y Z client2: RPUSH Z C client3: RPUSH X A client4: RPUSH X B; RPUSH Y B ``` ### BLPOP Ramblings There are two cases of how a key can appear and wake a blocking `BLPOP`: a. with lpush/rpush/rename commands. b. via multi-transaction. `(a)` is actually easy to reason about, because those commands operate on a single key and single key operations are strictly serializable in a shared-nothing architecture.
With `(b)` we need to consider the case where we have "BLPOP X Y 0" and then a multi-transaction fills both `y` and `x` using multiple "lpush" commands. Luckily, a multi-transaction in Dragonfly introduces a global barrier across all its shards, and it does not allow any other transactions to run until it finishes. So the blocking "blpop" won't be awakened until the multi-transaction finishes its run. By that time the state of the keys will be well defined and "blpop" will be able to choose the first non-empty key to pop from. ## Background reading: ### Strict Serializability Here is a [very nice diagram](https://jepsen.io/consistency) showing how various consistency models relate. Single node Redis is strictly serializable because all its operations are executed sequentially and atomically in a single thread. More formally: following the definition from https://jepsen.io/consistency/models/strict-serializable - due to the single threaded design of Redis, its transactions are executed in a global order, which is consistent with the main thread clock, hence it’s strictly serializable. Serializability is a global property that given a transaction log, there is an order with which transactions are consistent (the log order is not relevant).
Example of serializable but not linearizable transaction: https://gist.github.com/pbailis/8279494 More material to read: * [Fauna Serializability vs Linearizability](https://fauna.com/blog/serializability-vs-strict-serializability-the-dirty-secret-of-database-isolation-levels) * [Jepsen consistency diagrams](https://jepsen.io/consistency) * [Strict Serializability definition](https://jepsen.io/consistency/models/strict-serializable) * [Example of serializable but not linearizable schedule](https://gist.github.com/pbailis/8279494) * [Atomic clocks and distributed databases](https://www.cockroachlabs.com/blog/living-without-atomic-clocks/) * [Another cockroach article about consistency](https://www.cockroachlabs.com/blog/consistency-model/) * [Abadi blog](http://dbmsmusings.blogspot.com/) * [Peter Beilis blog](http://www.bailis.org/blog) (both wrote lots of material on the subject) ================================================ FILE: fuzz/FUZZING.md ================================================ # AFL++ Fuzzing for Dragonfly ## Install AFL++ AFL++ must be built from source with `AFL_PERSISTENT_RECORD` enabled for crash replay. ```bash sudo apt update sudo apt install llvm-18-dev clang-18 lld-18 gcc-13-plugin-dev git clone --depth=1 --branch v4.34c https://github.com/AFLplusplus/AFLplusplus.git cd AFLplusplus # Enable AFL_PERSISTENT_RECORD (required for stateful crash replay) sed -i 's|// #define AFL_PERSISTENT_RECORD|#define AFL_PERSISTENT_RECORD|' include/config.h make distrib sudo make install ``` ## Prepare System ```bash sudo afl-system-config ``` `run_fuzzer.sh` also runs these checks automatically (core_pattern, CPU governor). 
## Build Dragonfly ```bash cmake -B build-dbg -DUSE_AFL=ON -DCMAKE_BUILD_TYPE=Debug -GNinja ninja -C build-dbg dragonfly ``` ## Run Fuzzer ```bash cd fuzz ./run_fuzzer.sh # RESP protocol (default) ./run_fuzzer.sh memcache # Memcache text protocol ``` Configuration via environment variables: | Variable | Default | Description | |----------|---------|-------------| | `AFL_PROACTOR_THREADS` | `1` | Server threads (1 = most stable coverage) | | `AFL_LOOP_LIMIT` | `10000` | Iterations before server restart (= `AFL_PERSISTENT_RECORD`) | | `BUILD_DIR` | `build-dbg` | Path to build directory | ## Custom Mutators Each target has a custom AFL++ mutator that operates at the protocol level. Instead of flipping random bytes (which mostly breaks protocol framing and gets rejected by the parser), they: - Parse input into a list of commands - Mutate at the command/argument level (replace command, change argument, insert/remove commands, swap order) - Serialize back to valid protocol format | Target | Mutator | Details | |--------|---------|---------| | `resp` | `resp_mutator.py` | 150+ Redis commands, wraps in MULTI/EXEC | | `memcache` | `memcache_mutator.py` | Store/get/meta commands, noreply toggle | Mutators are loaded automatically by `run_fuzzer.sh`. AFL++'s built-in byte-level mutations also run alongside them (useful for parser edge cases). To use only the custom mutator: `export AFL_CUSTOM_MUTATOR_ONLY=1`. ## Crash Replay Dragonfly uses AFL++ persistent mode — the server accumulates state across iterations. A crash at iteration N depends on state built by inputs 1..N-1. `run_fuzzer.sh` syncs `AFL_PERSISTENT_RECORD` with `afl_loop_limit` so the full state history is always available on crash. When a crash occurs, AFL++ saves: ``` crashes/id:000000,sig:06,... # the crashing input crashes/RECORD:000000,cnt:000000 # first input after server start crashes/RECORD:000000,cnt:000001 # second input ... 
crashes/RECORD:000000,cnt:NNNNNN # input before the crash ``` ### Replay (RESP) ```bash ./build/dragonfly --port 6379 --logtostderr --proactor_threads 1 --dbfilename="" python3 fuzz/replay_crash.py fuzz/artifacts/resp/default/crashes 000000 ``` ### Replay (memcache) ```bash ./build/dragonfly --port 6379 --memcached_port=11211 --logtostderr --proactor_threads 1 --dbfilename="" python3 fuzz/replay_crash.py fuzz/artifacts/memcache/default/crashes 000000 127.0.0.1 11211 ``` ### Package crash for sharing ```bash cd fuzz # RESP ./package_crash.sh 000000 # Memcache ./package_crash.sh 000000 fuzz/artifacts/memcache/default/crashes ``` Creates `crash-000000.tar.gz` containing crash data and `replay_crash.py`. The recipient runs: ```bash # RESP ./build/dragonfly --port 6379 --logtostderr --proactor_threads 1 --dbfilename="" python3 replay_crash.py crashes 000000 # Memcache ./build/dragonfly --port 6379 --memcached_port=11211 --logtostderr --proactor_threads 1 --dbfilename="" python3 replay_crash.py crashes 000000 127.0.0.1 11211 ``` ## Seed Corpus | Target | Directory | Seeds | Coverage | |--------|-----------|-------|----------| | `resp` | `seeds/resp/` | 79 | string, list, hash, set, zset, stream, JSON, search, bloom, geo, HLL, bitops, scripting, ACL, pub/sub, transactions, server ops | | `memcache` | `seeds/memcache/` | 15 | set/get, add/replace, append/prepend, cas, incr/decr, delete, multiget, gat, noreply, meta commands, flush, stats | To add a new RESP seed: ``` *3 $3 SET $3 key $5 value ``` To add a new memcache seed: ``` set mykey 0 0 5 hello get mykey ``` ================================================ FILE: fuzz/dict/memcache.dict ================================================ # Memcache text protocol dictionary for AFL++ # Store commands "set" "add" "replace" "append" "prepend" "cas" # Retrieval commands "get" "gets" "gat" "gats" # Utility commands "delete" "incr" "decr" "flush_all" "stats" "version" "quit" # Meta commands "ms" "mg" "md" "ma" "mn" "me" # 
Flags/options "noreply" # Common keys "key" "mykey" "k1" "k2" "k3" "counter" # Numbers "0" "1" "5" "10" "100" "1000" "65535" "4294967295" "99999999999" # Expiry values "0" "30" "3600" "9999999" # Line endings "\x0d\x0a" # Partial commands for edge cases "set " "get " "delete " "incr " "decr " "cas " "gat " # Malformed patterns "\x0d" "\x0a" "\x00" "\xff" " " " " "" ================================================ FILE: fuzz/dict/resp.dict ================================================ # AFL++ dictionary for RESP protocol # Dragonfly command keywords and common patterns # RESP protocol markers "*" "$" "+" "-" ":" "\x0d\x0a" # Common commands - String operations "GET" "SET" "MGET" "MSET" "INCR" "DECR" "APPEND" "STRLEN" "SETEX" "SETNX" "GETSET" "GETRANGE" "SETRANGE" # List operations "LPUSH" "RPUSH" "LPOP" "RPOP" "LLEN" "LRANGE" "LINDEX" "LSET" "LTRIM" # Hash operations "HSET" "HGET" "HMSET" "HMGET" "HGETALL" "HDEL" "HEXISTS" "HLEN" "HKEYS" "HVALS" "HINCRBY" # Set operations "SADD" "SREM" "SMEMBERS" "SISMEMBER" "SCARD" "SINTER" "SUNION" "SDIFF" "SPOP" # Sorted set operations "ZADD" "ZREM" "ZRANGE" "ZRANGEBYSCORE" "ZRANK" "ZSCORE" "ZCARD" "ZCOUNT" "ZINCRBY" # Key operations "DEL" "EXISTS" "EXPIRE" "TTL" "PERSIST" "KEYS" "SCAN" "TYPE" "RENAME" "RENAMENX" # Transaction commands "MULTI" "EXEC" "DISCARD" "WATCH" "UNWATCH" # Pub/Sub commands "PUBLISH" "SUBSCRIBE" "UNSUBSCRIBE" "PSUBSCRIBE" "PUNSUBSCRIBE" # Stream commands "XADD" "XREAD" "XRANGE" "XLEN" "XDEL" "XTRIM" "XGROUP" "XREADGROUP" # JSON commands "JSON.SET" "JSON.GET" "JSON.DEL" "JSON.TYPE" "JSON.NUMINCRBY" "JSON.ARRAPPEND" "JSON.ARRLEN" # Bloom filter commands "BF.ADD" "BF.EXISTS" "BF.RESERVE" "BF.MADD" "BF.MEXISTS" # HyperLogLog commands "PFADD" "PFCOUNT" "PFMERGE" # Geo commands "GEOADD" "GEODIST" "GEORADIUS" "GEOSEARCH" # Server commands "PING" "ECHO" "INFO" "DBSIZE" "SELECT" # Cluster commands "CLUSTER" "READONLY" "READWRITE" # Common keys for testing "key" "mykey" "key1" "key2" "test" "foo" "bar" "user:1" 
"session:123" # Common values "value" "hello" "world" "123" "0" "1" "-1" # Number patterns (0, 1, -1 already above) "100" "1000" "-100" # Special arguments "NX" "XX" "EX" "PX" "GT" "LT" "WITHSCORES" "LIMIT" "COUNT" "MATCH" # Small RESP framing patterns (larger patterns removed — AFL++ warned about >33B tokens) "*1\x0d\x0a$" "*2\x0d\x0a$" "*3\x0d\x0a$" # Scripting commands "EVAL" "EVALSHA" "EVAL_RO" "EVALSHA_RO" "SCRIPT" # Bitfield commands "BITFIELD" "BITFIELD_RO" "BITOP" "BITCOUNT" "BITPOS" "GETBIT" "SETBIT" # More sorted set operations "ZINTER" "ZUNION" "ZINTERSTORE" "ZUNIONSTORE" "ZPOPMIN" "ZPOPMAX" "ZMPOP" # Edge case numbers "9223372036854775807" "-9223372036854775808" "2147483647" "-2147483648" "0.0" "-0.0" "inf" "-inf" "+inf" "nan" # Stream IDs and patterns "0-0" "0-*" "$" ">" "*" "MAXLEN" "MINID" # JSON paths "$.." "$[*]" "$[-1]" "$.name" "$..name" # RESP protocol edge cases "*-1\x0d\x0a" "$-1\x0d\x0a" "*0\x0d\x0a" "$0\x0d\x0a\x0d\x0a" # Lua scripting patterns "return redis.call" "redis.pcall" "KEYS[1]" "ARGV[1]" # Bitfield subcommands "OVERFLOW" "WRAP" "SAT" "FAIL" # Aggregate options "AGGREGATE" "SUM" "MIN" "MAX" "WEIGHTS" # Binary edge cases "\x00" "\xff" "\x00\x00\x00\x00" # --- Additional commands for broader coverage --- # Missing key operations "COPY" "SORT" "SORT_RO" "UNLINK" "TOUCH" "OBJECT" "RANDOMKEY" "DUMP" "RESTORE" "WAIT" "EXPIREAT" "PEXPIRE" "PEXPIREAT" "PEXPIRETIME" "EXPIRETIME" "PTTL" # String commands "GETDEL" "GETEX" "INCRBYFLOAT" "DECRBY" "INCRBY" "MSETNX" "PSETEX" "SUBSTR" # List commands "LPOS" "LMPOP" "LMOVE" "BLMOVE" "BLMPOP" "BLPOP" "BRPOP" "LPUSHX" "RPUSHX" "RPOPLPUSH" # Set commands "SRANDMEMBER" "SMOVE" "SMISMEMBER" "SINTERCARD" "SDIFFSTORE" "SINTERSTORE" "SUNIONSTORE" # Sorted set commands "ZDIFF" "ZDIFFSTORE" "ZLEXCOUNT" "ZRANGEBYLEX" "ZRANGESTORE" "ZRANDMEMBER" "ZREVRANGE" "ZREVRANGEBYLEX" "ZREVRANGEBYSCORE" "ZREVRANK" "ZMSCORE" "ZREMRANGEBYLEX" "ZREMRANGEBYRANK" "ZREMRANGEBYSCORE" "BZMPOP" "BZPOPMIN" "BZPOPMAX" # Hash 
commands "HRANDFIELD" "HSCAN" "HSETEX" "HSETNX" "HSTRLEN" "HINCRBYFLOAT" "HEXPIRE" # Server/client commands "CLIENT" "CONFIG" "MEMORY" "ACL" "HELLO" "COMMAND" "LATENCY" "SLOWLOG" "BGSAVE" "LASTSAVE" "ROLE" # Subcommands "OBJECT ENCODING" "OBJECT HELP" "OBJECT FREQ" "OBJECT IDLETIME" "CLIENT SETNAME" "CLIENT GETNAME" "CLIENT LIST" "CLIENT ID" "CLIENT INFO" "CONFIG GET" "CONFIG SET" "MEMORY USAGE" "MEMORY DOCTOR" "ACL LIST" "ACL WHOAMI" "ACL SETUSER" "COMMAND COUNT" "COMMAND INFO" # Scan operations (HSCAN already above) "SSCAN" "ZSCAN" # Function/script commands "FUNCTION" "FUNCTION LOAD" "FUNCTION LIST" "FUNCTION DELETE" # More JSON commands "JSON.ARRINSERT" "JSON.ARRTRIM" "JSON.ARRPOP" "JSON.ARRINDEX" "JSON.OBJKEYS" "JSON.OBJLEN" "JSON.STRAPPEND" "JSON.STRLEN" "JSON.TOGGLE" "JSON.CLEAR" "JSON.MERGE" "JSON.MGET" "JSON.MSET" "JSON.DEBUG" "JSON.RESP" # More Geo commands "GEOPOS" "GEOHASH" "GEOSEARCHSTORE" "GEORADIUSBYMEMBER" # Search commands "FT.CREATE" "FT.SEARCH" "FT.DROPINDEX" "FT.INFO" "FT.ALTER" # Additional arguments "REPLACE" "ABSTTL" "IDLETIME" "FREQ" "LEFT" "RIGHT" "BEFORE" "AFTER" "BY" "ASC" "DESC" "ALPHA" "STORE" "REV" "BYSCORE" "BYLEX" "CH" "KEEPTTL" "EXAT" "PXAT" "ENCODING" "REFCOUNT" # Malformed RESP for edge-case testing "*-2\x0d\x0a" "*999999\x0d\x0a" "$-2\x0d\x0a" "$999999999\x0d\x0a" "*\x0d\x0a" "$\x0d\x0a" "+\x0d\x0a" "-\x0d\x0a" ":\x0d\x0a" # Inline commands (no RESP framing) "PING\x0d\x0a" "PING\x0a" "SET key value\x0d\x0a" "GET key\x0a" "QUIT\x0d\x0a" # More binary patterns "\xfe\xff\x00\x01" "\x0d\x0a\x0d\x0a" "\x0d\x0d\x0a\x0a" "\x00\x01\x02\x03" # RESP edge cases (small fragments only) "$0\x0d\x0a\x0d\x0a" "$-1\x0d\x0a" ================================================ FILE: fuzz/generate_targeted_seeds.py ================================================ #!/usr/bin/env python3 """Generate PR-targeted fuzzing inputs from a code diff using an LLM. Fuzzing terminology used in this file: - Seed: An initial input file for the fuzzer. 
Each seed is a sequence of commands encoded in RESP wire format (see fuzz/seeds/resp/*.resp for examples). The fuzzer starts from these seeds and mutates them to explore code paths. - Targeted seed: A seed crafted specifically to exercise code paths changed in a PR. We send the PR diff + all existing seeds to an LLM, and it generates new seeds that target the changed code. - Focus commands: A list of command names (e.g. ["SET", "GET"]) that the AFL++ mutator should prefer. When set, the mutator picks these commands ~70% of the time instead of choosing uniformly from all known commands. Flow: 1. Read unified diff from stdin, extract changed C++ file paths. 2. Load all existing seed files so the LLM knows what's already covered. 3. Call Claude API: send the diff + seeds, get back JSON with command arrays + focus commands. 4. Encode commands as RESP wire format, write to output dir. The LLM returns commands as plain arrays (e.g. ["SET", "key", "value"]) and we handle RESP encoding ourselves — this avoids JSON escaping issues and byte-count mismatches. When ANTHROPIC_API_KEY is not available (e.g. fork PRs), exits with no output and the fuzzer runs with the existing seed corpus as-is. Usage: git diff base..HEAD | python3 fuzz/generate_targeted_seeds.py --output-dir /tmp/seeds """ import argparse import glob import json import os import re import sys # Max diff lines to send to the LLM (Haiku handles ~200K tokens, so this is generous) MAX_DIFF_LINES = 20000 LLM_SYSTEM_PROMPT = """\ You are a fuzzing expert for Dragonfly, a Redis-compatible in-memory database written in C++. Your job: given a code diff and existing seed files, generate NEW fuzzing seeds that \ target the changed code paths. You also return a list of Redis commands to focus on. ## Dragonfly architecture (for context) - src/server/*_family.cc — command implementations (e.g. 
string_family.cc has GET/SET/INCR) - src/server/main_service.cc — command dispatch, MULTI/EXEC - src/server/db_slice.cc — per-shard key-value storage - src/facade/redis_parser.cc — RESP protocol parsing - src/facade/dragonfly_connection.cc — connection handling - src/core/ — data structures (dash table, dense_set, compact_object, etc.) - src/server/journal/ — replication journal - src/server/cluster/ — cluster mode - src/server/search/ — search module (FT.* commands) - src/server/tiering/ — SSD tiering ## What to generate Based on the diff, figure out: 1. What commands are affected (new, modified, or impacted by infrastructure changes) 2. What edge cases the changes introduce (boundary values, empty inputs, error paths) 3. What command sequences would stress the changed code ## Output format Return valid JSON (no markdown, no explanation): { "focus_commands": ["CMD1", "CMD2", ...], "seeds": [ { "name": "pr_something.resp", "commands": [ ["SET", "mykey", "myvalue"], ["GET", "mykey"] ] } ] } Each "commands" entry is a list of Redis commands. Each command is a list of strings \ (command name + arguments). We handle RESP wire encoding — just give plain strings. CRITICAL: Output must be valid JSON. Do NOT use code expressions like "x" * 1024 or \ string concatenation. For long values write actual repeated characters inline, e.g. \ "xxxxxxxxxx" (just the literal string). Keep values short (under 100 chars) — \ the fuzzer will mutate and grow them. Rules for seeds: - 3-10 commands per seed, forming a logical sequence - Include setup commands before queries (e.g. SET before GET) - Test edge cases from the diff: boundary values, empty/huge inputs, type mismatches - Include at least one seed wrapping commands in MULTI/EXEC - Generate 3-8 seeds total - Prefix all names with "pr_" """ def extract_changed_files(diff_text): """Extract C++/header file paths from a unified diff.""" files = [] for match in re.finditer(r"^diff --git a/(.+?) 
b/(.+?)$", diff_text, re.MULTILINE): path = match.group(2) if re.search(r"\.(cc|h)$", path): files.append(path) return sorted(set(files)) def load_example_seeds(seeds_dir): """Load ALL existing seed files to show the LLM what's already covered. We send every seed so the LLM has full context about existing coverage and can generate complementary seeds for new/changed code paths. """ examples = [] for path in sorted(glob.glob(os.path.join(seeds_dir, "*.resp"))): name = os.path.basename(path) with open(path) as f: examples.append({"name": name, "content": f.read()}) return examples def truncate_diff(diff_text, max_lines=MAX_DIFF_LINES): """Truncate diff to max_lines.""" lines = diff_text.splitlines(True) if len(lines) <= max_lines: return diff_text, len(lines) return "".join(lines[:max_lines]), max_lines def encode_resp(commands): """Encode a list of commands as RESP wire format. Each command is a list of string arguments, e.g. ["SET", "key", "value"]. Returns bytes in RESP format: *N\\r\\n$len\\r\\narg\\r\\n... 
""" result = bytearray() for cmd in commands: if not cmd: continue result.extend(b"*%d\r\n" % len(cmd)) for arg in cmd: arg_bytes = arg.encode() if isinstance(arg, str) else arg result.extend(b"$%d\r\n%s\r\n" % (len(arg_bytes), arg_bytes)) return bytes(result) def call_llm(diff_text, changed_files, example_seeds, api_key, model): """Call Claude API to generate targeted seeds from the diff.""" try: import anthropic except ImportError: print("anthropic package not available", file=sys.stderr) return None truncated, num_lines = truncate_diff(diff_text) # Build examples section — show existing seeds so the LLM knows what's covered examples_text = "" for ex in example_seeds: examples_text += "--- %s ---\n%s\n\n" % (ex["name"], ex["content"].rstrip()) prompt = ( "Here are ALL existing seed files (RESP wire format) so you know what's already covered:\n\n" "%s\n" "Now analyze this diff and generate targeted fuzzing seeds.\n\n" "Changed files: %s\n\n" "Diff (%d lines):\n```\n%s\n```\n\n" "Respond with valid JSON only." ) % (examples_text, ", ".join(changed_files), num_lines, truncated) client = anthropic.Anthropic(api_key=api_key) response = client.messages.create( model=model, max_tokens=16384, system=LLM_SYSTEM_PROMPT, messages=[{"role": "user", "content": prompt}], ) text = response.content[0].text.strip() # Try to extract JSON from the response (LLMs sometimes wrap in markdown) json_match = re.search(r"```(?:json)?\s*\n(.*?)\n```", text, re.DOTALL) if json_match: text = json_match.group(1) try: return json.loads(text) except json.JSONDecodeError: pass # Try to find the outermost { ... 
} and parse that brace_match = re.search(r"\{.*\}", text, re.DOTALL) if brace_match: try: return json.loads(brace_match.group(0)) except json.JSONDecodeError: pass # Log raw response for debugging and raise print("Raw LLM response (first 2000 chars):\n%s" % text[:2000], file=sys.stderr) raise ValueError("Could not parse LLM response as JSON") def write_output(output_dir, focus_commands, seeds): """Write seed files and focus_commands.json to output directory.""" os.makedirs(output_dir, exist_ok=True) focus_path = os.path.join(output_dir, "focus_commands.json") with open(focus_path, "w") as f: json.dump(focus_commands, f) print("Wrote %d focus commands to %s" % (len(focus_commands), focus_path), file=sys.stderr) written = 0 for seed in seeds: name = seed.get("name") or "pr_seed_%d.resp" % written if not name.endswith(".resp"): name += ".resp" path = os.path.join(output_dir, name) with open(path, "wb") as f: f.write(seed["content"]) written += 1 print("Wrote %d seed files to %s" % (written, output_dir), file=sys.stderr) def main(): parser = argparse.ArgumentParser(description="Generate targeted fuzzing seeds from a PR diff") parser.add_argument( "--output-dir", default="fuzz/seeds/pr_targeted", help="Directory to write seeds and focus_commands.json", ) parser.add_argument( "--seeds-dir", default=None, help="Directory with existing seed files (auto-detected if not set)", ) parser.add_argument( "--api-key", default=None, help="Anthropic API key (or set ANTHROPIC_API_KEY env var)" ) parser.add_argument("--model", default="claude-haiku-4-5-20251001", help="Claude model to use") args = parser.parse_args() api_key = args.api_key or os.environ.get("ANTHROPIC_API_KEY") if not api_key: print("No ANTHROPIC_API_KEY set, skipping seed generation", file=sys.stderr) return diff_text = sys.stdin.read() if not diff_text.strip(): print("No diff provided, skipping", file=sys.stderr) return changed_files = extract_changed_files(diff_text) if not changed_files: print("No C++ files in 
diff, skipping", file=sys.stderr) return print("Changed C++ files: %s" % ", ".join(changed_files), file=sys.stderr) # Find seeds directory seeds_dir = args.seeds_dir if not seeds_dir: script_dir = os.path.dirname(os.path.abspath(__file__)) seeds_dir = os.path.join(script_dir, "seeds", "resp") example_seeds = load_example_seeds(seeds_dir) print("Loaded %d existing seeds" % len(example_seeds), file=sys.stderr) try: result = call_llm(diff_text, changed_files, example_seeds, api_key, args.model) except Exception as e: print("LLM call failed: %s" % e, file=sys.stderr) return if not result: return # Extract focus commands focus_commands = result.get("focus_commands", []) if not isinstance(focus_commands, list): focus_commands = [] # Encode command arrays as RESP and collect valid seeds valid_seeds = [] for s in result.get("seeds", []): if not isinstance(s, dict) or "commands" not in s: continue commands = s["commands"] if not isinstance(commands, list) or not commands: continue # Filter out non-list entries and ensure all args are strings clean_commands = [] for cmd in commands: if isinstance(cmd, list) and cmd: clean_commands.append([str(arg) for arg in cmd]) if not clean_commands: continue content = encode_resp(clean_commands) if content: valid_seeds.append({"name": s.get("name") or "", "content": content}) else: print("Discarding empty seed: %s" % s.get("name", "?"), file=sys.stderr) if not valid_seeds and not focus_commands: print("LLM returned no usable output", file=sys.stderr) return print( "Generated %d seeds, %d focus commands" % (len(valid_seeds), len(focus_commands)), file=sys.stderr, ) write_output(args.output_dir, focus_commands, valid_seeds) if __name__ == "__main__": main() ================================================ FILE: fuzz/memcache_mutator.py ================================================ """AFL++ custom mutator for memcache text protocol. Mutates at the command level instead of random bytes, keeping memcache protocol framing valid. 
Usage: export PYTHONPATH=/path/to/dragonfly/fuzz export AFL_PYTHON_MODULE=memcache_mutator afl-fuzz ... """ import random # fmt: off # (command, type, min_extra_args, max_extra_args) # type: "store" = key flags exptime bytes [noreply]\r\ndata\r\n # "cas" = key flags exptime bytes cas_unique [noreply]\r\ndata\r\n # "get" = key [key ...]\r\n # "gat" = exptime key [key ...]\r\n # "delta" = key delta [noreply]\r\n # "del" = key [noreply]\r\n # "bare" = \r\n (no args) # "meta_store" = key datalen [flags...]\r\ndata\r\n # "meta" = key [flags...]\r\n COMMANDS = [ # Store commands ("set", "store"), ("add", "store"), ("replace", "store"), ("append", "store"), ("prepend", "store"), ("cas", "cas"), # Retrieval ("get", "get"), ("gets", "get"), ("gat", "gat"), ("gats", "gat"), # Delete / arithmetic ("delete", "del"), ("incr", "delta"), ("decr", "delta"), # Utility ("flush_all", "bare"), ("stats", "bare"), ("version", "bare"), ("quit", "bare"), # Meta commands ("ms", "meta_store"), ("mg", "meta"), ("md", "meta"), ("ma", "meta"), ("mn", "bare"), ("me", "meta"), ] # fmt: on KEYS = [b"k", b"key", b"k1", b"k2", b"k3", b"mykey", b"counter", b"buf"] VALUES = [b"abc", b"hello", b"x", b"", b"0", b"12345", b"\x00\xff", b"a" * 100] EXPIRY = [b"0", b"10", b"100", b"3600", b"9999999"] FLAGS = [b"0", b"1", b"255", b"65535", b"4294967295"] DELTAS = [b"1", b"5", b"10", b"100", b"0", b"99999999999"] META_FLAGS = [b"T30", b"N10", b"R", b"v", b"h", b"l", b"t", b"c", b"f1", b"q", b"k"] FUZZ_VALUES = [b"\x00", b"\xff" * 4, b"\r\n", b"A" * 256, b"-1", b"NaN"] def init(seed): random.seed(seed) def _random_key(): if random.random() < 0.8: return random.choice(KEYS) return random.choice(FUZZ_VALUES) def _random_value(): if random.random() < 0.7: return random.choice(VALUES) return random.choice(FUZZ_VALUES) def _random_command(): """Generate a single random memcache command.""" cmd_name, cmd_type = random.choice(COMMANDS) cmd = cmd_name.encode() if isinstance(cmd_name, str) else cmd_name if cmd_type == 
"store": key = _random_key() flags = random.choice(FLAGS) expiry = random.choice(EXPIRY) value = _random_value() noreply = b" noreply" if random.random() < 0.3 else b"" return ( cmd + b" " + key + b" " + flags + b" " + expiry + b" " + str(len(value)).encode() + noreply + b"\r\n" + value + b"\r\n" ) elif cmd_type == "cas": key = _random_key() flags = random.choice(FLAGS) expiry = random.choice(EXPIRY) value = _random_value() cas_id = str(random.randint(0, 99999)).encode() noreply = b" noreply" if random.random() < 0.3 else b"" return ( cmd + b" " + key + b" " + flags + b" " + expiry + b" " + str(len(value)).encode() + b" " + cas_id + noreply + b"\r\n" + value + b"\r\n" ) elif cmd_type == "get": nkeys = random.randint(1, 4) keys = b" ".join(_random_key() for _ in range(nkeys)) return cmd + b" " + keys + b"\r\n" elif cmd_type == "gat": expiry = random.choice(EXPIRY) nkeys = random.randint(1, 3) keys = b" ".join(_random_key() for _ in range(nkeys)) return cmd + b" " + expiry + b" " + keys + b"\r\n" elif cmd_type == "delta": key = _random_key() delta = random.choice(DELTAS) noreply = b" noreply" if random.random() < 0.3 else b"" return cmd + b" " + key + b" " + delta + noreply + b"\r\n" elif cmd_type == "del": key = _random_key() noreply = b" noreply" if random.random() < 0.3 else b"" return cmd + b" " + key + noreply + b"\r\n" elif cmd_type == "meta_store": key = _random_key() value = _random_value() meta_flags = b" ".join(random.sample(META_FLAGS, random.randint(0, 3))) extra = (b" " + meta_flags) if meta_flags else b"" return ( cmd + b" " + key + b" " + str(len(value)).encode() + extra + b"\r\n" + value + b"\r\n" ) elif cmd_type == "meta": key = _random_key() meta_flags = b" ".join(random.sample(META_FLAGS, random.randint(0, 3))) extra = (b" " + meta_flags) if meta_flags else b"" return cmd + b" " + key + extra + b"\r\n" else: # bare return cmd + b"\r\n" def _parse_mc_commands(buf): """Best-effort parse of memcache text protocol into list of raw command lines. 
Returns (commands, success) where commands is a list of bytes.""" commands = [] data = bytes(buf) pos = 0 while pos < len(data): end = data.find(b"\r\n", pos) if end < 0: break line = data[pos:end] pos = end + 2 # Check if this is a store command that has a data block parts = line.split(b" ") if len(parts) >= 5 and parts[0].lower() in ( b"set", b"add", b"replace", b"append", b"prepend", b"cas", ): try: nbytes = int(parts[4]) if pos + nbytes + 2 <= len(data): value = data[pos : pos + nbytes] pos += nbytes + 2 # skip value + \r\n commands.append((line, value)) continue except (ValueError, IndexError): pass elif len(parts) >= 3 and parts[0].lower() == b"ms": try: nbytes = int(parts[2]) if len(parts) > 2 else int(parts[1]) if pos + nbytes + 2 <= len(data): value = data[pos : pos + nbytes] pos += nbytes + 2 commands.append((line, value)) continue except (ValueError, IndexError): pass commands.append((line, None)) return (commands, len(commands) > 0) def _commands_to_bytes(commands): """Serialize parsed commands back to memcache protocol bytes.""" parts = [] for line, value in commands: parts.append(line + b"\r\n") if value is not None: parts.append(value + b"\r\n") return b"".join(parts) def _mutate_commands(commands): """Apply random mutations to parsed memcache commands.""" result = list(commands) mutation = random.random() if mutation < 0.25 and len(result) > 0: # Replace a command entirely idx = random.randint(0, len(result) - 1) new_cmd = _random_command() # Parse the generated command back parsed, _ = _parse_mc_commands(new_cmd) if parsed: result[idx] = parsed[0] elif mutation < 0.45 and len(result) > 0: # Mutate a key or value in a command idx = random.randint(0, len(result) - 1) line, value = result[idx] parts = line.split(b" ") if len(parts) >= 2: cmd = parts[0].lower() # Mutate the correct key index depending on command if cmd in (b"gat", b"gats") and len(parts) >= 3: key_idx = random.randint(2, len(parts) - 1) parts[key_idx] = _random_key() else: parts[1] = 
_random_key() if value is not None: new_value = _random_value() # Update byte count in the header length_idx = None if cmd == b"ms" and len(parts) >= 3: length_idx = 2 elif len(parts) >= 5: length_idx = 4 if length_idx is not None: try: int(parts[length_idx]) parts[length_idx] = str(len(new_value)).encode() except ValueError: pass value = new_value result[idx] = (b" ".join(parts), value) elif mutation < 0.6: # Insert a new random command new_cmd = _random_command() parsed, _ = _parse_mc_commands(new_cmd) if parsed: pos = random.randint(0, len(result)) result.insert(pos, parsed[0]) elif mutation < 0.7 and len(result) > 1: # Remove a command idx = random.randint(0, len(result) - 1) result.pop(idx) elif mutation < 0.8 and len(result) >= 2: # Swap two commands i, j = random.sample(range(len(result)), 2) result[i], result[j] = result[j], result[i] elif mutation < 0.9 and len(result) > 0: # Duplicate a command idx = random.randint(0, len(result) - 1) result.insert(idx + 1, result[idx]) else: # Toggle noreply on a command if len(result) > 0: idx = random.randint(0, len(result) - 1) line, value = result[idx] if line.endswith(b" noreply"): line = line[:-8] else: line = line + b" noreply" result[idx] = (line, value) return result def fuzz(buf, add_buf, max_size): """Main mutation function called by AFL++.""" commands, ok = _parse_mc_commands(buf) if ok and commands: mutated = _mutate_commands(commands) result = _commands_to_bytes(mutated) else: n = random.randint(1, 5) result = b"".join(_random_command() for _ in range(n)) if len(result) > max_size: result = result[:max_size] return bytearray(result) def havoc_mutation(buf, max_size): """Called during havoc stage.""" commands, ok = _parse_mc_commands(buf) if not ok or not commands: return bytearray(_random_command()[:max_size]) mutated = _mutate_commands(commands) result = _commands_to_bytes(mutated) if len(result) > max_size: result = result[:max_size] return bytearray(result) def havoc_mutation_probability(): return 50 
================================================ FILE: fuzz/package_crash.sh ================================================ #!/usr/bin/env bash set -e GREEN='\033[0;32m' RED='\033[0;31m' NC='\033[0m' print_info() { echo -e "${GREEN}[INFO]${NC} $1"; } print_error() { echo -e "${RED}[ERROR]${NC} $1"; } usage() { echo "Usage: $0 [crashes_dir]" echo "" echo "Packages a crash and its RECORD files into a self-contained archive" echo "that can be sent to another developer for reproduction." echo "" echo "Arguments:" echo " crash_id Crash ID (e.g. 000000)" echo " crashes_dir Path to crashes directory (default: fuzz/artifacts/resp/default/crashes)" echo "" echo "Example:" echo " $0 000000" echo " $0 000001 /path/to/crashes" exit 1 } if [[ $# -lt 1 ]]; then usage fi CRASH_ID="$1" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" CRASHES_DIR="${2:-$SCRIPT_DIR/artifacts/resp/default/crashes}" if [[ ! -d "$CRASHES_DIR" ]]; then print_error "Crashes directory not found: $CRASHES_DIR" exit 1 fi # Find the crash input file CRASH_FILE=$(find "$CRASHES_DIR" -maxdepth 1 -name "id:${CRASH_ID},*" ! -name "RECORD:*" | head -1) if [[ -z "$CRASH_FILE" ]]; then print_error "Crash input not found for id:${CRASH_ID} in $CRASHES_DIR" exit 1 fi # Count RECORD files RECORD_COUNT=$(find "$CRASHES_DIR" -maxdepth 1 -name "RECORD:${CRASH_ID},cnt:*" | wc -l) ARCHIVE_NAME="crash-${CRASH_ID}" TMPDIR=$(mktemp -d) DEST="$TMPDIR/$ARCHIVE_NAME" mkdir -p "$DEST/crashes" print_info "Packaging crash ${CRASH_ID}..." 
print_info "Crash input: $(basename "$CRASH_FILE")" print_info "RECORD files: ${RECORD_COUNT}" # Copy crash input and RECORD files into crashes/ subdirectory cp "$CRASH_FILE" "$DEST/crashes/" if [[ $RECORD_COUNT -gt 0 ]]; then find "$CRASHES_DIR" -maxdepth 1 -name "RECORD:${CRASH_ID},cnt:*" -exec cp {} "$DEST/crashes/" \; fi # Copy replay_crash.py cp "$SCRIPT_DIR/replay_crash.py" "$DEST/" # Create archive OUTPUT="$(pwd)/${ARCHIVE_NAME}.tar.gz" tar -czf "$OUTPUT" -C "$TMPDIR" "$ARCHIVE_NAME" rm -rf "$TMPDIR" SIZE=$(du -h "$OUTPUT" | cut -f1) print_info "Archive created: ${OUTPUT} (${SIZE})" echo "" # Detect target from directory structure: artifacts//default/crashes TARGET_NAME=$(basename "$(dirname "$(dirname "$CRASHES_DIR")")") IS_MEMCACHE=false if [[ "$TARGET_NAME" == "memcache" ]]; then IS_MEMCACHE=true fi echo "To reproduce:" echo " 1. Start dragonfly:" if [[ "$IS_MEMCACHE" == true ]]; then echo " ./build/dragonfly --port 6379 --memcached_port=11211 --logtostderr --proactor_threads 1 --dbfilename=\"\"" else echo " ./build/dragonfly --port 6379 --logtostderr --proactor_threads 1 --dbfilename=\"\"" fi echo " 2. Extract and replay:" echo " tar xzf ${ARCHIVE_NAME}.tar.gz" echo " cd ${ARCHIVE_NAME}" if [[ "$IS_MEMCACHE" == true ]]; then echo " python3 replay_crash.py crashes ${CRASH_ID} 127.0.0.1 11211" else echo " python3 replay_crash.py crashes ${CRASH_ID}" fi ================================================ FILE: fuzz/replay_crash.py ================================================ #!/usr/bin/env python3 """Replays a crash from AFL++ persistent mode RECORD files. In persistent mode, a crash depends on accumulated server state from all previous iterations. AFL_PERSISTENT_RECORD saves these as RECORD files. This script replays them in order against a running Dragonfly instance. 
Usage: # Start dragonfly in another terminal: ./build-dbg/dragonfly --port 6379 --logtostderr --proactor_threads 1 # Replay crash: python3 fuzz/replay_crash.py fuzz/artifacts/resp/default/crashes 000000 """ import glob import os import socket import sys def send_input(host, port, data): """Send data over TCP. Mirrors SendFuzzInputToServer.""" try: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(0.2) s.connect((host, port)) except ConnectionRefusedError: print("\033[0;31m[ERROR]\033[0m Connection refused — is Dragonfly running?") sys.exit(1) try: s.sendall(data) except Exception: pass try: s.recv(4096) except Exception: pass s.close() def main(): if len(sys.argv) < 3: print(f"Usage: {sys.argv[0]} [host] [port]") sys.exit(1) crash_dir = sys.argv[1] crash_id = sys.argv[2] host = sys.argv[3] if len(sys.argv) > 3 else "127.0.0.1" port = int(sys.argv[4]) if len(sys.argv) > 4 else 6379 # Find RECORD files sorted by cnt pattern = os.path.join(crash_dir, f"RECORD:{crash_id},cnt:*") records = sorted(glob.glob(pattern)) # Find crash input file crash_files = [ f for f in glob.glob(os.path.join(crash_dir, f"id:{crash_id},*")) if not os.path.basename(f).startswith("RECORD:") ] if not crash_files: print(f"\033[0;31m[ERROR]\033[0m Crash input not found for id:{crash_id}") sys.exit(1) crash_file = crash_files[0] print(f"\033[0;32m[INFO]\033[0m Replaying crash {crash_id} against {host}:{port}") print(f"\033[0;32m[INFO]\033[0m RECORD files: {len(records)}") print(f"\033[0;32m[INFO]\033[0m Crash file: {crash_file}") print() # Replay all RECORD inputs for i, rec in enumerate(records): if i % 1000 == 0: print(f"\033[1;33m[REPLAY]\033[0m Progress: {i} / {len(records)}") with open(rec, "rb") as f: data = f.read() send_input(host, port, data) # Send the crash input print(f"\033[1;33m[REPLAY]\033[0m Sending crash input: {os.path.basename(crash_file)}") with open(crash_file, "rb") as f: data = f.read() send_input(host, port, data) print() print("\033[0;32m[INFO]\033[0m 
Replay complete. Check if the Dragonfly process crashed.") print( "\033[0;32m[INFO]\033[0m If not, the bug may depend on thread timing (non-deterministic)." ) if __name__ == "__main__": main() ================================================ FILE: fuzz/resp_mutator.py ================================================ """AFL++ custom mutator for RESP protocol. Instead of random byte-level mutations (which would break protocol framing and get rejected by the parser), this mutator operates at the command level: it parses the input into commands, then randomly replaces/inserts/removes/reorders commands and arguments while keeping RESP encoding valid. This ensures mutated inputs actually reach command execution code paths. Focus commands (optional, set via FUZZ_FOCUS_COMMANDS env var): When running PR-targeted fuzzing, generate_targeted_seeds.py produces a list of command names affected by the code change. This mutator reads that list and picks those commands ~70% of the time, concentrating mutations on the changed code. Commands not already in the COMMANDS table are auto-registered with default arity. Usage: export PYTHONPATH=/path/to/dragonfly/fuzz export AFL_PYTHON_MODULE=resp_mutator export AFL_CUSTOM_MUTATOR_ONLY=1 afl-fuzz ... """ import json import os import random import struct # fmt: off # Commands grouped by arity pattern: (name, min_args, max_args) # min/max are argument counts AFTER the command name itself. 
COMMANDS = [
    # String
    (b"GET", 1, 1), (b"SET", 2, 6), (b"MGET", 1, 5), (b"MSET", 2, 10),
    (b"SETNX", 2, 2), (b"SETEX", 3, 3), (b"PSETEX", 3, 3),
    (b"INCR", 1, 1), (b"DECR", 1, 1), (b"INCRBY", 2, 2), (b"DECRBY", 2, 2),
    (b"INCRBYFLOAT", 2, 2), (b"APPEND", 2, 2), (b"STRLEN", 1, 1),
    (b"GETRANGE", 3, 3), (b"SETRANGE", 3, 3), (b"GETSET", 2, 2),
    (b"GETDEL", 1, 1), (b"GETEX", 1, 3), (b"SUBSTR", 3, 3), (b"MSETNX", 2, 10),
    # Key
    (b"DEL", 1, 5), (b"UNLINK", 1, 5), (b"EXISTS", 1, 5),
    (b"EXPIRE", 2, 3), (b"EXPIREAT", 2, 3), (b"PEXPIRE", 2, 3), (b"PEXPIREAT", 2, 3),
    (b"PERSIST", 1, 1), (b"TTL", 1, 1), (b"PTTL", 1, 1),
    (b"EXPIRETIME", 1, 1), (b"PEXPIRETIME", 1, 1), (b"TYPE", 1, 1),
    (b"RENAME", 2, 2), (b"RENAMENX", 2, 2), (b"COPY", 2, 4), (b"DUMP", 1, 1),
    (b"TOUCH", 1, 5), (b"OBJECT", 2, 2), (b"RANDOMKEY", 0, 0), (b"KEYS", 1, 1),
    (b"SCAN", 1, 5), (b"SORT", 1, 7), (b"SORT_RO", 1, 7),
    # List
    (b"LPUSH", 2, 5), (b"RPUSH", 2, 5), (b"LPOP", 1, 2), (b"RPOP", 1, 2),
    (b"LLEN", 1, 1), (b"LINDEX", 2, 2), (b"LSET", 3, 3), (b"LRANGE", 3, 3),
    (b"LTRIM", 3, 3), (b"LREM", 3, 3), (b"LPOS", 2, 6), (b"LMOVE", 4, 4),
    (b"LMPOP", 2, 4), (b"LPUSHX", 2, 5), (b"RPUSHX", 2, 5), (b"RPOPLPUSH", 2, 2),
    (b"BLPOP", 2, 5), (b"BRPOP", 2, 5), (b"BLMOVE", 5, 5), (b"BLMPOP", 3, 5),
    # Hash
    (b"HSET", 3, 9), (b"HGET", 2, 2), (b"HDEL", 2, 5), (b"HEXISTS", 2, 2),
    (b"HLEN", 1, 1), (b"HKEYS", 1, 1), (b"HVALS", 1, 1), (b"HGETALL", 1, 1),
    (b"HINCRBY", 3, 3), (b"HINCRBYFLOAT", 3, 3), (b"HMSET", 3, 9), (b"HMGET", 2, 5),
    (b"HSETNX", 3, 3), (b"HSTRLEN", 2, 2), (b"HRANDFIELD", 1, 3), (b"HSCAN", 2, 6),
    # Set
    (b"SADD", 2, 5), (b"SREM", 2, 5), (b"SMEMBERS", 1, 1), (b"SISMEMBER", 2, 2),
    (b"SMISMEMBER", 2, 5), (b"SCARD", 1, 1), (b"SPOP", 1, 2), (b"SRANDMEMBER", 1, 2),
    (b"SMOVE", 3, 3), (b"SDIFF", 1, 3), (b"SINTER", 1, 3), (b"SUNION", 1, 3),
    (b"SDIFFSTORE", 2, 4), (b"SINTERSTORE", 2, 4), (b"SUNIONSTORE", 2, 4),
    (b"SINTERCARD", 2, 5), (b"SSCAN", 2, 6),
    # Sorted set
    (b"ZADD", 3, 9), (b"ZREM", 2, 5), (b"ZSCORE", 2, 2), (b"ZMSCORE", 2, 5),
    (b"ZRANK", 2, 2), (b"ZREVRANK", 2, 2), (b"ZCARD", 1, 1), (b"ZCOUNT", 3, 3),
    (b"ZLEXCOUNT", 3, 3), (b"ZRANGE", 3, 7), (b"ZRANGEBYLEX", 3, 7),
    (b"ZRANGEBYSCORE", 3, 7), (b"ZREVRANGE", 3, 5), (b"ZREVRANGEBYLEX", 3, 7),
    (b"ZREVRANGEBYSCORE", 3, 7), (b"ZRANGESTORE", 4, 8), (b"ZINCRBY", 3, 3),
    (b"ZRANDMEMBER", 1, 3), (b"ZPOPMIN", 1, 2), (b"ZPOPMAX", 1, 2),
    (b"BZPOPMIN", 2, 4), (b"BZPOPMAX", 2, 4), (b"ZDIFF", 2, 5), (b"ZDIFFSTORE", 3, 5),
    (b"ZMPOP", 2, 4), (b"BZMPOP", 3, 5), (b"ZREMRANGEBYRANK", 3, 3),
    (b"ZREMRANGEBYSCORE", 3, 3), (b"ZREMRANGEBYLEX", 3, 3), (b"ZSCAN", 2, 6),
    # Stream
    (b"XADD", 3, 9), (b"XLEN", 1, 1), (b"XRANGE", 3, 5), (b"XREVRANGE", 3, 5),
    (b"XREAD", 3, 7), (b"XTRIM", 2, 4), (b"XDEL", 2, 5), (b"XINFO", 2, 3),
    (b"XACK", 3, 5), (b"XGROUP", 3, 6), (b"XREADGROUP", 5, 9),
    (b"XAUTOCLAIM", 4, 6), (b"XCLAIM", 4, 8),
    # HyperLogLog
    (b"PFADD", 1, 5), (b"PFCOUNT", 1, 3), (b"PFMERGE", 2, 4),
    # Geo
    (b"GEOADD", 4, 10), (b"GEODIST", 3, 4), (b"GEOPOS", 2, 5), (b"GEOHASH", 2, 5),
    (b"GEOSEARCH", 4, 10), (b"GEOSEARCHSTORE", 5, 11),
    # Pub/Sub
    (b"SUBSCRIBE", 1, 3), (b"PUBLISH", 2, 2), (b"PSUBSCRIBE", 1, 3),
    # Transaction
    (b"MULTI", 0, 0), (b"EXEC", 0, 0), (b"DISCARD", 0, 0), (b"WATCH", 1, 3),
    (b"UNWATCH", 0, 0),
    # Script (the read-only variant is spelled EVAL_RO, as in the seed corpus)
    (b"EVAL", 2, 6), (b"EVALSHA", 2, 6), (b"EVAL_RO", 2, 6),
    # JSON
    (b"JSON.SET", 3, 4), (b"JSON.GET", 1, 4), (b"JSON.DEL", 1, 2), (b"JSON.TYPE", 1, 2),
    (b"JSON.NUMINCRBY", 3, 3), (b"JSON.ARRAPPEND", 3, 6), (b"JSON.ARRLEN", 1, 2),
    (b"JSON.ARRINSERT", 4, 6), (b"JSON.ARRTRIM", 4, 4), (b"JSON.ARRPOP", 1, 3),
    (b"JSON.ARRINDEX", 3, 5), (b"JSON.OBJKEYS", 1, 2), (b"JSON.OBJLEN", 1, 2),
    (b"JSON.STRAPPEND", 2, 3), (b"JSON.STRLEN", 1, 2), (b"JSON.TOGGLE", 2, 2),
    (b"JSON.CLEAR", 1, 2), (b"JSON.MERGE", 3, 3), (b"JSON.MGET", 2, 5),
    # Bloom filter
    (b"BF.ADD", 2, 2), (b"BF.EXISTS", 2, 2), (b"BF.MADD", 2, 5), (b"BF.MEXISTS", 2, 5),
    (b"BF.RESERVE", 3, 5),
    # Server (OBJECT is already listed under "Key"; listing it twice doubled its weight)
    (b"PING", 0, 1), (b"ECHO", 1, 1), (b"SELECT", 1, 1), (b"DBSIZE", 0, 0),
    (b"INFO", 0, 1), (b"CONFIG", 2, 3), (b"CLIENT", 1, 3), (b"COMMAND", 0, 2),
    (b"MEMORY", 1, 2), (b"ACL", 1, 5), (b"MONITOR", 0, 0), (b"RESET", 0, 0),
    (b"HELLO", 0, 5), (b"WAIT", 2, 2), (b"BGSAVE", 0, 1),
    (b"LATENCY", 1, 2), (b"SLOWLOG", 1, 2),
    # Bitops
    (b"SETBIT", 3, 3), (b"GETBIT", 2, 2), (b"BITCOUNT", 1, 4), (b"BITOP", 3, 5),
    (b"BITPOS", 2, 5), (b"BITFIELD", 2, 8),
    # Search
    (b"FT.CREATE", 3, 15), (b"FT.SEARCH", 2, 10), (b"FT.DROPINDEX", 1, 2),
    (b"FT.INFO", 1, 1), (b"FT.ALTER", 3, 8),
    # Throttle
    (b"CL.THROTTLE", 5, 5),
]
# fmt: on

# Pools that _random_arg() draws argument values from.
KEYS = [b"k", b"key", b"k1", b"k2", b"k3", b"src", b"dst", b"mylist", b"myset", b"myhash"]
VALUES = [b"v", b"val", b"hello", b"0", b"1", b"-1", b"100", b"3.14", b"", b"a b"]
SPECIAL = [b"*", b"?", b"[", b"NX", b"XX", b"EX", b"PX", b"GT", b"LT", b"KEEPTTL"]
JSON_VALUES = [b'{"a":1}', b"[1,2,3]", b'"str"', b"42", b"null", b"true"]
JSON_PATHS = [b"$", b"$.a", b"$.*", b"$.arr[0]", b"."]
SCORE_VALUES = [b"0", b"1", b"-inf", b"+inf", b"(1", b"(5", b"3.14"]
STREAM_IDS = [b"*", b"0-0", b"1-1", b"$", b">"]

# Fuzzy values: binary junk, edge cases
FUZZ_VALUES = [
    b"\x00",
    b"\xff" * 4,
    b"\r\n",
    b"$-1\r\n",
    b"*0\r\n",
    b"A" * 256,
    b"-1",
    b"99999999999",
    b"NaN",
    b"inf",
]

# Focus commands: when set via FUZZ_FOCUS_COMMANDS env var (JSON list of command names),
# the mutator will prefer these commands ~70% of the time. Used by PR fuzzing to
# concentrate mutations on commands affected by the code change.
_FOCUS_COMMANDS = []
_FOCUS_WEIGHT = 0.7

_focus_env = os.environ.get("FUZZ_FOCUS_COMMANDS", "")
if _focus_env:
    try:
        raw = json.loads(_focus_env)
        if isinstance(raw, str):
            # A single JSON string is accepted as a one-element list.
            raw = [raw]
        if isinstance(raw, list):
            _focus_names = {s.strip().upper() for s in raw if isinstance(s, str) and s.strip()}
        else:
            _focus_names = set()
        _FOCUS_COMMANDS = [c for c in COMMANDS if c[0].decode().upper() in _focus_names]
        # Add unknown commands (e.g. newly added in a PR) with default arity.
        # Sorted so the table order (and therefore every random.choice() made
        # after init(seed)) is the same in every process; plain set iteration
        # order for strings changes from run to run.
        _known = {c[0].decode().upper() for c in COMMANDS}
        for name in sorted(_focus_names - _known):
            entry = (name.encode(), 1, 3)
            COMMANDS.append(entry)
            _FOCUS_COMMANDS.append(entry)
    except (json.JSONDecodeError, TypeError, ValueError):
        # Best effort: a malformed focus list must not stop the fuzzer;
        # selection simply stays unweighted.
        pass


def _pick_command():
    """Pick a (name, min_args, max_args) tuple, preferring focus commands when available."""
    if _FOCUS_COMMANDS and random.random() < _FOCUS_WEIGHT:
        return random.choice(_FOCUS_COMMANDS)
    return random.choice(COMMANDS)


def init(seed):
    """Seed the module-level random generator with the value passed by the caller."""
    random.seed(seed)


def _encode_resp(*args):
    """Encode args as one RESP array of bulk strings.

    Non-bytes arguments are converted with str(...).encode() first.
    """
    parts = [b"*%d\r\n" % len(args)]
    for a in args:
        if not isinstance(a, bytes):
            a = str(a).encode()
        parts.append(b"$%d\r\n%s\r\n" % (len(a), a))
    return b"".join(parts)


# Cumulative probability thresholds for _random_arg(); values at or above the
# last threshold fall through to STREAM_IDS.
_ARG_POOLS = (
    (0.3, KEYS),
    (0.55, VALUES),
    (0.7, SPECIAL),
    (0.8, FUZZ_VALUES),
    (0.85, JSON_VALUES),
    (0.9, JSON_PATHS),
    (0.95, SCORE_VALUES),
)


def _random_arg():
    """Generate a random argument value from one of the value pools."""
    r = random.random()
    for threshold, pool in _ARG_POOLS:
        if r < threshold:
            return random.choice(pool)
    return random.choice(STREAM_IDS)


def _new_command():
    """Build a random command as a list: [name, arg, ...] with arity from COMMANDS."""
    cmd_name, min_args, max_args = _pick_command()
    nargs = random.randint(min_args, max_args)
    return [cmd_name] + [_random_arg() for _ in range(nargs)]


def _random_command():
    """Generate a single random RESP-encoded command."""
    return _encode_resp(*_new_command())


def _parse_resp_commands(buf):
    """Best-effort parse of RESP buffer into list of commands (each is list of bytes).

    Returns (commands, success). On parse failure returns ([], False).

    Only arrays of bulk strings are understood. A negative bulk length is
    recorded as an empty argument, and arrays that end up with no arguments
    are skipped. The two bytes after each bulk payload are skipped without
    checking that they are CRLF.
    """
    commands = []
    pos = 0
    data = bytes(buf)
    while pos < len(data):
        # Skip whitespace/newlines between commands
        while pos < len(data) and data[pos : pos + 1] in (b"\r", b"\n", b" "):
            pos += 1
        if pos >= len(data):
            break
        if data[pos : pos + 1] != b"*":
            return ([], False)
        # Parse *N\r\n
        end = data.find(b"\r\n", pos)
        if end < 0:
            return ([], False)
        try:
            nargs = int(data[pos + 1 : end])
        except ValueError:
            return ([], False)
        pos = end + 2
        args = []
        for _ in range(nargs):
            # Parse $LEN\r\n<payload>\r\n
            if pos >= len(data) or data[pos : pos + 1] != b"$":
                return ([], False)
            end = data.find(b"\r\n", pos)
            if end < 0:
                return ([], False)
            try:
                slen = int(data[pos + 1 : end])
            except ValueError:
                return ([], False)
            pos = end + 2
            if slen < 0:
                args.append(b"")
                continue
            if pos + slen + 2 > len(data):
                return ([], False)
            args.append(data[pos : pos + slen])
            pos += slen + 2
        if args:
            commands.append(args)
    return (commands, True)


def _mutate_commands(commands):
    """Apply one random mutation to a list of parsed commands and return the new list.

    The input list is not modified. One random draw selects the mutation;
    when the selected mutation's length precondition is not met, the chain
    falls through to the next applicable branch.
    """
    result = list(commands)
    mutation = random.random()

    if mutation < 0.2 and len(result) > 0:
        # Replace a random command entirely
        idx = random.randint(0, len(result) - 1)
        result[idx] = _new_command()
    elif mutation < 0.4 and len(result) > 0:
        # Mutate an argument of a random command (no-op for argument-less commands)
        idx = random.randint(0, len(result) - 1)
        cmd = list(result[idx])
        if len(cmd) > 1:
            arg_idx = random.randint(1, len(cmd) - 1)
            cmd[arg_idx] = _random_arg()
            result[idx] = cmd
    elif mutation < 0.55:
        # Insert a new random command
        pos = random.randint(0, len(result))
        result.insert(pos, _new_command())
    elif mutation < 0.65 and len(result) > 1:
        # Remove a random command (never the last remaining one)
        idx = random.randint(0, len(result) - 1)
        result.pop(idx)
    elif mutation < 0.75 and len(result) >= 2:
        # Swap two commands
        i, j = random.sample(range(len(result)), 2)
        result[i], result[j] = result[j], result[i]
    elif mutation < 0.85 and len(result) > 0:
        # Duplicate a command
        idx = random.randint(0, len(result) - 1)
        result.insert(idx + 1, list(result[idx]))
    elif mutation < 0.92 and len(result) > 0:
        # Wrap between one and five consecutive commands in MULTI/EXEC
        start = random.randint(0, len(result) - 1)
        end = random.randint(start + 1, min(start + 5, len(result)))
        result.insert(start, [b"MULTI"])
        # +1 because the MULTI inserted above shifted everything after it
        result.insert(end + 1, [b"EXEC"])
    else:
        # Add extra argument to a random command
        if len(result) > 0:
            idx = random.randint(0, len(result) - 1)
            result[idx] = list(result[idx]) + [_random_arg()]

    return result


def _commands_to_resp(commands):
    """Serialize list of commands back to RESP bytes."""
    return b"".join(_encode_resp(*cmd) for cmd in commands)


def _join_within(chunks, max_size):
    """Concatenate encoded commands without exceeding max_size bytes.

    Keeps the longest prefix of whole commands that fits, so the output stays
    well-framed RESP. Only when not even the first command fits does it fall
    back to cutting the raw bytes at max_size.
    """
    kept = []
    total = 0
    for chunk in chunks:
        if total + len(chunk) > max_size:
            break
        kept.append(chunk)
        total += len(chunk)
    if kept:
        return b"".join(kept)
    return b"".join(chunks)[:max_size]


def fuzz(buf, add_buf, max_size):
    """Main mutation function called by AFL++.

    Args:
        buf: input to mutate (converted with bytes() before parsing).
        add_buf: unused.
        max_size: upper bound on the length of the returned buffer.

    Returns:
        bytearray holding the mutated input, at most max_size bytes long.
    """
    # Try to parse the input as RESP
    commands, ok = _parse_resp_commands(buf)

    if ok and commands:
        # Parsed successfully — mutate at command level
        chunks = [_encode_resp(*cmd) for cmd in _mutate_commands(commands)]
    else:
        # Could not parse — generate random commands from scratch
        n = random.randint(1, 5)
        chunks = [_random_command() for _ in range(n)]

    return bytearray(_join_within(chunks, max_size))


def havoc_mutation(buf, max_size):
    """Called during havoc stage — single small mutation.

    Returns a bytearray of at most max_size bytes. Unparseable input is
    replaced by one freshly generated command.
    """
    commands, ok = _parse_resp_commands(buf)
    if not ok or not commands:
        return bytearray(_random_command()[:max_size])

    # Single small mutation
    chunks = [_encode_resp(*cmd) for cmd in _mutate_commands(commands)]
    return bytearray(_join_within(chunks, max_size))


def havoc_mutation_probability():
    """How often our havoc_mutation is called vs AFL++'s built-in mutations."""
    return 50


# ---- residue of the collapsed multi-file dump that shares this line span ----
# (retained as comments so the module above stays valid Python)
# ================================================
# FILE: fuzz/run_fuzzer.sh
# ================================================
# #!/usr/bin/env
bash

set -e

GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
NC='\033[0m'

# Print $1 as an absolute path, resolving relative paths against the directory
# the script was started from. run_fuzzer() changes into OUTPUT_DIR before
# exec'ing afl-fuzz, so every path handed to afl-fuzz must not depend on $PWD.
abs_path() {
    case "$1" in
        /*) printf '%s\n' "$1" ;;
        *) printf '%s\n' "$PWD/$1" ;;
    esac
}

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"

# Target: "resp" (default) or "memcache"
TARGET="${1:-resp}"

BUILD_DIR="$(abs_path "${BUILD_DIR:-$PROJECT_ROOT/build-dbg}")"
FUZZ_DIR="$SCRIPT_DIR"
OUTPUT_DIR="$(abs_path "${OUTPUT_DIR:-$FUZZ_DIR/artifacts/$TARGET}")"
CORPUS_DIR="$(abs_path "${CORPUS_DIR:-$FUZZ_DIR/corpus/$TARGET}")"
SEEDS_DIR="$(abs_path "${SEEDS_DIR:-$FUZZ_DIR/seeds/$TARGET}")"
DICT_FILE="$(abs_path "${DICT_FILE:-$FUZZ_DIR/dict/$TARGET.dict}")"
TIMEOUT="5000"
FUZZ_TARGET="$BUILD_DIR/dragonfly"
AFL_PROACTOR_THREADS="${AFL_PROACTOR_THREADS:-1}"

# Persistent record: restart server every N iterations and record the last N inputs.
# This ensures that on crash, ALL inputs that built the current server state are available
# for replay. Without this, state from earlier iterations is lost and crashes become
# non-reproducible. Max recommended by AFL++: 10000.
AFL_LOOP_LIMIT="${AFL_LOOP_LIMIT:-10000}"

print_info() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

print_note() {
    echo -e "${BLUE}[NOTE]${NC} $1"
}

print_warning() {
    echo -e "${YELLOW}[WARNING]${NC} $1"
}

# Abort early (exit 1) when the target name, the fuzzing binary or afl-fuzz
# itself is missing, before any directory is created.
check_requirements() {
    if [[ "$TARGET" != "resp" && "$TARGET" != "memcache" ]]; then
        print_warning "Unknown target: $TARGET (use 'resp' or 'memcache')"
        exit 1
    fi

    if [[ ! -f "${FUZZ_TARGET}" ]]; then
        print_warning "Dragonfly not found at ${FUZZ_TARGET}"
        print_warning "Build with: -DUSE_AFL=ON"
        exit 1
    fi

    if ! command -v afl-fuzz >/dev/null 2>&1; then
        print_warning "afl-fuzz not found in PATH"
        exit 1
    fi
}

# Create the output and corpus directories; when the corpus is empty, fill it
# from SEEDS_DIR or, failing that, with one minimal seed for the target.
setup_directories() {
    print_info "Setting up directories..."
    mkdir -p "${OUTPUT_DIR}"
    mkdir -p "${CORPUS_DIR}"

    if [[ -z "$(ls -A "$CORPUS_DIR" 2>/dev/null)" ]]; then
        if [[ -d "${SEEDS_DIR}" ]] && [[ -n "$(ls -A "${SEEDS_DIR}" 2>/dev/null)" ]]; then
            print_info "Copying seeds to corpus..."
            # Best effort: a seed that cannot be copied is skipped.
            cp "${SEEDS_DIR}"/* "${CORPUS_DIR}/" 2>/dev/null || true
        else
            print_warning "No seeds found, creating minimal seed"
            if [[ "$TARGET" == "memcache" ]]; then
                printf 'version\r\n' > "${CORPUS_DIR}/version"
            else
                # printf (not echo -e) so no stray newline follows the final CRLF
                printf '*1\r\n$4\r\nPING\r\n' > "${CORPUS_DIR}/ping"
            fi
        fi
    fi
}

show_config() {
    echo ""
    print_info "AFL++ Persistent Mode Configuration:"
    echo "  Target: ${TARGET}"
    echo "  Binary: ${FUZZ_TARGET}"
    echo "  Corpus: ${CORPUS_DIR}"
    echo "  Output: ${OUTPUT_DIR}"
    echo "  Dictionary: ${DICT_FILE}"
    echo "  Timeout: ${TIMEOUT}ms"
    echo "  Proactor threads: ${AFL_PROACTOR_THREADS}"
    echo "  Loop limit: ${AFL_LOOP_LIMIT} (= AFL_PERSISTENT_RECORD)"
    echo ""
    print_note "Fuzzing integrated in dragonfly (USE_AFL + persistent mode)"
    print_note "Usage: ./run_fuzzer.sh [resp|memcache]"
    print_note "To change proactor threads: export AFL_PROACTOR_THREADS=N (default: 1)"
    print_note "To change loop limit: export AFL_LOOP_LIMIT=N (default: 10000)"
    echo ""
}

# Assemble the afl-fuzz command line, export the AFL++ environment and replace
# this shell with afl-fuzz.
run_fuzzer() {
    print_info "Starting AFL++ persistent mode fuzzing (target: $TARGET)..."
    print_info "Press Ctrl+C to stop"
    echo ""

    AFL_CMD=(
        afl-fuzz
        -o "${OUTPUT_DIR}"
        -t "${TIMEOUT}"
        -m 4096
        -i "${CORPUS_DIR}"
    )

    if [[ -f "${DICT_FILE}" ]]; then
        AFL_CMD+=(-x "${DICT_FILE}")
    fi

    AFL_CMD+=(
        --
        "${FUZZ_TARGET}"
        --port=6379
        --logtostderr
        "--proactor_threads=${AFL_PROACTOR_THREADS}"
        "--afl_loop_limit=${AFL_LOOP_LIMIT}"
        --bind=0.0.0.0
        --bind=::
        --dbfilename=""
        --omit_basic_usage
        --rename_command=SHUTDOWN=
        --rename_command=DEBUG=
        --rename_command=FLUSHALL=
        --rename_command=FLUSHDB=
        --max_bulk_len=1048576
    )

    if [[ "$TARGET" == "memcache" ]]; then
        AFL_CMD+=(--memcached_port=11211 --afl_target_port=11211)
    fi

    print_info "Running: ${AFL_CMD[*]}"
    echo ""

    cd "${OUTPUT_DIR}"

    # Run AFL++ - fuzzing integrated in dragonfly via USE_AFL
    # AFL_HANG_TMOUT: Only consider it a hang if no response for 60 seconds
    # This prevents false positives from slow but legitimate operations
    export AFL_HANG_TMOUT=60000

    # Dragonfly has ~350K edges, default AFL++ bitmap is 64KB (massive collisions).
    # Use 512KB bitmap to reduce hash collisions and improve stability.
    export AFL_MAP_SIZE=524288

    # Record the last N inputs before a crash for replay.
    # Synced with afl_loop_limit so the full server state history is always captured.
    export AFL_PERSISTENT_RECORD="${AFL_LOOP_LIMIT}"

    # Even with 1 proactor thread, some coverage instability is expected.
    # Tell AFL++ to continue despite unstable coverage — don't bail on flaky edges.
    export AFL_IGNORE_PROBLEMS=1

    # More aggressive havoc mutations from the start — don't wait for deterministic
    # stages to finish. Useful for protocol fuzzing where random mutations find new paths.
    export AFL_EXPAND_HAVOC_NOW=1

    # Custom protocol mutator — mutates at command/argument level
    # instead of random bytes, keeping protocol framing valid.
    export PYTHONPATH="$FUZZ_DIR"
    if [[ "$TARGET" == "memcache" ]]; then
        export AFL_PYTHON_MODULE=memcache_mutator
    else
        export AFL_PYTHON_MODULE=resp_mutator
    fi

    exec "${AFL_CMD[@]}"
}

main() {
    check_requirements
    setup_directories
    show_config
    run_fuzzer
}

main "$@"

================================================
FILE: fuzz/seeds/memcache/add_replace.mc
================================================
set key1 0 0 3 abc add key2 0 0 3 def replace key1 0 0 3 xyz

================================================
FILE: fuzz/seeds/memcache/append_prepend.mc
================================================
set buf 0 0 5 hello append buf 0 0 6 world prepend buf 0 0 4 say get buf

================================================
FILE: fuzz/seeds/memcache/cas.mc
================================================
set mykey 0 0 3 abc gets mykey cas mykey 0 0 3 1 xyz

================================================
FILE: fuzz/seeds/memcache/delete.mc
================================================
set key1 0 0 1 a set key2 0 0 1 b delete key1 delete key2 noreply get key1

================================================
FILE: fuzz/seeds/memcache/expiry.mc
================================================
set exp1 0 10 3 abc set exp2 0 0 3 def set exp3 0 9999999 3 ghi get exp1 exp2 exp3

================================================
FILE: fuzz/seeds/memcache/flags.mc
================================================
set f1 0 0 3 abc set f2 1 0 3 def set f3 65535 0 3 ghi set f4 4294967295 0 3 jkl gets f1 f2 f3 f4

================================================
FILE: fuzz/seeds/memcache/flush.mc
================================================
set a 0 0 1 x set b 0 0 1 y flush_all get a b

================================================
FILE: fuzz/seeds/memcache/gat.mc
================================================
set mykey 0 100 5 hello gat 200 mykey gats 300 mykey

================================================
FILE: fuzz/seeds/memcache/incr_decr.mc
================================================ set counter 0 0 1 0 incr counter 1 incr counter 10 decr counter 5 get counter ================================================ FILE: fuzz/seeds/memcache/large_value.mc ================================================ set big 0 0 100 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa get big ================================================ FILE: fuzz/seeds/memcache/meta_commands.mc ================================================ ms mykey 5 hello mg mykey md mykey ma counter mn ================================================ FILE: fuzz/seeds/memcache/multiget.mc ================================================ set k1 0 0 1 a set k2 0 0 1 b set k3 0 0 1 c get k1 k2 k3 gets k1 k2 k3 ================================================ FILE: fuzz/seeds/memcache/noreply.mc ================================================ set key1 0 0 3 noreply abc add key2 0 0 3 noreply def replace key1 0 0 3 noreply xyz incr counter 1 noreply delete key2 noreply ================================================ FILE: fuzz/seeds/memcache/set_get.mc ================================================ set mykey 0 0 5 hello get mykey ================================================ FILE: fuzz/seeds/memcache/stats_version.mc ================================================ stats version quit ================================================ FILE: fuzz/seeds/resp/acl.resp ================================================ *2 $3 ACL $6 WHOAMI *2 $3 ACL $4 LIST ================================================ FILE: fuzz/seeds/resp/acl_ops.resp ================================================ *2 $3 ACL $6 WHOAMI *2 $3 ACL $4 LIST *2 $3 ACL $5 USERS *3 $3 ACL $3 CAT *2 $3 ACL $7 GENPASS *2 $7 COMMAND $5 COUNT *2 $7 COMMAND $4 DOCS ================================================ FILE: fuzz/seeds/resp/acl_ops2.resp ================================================ *2 $3 ACL $6 WHOAMI *2 $3 ACL $4 LIST *2 $3 
ACL $5 USERS *3 $3 ACL $3 CAT $6 string *2 $3 ACL $7 GENPASS *4 $3 ACL $7 SETUSER $8 testuser $2 on *3 $3 ACL $7 GETUSER $8 testuser *4 $3 ACL $6 DRYRUN $8 testuser $3 GET $3 key *2 $3 ACL $4 HELP *3 $3 ACL $3 LOG $5 RESET *3 $3 ACL $7 DELUSER $8 testuser *2 $3 ACL $4 SAVE *2 $3 ACL $4 LOAD *3 $4 AUTH $8 testuser $8 password *2 $7 COMMAND $5 COUNT *3 $7 COMMAND $4 INFO $3 GET *2 $7 COMMAND $4 DOCS $3 SET *2 $7 COMMAND $4 LIST ================================================ FILE: fuzz/seeds/resp/bf_add.resp ================================================ *4 $10 BF.RESERVE $7 mybloom $4 0.01 $4 1000 *3 $6 BF.ADD $7 mybloom $5 item1 *3 $9 BF.EXISTS $7 mybloom $5 item1 ================================================ FILE: fuzz/seeds/resp/bitfield.resp ================================================ *6 $8 BITFIELD $3 key $3 GET $2 u8 $1 0 ================================================ FILE: fuzz/seeds/resp/bitfield_ops.resp ================================================ *8 $8 BITFIELD $2 bk $3 SET $3 u8 $1 0 $3 200 $3 GET $3 u8 $1 0 *5 $8 BITFIELD $2 bk $6 INCRBY $3 u8 $1 0 $2 10 *6 $8 BITFIELD $2 bk $8 OVERFLOW $3 SAT $6 INCRBY $3 u8 $1 0 $3 100 *5 $5 BITOP $3 AND $4 dest $2 bk $2 bk *3 $6 BITPOS $2 bk $1 1 ================================================ FILE: fuzz/seeds/resp/bitops.resp ================================================ *4 $6 SETBIT $2 bk $1 7 $1 1 *3 $6 GETBIT $2 bk $1 0 *2 $8 BITCOUNT $2 bk ================================================ FILE: fuzz/seeds/resp/bloom_ops.resp ================================================ *4 $10 BF.RESERVE $2 bf $4 0.01 $4 1000 *3 $6 BF.ADD $2 bf $5 item1 *3 $9 BF.EXISTS $2 bf $5 item1 *5 $7 BF.MADD $2 bf $5 item2 $5 item3 $5 item4 *5 $10 BF.MEXISTS $2 bf $5 item1 $5 item2 $5 itemX ================================================ FILE: fuzz/seeds/resp/client.resp ================================================ *3 $6 CLIENT $7 SETNAME $10 testclient *2 $6 CLIENT $7 GETNAME *2 $6 CLIENT $4 LIST 
================================================ FILE: fuzz/seeds/resp/config.resp ================================================ *3 $6 CONFIG $3 GET $9 maxmemory ================================================ FILE: fuzz/seeds/resp/copy.resp ================================================ *3 $3 SET $3 src $5 hello *3 $4 COPY $3 src $3 dst ================================================ FILE: fuzz/seeds/resp/del.resp ================================================ *2 $3 DEL $3 key ================================================ FILE: fuzz/seeds/resp/eval.resp ================================================ *3 $4 EVAL $26 return redis.call("PING") $0 ================================================ FILE: fuzz/seeds/resp/expire_ops.resp ================================================ *3 $3 SET $2 ek $3 val *3 $6 EXPIRE $2 ek $3 300 *2 $3 TTL $2 ek *2 $4 PTTL $2 ek *2 $10 EXPIRETIME $2 ek *2 $11 PEXPIRETIME $2 ek *3 $8 EXPIREAT $2 ek $10 9999999999 *2 $7 PERSIST $2 ek *3 $7 PEXPIRE $2 ek $6 300000 *3 $9 PEXPIREAT $2 ek $13 9999999999000 *2 $5 TOUCH $2 ek ================================================ FILE: fuzz/seeds/resp/function.resp ================================================ *3 $8 FUNCTION $4 LOAD $56 #!lua name=mylib redis.register_function('myfunc', function() return 1 end) *2 $8 FUNCTION $4 LIST ================================================ FILE: fuzz/seeds/resp/function_ops.resp ================================================ *3 $8 FUNCTION $4 LOAD $56 #!lua name=mylib redis.register_function('myfunc', function() return 1 end) *2 $8 FUNCTION $4 LIST *3 $8 FUNCTION $6 DELETE $5 mylib ================================================ FILE: fuzz/seeds/resp/generic_ops.resp ================================================ *3 $3 SET $2 gk $3 val *2 $4 TYPE $2 gk *2 $6 EXISTS $2 gk *3 $6 EXPIRE $2 gk $3 300 *2 $3 TTL $2 gk *2 $4 PTTL $2 gk *2 $10 EXPIRETIME $2 gk *3 $7 PEXPIRE $2 gk $6 300000 *2 $11 PEXPIRETIME $2 gk *2 $7 PERSIST $2 gk *3 $4 COPY 
$2 gk $3 gk2 *3 $6 RENAME $3 gk2 $3 gk3 *2 $4 DUMP $3 gk3 *2 $6 UNLINK $3 gk3 *2 $4 KEYS $1 * *3 $4 SCAN $1 0 $5 COUNT $1 5 *2 $9 RANDOMKEY *2 $6 DBSIZE *2 $4 TIME *3 $6 SELECT $1 0 *5 $4 SORT $2 gk $2 BY $6 nosort $5 ALPHA ================================================ FILE: fuzz/seeds/resp/generic_ops2.resp ================================================ *3 $3 SET $3 gk1 $3 val *3 $3 SET $3 gk2 $3 val *2 $2 DEL $3 gk1 *2 $3 GET $3 gk2 *2 $3 TTL $3 gk2 *3 $8 RENAMENX $3 gk2 $3 gk3 *2 $4 ECHO $5 hello *3 $5 STICK $3 gk3 *2 $5 TOUCH $3 gk3 *2 $4 TYPE $3 gk3 *3 $4 MOVE $3 gk3 $1 1 *2 $7 SORT_RO $3 gk3 *3 $3 SET $3 gk4 $3 val *4 $7 RESTORE $3 gk5 $1 0 $5 dummy ================================================ FILE: fuzz/seeds/resp/geo_ops.resp ================================================ *8 $6 GEOADD $2 gk $9 13.361389 $9 38.115556 $7 Palermo $9 15.087269 $9 37.502669 $7 Catania *5 $7 GEODIST $2 gk $7 Palermo $7 Catania $2 km *3 $7 GEOHASH $2 gk $7 Palermo *3 $6 GEOPOS $2 gk $7 Palermo *7 $9 GEOSEARCH $2 gk $9 FROMLONLAT $2 15 $2 37 $6 BYRADIUS $3 200 $2 km *6 $10 GEORADIUS $2 gk $2 15 $2 37 $3 200 $2 km ================================================ FILE: fuzz/seeds/resp/geo_ops2.resp ================================================ *11 $6 GEOADD $3 gx1 $9 13.361389 $9 38.115556 $7 Palermo $9 15.087269 $9 37.502669 $7 Catania $9 2.349014 $9 48.864716 $5 Paris *7 $10 GEORADIUS $3 gx1 $2 15 $2 37 $3 200 $2 km *6 $19 GEORADIUSBYMEMBER $3 gx1 $7 Palermo $3 200 $2 km *7 $13 GEORADIUS_RO $3 gx1 $2 15 $2 37 $3 200 $2 km *6 $22 GEORADIUSBYMEMBER_RO $3 gx1 $7 Palermo $3 200 $2 km *9 $9 GEOSEARCH $3 gx1 $10 FROMLONLAT $2 15 $2 37 $6 BYRADIUS $3 200 $2 km $3 ASC *10 $14 GEOSEARCHSTORE $4 gdst $3 gx1 $10 FROMLONLAT $2 15 $2 37 $6 BYRADIUS $3 200 $2 km ================================================ FILE: fuzz/seeds/resp/geoadd.resp ================================================ *5 $6 GEOADD $5 mygeo $9 13.361389 $9 38.115556 $7 Palermo *5 $7 GEODIST $5 mygeo $7 
Palermo $7 Catania $2 km ================================================ FILE: fuzz/seeds/resp/get.resp ================================================ *2 $3 GET $3 key ================================================ FILE: fuzz/seeds/resp/getdel.resp ================================================ *3 $3 SET $1 k $1 v *2 $6 GETDEL $1 k ================================================ FILE: fuzz/seeds/resp/hash_ops.resp ================================================ *8 $4 HSET $2 hh $2 f1 $2 v1 $2 f2 $2 v2 $2 f3 $2 10 *3 $4 HGET $2 hh $2 f1 *4 $5 HMGET $2 hh $2 f1 $2 f2 *2 $7 HGETALL $2 hh *2 $5 HKEYS $2 hh *2 $5 HVALS $2 hh *2 $4 HLEN $2 hh *3 $7 HEXISTS $2 hh $2 f1 *3 $7 HSTRLEN $2 hh $2 f1 *4 $7 HINCRBY $2 hh $2 f3 $1 5 *4 $12 HINCRBYFLOAT $2 hh $2 f3 $3 1.5 *3 $10 HRANDFIELD $2 hh $1 2 *4 $6 HSETNX $2 hh $4 newf $4 newv *3 $4 HDEL $2 hh $2 f2 *3 $5 HSCAN $2 hh $1 0 ================================================ FILE: fuzz/seeds/resp/hash_ops2.resp ================================================ *6 $4 HSET $3 hx1 $2 f1 $2 v1 $2 f2 $2 v2 *4 $5 HMSET $3 hx1 $2 f3 $2 v3 *4 $6 HSETNX $3 hx1 $6 newkey $5 newvl *4 $7 HSTRLEN $3 hx1 $2 f1 *3 $12 HINCRBYFLOAT $3 hx1 $2 f1 $3 1.5 *3 $9 HRANDFIELD $3 hx1 $1 2 *5 $6 HSETEX $3 hx1 $3 300 $2 f4 $2 v4 *4 $7 HEXPIRE $3 hx1 $3 300 $2 f4 ================================================ FILE: fuzz/seeds/resp/hll_ops.resp ================================================ *5 $5 PFADD $4 hll1 $1 a $1 b $1 c *4 $5 PFADD $4 hll2 $1 c $1 d $1 e *2 $7 PFCOUNT $4 hll1 *3 $7 PFCOUNT $4 hll1 $4 hll2 *4 $7 PFMERGE $4 hll3 $4 hll1 $4 hll2 ================================================ FILE: fuzz/seeds/resp/hset.resp ================================================ *4 $4 HSET $4 hash $5 field $5 value ================================================ FILE: fuzz/seeds/resp/json.resp ================================================ *4 $8 JSON.SET $3 doc $1 $ $15 {"name":"test"} ================================================ FILE: 
fuzz/seeds/resp/json_ops.resp ================================================ *4 $8 JSON.SET $2 jk $1 $ $52 {"name":"test","age":30,"tags":["a","b"],"nested":{"x":1}} *3 $8 JSON.GET $2 jk $1 $ *3 $9 JSON.TYPE $2 jk $1 $ *3 $10 JSON.STRLEN $2 jk $6 $.name *3 $11 JSON.OBJLEN $2 jk $1 $ *3 $11 JSON.OBJKEYS $2 jk $1 $ *3 $10 JSON.ARRLEN $2 jk $6 $.tags *4 $13 JSON.ARRAPPEND $2 jk $6 $.tags $3 "c" *5 $13 JSON.ARRINSERT $2 jk $6 $.tags $1 0 $3 "z" *4 $11 JSON.ARRPOP $2 jk $6 $.tags $2 -1 *5 $12 JSON.ARRTRIM $2 jk $6 $.tags $1 0 $1 2 *4 $12 JSON.ARRINDEX $2 jk $6 $.tags $3 "a" *3 $14 JSON.NUMINCRBY $2 jk $5 $.age $1 1 *3 $14 JSON.NUMMULTBY $2 jk $5 $.age $1 2 *4 $12 JSON.STRAPPEND $2 jk $6 $.name $4 "_x" *3 $11 JSON.TOGGLE $2 jk $6 $.tags *3 $10 JSON.CLEAR $2 jk $6 $.tags *3 $8 JSON.DEL $2 jk $8 $.nested *3 $9 JSON.RESP $2 jk $1 $ ================================================ FILE: fuzz/seeds/resp/json_ops2.resp ================================================ *4 $8 JSON.SET $3 jm1 $1 $ $13 {"a":1,"b":2} *4 $8 JSON.SET $3 jm2 $1 $ $13 {"a":3,"c":4} *3 $9 JSON.MGET $3 jm1 $3 jm2 $1 $ *4 $9 JSON.MSET $3 jm1 $3 $.a $1 9 *4 $10 JSON.MERGE $3 jm1 $1 $ $9 {"d":"new"} *3 $10 JSON.DEBUG $6 MEMORY $3 jm1 $1 $ *3 $10 JSON.FORGET $3 jm2 $3 $.c ================================================ FILE: fuzz/seeds/resp/list_blocking.resp ================================================ *5 $5 RPUSH $3 lb1 $1 a $1 b $1 c *5 $5 RPUSH $3 lb2 $1 x $1 y $1 z *3 $10 RPOPLPUSH $3 lb1 $3 lb2 *5 $5 LMOVE $3 lb1 $3 lb2 $4 LEFT $5 RIGHT *4 $5 LMPOP $1 2 $3 lb1 $3 lb2 $4 LEFT *4 $5 LPUSH $3 bq1 $1 1 $1 2 *3 $5 BLPOP $3 bq1 $1 1 *3 $5 BRPOP $3 bq1 $1 1 *5 $6 BLMOVE $3 lb1 $3 lb2 $4 LEFT $5 RIGHT $1 1 *5 $6 BLMPOP $1 1 $1 1 $3 lb1 $4 LEFT ================================================ FILE: fuzz/seeds/resp/list_ops.resp ================================================ *5 $5 RPUSH $2 ll $1 a $1 b $1 c *3 $6 LPUSHX $2 ll $1 x *3 $6 RPUSHX $2 ll $1 z *2 $4 LLEN $2 ll *4 $6 LRANGE $2 ll $1 0 $2 -1 *3 
$6 LINDEX $2 ll $1 2 *5 $7 LINSERT $2 ll $6 BEFORE $1 b $4 new1 *4 $4 LSET $2 ll $1 0 $4 head *4 $5 LTRIM $2 ll $1 0 $1 4 *4 $4 LREM $2 ll $1 1 $1 a *2 $4 LPOP $2 ll *2 $4 RPOP $2 ll *5 $5 RPUSH $2 l2 $1 1 $1 2 $1 3 *4 $5 LMOVE $2 ll $2 l2 $4 LEFT ================================================ FILE: fuzz/seeds/resp/lmpop.resp ================================================ *5 $5 RPUSH $6 mylist $1 a $1 b $1 c *4 $5 LMPOP $1 1 $6 mylist $4 LEFT ================================================ FILE: fuzz/seeds/resp/lpos.resp ================================================ *7 $5 RPUSH $6 mylist $1 a $1 b $1 c $1 a $1 d *3 $4 LPOS $6 mylist $1 a ================================================ FILE: fuzz/seeds/resp/lpush.resp ================================================ *3 $5 LPUSH $4 list $4 item ================================================ FILE: fuzz/seeds/resp/memory.resp ================================================ *3 $3 SET $5 mykey $9 somevalue *3 $6 MEMORY $5 USAGE $5 mykey ================================================ FILE: fuzz/seeds/resp/monitor.resp ================================================ *1 $7 MONITOR ================================================ FILE: fuzz/seeds/resp/mset.resp ================================================ *5 $4 MSET $1 a $1 1 $1 b $1 2 *3 $4 MGET $1 a $1 b ================================================ FILE: fuzz/seeds/resp/multi_type_pipeline.resp ================================================ *3 $3 SET $2 pk $5 hello *5 $5 RPUSH $2 pl $1 a $1 b $1 c *4 $4 HSET $2 ph $1 f $1 v *4 $4 SADD $2 ps $1 x $1 y *6 $4 ZADD $2 pz $1 1 $1 a $1 2 $1 b *5 $4 XADD $2 px $1 * $1 k $1 v *4 $8 JSON.SET $2 pj $1 $ $13 {"a":1,"b":2} *2 $4 TYPE $2 pk *2 $4 TYPE $2 pl *2 $4 TYPE $2 ph *2 $4 TYPE $2 ps *2 $4 TYPE $2 pz *2 $4 TYPE $2 px *8 $3 DEL $2 pk $2 pl $2 ph $2 ps $2 pz $2 px $2 pj ================================================ FILE: fuzz/seeds/resp/object.resp ================================================ *3 $3 
SET $5 mykey $3 val *3 $6 OBJECT $8 ENCODING $5 mykey ================================================ FILE: fuzz/seeds/resp/pfadd.resp ================================================ *5 $5 PFADD $4 hll1 $1 a $1 b $1 c *2 $7 PFCOUNT $4 hll1 ================================================ FILE: fuzz/seeds/resp/ping.resp ================================================ *1 $4 PING ================================================ FILE: fuzz/seeds/resp/pipeline.resp ================================================ *1 $4 PING *3 $3 SET $1 a $1 1 *2 $4 INCR $1 a *2 $3 GET $1 a *2 $3 DEL $1 a ================================================ FILE: fuzz/seeds/resp/pubsub_ops.resp ================================================ *3 $7 PUBLISH $4 chan $5 hello *2 $6 PUBSUB $8 CHANNELS *3 $6 PUBSUB $6 NUMSUB $4 chan ================================================ FILE: fuzz/seeds/resp/pubsub_ops2.resp ================================================ *3 $7 PUBLISH $5 chan1 $3 msg *3 $7 PUBLISH $5 chan2 $4 msg2 *2 $6 PUBSUB $8 CHANNELS *3 $6 PUBSUB $6 NUMSUB $5 chan1 *2 $6 PUBSUB $8 NUMPAT *2 $9 SUBSCRIBE $5 chan1 *2 $11 UNSUBSCRIBE $5 chan1 *2 $10 PSUBSCRIBE $5 chan* *2 $12 PUNSUBSCRIBE $5 chan* *2 $10 SSUBSCRIBE $5 chan1 *3 $8 SPUBLISH $5 chan1 $4 smsg ================================================ FILE: fuzz/seeds/resp/rename.resp ================================================ *3 $3 SET $3 foo $5 hello *3 $6 RENAME $3 foo $3 bar ================================================ FILE: fuzz/seeds/resp/rpoplpush.resp ================================================ *3 $5 LPUSH $3 src $1 a *3 $5 LPUSH $3 src $1 b *3 $9 RPOPLPUSH $3 src $3 dst ================================================ FILE: fuzz/seeds/resp/sadd.resp ================================================ *3 $4 SADD $3 set $6 member ================================================ FILE: fuzz/seeds/resp/scan_hscan.resp ================================================ *6 $4 HSET $1 h $2 f1 $2 v1 $2 f2 $2 v2 *3 
$5 HSCAN $1 h $1 0 ================================================ FILE: fuzz/seeds/resp/script_ops.resp ================================================ *3 $4 EVAL $28 return redis.call('PING') $1 0 *4 $4 EVAL $44 return redis.call('SET', KEYS[1], ARGV[1]) $1 1 $2 ek $2 ev *4 $7 EVAL_RO $37 return redis.call('GET', KEYS[1]) $1 1 $2 ek *2 $6 SCRIPT $5 FLUSH ================================================ FILE: fuzz/seeds/resp/script_ops2.resp ================================================ *4 $4 EVAL $44 return redis.call('SET', KEYS[1], ARGV[1]) $1 1 $3 esk $5 esval *4 $7 EVAL_RO $37 return redis.call('GET', KEYS[1]) $1 1 $3 esk *3 $6 SCRIPT $5 FLUSH $5 ASYNC *3 $6 SCRIPT $6 EXISTS $40 e0e1f9fabfc9d4800c877a703b823ac0578ff831 *4 $8 EVALSHA $40 e0e1f9fabfc9d4800c877a703b823ac0578ff831 $1 0 *4 $11 EVALSHA_RO $40 e0e1f9fabfc9d4800c877a703b823ac0578ff831 $1 0 ================================================ FILE: fuzz/seeds/resp/sdiffstore.resp ================================================ *4 $4 SADD $2 s1 $1 a $1 b *3 $4 SADD $2 s2 $1 b *4 $10 SDIFFSTORE $3 dst $2 s1 $2 s2 ================================================ FILE: fuzz/seeds/resp/search_ops.resp ================================================ *8 $9 FT.CREATE $5 myidx $2 ON $4 HASH $6 SCHEMA $5 title $4 TEXT $5 score $7 NUMERIC *3 $7 FT.INFO $5 myidx *8 $4 HSET $4 doc1 $5 title $5 hello $5 score $1 1 *8 $4 HSET $4 doc2 $5 title $5 world $5 score $1 2 *3 $9 FT.SEARCH $5 myidx $5 hello *5 $9 FT.SEARCH $5 myidx $1 * $5 LIMIT $1 0 $1 5 *2 $8 FT._LIST *3 $12 FT.DROPINDEX $5 myidx ================================================ FILE: fuzz/seeds/resp/search_ops2.resp ================================================ *8 $9 FT.CREATE $5 idx2 $2 ON $4 HASH $6 PREFIX $1 1 $4 doc: $6 SCHEMA $5 title $4 TEXT $5 score $7 NUMERIC *4 $4 HSET $5 doc:1 $5 title $5 hello $5 score $1 1 *4 $4 HSET $5 doc:2 $5 title $5 world $5 score $1 2 *3 $9 FT.SEARCH $5 idx2 $5 hello *7 $9 FT.SEARCH $5 idx2 $1 * $6 SORTBY $5 score 
$5 LIMIT $1 0 $1 1 *2 $7 FT.INFO $5 idx2 *5 $8 FT.ALTER $5 idx2 $6 SCHEMA $3 ADD $3 tag $3 TAG *3 $9 FT.CONFIG $3 GET $1 * *3 $9 FT.CONFIG $3 SET $14 MAXSEARCHRESULTS $5 10000 *6 $12 FT.SYNUPDATE $5 idx2 $2 g1 $5 hello $2 hi $3 hey *2 $10 FT.SYNDUMP $5 idx2 *3 $12 FT.AGGREGATE $5 idx2 $1 * *2 $10 FT.TAGVALS $5 idx2 $3 tag *2 $12 FT.DROPINDEX $5 idx2 ================================================ FILE: fuzz/seeds/resp/server_ops.resp ================================================ *2 $4 INFO $6 server *2 $4 INFO $6 memory *2 $4 INFO $11 replication *1 $6 DBSIZE *3 $6 CLIENT $7 SETNAME $4 fuzz *2 $6 CLIENT $7 GETNAME *2 $6 CLIENT $2 ID *2 $6 CLIENT $4 INFO *3 $6 CONFIG $3 GET $9 maxmemory *2 $4 ROLE *2 $7 LASTSAVE *3 $6 MEMORY $5 USAGE $4 nokey *2 $7 SLOWLOG $3 LEN *2 $7 LATENCY $6 LATEST *3 $5 HELLO $1 2 ================================================ FILE: fuzz/seeds/resp/server_ops2.resp ================================================ *2 $4 INFO $3 all *2 $6 CLIENT $4 LIST *3 $6 CLIENT $4 INFO *2 $7 CLUSTER $4 INFO *2 $7 CLUSTER $5 MYID *2 $7 CLUSTER $5 SLOTS *1 $8 READONLY *1 $9 READWRITE *2 $7 SLOWLOG $3 GET *2 $7 LATENCY $7 HISTORY $5 event *2 $6 MEMORY $6 DOCTOR *2 $6 MEMORY $5 STATS *3 $5 HELLO $1 3 *4 $4 DFLY $7 CLUSTER $6 CONFIG $2 {} *2 $1 QUIT ================================================ FILE: fuzz/seeds/resp/set.resp ================================================ *3 $3 SET $3 key $5 value ================================================ FILE: fuzz/seeds/resp/set_ops.resp ================================================ *6 $4 SADD $2 s1 $1 a $1 b $1 c $1 d *5 $4 SADD $2 s2 $1 c $1 d $1 e *2 $5 SCARD $2 s1 *2 $8 SMEMBERS $2 s1 *3 $9 SISMEMBER $2 s1 $1 a *4 $10 SMISMEMBER $2 s1 $1 a $1 z *3 $4 SREM $2 s1 $1 d *3 $5 SMOVE $2 s1 $2 s2 $1 a *3 $6 SUNION $2 s1 $2 s2 *3 $5 SINTER $2 s1 $2 s2 *3 $5 SDIFF $2 s1 $2 s2 *4 $11 SUNIONSTORE $4 sdst $2 s1 $2 s2 *4 $11 SINTERSTORE $4 idst $2 s1 $2 s2 *4 $10 SDIFFSTORE $4 ddst $2 s1 $2 s2 *4 $10 SINTERCARD $1 2 
$2 s1 $2 s2 *3 $4 SPOP $2 s1 $1 1 *3 $5 SSCAN $2 s2 $1 0 ================================================ FILE: fuzz/seeds/resp/set_ops2.resp ================================================ *4 $4 SADD $3 sx1 $1 a $1 b *4 $6 SADDEX $3 sx1 $3 300 $1 c ================================================ FILE: fuzz/seeds/resp/smove.resp ================================================ *3 $4 SADD $3 src $1 a *3 $4 SADD $3 dst $1 b *4 $5 SMOVE $3 src $3 dst $1 a ================================================ FILE: fuzz/seeds/resp/sort.resp ================================================ *4 $5 LPUSH $4 list $1 3 $1 1 *3 $5 LPUSH $4 list $1 2 *4 $4 SORT $4 list $5 STORE $6 sorted ================================================ FILE: fuzz/seeds/resp/srandmember.resp ================================================ *7 $4 SADD $5 myset $1 a $1 b $1 c $1 d $1 e *3 $11 SRANDMEMBER $5 myset $1 3 ================================================ FILE: fuzz/seeds/resp/stream_ops.resp ================================================ *5 $4 XADD $2 st $1 * $1 k $1 v *5 $4 XADD $2 st $1 * $1 k $2 v2 *5 $4 XADD $2 st $1 * $1 k $2 v3 *2 $4 XLEN $2 st *4 $6 XRANGE $2 st $1 - $1 + *4 $9 XREVRANGE $2 st $1 + $1 - *4 $5 XTRIM $2 st $6 MAXLEN $1 2 *4 $6 XGROUP $6 CREATE $2 st $2 g1 $1 0 *7 $10 XREADGROUP $5 GROUP $2 g1 $2 c1 $7 STREAMS $2 st $1 > *4 $4 XACK $2 st $2 g1 $3 0-1 *4 $8 XPENDING $2 st $2 g1 $1 - $1 + $2 10 *4 $5 XINFO $6 STREAM $2 st *3 $6 XSETID $2 st $3 0-5 ================================================ FILE: fuzz/seeds/resp/stream_ops2.resp ================================================ *5 $4 XADD $3 sx1 $1 * $1 k $2 v1 *5 $4 XADD $3 sx1 $1 * $1 k $2 v2 *5 $4 XADD $3 sx1 $1 * $1 k $2 v3 *4 $6 XGROUP $6 CREATE $3 sx1 $3 sg1 $1 0 *7 $10 XREADGROUP $5 GROUP $3 sg1 $2 c1 $7 STREAMS $3 sx1 $1 > *5 $6 XCLAIM $3 sx1 $3 sg1 $2 c1 $1 0 $3 0-1 *6 $10 XAUTOCLAIM $3 sx1 $3 sg1 $2 c1 $1 0 $3 0-0 *3 $4 XDEL $3 sx1 $3 0-1 ================================================ FILE: 
fuzz/seeds/resp/string_ops.resp ================================================ *3 $3 SET $2 sk $5 hello *3 $6 APPEND $2 sk $6 _world *2 $6 STRLEN $2 sk *4 $8 GETRANGE $2 sk $1 0 $1 4 *4 $8 SETRANGE $2 sk $1 6 $3 foo *3 $5 SETEX $3 sk2 $2 60 $4 temp *3 $6 PSETEX $3 sk3 $5 60000 $4 temp *3 $5 SETNX $3 sk4 $3 new *3 $6 GETSET $2 sk $3 old *6 $4 MSET $2 m1 $2 v1 $2 m2 $2 v2 *3 $4 MGET $2 m1 $2 m2 *3 $3 SET $2 ci $1 0 *2 $4 INCR $2 ci *2 $4 DECR $2 ci *3 $6 INCRBY $2 ci $2 10 *3 $6 DECRBY $2 ci $1 5 *3 $12 INCRBYFLOAT $2 ci $3 1.5 *2 $6 GETDEL $2 m2 *4 $5 GETEX $2 m1 $2 EX $2 60 ================================================ FILE: fuzz/seeds/resp/string_ops2.resp ================================================ *4 $5 MSETNX $2 nx1 $2 v1 $2 nx2 $2 v2 *3 $7 PREPEND $2 nx1 $3 pre *3 $6 SUBSTR $2 nx1 $1 0 $1 3 *2 $6 DIGEST $2 nx1 *4 $5 SETEX $2 sx $1 3 $3 val *4 $6 PSETEX $2 px $4 3000 $3 val *3 $5 GETEX $2 sx $2 EX $1 5 *3 $6 APPEND $2 sx $4 _end *3 $8 SETRANGE $2 sx $1 0 $3 NEW *2 $6 GETDEL $2 px ================================================ FILE: fuzz/seeds/resp/subscribe.resp ================================================ *2 $9 SUBSCRIBE $9 mychannel ================================================ FILE: fuzz/seeds/resp/throttle.resp ================================================ *6 $11 CL.THROTTLE $6 myrate $2 10 $2 30 $2 60 $1 1 ================================================ FILE: fuzz/seeds/resp/transaction.resp ================================================ *1 $5 MULTI *3 $3 SET $1 a $1 1 *1 $4 EXEC ================================================ FILE: fuzz/seeds/resp/transaction_ops2.resp ================================================ *3 $3 SET $2 tk $3 val *1 $5 WATCH $2 tk *1 $5 MULTI *3 $3 SET $2 tk $4 new1 *1 $7 DISCARD *1 $7 UNWATCH *1 $5 MULTI *3 $3 SET $2 tk $4 new2 *1 $4 EXEC ================================================ FILE: fuzz/seeds/resp/watch.resp ================================================ *2 $5 WATCH $1 a *1 $5 MULTI *3 $3 
SET $1 a $1 1 *1 $4 EXEC ================================================ FILE: fuzz/seeds/resp/watch_multi.resp ================================================ *2 $5 WATCH $1 k *1 $5 MULTI *3 $3 SET $1 k $1 1 *1 $4 EXEC ================================================ FILE: fuzz/seeds/resp/xadd.resp ================================================ *5 $4 XADD $6 stream $1 * $5 field $5 value ================================================ FILE: fuzz/seeds/resp/xread.resp ================================================ *5 $5 XREAD $5 COUNT $1 1 $7 STREAMS $6 stream $1 0 ================================================ FILE: fuzz/seeds/resp/zadd.resp ================================================ *5 $4 ZADD $4 zset $1 1 $6 member ================================================ FILE: fuzz/seeds/resp/zmpop.resp ================================================ *8 $4 ZADD $5 myzst $1 1 $1 a $1 2 $1 b $1 3 $1 c *4 $5 ZMPOP $1 1 $5 myzst $3 MIN ================================================ FILE: fuzz/seeds/resp/zrangebyscore.resp ================================================ *5 $13 ZRANGEBYSCORE $4 zset $4 -inf $4 +inf $10 WITHSCORES ================================================ FILE: fuzz/seeds/resp/zset_ops.resp ================================================ *8 $4 ZADD $2 z1 $1 1 $1 a $1 2 $1 b $1 3 $1 c *6 $4 ZADD $2 z2 $1 2 $1 b $1 4 $1 d *3 $7 ZINCRBY $2 z1 $1 5 $1 a *3 $5 ZSCORE $2 z1 $1 a *4 $7 ZMSCORE $2 z1 $1 a $1 c *2 $5 ZCARD $2 z1 *4 $6 ZCOUNT $2 z1 $4 -inf $4 +inf *3 $5 ZRANK $2 z1 $1 b *3 $8 ZREVRANK $2 z1 $1 b *4 $6 ZRANGE $2 z1 $1 0 $2 -1 *4 $9 ZREVRANGE $2 z1 $1 0 $2 -1 *4 $12 ZRANGEBYLEX $2 z1 $1 - $1 + *5 $12 ZRANGEBYSCORE $2 z1 $1 1 $1 3 $10 WITHSCORES *4 $15 ZREMRANGEBYRANK $2 z2 $1 0 $1 0 *4 $16 ZREMRANGEBYSCORE $2 z2 $1 0 $1 2 *3 $7 ZPOPMIN $2 z1 $1 1 *3 $7 ZPOPMAX $2 z1 $1 1 *3 $6 ZUNION $1 2 $2 z1 $2 z2 *4 $11 ZUNIONSTORE $4 zdst $1 2 $2 z1 *3 $5 ZSCAN $2 z1 $1 0 *3 $11 ZRANDMEMBER $2 z1 $1 2 
================================================ FILE: fuzz/seeds/resp/zset_ops2.resp ================================================ *8 $4 ZADD $3 za1 $1 1 $1 a $1 2 $1 b $1 3 $1 c *8 $4 ZADD $3 za2 $1 2 $1 b $1 4 $1 d $1 5 $1 e *4 $6 ZINTER $1 2 $3 za1 $3 za2 *4 $11 ZINTERSTORE $4 zint $1 2 $3 za1 $3 za2 *5 $10 ZINTERCARD $1 2 $3 za1 $3 za2 *4 $5 ZDIFF $1 2 $3 za1 $3 za2 *4 $10 ZDIFFSTORE $5 zdiff $1 2 $3 za1 *3 $4 ZREM $3 za2 $1 d *4 $14 ZREMRANGEBYLEX $3 za1 $3 [a] $3 [b] *6 $11 ZRANGESTORE $5 zrngs $3 za1 $1 0 $2 -1 $7 BYSCORE *4 $9 ZLEXCOUNT $3 za1 $1 - $1 + *6 $15 ZREVRANGEBYSCORE $3 za1 $4 +inf $4 -inf $10 WITHSCORES $5 LIMIT $1 0 $1 2 *4 $13 ZREVRANGEBYLEX $3 za1 $1 + $1 - ================================================ FILE: fuzz/triage_crashes.sh ================================================
#!/usr/bin/env bash
# Triage AFL++ crash artifacts: replay each crash against a fresh Dragonfly
# instance and report whether it's confirmed or a false positive.
#
# Usage:
#   ./fuzz/triage_crashes.sh <dragonfly_binary> <mode> <crashes.zip>
#
#   dragonfly_binary  Path to Dragonfly binary
#   mode              Protocol: 'resp' or 'memcache'
#   crashes.zip       .zip downloaded from CI artifacts (contains crash-*.tar.gz files)
#
# Examples:
#   ./fuzz/triage_crashes.sh ./build-dbg/dragonfly resp fuzz-long-resp-crashes-35.zip
#   ./fuzz/triage_crashes.sh ./build-dbg/dragonfly memcache fuzz-long-memcache-crashes-35.zip

set -euo pipefail

# ─── Colors ───────────────────────────────────────────────────────────────────
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'

# ─── Config ───────────────────────────────────────────────────────────────────
RESP_PORT=6379
MC_PORT=11211
STARTUP_TIMEOUT=5   # seconds to wait for Dragonfly to accept connections
POST_REPLAY_WAIT=3  # seconds to wait after replay for Dragonfly to crash

# Colored, tagged one-line loggers; each prints its first argument.
print_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
print_error() { echo -e "${RED}[ERROR]${NC} $1"; }
print_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }

# Print the command synopsis and exit with status 1.
usage() {
  # The synopsis names the three positional arguments that the argument
  # check below requires.
  echo -e "${BOLD}Usage:${NC} $0 <dragonfly_binary> <mode> <crashes.zip>"
  echo ""
  echo " dragonfly_binary Path to Dragonfly binary"
  echo " mode Protocol: 'resp' or 'memcache'"
  echo " crashes.zip .zip downloaded from CI artifacts"
  echo ""
  echo "Examples:"
  echo " $0 ./build-dbg/dragonfly resp fuzz-long-resp-crashes-35.zip"
  echo " $0 ./build-dbg/dragonfly memcache fuzz-long-memcache-crashes-35.zip"
  exit 1
}

# ─── Args ─────────────────────────────────────────────────────────────────────
if [[ $# -lt 3 ]]; then
  usage
fi

DRAGONFLY_BIN="$(realpath "$1")"
MODE="$2"
CRASHES_ZIP="$(realpath "$3")"

if [[ ! -f "$DRAGONFLY_BIN" ]]; then
  print_error "Dragonfly binary not found: $DRAGONFLY_BIN"
  exit 1
fi

if [[ "$MODE" != "resp" && "$MODE" != "memcache" ]]; then
  print_error "Mode must be 'resp' or 'memcache', got: $MODE"
  exit 1
fi

if [[ ! -f "$CRASHES_ZIP" ]]; then
  print_error "Crashes zip not found: $CRASHES_ZIP"
  exit 1
fi

if [[ "$CRASHES_ZIP" != *.zip ]]; then
  print_error "Expected a .zip file (CI artifact), got: $CRASHES_ZIP"
  exit 1
fi

# ─── Working directory ────────────────────────────────────────────────────────
WORK_DIR=$(mktemp -d /tmp/triage_XXXXXX)
DF_PID=""

# Kill the Dragonfly instance recorded in DF_PID (if any) and delete the
# scratch directory. Errors from kill are ignored on purpose: the process may
# already be gone.
cleanup() {
  if [[ -n "$DF_PID" ]]; then
    kill -9 "$DF_PID" 2>/dev/null || true
  fi
  rm -rf "$WORK_DIR"
}

# Run cleanup exactly once, on exit. A handler that returns from INT/TERM lets
# bash resume the script, which would then keep running with WORK_DIR already
# removed; turning the signals into an explicit exit routes them through the
# EXIT trap instead (130 = 128 + SIGINT, 143 = 128 + SIGTERM).
trap cleanup EXIT
trap 'exit 130' INT
trap 'exit 143' TERM

# ─── Extract zip ──────────────────────────────────────────────────────────────
print_info "Extracting $(basename "$CRASHES_ZIP")..."
CRASHES_DIR="$WORK_DIR/input"
unzip -q "$CRASHES_ZIP" -d "$CRASHES_DIR"

# ─── Find crash archives ──────────────────────────────────────────────────────
# Collect every crash-*.tar.gz below the extraction directory, sorted by path.
mapfile -t CRASH_ARCHIVES < <(find "$CRASHES_DIR" -name 'crash-*.tar.gz' | sort)
TOTAL=${#CRASH_ARCHIVES[@]}

if (( TOTAL == 0 )); then
  print_error "No crash-*.tar.gz files found in: $CRASHES_DIR"
  exit 1
fi

print_info "Found $TOTAL crash archive(s) mode=$MODE binary=$DRAGONFLY_BIN"
echo ""

# ─── Locate replay_crash.py ───────────────────────────────────────────────────
# The replay helper is looked up in the directory that contains this script.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPLAY_SCRIPT="$SCRIPT_DIR/replay_crash.py"

if [[ ! -f "$REPLAY_SCRIPT" ]]; then
  print_error "replay_crash.py not found at: $REPLAY_SCRIPT"
  print_error "Run this script from the repository root or fuzz/ directory."
  exit 1
fi

# ─── Helpers ──────────────────────────────────────────────────────────────────

# Poll until a TCP connection to host:port succeeds.
#   $1 host, $2 port, $3 timeout in seconds
# Returns 0 as soon as a connect succeeds, 1 once the timeout has elapsed.
wait_for_port() {
  local target_host="$1" target_port="$2" limit="$3"
  local give_up_at=$((SECONDS + limit))

  while (( SECONDS < give_up_at )); do
    # Opening /dev/tcp/<host>/<port> in a subshell is bash's built-in connect probe.
    (>/dev/tcp/"$target_host"/"$target_port") 2>/dev/null && return 0
    sleep 0.2
  done
  return 1
}

# Poll until a TCP connection to 127.0.0.1:port is refused.
#   $1 port, $2 timeout in seconds (defaults to 5)
# Returns 0 as soon as a connect fails, 1 once the timeout has elapsed.
wait_port_free() {
  local target_port="$1" limit="${2:-5}"
  local give_up_at=$((SECONDS + limit))

  while (( SECONDS < give_up_at )); do
    (>/dev/tcp/127.0.0.1/"$target_port") 2>/dev/null || return 0
    sleep 0.2
  done
  return 1
}

# Show crash info from glog log directory.
show_crash_log() { local log_dir="$1" local fatal_link="$log_dir/dragonfly.FATAL" if [[ -f "$fatal_link" ]]; then # Skip the 4-line glog file header, show crash message + stack trace sed -n '5,$p' "$fatal_link" | head -40 | sed 's/^/ /' return fi # No FATAL file — fall back to tail of INFO log local info_log info_log=$(ls -t "$log_dir"/dragonfly.*.log.INFO.* 2>/dev/null | head -1 || true) if [[ -n "$info_log" ]]; then echo " (no FATAL log — last INFO log lines:)" tail -20 "$info_log" | sed 's/^/ /' else echo " (no log files found in $log_dir)" fi } # ─── Main loop ──────────────────────────────────────────────────────────────── CONFIRMED=0 FALSE_POSITIVE=0 FAILED=0 for CRASH_ARCHIVE in "${CRASH_ARCHIVES[@]}"; do CRASH_NAME=$(basename "$CRASH_ARCHIVE" .tar.gz) # crash-000000 CRASH_ID="${CRASH_NAME#crash-}" # 000000 IDX=$((CONFIRMED + FALSE_POSITIVE + FAILED + 1)) echo -e "${CYAN}${BOLD}─── [$IDX/$TOTAL] Crash ${CRASH_ID} ───${NC}" # Extract this crash archive EXTRACT_DIR="$WORK_DIR/current_crash" rm -rf "$EXTRACT_DIR" mkdir -p "$EXTRACT_DIR" tar -xzf "$CRASH_ARCHIVE" -C "$EXTRACT_DIR" CRASH_DATA_DIR="$EXTRACT_DIR/${CRASH_NAME}/crashes" if [[ ! -d "$CRASH_DATA_DIR" ]]; then print_warn "Expected directory not found: $CRASH_DATA_DIR — skipping" FAILED=$((FAILED + 1)) echo "" continue fi # Kill any leftover process on the port from a previous iteration if (>/dev/tcp/127.0.0.1/"$RESP_PORT") 2>/dev/null; then print_warn "Port $RESP_PORT still in use — waiting..." wait_port_free "$RESP_PORT" 5 || { print_error "Port $RESP_PORT still blocked after 5s — cannot start Dragonfly" FAILED=$((FAILED + 1)) echo "" continue } fi # Start Dragonfly — use --log_dir so glog writes to separate per-level files # (dragonfly.FATAL symlink is created on crash and contains the fatal message) LOG_DIR="$WORK_DIR/logs_${CRASH_ID}" mkdir -p "$LOG_DIR" # Mirror the exact flags used by run_fuzzer.sh so replay runs in the same # server configuration as when the crash was found. 
# Missing rename_command flags are the most common cause of false positives: # if FLUSHALL/FLUSHDB/SHUTDOWN are not disabled, they execute during replay, # wiping state or shutting down the server before the crash can trigger. DF_ARGS=( --port "$RESP_PORT" --log_dir="$LOG_DIR" --proactor_threads 1 --dbfilename="" --omit_basic_usage --rename_command=SHUTDOWN= --rename_command=DEBUG= --rename_command=FLUSHALL= --rename_command=FLUSHDB= --max_bulk_len=1048576 ) [[ "$MODE" == "memcache" ]] && DF_ARGS+=(--memcached_port="$MC_PORT") "$DRAGONFLY_BIN" "${DF_ARGS[@]}" >/dev/null 2>&1 & DF_PID=$! if ! wait_for_port 127.0.0.1 "$RESP_PORT" "$STARTUP_TIMEOUT"; then print_error "Dragonfly did not start within ${STARTUP_TIMEOUT}s (crash $CRASH_ID)" kill -9 "$DF_PID" 2>/dev/null || true wait "$DF_PID" 2>/dev/null && true || true DF_PID="" FAILED=$((FAILED + 1)) echo "" continue fi # In memcache mode also verify the memcache listener is up before replaying if [[ "$MODE" == "memcache" ]] && ! wait_for_port 127.0.0.1 "$MC_PORT" 3; then print_error "Memcache port $MC_PORT not ready (crash $CRASH_ID)" kill -9 "$DF_PID" 2>/dev/null || true wait "$DF_PID" 2>/dev/null && true || true DF_PID="" FAILED=$((FAILED + 1)) echo "" continue fi # Replay the crash REPLAY_PORT="$RESP_PORT" [[ "$MODE" == "memcache" ]] && REPLAY_PORT="$MC_PORT" if ! python3 "$REPLAY_SCRIPT" \ "$CRASH_DATA_DIR" "$CRASH_ID" 127.0.0.1 "$REPLAY_PORT" \ >/dev/null 2>&1; then print_warn "Replay script failed for crash $CRASH_ID — skipping" kill -9 "$DF_PID" 2>/dev/null || true wait "$DF_PID" 2>/dev/null && true || true DF_PID="" FAILED=$((FAILED + 1)) echo "" continue fi # Wait for Dragonfly to die (poll every 100ms) DIED=false for _ in $(seq 1 $((POST_REPLAY_WAIT * 10))); do if ! kill -0 "$DF_PID" 2>/dev/null; then DIED=true break fi sleep 0.1 done if ! 
$DIED; then echo -e " ${YELLOW}FALSE POSITIVE${NC} — Dragonfly alive after replay" FALSE_POSITIVE=$((FALSE_POSITIVE + 1)) kill -9 "$DF_PID" 2>/dev/null || true wait "$DF_PID" 2>/dev/null && true || true DF_PID="" else # Capture signal without triggering set -e (assignment always exits 0) wait "$DF_PID" 2>/dev/null && EXIT_CODE=0 || EXIT_CODE=$? DF_PID="" # Sanity check: exit code > 128 means killed by signal; otherwise not a signal death if [[ $EXIT_CODE -le 128 ]]; then echo -e " ${YELLOW}FALSE POSITIVE${NC} — Dragonfly exited cleanly (code $EXIT_CODE)" FALSE_POSITIVE=$((FALSE_POSITIVE + 1)) echo "" continue fi SIGNAL=$((EXIT_CODE - 128)) CONFIRMED=$((CONFIRMED + 1)) if [[ $SIGNAL -eq 6 ]]; then echo -e " ${RED}CONFIRMED${NC} — SIGABRT (signal 6) — assertion / LOG(FATAL)" show_crash_log "$LOG_DIR" elif [[ $SIGNAL -eq 11 ]]; then echo -e " ${RED}CONFIRMED${NC} — SIGSEGV (signal 11) — segmentation fault" show_crash_log "$LOG_DIR" else echo -e " ${RED}CONFIRMED${NC} — signal $SIGNAL (exit code $EXIT_CODE)" show_crash_log "$LOG_DIR" fi fi echo "" done # ─── Summary ────────────────────────────────────────────────────────────────── echo -e "${CYAN}${BOLD}═══ Triage Summary ═══${NC}" printf " %-18s %d\n" "Total:" "$TOTAL" printf " ${RED}%-18s %d${NC}\n" "Confirmed:" "$CONFIRMED" printf " ${YELLOW}%-18s %d${NC}\n" "False positive:" "$FALSE_POSITIVE" [[ $FAILED -gt 0 ]] && printf " ${RED}%-18s %d${NC}\n" "Failed/skipped:" "$FAILED" # Exit 1 if any confirmed crashes found [[ $CONFIRMED -gt 0 ]] && exit 1 exit 0 ================================================ FILE: go.work ================================================ go 1.24.0 toolchain go1.24.7 use ( ./contrib/charts/dragonfly ./tools/replay ) ================================================ FILE: go.work.sum ================================================ cel.dev/expr v0.16.1/go.mod h1:AsGA5zb3WruAEQeQng1RZdGEXmBj0jvMWh6l5SnNuC8= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= 
cloud.google.com/go/auth v0.10.2/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= cloud.google.com/go/auth/oauth2adapt v0.2.5/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= cloud.google.com/go/cloudbuild v1.19.0/go.mod h1:ZGRqbNMrVGhknIIjwASa6MqoRTOpXIVMSI+Ew5DMPuY= cloud.google.com/go/compute v1.19.1/go.mod h1:6ylj3a05WF8leseCdIf77NK0g1ey+nj5IKd5/kvShxE= cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= cloud.google.com/go/iam v1.2.2/go.mod h1:0Ys8ccaZHdI1dEUilwzqng/6ps2YB6vRsjIe00/+6JY= cloud.google.com/go/longrunning v0.6.2/go.mod h1:k/vIs83RN4bE3YCswdXC5PFfWVILjm3hpEUlSko4PiI= cloud.google.com/go/monitoring v1.21.2/go.mod h1:hS3pXvaG8KgWTSz+dAdyzPrGUYmi2Q+WFX8g2hqVEZU= cloud.google.com/go/storage v1.47.0/go.mod h1:Ks0vP374w0PW6jOUameJbapbQKXqkjGd/OJRp2fb9IQ= github.com/Azure/azure-sdk-for-go v51.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0/go.mod h1:XCW7KnZet0Opnr7HccfUw1PLc4CjHqpcaxW8DHklNkQ= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appcontainers/armappcontainers/v3 v3.0.0/go.mod h1:LDN3sr8FJ36sY6ZmMes6Q2vHJ+5r1aFsE3wEo7VbXJg= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= github.com/Azure/go-autorest/autorest v0.11.20/go.mod h1:o3tqFY+QR40VOlk+pV4d77mORO64jOXSgEnPQgLK6JY= github.com/Azure/go-autorest/autorest/adal v0.9.13/go.mod h1:W/MM4U6nLxnIskrw4UwWzlHfGjwUS50aOsc/I3yuU8M= github.com/Azure/go-autorest/autorest/azure/auth v0.5.8/go.mod h1:kxyKZTSfKh8OVFWPAgOgQ/frrJgeYQJPyR5fLFmXko4= 
github.com/Azure/go-autorest/autorest/azure/cli v0.4.2/go.mod h1:7qkJkT+j6b+hIpzMOwPChJhTqS8VbsqqgULzMNRugoM= github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= github.com/Azure/go-autorest/autorest/to v0.4.0/go.mod h1:fE8iZBn7LQR7zH/9XU2NcPR4o9jEImooCeWJcYV/zLE= github.com/Azure/go-autorest/autorest/validation v0.3.1/go.mod h1:yhLgjC0Wda5DYXl6JAsWyUe4KVNffhoDhG0zVzUMo3E= github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.24.1/go.mod h1:itPGVDKf9cC/ov4MdvJ2QZ0khw4bfoo9jzwTJlaxy2k= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1/go.mod h1:jyqM3eLpJ3IbIFDTKVz2rF9T/xWGW0rIriGwnz8l9Tk= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1/go.mod h1:viRWSEhtMZqz1rhwmOVKkWl6SwmVowfL9O2YR5gI2PE= github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/apparentlymart/go-textseg/v13 v13.0.0/go.mod h1:ZK2fH7c4NqDTLtiYLvIkEghdlcqw7yxLeM89kiTRPUo= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/aws/aws-lambda-go v1.47.0/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= github.com/aws/aws-sdk-go v1.44.122/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ00z/TKoufEY6K/a0k6AhaJrQKdFe6OfVXsa4= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bradleyfalzon/ghinstallation 
v1.1.1/go.mod h1:vyCmHTciHx/uuyN82Zc3rXN3X2KTK8nUTCrTMwAhcug= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/containerd/stargz-snapshotter/estargz v0.14.3/go.mod h1:KY//uOCIkSuNAHhJogcZtrNHdKrA99/FCCRjE3HD36o= github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= github.com/docker/cli v27.1.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker-credential-helpers v0.7.0/go.mod h1:rETQfLdHNT3foU5kuNkFR1R1V12OJRRO5lzt2D1b5X0= github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8= github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod 
h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/gonvenience/wrap v1.1.2/go.mod h1:GiryBSXoI3BAAhbWD1cZVj7RZmtiu0ERi/6R6eJfslI= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= github.com/google/go-github/v29 v29.0.2/go.mod h1:CHKiKKPHJ0REzfwc14QMklvtHwCveD0PxlMjLlzAM5E= github.com/google/go-github/v44 v44.1.0/go.mod h1:iWn00mWcP6PRWHhXm0zuFJ8wbEjE5AGO5D5HXYM4zgw= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= github.com/googleapis/gax-go/v2 v2.14.0/go.mod h1:lhBCnjdLrWRaPvLWhmc8IS24m9mr07qSYnHncrgo+zk= github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-getter v1.7.5/go.mod h1:W7TalhMmbPmsSMdNjD0ZskARur/9GJ17cfHTRtXV744= github.com/hashicorp/go-getter/v2 v2.2.3/go.mod h1:hp5Yy0GMQvwWVUmwLs3ygivz1JSLI323hdIE9J9m7TY= github.com/hashicorp/go-safetemp v1.0.0/go.mod h1:oaerMy3BhqiTbVye6QuFhFtIceqFoDHxNAB65b+Rj1I= github.com/hashicorp/go-version v1.7.0/go.mod 
h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/hcl/v2 v2.22.0/go.mod h1:62ZYHrXgPoX8xBnzl8QzbWq4dyDsDtfCRgIq1rbJEvA= github.com/hashicorp/terraform-json v0.23.0/go.mod h1:MHdXbBAbSg0GvzuWazEGKAn/cyNfIB7mN6y7KJN6y2c= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a/go.mod h1:yL958EeXv8Ylng6IfnvG4oflryUi3vgA3xPs9hmII1s= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jstemmer/go-junit-report v1.0.0/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/mitchellh/go-testing-interface v1.14.1/go.mod h1:gfgS7OtZj6MA4U1UrDRp04twqAjfvlZyCfX3sDjEym8= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0-rc3/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= github.com/oracle/oci-go-sdk v7.1.0+incompatible/go.mod h1:VQb79nF8Z2cwLkLS35ukwStZIg5F66tcBccjip/j888= github.com/peterbourgon/diskv 
v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/slack-go/slack v0.15.0/go.mod h1:hlGi5oXA+Gt+yWTPP0plCdRKmjsDxecdHxYQdlMQKOw= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/tmccombs/hcl2json v0.6.4/go.mod h1:+ppKlIW3H5nsAsZddXPy2iMyvld3SHxyjswOZhavRDk= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/urfave/cli v1.22.16 h1:MH0k6uJxdwdeWQTwhSO42Pwr4YLrNLwBtg1MRgTqPdQ= github.com/urfave/cli v1.22.16/go.mod h1:EeJR6BKodywf4zciqrdw6hpCPk68JO9z5LazXZMn5Po= github.com/vbatts/tar-split v0.11.3/go.mod h1:9QlHN18E+fEH7RdG+QAJJcuya3rqT7eXSTY7wGrAokY= github.com/zclconf/go-cty v1.15.0/go.mod h1:VvMs5i0vgZdhYawQNq5kePSpLAoz8u1xvZgrPIxfnZE= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/detectors/gcp v1.29.0/go.mod h1:GW2aWZNwR2ZxDLdv8OyC2G8zkRoQBuURgV7RPQgcPoU= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= go.opentelemetry.io/otel v1.29.0/go.mod 
h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= go.opentelemetry.io/otel/sdk v1.29.0/go.mod h1:pM8Dx5WKnvxLCb+8lG1PRNIDxu9g9b9g59Qr7hfAAok= go.opentelemetry.io/otel/sdk/metric v1.29.0/go.mod h1:6zZLdCl2fkauYoZIOn/soQIDSWFmNSRcICarHfuhNJQ= go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.206.0/go.mod h1:BtB8bfjTYIrai3d8UyvPmV9REGgox7coh+ZRwm0b+W8= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20241113202542-65e8d215514f/go.mod h1:Q5m6g8b5KaFFzsQFIGdJkSJDGeJiybVenoYFMMa3ohI= google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4= google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a/go.mod h1:9i1T9n4ZinTUZGgzENMi8MDDgbGC5mqTS75JAv6xN3A= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= k8s.io/gengo/v2 v2.0.0-20250604051438-85fd79dbfd9f/go.mod h1:EJykeLsmFC60UQbYJezXkEsG2FLrt0GPNkU5iK5GWxU= sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= 
================================================ FILE: patches/mimalloc-v2.2.4/0_base.patch ================================================ diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ce084f6..00eba70c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.18) +cmake_minimum_required(VERSION 3.16) project(libmimalloc C CXX) set(CMAKE_C_STANDARD 11) @@ -44,7 +44,38 @@ option(MI_WIN_USE_FLS "Use Fiber local storage on Windows to detect thread option(MI_CHECK_FULL "Use full internal invariant checking in DEBUG mode (deprecated, use MI_DEBUG_FULL instead)" OFF) option(MI_USE_LIBATOMIC "Explicitly link with -latomic (on older systems) (deprecated and detected automatically)" OFF) -include(CheckLinkerFlag) # requires cmake 3.18 +function(CHECK_LINKER_FLAG _lang _flag _var) + get_property (_supported_languages GLOBAL PROPERTY ENABLED_LANGUAGES) + if (NOT _lang IN_LIST _supported_languages) + message (SEND_ERROR "check_linker_flag: ${_lang}: unknown language.") + return() + endif() + include (Check${_lang}SourceCompiles) + set(CMAKE_REQUIRED_LINK_OPTIONS "${_flag}") + # Normalize locale during test compilation. 
+ set(_locale_vars LC_ALL LC_MESSAGES LANG) + foreach(v IN LISTS _locale_vars) + set(_locale_vars_saved_${v} "$ENV{${v}}") + set(ENV{${v}} C) + endforeach() + if (_lang MATCHES "^(C|CXX)$") + set (_source "int main() { return 0; }") + elseif (_lang STREQUAL "Fortran") + set (_source " program test\n stop\n end program") + elseif (_lang MATCHES "^(OBJC|OBJCXX)$") + set (_source "#ifndef __OBJC__\n# error \"Not an Objective-C++ compiler\"\n#endif\nint main(void) { return 0; }") + else() + message (SEND_ERROR "check_linker_flag: ${_lang}: unsupported language.") + return() + endif() + set(_common_patterns "") + check_c_source_compiles("${_source}" ${_var} ${_common_patterns}) + foreach(v IN LISTS _locale_vars) + set(ENV{${v}} ${_locale_vars_saved_${v}}) + endforeach() + set(${_var} "${${_var}}" PARENT_SCOPE) +endfunction() + include(CheckIncludeFiles) include(GNUInstallDirs) include("cmake/mimalloc-config-version.cmake") diff --git a/src/alloc.c b/src/alloc.c index 0fed5e75..870f8d10 100644 --- a/src/alloc.c +++ b/src/alloc.c @@ -670,6 +670,24 @@ mi_decl_restrict void* _mi_heap_malloc_guarded(mi_heap_t* heap, size_t size, boo } #endif +bool mi_heap_page_is_underutilized(mi_heap_t* heap, void* p, float ratio) mi_attr_noexcept { + mi_page_t* page = _mi_ptr_page(p); // get the page that this belongs to + + mi_heap_t* page_heap = (mi_heap_t*)(mi_atomic_load_acquire(&(page)->xheap)); + + // the heap id matches and it is not a full page + if (mi_likely(page_heap == heap && page->flags.x.in_full == 0)) { + // first in the list, meaning it's the head of page queue, thus being used for malloc + if (page->prev == NULL) + return false; + + // this page belong to this heap and is not first in the page queue. Lets check its + // utilization. + return page->used <= (unsigned)(page->capacity * ratio); + } + return false; +} + // ------------------------------------------------------ // ensure explicit external inline definitions are emitted! 
// ------------------------------------------------------ ================================================ FILE: patches/mimalloc-v2.2.4/1_add_stat_type.patch ================================================ diff --git a/include/mimalloc/types.h b/include/mimalloc/types.h index a15d9cba..ee822ca9 100644 --- a/include/mimalloc/types.h +++ b/include/mimalloc/types.h @@ -682,4 +682,23 @@ void _mi_stat_counter_increase(mi_stat_counter_t* stat, size_t amount); #define mi_heap_stat_decrease(heap,stat,amount) mi_stat_decrease( (heap)->tld->stats.stat, amount) #define mi_heap_stat_adjust_decrease(heap,stat,amount) mi_stat_adjust_decrease( (heap)->tld->stats.stat, amount) +#define MI_DFLY_PAGE_BELOW_THRESHOLD 1 +#define MI_DFLY_PAGE_FULL 2 +#define MI_DFLY_HEAP_MISMATCH 4 +#define MI_DFLY_PAGE_USED_FOR_MALLOC 8 + +typedef struct mi_page_usage_stats_s { + uintptr_t page_address; + size_t block_size; + uint16_t capacity; + uint16_t reserved; + uint16_t used; + // Collects the current state of page as returned by mi_heap_page_is_underutilized + // 0th bit set: page usage is below threshold: MI_DFLY_PAGE_BELOW_THRESHOLD + // 1st bit set: the page is full: MI_DFLY_PAGE_FULL + // 2nd bit set: the page heap did not match the heap requested: MI_DFLY_HEAP_MISMATCH + // 3rd bit set: that the page is currently used for malloc operations: MI_DFLY_PAGE_USED_FOR_MALLOC + uint8_t flags; +} mi_page_usage_stats_t; + #endif ================================================ FILE: patches/mimalloc-v2.2.4/2_return_stat.patch ================================================ diff --git a/src/alloc.c b/src/alloc.c index 893f3094..88318d0e 100644 --- a/src/alloc.c +++ b/src/alloc.c @@ -676,22 +676,45 @@ mi_decl_restrict void* _mi_heap_malloc_guarded(mi_heap_t* heap, size_t size, boo } #endif -bool mi_heap_page_is_underutilized(mi_heap_t* heap, void* p, float ratio) mi_attr_noexcept { - mi_page_t* page = _mi_ptr_page(p); // get the page that this belongs to +mi_page_usage_stats_t 
mi_heap_page_is_underutilized(mi_heap_t *heap, void *p, float ratio, + bool return_detailed_stats) mi_attr_noexcept { + mi_page_t *page = _mi_ptr_page(p); // get the page that this belongs to + mi_heap_t *page_heap = (mi_heap_t *) (mi_atomic_load_acquire(&(page)->xheap)); + + if (!return_detailed_stats) { + mi_page_usage_stats_t result = {.flags = 0}; + if (mi_likely(page_heap == heap && page->flags.x.in_full == 0)) { + if (page->prev != NULL && page->used <= (unsigned) (page->capacity * ratio)) + result.flags = MI_DFLY_PAGE_BELOW_THRESHOLD; + } + return result; + } + + mi_page_usage_stats_t result = { + .page_address = (uintptr_t) page, + .block_size = page->block_size, + .capacity = page->capacity, + .reserved = page->reserved, + .used = page->used, + .flags = 0, + }; - mi_heap_t* page_heap = (mi_heap_t*)(mi_atomic_load_acquire(&(page)->xheap)); + if (page->flags.x.in_full == 1) { + result.flags |= MI_DFLY_PAGE_FULL; + } + + if (page_heap != heap) { + result.flags |= MI_DFLY_HEAP_MISMATCH; + } - // the heap id matches and it is not a full page - if (mi_likely(page_heap == heap && page->flags.x.in_full == 0)) { - // first in the list, meaning it's the head of page queue, thus being used for malloc - if (page->prev == NULL) - return false; + if (page->prev == NULL) { + result.flags |= MI_DFLY_PAGE_USED_FOR_MALLOC; + } - // this page belong to this heap and is not first in the page queue. Lets check its - // utilization. 
- return page->used <= (unsigned)(page->capacity * ratio); + if (result.flags == 0 && result.used <= (unsigned) (result.capacity * ratio)) { + result.flags = MI_DFLY_PAGE_BELOW_THRESHOLD; } - return false; + return result; } // ------------------------------------------------------ ================================================ FILE: patches/mimalloc-v2.2.4/3_track_full_size.patch ================================================ commit e0cda4eb4a54cfcd33afcd5fbd7ecd86510ac4f9 Author: Roman Gershman Date: Wed Sep 3 23:30:34 2025 +0300 chore: track comitted size of full pages in a heap Signed-off-by: Roman Gershman diff --git a/include/mimalloc/types.h b/include/mimalloc/types.h index a15d9cba..34d99a94 100644 --- a/include/mimalloc/types.h +++ b/include/mimalloc/types.h @@ -559,9 +559,10 @@ struct mi_heap_s { uintptr_t cookie; // random cookie to verify pointers (see `_mi_ptr_cookie`) uintptr_t keys[2]; // two random keys used to encode the `thread_delayed_free` list mi_random_ctx_t random; // random number context used for secure allocation - size_t page_count; // total number of pages in the `pages` queues. - size_t page_retired_min; // smallest retired index (retired pages are fully free, but still in the page queues) - size_t page_retired_max; // largest retired index into the `pages` array. + uint32_t page_count; // total number of pages in the `pages` queues. + uint16_t page_retired_min; // smallest retired index (retired pages are fully free, but still in the page queues) + uint16_t page_retired_max; // largest retired index into the `pages` array. + size_t full_page_size; // total size of pages residing in MI_BIN_FULL bin. long generic_count; // how often is `_mi_malloc_generic` called? long generic_collect_count; // how often is `_mi_malloc_generic` called without collecting? 
mi_heap_t* next; // list of heaps per thread diff --git a/src/init.c b/src/init.c index 3fc8b033..61ee4c76 100644 --- a/src/init.c +++ b/src/init.c @@ -118,6 +118,7 @@ mi_decl_cache_align const mi_heap_t _mi_heap_empty = { { {0}, {0}, 0, true }, // random 0, // page count MI_BIN_FULL, 0, // page retired min/max + 0, // full page size 0, 0, // generic count NULL, // next false, // can reclaim @@ -167,6 +168,7 @@ mi_decl_cache_align mi_heap_t _mi_heap_main = { { {0x846ca68b}, {0}, 0, true }, // random 0, // page count MI_BIN_FULL, 0, // page retired min/max + 0, // full page size 0, 0, // generic count NULL, // next heap false, // can reclaim diff --git a/src/page-queue.c b/src/page-queue.c index c719b626..524b09d8 100644 --- a/src/page-queue.c +++ b/src/page-queue.c @@ -232,6 +232,10 @@ static void mi_page_queue_remove(mi_page_queue_t* queue, mi_page_t* page) { page->next = NULL; page->prev = NULL; // mi_atomic_store_ptr_release(mi_atomic_cast(void*, &page->heap), NULL); + if (mi_page_queue_is_full(queue)) { + mi_assert_internal(heap->full_page_size >= mi_page_block_size(page) * page->capacity); + heap->full_page_size -= mi_page_block_size(page) * page->capacity; + } mi_page_set_in_full(page,false); } @@ -246,6 +250,9 @@ static void mi_page_queue_push(mi_heap_t* heap, mi_page_queue_t* queue, mi_page_ (mi_page_is_large_or_huge(page) && mi_page_queue_is_huge(queue)) || (mi_page_is_in_full(page) && mi_page_queue_is_full(queue))); + if (mi_page_queue_is_full(queue)) { + heap->full_page_size += mi_page_block_size(page) * page->capacity; + } mi_page_set_in_full(page, mi_page_queue_is_full(queue)); // mi_atomic_store_ptr_release(mi_atomic_cast(void*, &page->heap), heap); page->next = queue->first; @@ -339,6 +346,12 @@ static void mi_page_queue_enqueue_from_ex(mi_page_queue_t* to, mi_page_queue_t* } } + if (mi_page_queue_is_full(to)) { + heap->full_page_size += mi_page_block_size(page) * page->capacity; + } else if (mi_page_queue_is_full(from)) { + 
mi_assert_internal(heap->full_page_size >= mi_page_block_size(page) * page->capacity); + heap->full_page_size -= mi_page_block_size(page) * page->capacity; + } mi_page_set_in_full(page, mi_page_queue_is_full(to)); } ================================================ FILE: patches/mimalloc-v2.2.4/4_fix_heap_collect.patch ================================================ diff --git a/src/heap.c b/src/heap.c index f96e60d0..5cb7c1ff 100644 --- a/src/heap.c +++ b/src/heap.c @@ -24,7 +24,7 @@ terms of the MIT license. A copy of the license can be found in the file typedef bool (heap_page_visitor_fun)(mi_heap_t* heap, mi_page_queue_t* pq, mi_page_t* page, void* arg1, void* arg2); // Visit all pages in a heap; returns `false` if break was called. -static bool mi_heap_visit_pages(mi_heap_t* heap, heap_page_visitor_fun* fn, void* arg1, void* arg2) +static bool mi_heap_visit_pages(mi_heap_t* heap, size_t max_q_id, heap_page_visitor_fun* fn, void* arg1, void* arg2) { if (heap==NULL || heap->page_count==0) return 0; @@ -34,7 +34,7 @@ static bool mi_heap_visit_pages(mi_heap_t* heap, heap_page_visitor_fun* fn, void size_t count = 0; #endif - for (size_t i = 0; i <= MI_BIN_FULL; i++) { + for (size_t i = 0; i <= max_q_id; i++) { mi_page_queue_t* pq = &heap->pages[i]; mi_page_t* page = pq->first; while(page != NULL) { @@ -47,7 +47,6 @@ static bool mi_heap_visit_pages(mi_heap_t* heap, heap_page_visitor_fun* fn, void page = next; // and continue } } - mi_assert_internal(count == total); return true; } @@ -67,7 +66,7 @@ static bool mi_heap_page_is_valid(mi_heap_t* heap, mi_page_queue_t* pq, mi_page_ #if MI_DEBUG>=3 static bool mi_heap_is_valid(mi_heap_t* heap) { mi_assert_internal(heap!=NULL); - mi_heap_visit_pages(heap, &mi_heap_page_is_valid, NULL, NULL); + mi_heap_visit_pages(heap, MI_BIN_FULL, &mi_heap_page_is_valid, NULL, NULL); return true; } #endif @@ -149,7 +148,7 @@ static void mi_heap_collect_ex(mi_heap_t* heap, mi_collect_t collect) // if abandoning, mark all pages to no 
longer add to delayed_free if (collect == MI_ABANDON) { - mi_heap_visit_pages(heap, &mi_heap_page_never_delayed_free, NULL, NULL); + mi_heap_visit_pages(heap, MI_BIN_FULL, &mi_heap_page_never_delayed_free, NULL, NULL); } // free all current thread delayed blocks. @@ -160,7 +159,7 @@ static void mi_heap_collect_ex(mi_heap_t* heap, mi_collect_t collect) _mi_heap_collect_retired(heap, force); // collect all pages owned by this thread - mi_heap_visit_pages(heap, &mi_heap_page_collect, &collect, NULL); + mi_heap_visit_pages(heap, collect == MI_NORMAL ? MI_BIN_HUGE : MI_BIN_FULL, &mi_heap_page_collect, &collect, NULL); mi_assert_internal( collect != MI_ABANDON || mi_atomic_load_ptr_acquire(mi_block_t,&heap->thread_delayed_free) == NULL ); // collect abandoned segments (in particular, purge expired parts of segments in the abandoned segment list) @@ -368,7 +367,7 @@ static bool _mi_heap_page_destroy(mi_heap_t* heap, mi_page_queue_t* pq, mi_page_ } void _mi_heap_destroy_pages(mi_heap_t* heap) { - mi_heap_visit_pages(heap, &_mi_heap_page_destroy, NULL, NULL); + mi_heap_visit_pages(heap, MI_BIN_FULL, &_mi_heap_page_destroy, NULL, NULL); mi_heap_reset_pages(heap); } @@ -539,7 +538,7 @@ bool mi_heap_check_owned(mi_heap_t* heap, const void* p) { if (heap==NULL || !mi_heap_is_initialized(heap)) return false; if (((uintptr_t)p & (MI_INTPTR_SIZE - 1)) != 0) return false; // only aligned pointers bool found = false; - mi_heap_visit_pages(heap, &mi_heap_page_check_owned, (void*)p, &found); + mi_heap_visit_pages(heap, MI_BIN_FULL, &mi_heap_page_check_owned, (void*)p, &found); return found; } @@ -705,7 +704,7 @@ static bool mi_heap_visit_areas_page(mi_heap_t* heap, mi_page_queue_t* pq, mi_pa // Visit all heap pages as areas static bool mi_heap_visit_areas(const mi_heap_t* heap, mi_heap_area_visit_fun* visitor, void* arg) { if (visitor == NULL) return false; - return mi_heap_visit_pages((mi_heap_t*)heap, &mi_heap_visit_areas_page, (void*)(visitor), arg); // note: function pointer to 
void* :-{ + return mi_heap_visit_pages((mi_heap_t*)heap, MI_BIN_FULL, &mi_heap_visit_areas_page, (void*)(visitor), arg); // note: function pointer to void* :-{ } // Just to pass arguments ================================================ FILE: pyproject.toml ================================================ [tool.black] line-length = 100 include = '\.py$' extend-exclude = ''' /( | .git | .__pycache__ | build-dbg | build-opt | helio )/ ''' ================================================ FILE: src/.gitignore ================================================ server/version.cc ================================================ FILE: src/CMakeLists.txt ================================================
option(ENABLE_GIT_VERSION "Build with Git metadata" OFF)
option(WITH_SIMSIMD "Enable SimSIMD vector optimizations" OFF)
option(SIMSIMD_NATIVE_F16 "Enable native float16 support in SimSIMD" OFF)
option(WITH_SEARCH "Enable compilation of search module" ON)

# Name of the make tool to invoke: gmake on FreeBSD, make everywhere else.
if ("${CMAKE_SYSTEM_NAME}" STREQUAL "FreeBSD")
  set(DFLY_TOOLS_MAKE "gmake")
else()
  set(DFLY_TOOLS_MAKE "make")
endif()

# cur_gen_dir(<out_dir>)
#
# Sets <out_dir> in the caller's scope to the directory under ROOT_GEN_DIR that
# mirrors the current source directory (taken relative to PROJECT_SOURCE_DIR),
# and creates that directory at configure time.
# ROOT_GEN_DIR is read when the function is called, so it only has to be
# defined before the first call, not before this definition.
function(cur_gen_dir out_dir)
  file(RELATIVE_PATH _rel_folder "${PROJECT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}")
  set(_tmp_dir ${ROOT_GEN_DIR}/${_rel_folder})
  set(${out_dir} ${_tmp_dir} PARENT_SCOPE)
  file(MAKE_DIRECTORY ${_tmp_dir})
endfunction()

set(ROOT_GEN_DIR ${CMAKE_SOURCE_DIR}/genfiles)
file(MAKE_DIRECTORY ${ROOT_GEN_DIR})
include_directories(${ROOT_GEN_DIR}/src)

# gen_bison(<name>)
#
# Adds a custom command that runs bison on <name>.y from the current source
# directory and writes <name>.cc and <name>.hh into the directory returned by
# cur_gen_dir().
function(gen_bison name)
  GET_FILENAME_COMPONENT(_in ${name}.y ABSOLUTE)
  cur_gen_dir(gen_dir)
  # add_library(${lib_name} ${gen_dir}/${name}.cc)
  set(full_path_cc ${gen_dir}/${name}.cc ${gen_dir}/${name}.hh)

  ADD_CUSTOM_COMMAND(
    OUTPUT ${full_path_cc}
    COMMAND mkdir -p ${gen_dir}
    COMMAND bison --language=c++ -o ${gen_dir}/${name}.cc ${name}.y -Wconflicts-sr
    DEPENDS ${_in}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    COMMENT "Generating parser from ${name}.y"
    VERBATIM)

  # Mark the files the command actually produces (same paths as OUTPUT above),
  # mirroring what gen_flex does for its outputs.
  set_source_files_properties(${full_path_cc} PROPERTIES GENERATED TRUE)
endfunction()

Message(STATUS "THIRD_PARTY_LIB_DIR ${THIRD_PARTY_LIB_DIR}")
include(external_libs.cmake)

# Version metadata substituted into server/version.cc.in below.
if(ENABLE_GIT_VERSION)
  include(GetGitRevisionDescription.cmake)

  # get_git_head_revision() runs ${GIT_EXECUTABLE} when .git is a file
  # (submodule / worktree) but does not locate Git itself, so find it first.
  find_package(Git QUIET)

  get_git_head_revision(GIT_REFSPEC GIT_SHA1)
  git_local_changes(GIT_CLEAN_DIRTY)
  # Turn the CLEAN/DIRTY answer into a version-string suffix.
  if("${GIT_CLEAN_DIRTY}" STREQUAL "DIRTY")
    set(GIT_CLEAN_DIRTY "-dirty")
  else()
    set(GIT_CLEAN_DIRTY "")
  endif()
  Message(STATUS "GIT_SHA1 ${GIT_SHA1}")
  git_describe(GIT_VER --always)
  Message(STATUS "GIT_VER ${GIT_VER}")
  string(TIMESTAMP PRJ_BUILD_TIME "%Y-%m-%d %H:%M:%S" UTC)
else(ENABLE_GIT_VERSION)
  # Fixed placeholder values when git metadata is disabled.
  set(GIT_VER "dev")
  set(GIT_SHA1 "0000000")
  set(GIT_CLEAN_DIRTY "-dev")
  set(PRJ_BUILD_TIME "bigbang")
endif(ENABLE_GIT_VERSION)

# gen_flex(<name>)
#
# Adds a custom command that runs ${REFLEX} on <name>.lex from the current
# source directory and writes <name>.cc and <name>.h into the directory
# returned by cur_gen_dir(). The command depends on the reflex_project target.
function(gen_flex name)
  GET_FILENAME_COMPONENT(_in ${name}.lex ABSOLUTE)
  cur_gen_dir(gen_dir)
  ADD_CUSTOM_COMMAND(
    OUTPUT ${gen_dir}/${name}.cc ${gen_dir}/${name}.h
    COMMAND mkdir -p ${gen_dir}
    COMMAND ${REFLEX} -o ${gen_dir}/${name}.cc --unicode --header-file=${gen_dir}/${name}.h
            --bison-complete --bison-locations ${_in}
    DEPENDS ${_in} reflex_project
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    COMMENT "Generating lexer from ${name}.lex"
    VERBATIM)

  set_source_files_properties(${gen_dir}/${name}.h ${gen_dir}/${name}.cc
                              PROPERTIES GENERATED TRUE)
endfunction()

# NOTE: the generated version.cc is written into the source tree
# (src/server/version.cc), not into the build directory.
configure_file(server/version.cc.in "${CMAKE_CURRENT_SOURCE_DIR}/server/version.cc" @ONLY)

add_subdirectory(redis)
add_subdirectory(core)
add_subdirectory(facade)
add_subdirectory(server)
================================================ FILE: src/GetGitRevisionDescription.cmake ================================================ # - Returns a version string from Git # # These functions force a re-configure on each git commit so that you can # trust the values of the variables in your build system. 
# # get_git_head_revision(<refspecvar> <hashvar> [ALLOW_LOOKING_ABOVE_CMAKE_SOURCE_DIR]) # # Returns the refspec and sha hash of the current head revision # # git_describe(<var> [<additional arguments to git describe> ...]) # # Returns the results of git describe on the source tree, and adjusting # the output so that it tests false if an error occurs. # # git_describe_working_tree(<var> [<additional arguments to git describe> ...]) # # Returns the results of git describe on the working tree (--dirty option), # and adjusting the output so that it tests false if an error occurs. # # git_get_exact_tag(<var> [<additional arguments to git describe> ...]) # # Returns the results of git describe --exact-match on the source tree, # and adjusting the output so that it tests false if there was no exact # matching tag. # # git_local_changes(<var>) # # Returns either "CLEAN" or "DIRTY" with respect to uncommitted changes. # Uses the return code of "git diff-index --quiet HEAD --". # Does not regard untracked files. # # Requires CMake 2.6 or newer (uses the 'function' command) # # Original Author: # 2009-2020 Ryan Pavlik # http://academic.cleardefinition.com # # Copyright 2009-2013, Iowa State University. # Copyright 2013-2020, Ryan Pavlik # Copyright 2013-2020, Contributors # SPDX-License-Identifier: BSL-1.0 # Distributed under the Boost Software License, Version 1.0. # (See accompanying file LICENSE_1_0.txt or copy at # http://www.boost.org/LICENSE_1_0.txt) if(__get_git_revision_description) return() endif() set(__get_git_revision_description YES) # We must run the following at "include" time, not at function call time, # to find the path to this module rather than the path to a calling list file get_filename_component(_gitdescmoddir ${CMAKE_CURRENT_LIST_FILE} PATH) # Function _git_find_closest_git_dir finds the next closest .git directory # that is part of any directory in the path defined by _start_dir. # The result is returned in the parent scope variable whose name is passed # as variable _git_dir_var. If no .git directory can be found, the # function returns an empty string via _git_dir_var. 
# # Example: Given a path C:/bla/foo/bar and assuming C:/bla/.git exists and # neither foo nor bar contain a file/directory .git. This will return # C:/bla/.git # function(_git_find_closest_git_dir _start_dir _git_dir_var) set(cur_dir "${_start_dir}") set(git_dir "${_start_dir}/.git") while(NOT EXISTS "${git_dir}") # .git dir not found, search parent directories set(git_previous_parent "${cur_dir}") get_filename_component(cur_dir "${cur_dir}" DIRECTORY) if(cur_dir STREQUAL git_previous_parent) # We have reached the root directory, we are not in git set(${_git_dir_var} "" PARENT_SCOPE) return() endif() set(git_dir "${cur_dir}/.git") endwhile() set(${_git_dir_var} "${git_dir}" PARENT_SCOPE) endfunction() function(get_git_head_revision _refspecvar _hashvar) _git_find_closest_git_dir("${CMAKE_CURRENT_SOURCE_DIR}" GIT_DIR) if("${ARGN}" STREQUAL "ALLOW_LOOKING_ABOVE_CMAKE_SOURCE_DIR") set(ALLOW_LOOKING_ABOVE_CMAKE_SOURCE_DIR TRUE) else() set(ALLOW_LOOKING_ABOVE_CMAKE_SOURCE_DIR FALSE) endif() if(NOT "${GIT_DIR}" STREQUAL "") file(RELATIVE_PATH _relative_to_source_dir "${CMAKE_SOURCE_DIR}" "${GIT_DIR}") if("${_relative_to_source_dir}" MATCHES "[.][.]" AND NOT ALLOW_LOOKING_ABOVE_CMAKE_SOURCE_DIR) # We've gone above the CMake root dir. set(GIT_DIR "") endif() endif() if("${GIT_DIR}" STREQUAL "") set(${_refspecvar} "GITDIR-NOTFOUND" PARENT_SCOPE) set(${_hashvar} "GITDIR-NOTFOUND" PARENT_SCOPE) return() endif() # Check if the current source dir is a git submodule or a worktree. # In both cases .git is a file instead of a directory. # if(NOT IS_DIRECTORY ${GIT_DIR}) # The following git command will return a non empty string that # points to the super project working tree if the current # source dir is inside a git submodule. # Otherwise the command will return an empty string. 
# execute_process( COMMAND "${GIT_EXECUTABLE}" rev-parse --show-superproject-working-tree WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" OUTPUT_VARIABLE out ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) if(NOT "${out}" STREQUAL "") # If out is not empty, GIT_DIR/CMAKE_CURRENT_SOURCE_DIR is in a submodule file(READ ${GIT_DIR} submodule) string(REGEX REPLACE "gitdir: (.*)$" "\\1" GIT_DIR_RELATIVE ${submodule}) string(STRIP ${GIT_DIR_RELATIVE} GIT_DIR_RELATIVE) get_filename_component(SUBMODULE_DIR ${GIT_DIR} PATH) get_filename_component(GIT_DIR ${SUBMODULE_DIR}/${GIT_DIR_RELATIVE} ABSOLUTE) set(HEAD_SOURCE_FILE "${GIT_DIR}/HEAD") else() # GIT_DIR/CMAKE_CURRENT_SOURCE_DIR is in a worktree file(READ ${GIT_DIR} worktree_ref) # The .git file contains a path to the worktree information directory # inside the parent git repo of the worktree. # string(REGEX REPLACE "gitdir: (.*)$" "\\1" git_worktree_dir ${worktree_ref}) string(STRIP ${git_worktree_dir} git_worktree_dir) _git_find_closest_git_dir("${git_worktree_dir}" GIT_DIR) set(HEAD_SOURCE_FILE "${git_worktree_dir}/HEAD") endif() else() set(HEAD_SOURCE_FILE "${GIT_DIR}/HEAD") endif() set(GIT_DATA "${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/git-data") if(NOT EXISTS "${GIT_DATA}") file(MAKE_DIRECTORY "${GIT_DATA}") endif() if(NOT EXISTS "${HEAD_SOURCE_FILE}") return() endif() set(HEAD_FILE "${GIT_DATA}/HEAD") configure_file("${HEAD_SOURCE_FILE}" "${HEAD_FILE}" COPYONLY) configure_file("${_gitdescmoddir}/GetGitRevisionDescription.cmake.in" "${GIT_DATA}/grabRef.cmake" @ONLY) include("${GIT_DATA}/grabRef.cmake") set(${_refspecvar} "${HEAD_REF}" PARENT_SCOPE) set(${_hashvar} "${HEAD_HASH}" PARENT_SCOPE) endfunction() function(git_describe _var) if(NOT GIT_FOUND) find_package(Git QUIET) endif() get_git_head_revision(refspec hash) if(NOT GIT_FOUND) set(${_var} "GIT-NOTFOUND" PARENT_SCOPE) return() endif() if(NOT hash) set(${_var} "HEAD-HASH-NOTFOUND" PARENT_SCOPE) return() endif() # TODO sanitize #if((${ARGN}" MATCHES "&&") OR # 
(ARGN MATCHES "||") OR # (ARGN MATCHES "\\;")) # message("Please report the following error to the project!") # message(FATAL_ERROR "Looks like someone's doing something nefarious with git_describe! Passed arguments ${ARGN}") #endif() #message(STATUS "Arguments to execute_process: ${ARGN}") execute_process( COMMAND "${GIT_EXECUTABLE}" describe --tags --always ${hash} ${ARGN} WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" RESULT_VARIABLE res OUTPUT_VARIABLE out ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) if(NOT res EQUAL 0) set(out "${out}-${res}-NOTFOUND") endif() set(${_var} "${out}" PARENT_SCOPE) endfunction() function(git_describe_working_tree _var) if(NOT GIT_FOUND) find_package(Git QUIET) endif() if(NOT GIT_FOUND) set(${_var} "GIT-NOTFOUND" PARENT_SCOPE) return() endif() execute_process( COMMAND "${GIT_EXECUTABLE}" describe --dirty ${ARGN} WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" RESULT_VARIABLE res OUTPUT_VARIABLE out ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) if(NOT res EQUAL 0) set(out "${out}-${res}-NOTFOUND") endif() set(${_var} "${out}" PARENT_SCOPE) endfunction() function(git_get_exact_tag _var) git_describe(out --exact-match ${ARGN}) set(${_var} "${out}" PARENT_SCOPE) endfunction() function(git_local_changes _var) if(NOT GIT_FOUND) find_package(Git QUIET) endif() get_git_head_revision(refspec hash) if(NOT GIT_FOUND) set(${_var} "GIT-NOTFOUND" PARENT_SCOPE) return() endif() if(NOT hash) set(${_var} "HEAD-HASH-NOTFOUND" PARENT_SCOPE) return() endif() execute_process( COMMAND "${GIT_EXECUTABLE}" diff-index --quiet HEAD -- WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" RESULT_VARIABLE res OUTPUT_VARIABLE out ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) if(res EQUAL 0) set(${_var} "CLEAN" PARENT_SCOPE) else() set(${_var} "DIRTY" PARENT_SCOPE) endif() endfunction() ================================================ FILE: src/GetGitRevisionDescription.cmake.in ================================================ # # Internal file for 
GetGitRevisionDescription.cmake # # Requires CMake 2.6 or newer (uses the 'function' command) # # Original Author: # 2009-2010 Ryan Pavlik # http://academic.cleardefinition.com # Iowa State University HCI Graduate Program/VRAC # # Copyright 2009-2012, Iowa State University # Copyright 2011-2015, Contributors # Distributed under the Boost Software License, Version 1.0. # (See accompanying file LICENSE_1_0.txt or copy at # http://www.boost.org/LICENSE_1_0.txt) # SPDX-License-Identifier: BSL-1.0

# Resolves HEAD_REF / HEAD_HASH from the copied HEAD file.
# The at-sign placeholders in this template are substituted by the including
# module (configure_file with the ONLY-at-signs option) before it is included.
set(HEAD_HASH)

file(READ "@HEAD_FILE@" HEAD_CONTENTS LIMIT 1024)

string(STRIP "${HEAD_CONTENTS}" HEAD_CONTENTS)
if(HEAD_CONTENTS MATCHES "ref")
  # named branch
  string(REPLACE "ref: " "" HEAD_REF "${HEAD_CONTENTS}")
  if(EXISTS "@GIT_DIR@/${HEAD_REF}")
    # Loose ref: copy the ref file; its content is read as the hash below.
    configure_file("@GIT_DIR@/${HEAD_REF}" "@GIT_DATA@/head-ref" COPYONLY)
  else()
    # No loose ref file: look the ref up in packed-refs instead.
    configure_file("@GIT_DIR@/packed-refs" "@GIT_DATA@/packed-refs" COPYONLY)
    file(READ "@GIT_DATA@/packed-refs" PACKED_REFS)
    # Quote the file content so it is always passed as exactly one operand:
    # unquoted, an empty file leaves if() with no left-hand side (a configure
    # error) and a ';' in the content would split it into several arguments.
    if("${PACKED_REFS}" MATCHES "([0-9a-z]*) ${HEAD_REF}")
      set(HEAD_HASH "${CMAKE_MATCH_1}")
    endif()
  endif()
else()
  # detached HEAD
  configure_file("@GIT_DIR@/HEAD" "@GIT_DATA@/head-ref" COPYONLY)
endif()

# Hash not taken from packed-refs: read it from the copied head-ref file.
if(NOT HEAD_HASH)
  file(READ "@GIT_DATA@/head-ref" HEAD_HASH LIMIT 1024)
  string(STRIP "${HEAD_HASH}" HEAD_HASH)
endif()
================================================ FILE: src/common/arg_range.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include "base/iterator.h" namespace cmn { using ArgSlice = absl::Span; using OwnedArgSlice = absl::Span; inline std::string_view ToSV(std::string_view slice) { return slice; } inline std::string_view ToSV(const std::string& slice) { return slice; } inline std::string_view ToSV(std::string&& slice) = delete; constexpr auto kToSV = [](auto&& v) { return ToSV(std::forward(v)); }; struct ArgRange { ArgRange(ArgRange&&) = default; ArgRange(const ArgRange&) = default; ArgRange(ArgRange& range) : ArgRange((const ArgRange&)range) { } template , bool> = true> ArgRange(T&& span) : span(std::forward(span)) { // NOLINT google-explicit-constructor) } size_t Size() const { return std::visit([](const auto& span) { return span.size(); }, span); } auto Range() const { return base::it::Wrap(kToSV, span); } auto begin() const { return Range().first; } auto end() const { return Range().second; } std::string_view operator[](size_t idx) const { return std::visit([idx](const auto& span) -> std::string_view { return span[idx]; }, span); } std::variant span; }; } // namespace cmn ================================================ FILE: src/common/backed_args.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include namespace cmn { class BackedArguments { constexpr static size_t kLenCap = 5; constexpr static size_t kStorageCap = 88; public: using value_type = std::string_view; BackedArguments() { } class iterator { public: using iterator_category = std::random_access_iterator_tag; using value_type = std::string_view; using difference_type = std::ptrdiff_t; using pointer = const std::string_view*; using reference = std::string_view; iterator(const BackedArguments* ba, size_t index) : ba_(ba), index_(index) { } iterator& operator++() { ++index_; return *this; } iterator& operator--() { --index_; return *this; } iterator& operator+=(int delta) { index_ += delta; return *this; } iterator operator+(int delta) const { iterator res(*this); res += delta; return res; } ptrdiff_t operator-(iterator other) const { return ptrdiff_t(index_) - ptrdiff_t(other.index_); } bool operator==(const iterator& other) const { return index_ == other.index_ && ba_ == other.ba_; } bool operator!=(const iterator& other) const { return !(*this == other); } std::string_view operator*() const { return ba_->at(index_); } private: const BackedArguments* ba_; size_t index_; }; // Construct the arguments from iterator range. // TODO: In general we could get away without the len argument, // but that would require fixing base::it::CompoundIterator to support subtraction. // Similarly, I wish that CompoundIterator supported the -> operator. template BackedArguments(I begin, I end, size_t len) { Assign(begin, end, len); } template void Assign(I begin, I end, size_t len); void Reserve(size_t arg_cnt, size_t total_size) { offsets_.reserve(arg_cnt); storage_.reserve(total_size); } size_t HeapMemory() const { size_t s1 = offsets_.capacity() <= kLenCap ? 0 : offsets_.capacity() * sizeof(uint32_t); size_t s2 = storage_.capacity() <= kStorageCap ? 
0 : storage_.capacity(); return s1 + s2; } void SwapArgs(cmn::BackedArguments& other) { offsets_.swap(other.offsets_); storage_.swap(other.storage_); } // The capacity is chosen so that we allocate a fully utilized (128 bytes) block. using StorageType = absl::InlinedVector; std::string_view Front() const { return std::string_view{storage_.data(), elem_len(0)}; } size_t size() const { return offsets_.size(); } bool empty() const { return offsets_.empty(); } size_t elem_len(size_t i) const { return elem_capacity(i) - 1; } size_t elem_capacity(size_t i) const { uint32_t next_offs = i + 1 >= offsets_.size() ? storage_.size() : offsets_[i + 1]; return next_offs - offsets_[i]; } std::string_view at(uint32_t index) const { uint32_t offset = offsets_[index]; return std::string_view{storage_.data() + offset, elem_len(index)}; } char* data(uint32_t index) { uint32_t offset = offsets_[index]; return storage_.data() + offset; } std::string_view operator[](uint32_t index) const { return at(index); } iterator begin() const { return {this, 0}; } iterator end() const { return {this, offsets_.size()}; } void clear() { // Clear the contents without deallocating memory. clear() deallocates inlined_vector. offsets_.resize(0); storage_.resize(0); } std::string_view back() const { assert(size() > 0); return at(size() - 1); } // Reserves space for additional argument of given length at the end. 
void PushArg(size_t len) { size_t old_size = storage_.size(); offsets_.push_back(old_size); storage_.resize(old_size + len + 1); } void PushArg(std::string_view arg) { PushArg(arg.size()); char* dest = storage_.data() + offsets_.back(); if (arg.size() > 0) memcpy(dest, arg.data(), arg.size()); dest[arg.size()] = '\0'; } void PopArg() { uint32_t last_offs = offsets_.back(); offsets_.pop_back(); storage_.resize(last_offs); } protected: absl::InlinedVector offsets_; StorageType storage_; }; static_assert(sizeof(BackedArguments) == 128); template void BackedArguments::Assign(I begin, I end, size_t len) { offsets_.resize(len); size_t total_size = 0; unsigned idx = 0; for (auto it = begin; it != end; ++it) { offsets_[idx++] = total_size; total_size += (*it).size() + 1; // +1 for '\0' } storage_.resize(total_size); // Reclaim memory if we have too much allocated. if (storage_.capacity() > kStorageCap && total_size < storage_.capacity() / 2) storage_.shrink_to_fit(); char* next = storage_.data(); for (auto it = begin; it != end; ++it) { size_t sz = (*it).size(); if (sz > 0) { memcpy(next, (*it).data(), sz); } next[sz] = '\0'; next += sz + 1; } } } // namespace cmn ================================================ FILE: src/common/heap_size.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // // This file provides utilities to *estimate* heap memory usage of classes. // The main function exposed here is HeapSize() (with various overloads). // It supports simple structs (returns 0), std::string (returns capacity if it's larger than SSO) // and common containers, such as std::vector, std::deque, absl::flat_hash_map and unique_ptr. // // Example usage: // absl::flat_hash_map>> m; // ... 
// size_t size = HeapSize(m); #pragma once #include #include #include #include #include #include #include #include #include namespace cmn { namespace heap_size_detail { template struct has_marked_stackonly : std::false_type {}; template struct has_marked_stackonly> : std::true_type {}; template constexpr bool StackOnlyType() { return std::is_trivial_v || std::is_same_v || has_marked_stackonly::value; } template struct has_used_mem : std::false_type {}; template struct has_used_mem> : std::true_type {}; template size_t AccumulateContainer(const Container& c); } // namespace heap_size_detail inline size_t HeapSize(const std::string& s) { constexpr size_t kSmallStringOptSize = 15; return s.capacity() > kSmallStringOptSize ? s.capacity() : 0UL; } template ::value, bool> = true> size_t HeapSize(const T& t) { return t.UsedMemory(); } template (), bool> = true> size_t HeapSize(const T& t) { return 0; } template size_t HeapSize(absl::Span) { return 0; } // Declare first, so that we can use these "recursively" template size_t HeapSize(const std::vector& v); template size_t HeapSize(const std::unique_ptr& t); template size_t HeapSize(const std::deque& d); template size_t HeapSize(const std::pair& p); template size_t HeapSize(const absl::InlinedVector& v); template size_t HeapSize(const absl::flat_hash_map& m); template size_t HeapSize(const absl::flat_hash_set& s); template size_t HeapSize(const std::unique_ptr& t) { if (t == nullptr) { return 0; } else { return sizeof(T) + HeapSize(*t); } } template size_t HeapSize(const std::vector& v) { return (v.capacity() * sizeof(T)) + heap_size_detail::AccumulateContainer(v); } template size_t HeapSize(const std::deque& d) { return (d.size() * sizeof(T)) + heap_size_detail::AccumulateContainer(d); } template size_t HeapSize(const std::pair& p) { return HeapSize(p.first) + HeapSize(p.second); } template size_t HeapSize(const absl::InlinedVector& v) { size_t size = 0; if (v.capacity() > N) { size += v.capacity() * sizeof(T); } size += 
heap_size_detail::AccumulateContainer(v); return size; } template size_t HeapSize(const absl::flat_hash_map& m) { size_t size = m.capacity() * sizeof(typename absl::flat_hash_map::value_type); if constexpr (!heap_size_detail::StackOnlyType() || !heap_size_detail::StackOnlyType()) { for (const auto& kv : m) { size += HeapSize(kv); } } return size; } template size_t HeapSize(const absl::flat_hash_set& s) { size_t size = s.capacity() * sizeof(typename absl::flat_hash_set::value_type); if constexpr (!heap_size_detail::StackOnlyType()) { for (const auto& k : s) { size += HeapSize(k); } } return size; } namespace heap_size_detail { template size_t AccumulateContainer(const Container& c) { size_t size = 0; if constexpr (!heap_size_detail::StackOnlyType()) { for (const auto& e : c) { size += HeapSize(e); } } return size; } } // namespace heap_size_detail } // namespace cmn ================================================ FILE: src/common/string_or_view.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include namespace cmn { class StringOrView { public: static StringOrView FromString(std::string s) { StringOrView sov; sov.val_ = std::move(s); return sov; } static StringOrView FromView(std::string_view sv) { StringOrView sov; sov.val_ = sv; return sov; } StringOrView() = default; StringOrView(const StringOrView& o) = default; StringOrView(StringOrView&& o) = default; StringOrView& operator=(const StringOrView& o) = default; StringOrView& operator=(StringOrView&& o) = default; bool operator==(const StringOrView& o) const { return *this == o.view(); } bool operator==(std::string_view o) const { return view() == o; } bool operator!=(const StringOrView& o) const { return *this != o.view(); } bool operator!=(std::string_view o) const { return !(*this == o); } std::string_view view() const { return visit([](const auto& s) -> std::string_view { return s; }, val_); } friend std::ostream& operator<<(std::ostream& o, const StringOrView& key) { return o << key.view(); } // Make hashable template friend H AbslHashValue(H h, const StringOrView& c) { return H::combine(std::move(h), c.view()); } // If the key is backed by a string_view, replace it with a string with the same value void MakeOwned() { if (std::holds_alternative(val_)) val_ = std::string{std::get(val_)}; } // Move out of value as string std::string Take() && { MakeOwned(); return std::move(std::get(val_)); } std::string* GetMutable() { MakeOwned(); return &std::get(val_); } bool empty() const { return visit([](const auto& s) { return s.empty(); }, val_); } private: std::variant val_; }; } // namespace cmn ================================================ FILE: src/core/CMakeLists.txt ================================================ find_library(LIB_PCRE2 NAMES pcre2-8) if(LIB_PCRE2) set(PCRE2_LIB ${LIB_PCRE2}) else() message(STATUS "pcre2-8 not found. 
Building without PCRE2 support.") set(PCRE2_LIB "") endif() find_library(LIB_RE2 NAMES re2) if(LIB_RE2) set(RE2_LIB ${LIB_RE2}) else() message(STATUS "re2 not found. Building without RE2 support.") set(RE2_LIB "") endif() if (WITH_SEARCH) add_subdirectory(search) else() add_library(dfly_search_core INTERFACE) endif() add_subdirectory(json) add_subdirectory(page_usage) add_library(dfly_core allocation_tracker.cc bloom.cc topk.cc compact_object.cc cms.cc dense_set.cc dragonfly_core.cc extent_tree.cc huff_coder.cc interpreter.cc glob_matcher.cc mi_memory_resource.cc qlist.cc dict_builder.cc sds_utils.cc segment_allocator.cc score_map.cc small_string.cc sorted_map.cc task_queue.cc tx_queue.cc string_set.cc string_map.cc tiering_types.cc top_keys.cc detail/bitpacking.cc detail/listpack_wrap.cc detail/listpack.cc oah_entry.cc) cxx_link(dfly_core base dfly_search_core dfly_page_usage fibers2 jsonpath absl::flat_hash_map absl::str_format absl::random_random redis_lib TRDP::lua lua_modules OpenSSL::Crypto TRDP::dconv TRDP::lz4 TRDP::hdr_histogram) add_executable(dash_bench dash_bench.cc) cxx_link(dash_bench dfly_core redis_test_lib) helio_cxx_test(dfly_core_test dfly_core TRDP::fast_float ${PCRE2_LIB} ${RE2_LIB} LABELS DFLY) helio_cxx_test(compact_object_test dfly_core LABELS DFLY) helio_cxx_test(extent_tree_test dfly_core LABELS DFLY) helio_cxx_test(dash_test dfly_core file redis_test_lib DATA testdata/ids.txt.zst LABELS DFLY) helio_cxx_test(interpreter_test dfly_core LABELS DFLY) helio_cxx_test(string_set_test dfly_core LABELS DFLY) helio_cxx_test(string_map_test dfly_core LABELS DFLY) helio_cxx_test(oah_set_test dfly_core LABELS DFLY) helio_cxx_test(sorted_map_test dfly_core redis_test_lib LABELS DFLY) helio_cxx_test(bptree_set_test dfly_core LABELS DFLY) helio_cxx_test(linear_search_map_test dfly_core LABELS DFLY) helio_cxx_test(score_map_test dfly_core LABELS DFLY) helio_cxx_test(flatbuffers_test dfly_core TRDP::flatbuffers LABELS DFLY) helio_cxx_test(bloom_test 
dfly_core LABELS DFLY) helio_cxx_test(allocation_tracker_test dfly_core absl::random_random LABELS DFLY) helio_cxx_test(qlist_test dfly_core DATA testdata/list.txt.zst LABELS DFLY) helio_cxx_test(listpack_test dfly_core redis_lib LABELS DFLY) helio_cxx_test(zstd_test dfly_core TRDP::zstd LABELS DFLY) helio_cxx_test(dict_builder_test dfly_core LABELS DFLY) helio_cxx_test(top_keys_test dfly_core LABELS DFLY) helio_cxx_test(topk_test dfly_core LABELS DFLY) helio_cxx_test(page_usage_stats_test dfly_core LABELS DFLY) helio_cxx_test(cms_test dfly_core LABELS DFLY) helio_cxx_test(memory_test TRDP::mimalloc2 LABELS DFLY) if(LIB_PCRE2) target_compile_definitions(dfly_core_test PRIVATE USE_PCRE2=1) # target_compile_definitions(dfly_core PUBLIC USE_PCRE2=1) endif() if(LIB_RE2) target_compile_definitions(dfly_core_test PRIVATE USE_RE2) endif() ================================================ FILE: src/core/allocation_tracker.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/allocation_tracker.h" #include "absl/random/random.h" #include "base/logging.h" #include "util/fibers/stacktrace.h" namespace dfly { namespace { thread_local AllocationTracker g_tracker; thread_local absl::InsecureBitGen g_bitgen; bool CanCallVlog(std::string_view trace) { // GLOG fails when logging while flushing the current log under a mutex return trace.find("LogMessage::Flush") == std::string::npos; } } // namespace AllocationTracker& AllocationTracker::Get() { return g_tracker; } bool AllocationTracker::Add(const TrackingInfo& info) { if (tracking_.size() >= tracking_.capacity()) { return false; } tracking_.push_back(info); UpdateAbsSizes(); return true; } bool AllocationTracker::Remove(size_t lower_bound, size_t upper_bound) { size_t before_size = tracking_.size(); tracking_.erase(std::remove_if(tracking_.begin(), tracking_.end(), [&](const TrackingInfo& info) { return info.lower_bound == lower_bound && info.upper_bound == upper_bound; }), tracking_.end()); UpdateAbsSizes(); return before_size != tracking_.size(); } void AllocationTracker::Clear() { tracking_.clear(); } absl::Span AllocationTracker::GetRanges() const { return absl::MakeConstSpan(tracking_); } void AllocationTracker::ProcessNew(void* ptr, size_t size) { if (size < abs_min_size_ || size > abs_max_size_) { return; } if (inside_tracker_) { return; } // Prevent endless recursion, in case logging allocates memory inside_tracker_ = true; for (const auto& band : tracking_) { if (size > band.upper_bound || size < band.lower_bound) { continue; } // Micro optimization: in case sample_odds == 1.0 - do not draw a random number if (band.sample_odds != 1.0 && absl::Uniform(g_bitgen, 0.0, 1.0) >= band.sample_odds) { continue; } size_t usable = mi_usable_size(ptr); std::string trace = util::fb2::GetStacktrace(); if (CanCallVlog(trace)) { DCHECK_GE(usable, size); LOG(INFO) << "Allocating " << usable << " bytes (" << ptr << "). 
Stack: " << trace; } break; } inside_tracker_ = false; } void AllocationTracker::ProcessDelete(void* ptr) { if (inside_tracker_) { return; } inside_tracker_ = true; // we partially handle deletes, specifically when specifying a single range with // 100% sampling rate. if (tracking_.size() == 1 && tracking_.front().sample_odds == 1) { size_t usable = mi_usable_size(ptr); if (usable <= tracking_.front().upper_bound && usable >= tracking_.front().lower_bound) { std::string trace = util::fb2::GetStacktrace(); LOG_IF(INFO, CanCallVlog(trace)) << "Deallocating " << usable << " bytes (" << ptr << ")\n" << trace; } } inside_tracker_ = false; } void AllocationTracker::UpdateAbsSizes() { abs_min_size_ = 0; abs_max_size_ = 0; for (const auto& tracker : tracking_) { abs_min_size_ = std::min(abs_min_size_, tracker.lower_bound); abs_max_size_ = std::max(abs_max_size_, tracker.upper_bound); } } } // namespace dfly ================================================ FILE: src/core/allocation_tracker.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include namespace dfly { // Allows "tracking" of memory allocations by size bands. Tracking is naive in that it only prints // the stack trace of the memory allocation, if matched by size & sampling criteria. // Supports up to 4 different bands in parallel. // // Thread-local. Must be configured in all relevant threads separately. // // #define INJECT_ALLOCATION_TRACKER before #include exactly once to override new/delete class AllocationTracker { public: struct TrackingInfo { size_t lower_bound = 0; size_t upper_bound = 0; double sample_odds = 0.0; }; // Returns a thread-local reference. static AllocationTracker& Get(); // Will track memory allocations in range [lower, upper]. Sample odds must be between [0, 1], // where 1 means all allocations are tracked and 0 means none. 
bool Add(const TrackingInfo& info); // Removes all tracking exactly matching lower_bound and upper_bound. // Returns true if the tracking range [lower_bound, upper_bound] was removed // and false, otherwise. bool Remove(size_t lower_bound, size_t upper_bound); // Clears *all* tracking. void Clear(); absl::Span GetRanges() const; void ProcessNew(void* ptr, size_t size); void ProcessDelete(void* ptr); private: void UpdateAbsSizes(); absl::InlinedVector tracking_; bool inside_tracker_ = false; size_t abs_min_size_ = 0; size_t abs_max_size_ = 0; }; } // namespace dfly #ifdef INJECT_ALLOCATION_TRACKER // Code here is copied from mimalloc-new-delete, and modified to add tracking void operator delete(void* p) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free(p); }; void operator delete[](void* p) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free(p); }; void operator delete(void* p, const std::nothrow_t&) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free(p); } void operator delete[](void* p, const std::nothrow_t&) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free(p); } void* operator new(std::size_t n) noexcept(false) { auto v = mi_new(n); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } void* operator new[](std::size_t n) noexcept(false) { auto v = mi_new(n); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } void* operator new(std::size_t n, const std::nothrow_t& tag) noexcept { (void)(tag); auto v = mi_new_nothrow(n); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } void* operator new[](std::size_t n, const std::nothrow_t& tag) noexcept { (void)(tag); auto v = mi_new_nothrow(n); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } #if (__cplusplus >= 201402L || _MSC_VER >= 1916) void operator delete(void* p, std::size_t n) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_size(p, n); }; void operator delete[](void* p, std::size_t n) 
noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_size(p, n); }; #endif #if (__cplusplus > 201402L || defined(__cpp_aligned_new)) void operator delete(void* p, std::align_val_t al) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_aligned(p, static_cast(al)); } void operator delete[](void* p, std::align_val_t al) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_aligned(p, static_cast(al)); } void operator delete(void* p, std::size_t n, std::align_val_t al) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_size_aligned(p, n, static_cast(al)); }; void operator delete[](void* p, std::size_t n, std::align_val_t al) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_size_aligned(p, n, static_cast(al)); }; void operator delete(void* p, std::align_val_t al, const std::nothrow_t&) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_aligned(p, static_cast(al)); } void operator delete[](void* p, std::align_val_t al, const std::nothrow_t&) noexcept { dfly::AllocationTracker::Get().ProcessDelete(p); mi_free_aligned(p, static_cast(al)); } void* operator new(std::size_t n, std::align_val_t al) noexcept(false) { auto v = mi_new_aligned(n, static_cast(al)); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } void* operator new[](std::size_t n, std::align_val_t al) noexcept(false) { auto v = mi_new_aligned(n, static_cast(al)); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } void* operator new(std::size_t n, std::align_val_t al, const std::nothrow_t&) noexcept { auto v = mi_new_aligned_nothrow(n, static_cast(al)); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } void* operator new[](std::size_t n, std::align_val_t al, const std::nothrow_t&) noexcept { auto v = mi_new_aligned_nothrow(n, static_cast(al)); dfly::AllocationTracker::Get().ProcessNew(v, n); return v; } #endif #endif // INJECT_ALLOCATION_TRACKER 
================================================ FILE: src/core/allocation_tracker_test.cc ================================================ #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" #define INJECT_ALLOCATION_TRACKER #include "core/allocation_tracker.h" namespace dfly { namespace { using namespace std; using namespace testing; class LogSink : public google::LogSink { public: void send(google::LogSeverity severity, const char* full_filename, const char* base_filename, int line, const struct tm* tm_time, const char* message, size_t message_len) override { logs_.push_back(string(message, message_len)); } const vector& GetLogs() const { return logs_; } void Clear() { logs_.clear(); } private: vector logs_; }; class AllocationTrackerTest : public Test { protected: AllocationTrackerTest() { google::AddLogSink(&log_sink_); } ~AllocationTrackerTest() { google::RemoveLogSink(&log_sink_); AllocationTracker::Get().Clear(); } vector GetLogsDelta() { auto logs = log_sink_.GetLogs(); log_sink_.Clear(); return logs; } void Allocate(size_t s) { CHECK(buffer_.empty()); buffer_.resize(s); // allocate 1mb before setting up tracking } void Deallocate() { buffer_.clear(); // Force deallocation buffer_.shrink_to_fit(); } private: LogSink log_sink_; string buffer_; }; TEST_F(AllocationTrackerTest, UnusedTracker) { Allocate(1'000'000); // allocate 1mb before setting up tracking EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); } TEST_F(AllocationTrackerTest, UsedTracker) { AllocationTracker::Get().Add( {.lower_bound = 1'000'000, .upper_bound = 2'000'000, .sample_odds = 1.0}); Allocate(1'000'000); // allocate 1mb before setting up tracking EXPECT_THAT(GetLogsDelta(), Contains(HasSubstr("Allocating"))); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Deallocating")))); Deallocate(); EXPECT_THAT(GetLogsDelta(), Contains(HasSubstr("Deallocating"))); // Allocate below threshold Allocate(100'000); 
EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Deallocating")))); // Allocate above threshold Allocate(10'000'000); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Deallocating")))); // Remove allocator - stops logging EXPECT_TRUE(AllocationTracker::Get().Remove(1'000'000, 2'000'000)); Allocate(1'000'000); // allocate 1mb before setting up tracking EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Deallocating")))); } TEST_F(AllocationTrackerTest, MultipleRanges) { AllocationTracker::Get().Add( {.lower_bound = 1'000'000, .upper_bound = 2'000'000, .sample_odds = 1.0}); AllocationTracker::Get().Add( {.lower_bound = 100'000'000, .upper_bound = 200'000'000, .sample_odds = 1.0}); // Below all ranges Allocate(100'000); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); // Between ranges Allocate(10'000'000); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); // Above all ranges Allocate(500'000'000); EXPECT_THAT(GetLogsDelta(), Not(Contains(HasSubstr("Allocating")))); Deallocate(); // First range Allocate(1'000'000); EXPECT_THAT(GetLogsDelta(), Contains(HasSubstr("Allocating"))); Deallocate(); // Second range Allocate(100'000'000); EXPECT_THAT(GetLogsDelta(), Contains(HasSubstr("Allocating"))); Deallocate(); } TEST_F(AllocationTrackerTest, Sampling) { // Statistically, 80% of logs should be logged AllocationTracker::Get().Add( {.lower_bound = 1'000'000, .upper_bound = 2'000'000, .sample_odds = 0.8}); const int kIterations = 10'000; for (int i = 0; i < kIterations; ++i) { Allocate(1'000'000); Deallocate(); } int allocations = 0; int deallocations = 0; for (const string& s : GetLogsDelta()) { if (absl::StrContains(s, "Allocating")) { ++allocations; } if 
(absl::StrContains(s, "Deallocating")) { ++deallocations; } } EXPECT_GE(allocations, kIterations * 0.7); EXPECT_LE(allocations, kIterations * 0.9); EXPECT_EQ(deallocations, 0); // we only track deletions when sample_odds == 1.0 } } // namespace } // namespace dfly ================================================ FILE: src/core/bloom.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/bloom.h" #include #include #include #include #include #include "base/logging.h" namespace dfly { using namespace std; namespace { XXH128_hash_t Hash(string_view str) { return XXH3_128bits_withSeed(str.data(), str.size(), 0xc6a4a7935bd1e995ULL); // murmur2 seed } uint64_t GetMask(unsigned log) { return (1ULL << log) - 1; } uint64_t BitIndex(uint64_t low, uint64_t hi, unsigned i, uint64_t mask) { return (low + hi * i) & mask; } constexpr double kDenom = M_LN2 * M_LN2; constexpr double kSBFErrorFactor = 0.5; double BPE(double fp_prob) { return -log(fp_prob) / kDenom; } } // namespace Bloom::~Bloom() { CHECK(bf_ == nullptr); } Bloom::Bloom(Bloom&& o) noexcept : hash_cnt_(o.hash_cnt_), bit_log_(o.bit_log_), bf_(o.bf_) { o.bf_ = nullptr; } void Bloom::Init(uint64_t entries, double fp_prob, PMR_NS::memory_resource* heap) { CHECK(bf_ == nullptr); CHECK(fp_prob > 0 && fp_prob < 1); if (fp_prob > 0.5) fp_prob = 0.5; double bpe = BPE(fp_prob); hash_cnt_ = ceil(M_LN2 * bpe); uint64_t bits = uint64_t(ceil(entries * bpe)); if (bits < 512) { bits = 512; } bits = absl::bit_ceil(bits); // make it power of 2. uint64_t length = bits / 8; bf_ = (uint8_t*)heap->allocate(length); memset(bf_, 0, length); bit_log_ = absl::countr_zero(bits); } void Bloom::Init(uint8_t* blob, size_t len, unsigned hash_cnt) { DCHECK_EQ(len * 8, absl::bit_ceil(len * 8)); // must be power of two. 
CHECK(bf_ == nullptr); hash_cnt_ = hash_cnt; bf_ = blob; bit_log_ = absl::countr_zero(len * 8); } void Bloom::Destroy(PMR_NS::memory_resource* resource) { resource->deallocate(CHECK_NOTNULL(bf_), bitlen() / 8); bf_ = nullptr; } bool Bloom::Exists(std::string_view str) const { XXH128_hash_t hash = Hash(str); uint64_t fp[2] = {hash.low64, hash.high64}; return Exists(fp); } bool Bloom::Exists(const uint64_t fp[2]) const { uint64_t mask = GetMask(bit_log_); for (unsigned i = 0; i < hash_cnt_; ++i) { uint64_t index = BitIndex(fp[0], fp[1], i, mask); if (!IsSet(index)) return false; } return true; } bool Bloom::Add(std::string_view str) { XXH128_hash_t hash = Hash(str); uint64_t fp[2] = {hash.low64, hash.high64}; return Add(fp); } bool Bloom::Add(const uint64_t fp[2]) { uint64_t mask = GetMask(bit_log_); unsigned changes = 0; for (uint64_t i = 0; i < hash_cnt_; i++) { uint64_t index = BitIndex(fp[0], fp[1], i, mask); changes += Set(index); } return changes != 0; } size_t Bloom::Capacity(double fp_prob) const { if (fp_prob > 0.5) fp_prob = 0.5; double bpe = BPE(fp_prob); return floor(bitlen() / bpe); } inline bool Bloom::IsSet(size_t bit_idx) const { uint64_t byte_idx = bit_idx / 8; bit_idx %= 8; // index within the byte uint8_t b = bf_[byte_idx]; return (b & (1 << bit_idx)) != 0; } inline bool Bloom::Set(size_t bit_idx) { uint64_t byte_idx = bit_idx / 8; bit_idx %= 8; uint8_t b = bf_[byte_idx]; bf_[byte_idx] |= (1 << bit_idx); return bf_[byte_idx] != b; } /////////////////////////////////////////////////////////////////////////////// // SBF implementation /////////////////////////////////////////////////////////////////////////////// SBF::SBF(uint64_t initial_capacity, double fp_prob, double grow_factor, PMR_NS::memory_resource* mr) : filters_(1, mr), grow_factor_(grow_factor), fp_prob_(fp_prob * kSBFErrorFactor) { filters_.front().Init(initial_capacity, fp_prob_, mr); max_capacity_ = filters_.front().Capacity(fp_prob_); } SBF::SBF(double grow_factor, double fp_prob, 
size_t max_capacity, size_t prev_size, size_t current_size, PMR_NS::memory_resource* mr) : filters_(mr), grow_factor_(grow_factor), fp_prob_(fp_prob), prev_size_(prev_size), current_size_(current_size), max_capacity_(max_capacity) { } SBF::~SBF() { PMR_NS::memory_resource* mr = filters_.get_allocator().resource(); for (auto& f : filters_) f.Destroy(mr); } SBF& SBF::operator=(SBF&& src) noexcept { filters_.clear(); filters_.swap(src.filters_); grow_factor_ = src.grow_factor_; fp_prob_ = src.fp_prob_; current_size_ = src.current_size_; max_capacity_ = src.max_capacity_; return *this; } void SBF::AddFilter(const std::string& blob, unsigned hash_cnt) { PMR_NS::memory_resource* mr = filters_.get_allocator().resource(); uint8_t* ptr = (uint8_t*)mr->allocate(blob.size(), 1); memcpy(ptr, blob.data(), blob.size()); filters_.emplace_back().Init(ptr, blob.size(), hash_cnt); } bool SBF::Add(std::string_view str) { DCHECK_LT(current_size_, max_capacity_); XXH128_hash_t hash = Hash(str); uint64_t fp[2] = {hash.low64, hash.high64}; auto exists = [fp](const Bloom& b) { return b.Exists(fp); }; // Check for all the previous filters whether the item exists. if (any_of(next(filters_.crbegin()), filters_.crend(), exists)) { return false; } if (!filters_.back().Add(fp)) return false; ++current_size_; // Based on the paper, the optimal fill ratio for SBF is 50%. // Lets add a new slice if we reach it. 
if (current_size_ >= max_capacity_) { fp_prob_ *= kSBFErrorFactor; filters_.emplace_back().Init(max_capacity_ * grow_factor_, fp_prob_, filters_.get_allocator().resource()); current_size_ = 0; max_capacity_ = filters_.back().Capacity(fp_prob_); } return true; } bool SBF::Exists(std::string_view str) const { XXH128_hash_t hash = Hash(str); uint64_t fp[2] = {hash.low64, hash.high64}; auto exists = [fp](const Bloom& b) { return b.Exists(fp); }; return any_of(filters_.crbegin(), filters_.crend(), exists); } size_t SBF::MallocUsed() const { size_t res = filters_.capacity() * sizeof(Bloom); for (const auto& b : filters_) { res += (b.bitlen() / 8); } res += sizeof(SBF); return res; } } // namespace dfly ================================================ FILE: src/core/bloom.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include "base/pmr/memory_resource.h" namespace dfly { /// Bloom filter based on the design of https://github.com/jvirkki/libbloom class Bloom { public: Bloom() = default; Bloom(const Bloom&) = delete; Bloom& operator=(const Bloom&) = delete; // Note, that Destroy() must be called before calling the d'tor ~Bloom(); // Initializes a new Bloom object // entries - entries are silently rounded up to the minimum capacity. // fp_prob - False-positive probability of collision. Must be in (0, 1) range. // heap void Init(uint64_t entries, double fp_prob, PMR_NS::memory_resource* resource); // Direct initializer. len*8 must be power of 2. void Init(uint8_t* blob, size_t len, unsigned hash_cnt); // Destroys the object, must be called before destructing the object. // resource - resource with which the object was initialized. void Destroy(PMR_NS::memory_resource* resource); Bloom(Bloom&& o) noexcept; bool Exists(std::string_view str) const; // Equivalent to the Exist above but accepts two fingerprints of the item. 
bool Exists(const uint64_t fp[2]) const; // Adds an item to the bloom filter. // Returns true if element was not present and was added, // false - if element (or a collision) had already been added previously. bool Add(std::string_view str); bool Add(const uint64_t fp[2]); size_t bitlen() const { return 1ULL << bit_log_; } // Max element capacity for this bloom filter. // Note that capacity is floor(bit_len / bpe), where bpe (bits per element) is // derived from fp_prob. size_t Capacity(double fp_prob) const; std::string_view data() const { return std::string_view{reinterpret_cast(bf_), bitlen() / 8}; } unsigned hash_cnt() const { return hash_cnt_; } private: bool IsSet(size_t index) const; bool Set(size_t index); // return true if bit was set (i.e was 0 before) uint8_t hash_cnt_ = 0; uint8_t bit_log_ = 0; // log of bit length of the filter. bit length is always power of 2. uint8_t* bf_ = nullptr; // pointer to the blob. }; /** * @brief Scalable bloom filter. * Based on https://gsd.di.uminho.pt/members/cbm/ps/dbloom.pdf * Please note that for SBF, the original paper assumes partitioning of bit space into K * disjoint segments where K is number of hash functions. This is done to reduce index collisions. * We do not do this, because we use power of 2 bit lengths. * TODO: to test the actual rate of this filter. */ class SBF { public: SBF(uint64_t initial_capacity, double fp_prob, double grow_factor, PMR_NS::memory_resource* mr); SBF(const SBF&) = delete; // C'tor used for loading persisted filters into SBF. // Should be followed by AddFilter. 
SBF(double grow_factor, double fp_prob, size_t max_capacity, size_t prev_size, size_t current_size, PMR_NS::memory_resource* mr); ~SBF(); SBF& operator=(SBF&& src) noexcept; void AddFilter(const std::string& blob, unsigned hash_cnt); bool Add(std::string_view str); bool Exists(std::string_view str) const; size_t current_size() const { return current_size_; } size_t prev_size() const { return prev_size_; } double grow_factor() const { return grow_factor_; } // expected fp probability for the current filter. double fp_probability() const { return fp_prob_; } uint32_t num_filters() const { return filters_.size(); } std::string_view data(size_t idx) const { return filters_[idx].data(); } unsigned hashfunc_cnt(size_t idx) const { return filters_[idx].hash_cnt(); } // max capacity of the current filter. size_t max_capacity() const { return max_capacity_; } size_t MallocUsed() const; private: // multiple filters from the smallest to the largest. std::vector> filters_; double grow_factor_; double fp_prob_; size_t prev_size_ = 0; size_t current_size_ = 0; size_t max_capacity_; }; } // namespace dfly ================================================ FILE: src/core/bloom_test.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/bloom.h" #include #include #include "base/gtest.h" namespace dfly { using namespace std; class BloomTest : public ::testing::Test { protected: BloomTest() { bloom_.Init(1000, 0.001, PMR_NS::get_default_resource()); } ~BloomTest() { bloom_.Destroy(PMR_NS::get_default_resource()); } Bloom bloom_; }; TEST_F(BloomTest, Basic) { EXPECT_FALSE(bloom_.Exists(string_view{})); EXPECT_TRUE(bloom_.Add(string_view{})); EXPECT_TRUE(bloom_.Exists(string_view{})); EXPECT_FALSE(bloom_.Add(string_view{})); vector values; for (unsigned i = 0; i < 100; ++i) { values.push_back(absl::StrCat("val", i)); } for (const auto& val : values) { EXPECT_FALSE(bloom_.Exists(val)); EXPECT_TRUE(bloom_.Add(val)); EXPECT_TRUE(bloom_.Exists(val)); EXPECT_FALSE(bloom_.Add(val)); } } TEST_F(BloomTest, ErrorBound) { size_t max_capacity = bloom_.Capacity(0.001); for (unsigned i = 0; i < max_capacity; ++i) { ASSERT_FALSE(bloom_.Exists(absl::StrCat("item", i))); } unsigned collisions = 0; for (unsigned i = 0; i < max_capacity; ++i) { if (!bloom_.Add(absl::StrCat("item", i))) { ++collisions; } } EXPECT_EQ(collisions, 0) << max_capacity; } TEST_F(BloomTest, Extreme) { Bloom b2; // Init with unreasonable large error probability. b2.Init(10, 0.999, PMR_NS::get_default_resource()); EXPECT_EQ(512, b2.bitlen()); // minimal bit length, even though requested smaller capacity. EXPECT_LT(b2.Capacity(0.999), 512); // make sure our element capacity is smaller. b2.Destroy(PMR_NS::get_default_resource()); } TEST_F(BloomTest, SBF) { SBF sbf(10, 0.001, 2, PMR_NS::get_default_resource()); unsigned collisions = 0; constexpr unsigned kNumElems = 2000000; for (unsigned i = 0; i < kNumElems; ++i) { if (!sbf.Add(absl::StrCat("item", i))) { ++collisions; } } // TODO: to revisit the math for deriving number of hash functions for each filter // according the the SBF paper. 
EXPECT_LE(collisions, kNumElems * 0.008); } static void BM_BloomExist(benchmark::State& state) { constexpr size_t kCapacity = 1U << 22; Bloom bloom; bloom.Init(kCapacity, 0.001, PMR_NS::get_default_resource()); for (size_t i = 0; i < kCapacity * 0.8; ++i) { bloom.Add(absl::StrCat("val", i)); } unsigned i = 0; char buf[32]; memset(buf, 'x', sizeof(buf)); string_view sv{buf, sizeof(buf)}; while (state.KeepRunning()) { absl::numbers_internal::FastIntToBuffer(i, buf); bloom.Exists(sv); } bloom.Destroy(PMR_NS::get_default_resource()); } BENCHMARK(BM_BloomExist); } // namespace dfly ================================================ FILE: src/core/bptree_set.h ================================================ // Copyright 2023, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include "core/detail/bptree_internal.h" #include "core/detail/stateless_allocator.h" namespace dfly { template struct DefaultCompareTo { int operator()(const T& a, const T& b) const { std::less cmp; return cmp(a, b) ? -1 : (cmp(b, a) ? 1 : 0); } }; template struct BPTreePolicy { using KeyT = T; // The three way comparator that should accept a query ( or key) on the left, and the key // on the right. using KeyCompareTo = DefaultCompareTo; }; template > class BPTree { BPTree(const BPTree&) = delete; BPTree& operator=(const BPTree&) = delete; using BPTreeNode = detail::BPTreeNode; using BPTreePath = detail::BPTreePath; public: using KeyT = typename Policy::KeyT; BPTree(PMR_NS::memory_resource* mr = PMR_NS::get_default_resource()) : mr_(mr) { } ~BPTree() { Clear(); } // true if inserted, false if skipped. 
bool Insert(KeyT item); bool Contains(KeyT item) const; bool Delete(KeyT item); std::optional GetRank(KeyT item, bool reverse = false) const; size_t Height() const { return height_; } size_t Size() const { return count_; // number of items in the tree } bool Empty() const { return count_ == 0; } size_t NodeCount() const { // number of nodes in the tree (usually, order of magnitude smaller than Size()). return num_nodes_; } void Clear(); const BPTreeNode* DEBUG_root() const { return root_; } BPTreePath FromRank(uint32_t rank) const { BPTreePath path; ToRank(rank, &path); return path; } /// @brief Iterates over all items in the range [rank_start, rank_end] by rank. /// @param rank_start /// @param rank_end - inclusive. /// @param cb - callback to be called for each item in the range. /// Should return false to stop iteration. bool Iterate(uint32_t rank_start, uint32_t rank_end, std::function cb) const; /// @brief Iterates over all items in the range [rank_start, rank_end] by rank in reverse order. /// @param rank_start /// @param rank_end /// @param cb - callback to be called for each item in the range. /// Should return false to stop iteration. bool IterateReverse(uint32_t rank_start, uint32_t rank_end, std::function cb) const; /// @brief Returns the path to the first item in the tree for which comp(q, key) >= 0. /// @param item /// @return the path if such item exists, empty path otherwise. template BPTreePath GEQ(Q&& query) const; /// @brief Returns the path to the largest item in the tree such that comp(q, key) <= 0. /// @param key /// @return the path if such item exists, empty path otherwise. template BPTreePath LEQ(Q&& query) const; /// @brief Deletes the element pointed by path. /// @param path void Delete(BPTreePath path); /// @brief Forces an update to the key. Assumes key has the same value. /// Replaces old with new_obj. 
void ForceUpdate(KeyT old, KeyT new_obj); private: BPTreeNode* CreateNode(bool leaf); void DestroyNode(BPTreeNode* node); void InsertToFullLeaf(KeyT item, const BPTreePath& path); // Returns true if insertion was handled by rebalancing. bool RebalanceLeafAndInsert(const BPTreePath& path, unsigned parent_depth, KeyT item, unsigned insert_pos); void IncreaseSubtreeCounts(const BPTreePath& path, unsigned depth, int32_t delta); // Charts the path towards key. Returns true if key is found. // In that case comp(q, path->Last().first->Key(path->Last().second)) == 0. // Fills the tree path not including the key itself. In case key was not found, // returns the path to the item that is greater than the key. template bool Locate(Q&& q, BPTreePath* path) const; // Sets the tree path to item at specified rank. Rank is 0-based and must be less than Size(). // returns the index of the key in the last node of the path. void ToRank(uint32_t rank, BPTreePath* path) const; BPTreeNode* root_ = nullptr; // root node or NULL if empty tree uint32_t count_ = 0; // number of items in tree uint32_t height_ = 0; // height of tree from root to leaf uint32_t num_nodes_ = 0; // number of nodes in tree PMR_NS::memory_resource* mr_; }; template bool BPTree::Contains(KeyT item) const { BPTreePath path; bool found = Locate(item, &path); return found; } template void BPTree::Clear() { if (!root_) return; BPTreePath path; BPTreeNode* node = root_; auto deep_left = [&](unsigned pos) { do { path.Push(node, pos); node = node->Child(pos); pos = 0; } while (!node->IsLeaf()); }; if (!root_->IsLeaf()) deep_left(0); while (true) { DestroyNode(node); if (path.Depth() == 0) { break; } node = path.Last().first; unsigned pos = path.Last().second; path.Pop(); if (pos < node->NumItems()) { deep_left(pos + 1); } } root_ = nullptr; height_ = count_ = 0; } template bool BPTree::Insert(KeyT item) { if (!root_) { root_ = CreateNode(true); root_->InitSingle(item); count_ = height_ = 1; return true; } BPTreePath path; 
bool found = Locate(item, &path); if (found) { return false; } assert(path.Depth() > 0u); BPTreeNode* leaf = path.Last().first; assert(leaf->IsLeaf()); if (leaf->NumItems() == detail::BPNodeLayout::kMaxLeafKeys) { InsertToFullLeaf(item, path); } else { unsigned pos = path.Last().second; leaf->LeafInsert(pos, item); if (path.Depth() > 1) IncreaseSubtreeCounts(path, path.Depth() - 2, 1); } count_++; return true; } template bool BPTree::Delete(KeyT item) { if (!root_) return false; BPTreePath path; bool found = Locate(item, &path); if (!found) return false; Delete(path); return true; } template std::optional BPTree::GetRank(KeyT item, bool reverse) const { if (!root_) return std::nullopt; BPTreePath path; bool found = Locate(item, &path); if (!found) return std::nullopt; if (reverse) { return count_ - path.Rank() - 1; } return path.Rank(); } template template bool BPTree::Locate(Q&& q, BPTreePath* path) const { assert(root_); BPTreeNode* node = root_; typename Policy::KeyCompareTo cmp; auto cmp_cb = [&](const KeyT& key) { return cmp(q, key); }; while (true) { typename BPTreeNode::SearchResult res = node->BSearch(cmp_cb); path->Push(node, res.index); if (res.found) { return true; } assert(res.index <= node->NumItems()); if (node->IsLeaf()) { break; } node = node->Child(res.index); } return false; } template void BPTree::InsertToFullLeaf(KeyT item, const BPTreePath& path) { using Layout = detail::BPNodeLayout; using Comp [[maybe_unused]] = typename Policy::KeyCompareTo; assert(path.Depth() > 0u); BPTreeNode* node = path.Last().first; assert(node->IsLeaf() && node->AvailableSlotCount() == 0); unsigned insert_pos = path.Last().second; unsigned level = path.Depth() - 1; if (level > 0 && RebalanceLeafAndInsert(path, level - 1, item, insert_pos)) { // Update the tree count of the ascendants. 
IncreaseSubtreeCounts(path, level - 1, 1); return; } KeyT median; BPTreeNode* right = CreateNode(true); node->Split(right, &median); assert(node->NumItems() < Layout::kMaxLeafKeys); if (insert_pos <= node->NumItems()) { assert(Comp()(item, median) < 0); node->LeafInsert(insert_pos, item); } else { assert(Comp()(item, median) > 0); right->LeafInsert(insert_pos - node->NumItems() - 1, item); } // we must add the newly created `right` to the parent and update its tree count. while (level > 0) { --level; // level up, now node is parent. node = path.Node(level); unsigned pos = path.Position(level); // position of the child node in parent. assert(!node->IsLeaf() && pos <= node->NumItems()); assert(right); // Terminal case: Node is not full so we can just add `right` to it. if (node->NumItems() < Layout::kMaxInnerKeys) { // We do not update the subtree count of the node here because the surpus of another item // resulted with the additional key in this node. node->InnerInsert(pos, median, right); node->IncreaseTreeCount(1); right = nullptr; break; } // We need to insert right into a node as position pos. Node is full so we must handle it // either via rebalancing "node" or via its splitting. Rebalancing is a better case, we try // it first. if (level > 0) { // see if we can rebalance node (right's parent) via node's parent. BPTreeNode* parent = path.Node(level - 1); unsigned parent_pos = path.Position(level - 1); assert(parent->Child(parent_pos) == node); auto [new_node, inner_pos] = parent->RebalanceChild(parent_pos, pos); if (new_node) { // we rebalanced inner_full so we can insert (median, right) and stop propagating. new_node->InnerInsert(inner_pos, median, right); if (new_node != node) { // Fix subtree counts if right was migrated to the sibling. node->IncreaseTreeCount(-right->TreeCount()); new_node->IncreaseTreeCount(right->TreeCount() + 1); } else { node->IncreaseTreeCount(1); } right = nullptr; break; } } // node is not rebalanced, so we need to split it. 
BPTreeNode* next_right = CreateNode(false); KeyT next_median; node->Split(next_right, &next_median); assert(node->NumItems() < Layout::kMaxInnerKeys); if (pos <= node->NumItems()) { assert(Comp()(median, next_median) < 0); node->InnerInsert(pos, median, right); node->IncreaseTreeCount(1); } else { assert(Comp()(median, next_median) > 0); next_right->InnerInsert(pos - node->NumItems() - 1, median, right); // Fix tree counts. node->IncreaseTreeCount(-right->TreeCount()); next_right->IncreaseTreeCount(right->TreeCount() + 1); } right = next_right; median = next_median; } if (right) { assert(level == 0); BPTreeNode* new_root = CreateNode(false); new_root->InitSingle(median); new_root->SetChild(0, root_); new_root->SetChild(1, right); new_root->SetTreeCount(root_->TreeCount() + right->TreeCount() + 1); root_ = new_root; height_++; } else { if (level > 0) { IncreaseSubtreeCounts(path, level - 1, 1); } } } template bool BPTree::RebalanceLeafAndInsert(const BPTreePath& path, unsigned parent_depth, KeyT item, unsigned insert_pos) { BPTreeNode* parent = path.Node(parent_depth); unsigned pos = path.Position(parent_depth); std::pair rebalance_res = parent->RebalanceChild(pos, insert_pos); if (rebalance_res.first) { rebalance_res.first->LeafInsert(rebalance_res.second, item); return true; } return false; } template void BPTree::IncreaseSubtreeCounts(const BPTreePath& path, unsigned depth, int32_t delta) { for (int i = depth; i >= 0; --i) { BPTreeNode* node = path.Node(i); node->IncreaseTreeCount(delta); } } template bool BPTree::Iterate(uint32_t rank_start, uint32_t rank_end, std::function cb) const { if (rank_start >= Size()) return true; assert(rank_start <= rank_end); BPTreePath path; ToRank(rank_start, &path); for (uint32_t i = rank_start; i <= rank_end; ++i) { if (!cb(path.Terminal())) return false; if (!path.Next()) return true; } return true; } template bool BPTree::IterateReverse(uint32_t rank_start, uint32_t rank_end, std::function cb) const { assert(rank_start <= 
rank_end && rank_end < count_); BPTreePath path; ToRank(count_ - 1 - rank_start, &path); for (uint32_t i = rank_start; i <= rank_end; ++i) { if (!cb(path.Terminal())) return false; path.Prev(); } return true; } template void BPTree::ToRank(uint32_t rank, BPTreePath* path) const { assert(root_ && rank < count_); BPTreeNode* node = root_; if (rank + 1 == count_) { // Corner case where we search for the node on the right. while (!node->IsLeaf()) { path->Push(node, node->NumItems()); node = node->Child(node->NumItems()); } path->Push(node, node->NumItems() - 1); return; } while (!node->IsLeaf()) { // handle common corner case of search of left-most node, and avoid counting sub-tree count. if (rank == 0) { path->Push(node, 0); node = node->Child(0); continue; } for (unsigned i = 0; i <= node->NumItems(); ++i) { uint32_t subtree_cnt = node->GetChildTreeCount(i); if (subtree_cnt > rank) { path->Push(node, i); node = node->Child(i); break; } assert(i < node->NumItems()); rank -= subtree_cnt; if (rank == 0) { path->Push(node, i); return; } --rank; } } assert(node->IsLeaf()); assert(rank < node->NumItems()); path->Push(node, rank); } template template auto BPTree::GEQ(Q&& query) const -> BPTreePath { BPTreePath path; bool res = Locate(query, &path); // if we did not find the item and the path does not lead to any key in the node, // adjust the path to point to the next key in the tree. // In case we are past all items in the tree, Next() will collapse to the empty path. if (!res && path.Last().second >= path.Last().first->NumItems()) { path.Next(); } return path; } template template auto BPTree::LEQ(Q&& query) const -> BPTreePath { BPTreePath path; bool res = Locate(query, &path); if (!res) { // fix the result in case the path leads to key greater than item. 
path.Prev(); } return path; } template detail::BPTreeNode* BPTree::CreateNode(bool leaf) { num_nodes_++; void* ptr = mr_->allocate(detail::kBPNodeSize, 8); BPTreeNode* node = new (ptr) BPTreeNode(leaf); return node; } template void BPTree::Delete(BPTreePath path) { using Comp [[maybe_unused]] = typename Policy::KeyCompareTo; BPTreeNode* node = path.Last().first; unsigned key_pos = path.Last().second; // Remove the key from the node. if (node->IsLeaf()) { node->ShiftLeft(key_pos); // shift left everything after key_pos. } else { // We can not remove the item from the inner node because it also serves as a separator. // Therefore, we swap it the rightmost key in the left subtree and pop from there instead. path.DigRight(); BPTreeNode* leaf = path.Last().first; assert(Comp()(leaf->Key(leaf->NumItems() - 1), node->Key(key_pos)) < 0); // set a new separator. node->SetKey(key_pos, leaf->Key(leaf->NumItems() - 1)); leaf->LeafEraseRight(); // pop the rightmost key from the leaf. node = leaf; } count_--; assert(node->IsLeaf()); // go up the tree and rebalance if number of items in the node is less // than low limit. We either merge or rebalance nodes. while (node->NumItems() < node->MinItems()) { if (node == root_) { if (node->NumItems() == 0) { // terminal case, we reached the root - and it has either a single child (0 delimiters) // or no children at all (leaf). The former is more common case: the tree can only shrink // through the root. if (node->IsLeaf()) { assert(count_ == 0u); root_ = nullptr; } else { root_ = root_->Child(0); } --height_; DestroyNode(node); } return; } // The node has a parent. Pop the node from the path and try rebalance it via its parent. assert(path.Depth() > 0u); path.Pop(); BPTreeNode* parent = path.Last().first; unsigned pos = path.Last().second; assert(parent->Child(pos) == node); node = parent->MergeOrRebalanceChild(pos); parent->IncreaseTreeCount(-1); if (node == nullptr) // succeeded to merge/rebalance without the need to propagate. 
break; DestroyNode(node); // assert(parent->TreeCount() == parent->DEBUG_TreeCount()); node = parent; } if (path.Depth() >= 2) { IncreaseSubtreeCounts(path, path.Depth() - 2, -1); } } template void BPTree::DestroyNode(BPTreeNode* node) { void* ptr = node; mr_->deallocate(ptr, detail::kBPNodeSize, 8); num_nodes_--; } template void BPTree::ForceUpdate(KeyT old, KeyT new_obj) { BPTreePath path; [[maybe_unused]] bool found = Locate(old, &path); assert(path.Depth() > 0u); assert(found); BPTreeNode* node = path.Last().first; node->SetKey(path.Last().second, new_obj); } } // namespace dfly ================================================ FILE: src/core/bptree_set_test.cc ================================================ // Copyright 2023, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #include "core/bptree_set.h" #include #include #include #include extern "C" { #include "redis/sds.h" #include "redis/zmalloc.h" } #include "base/gtest.h" #include "base/init.h" #include "base/logging.h" #include "core/mi_memory_resource.h" using namespace std; namespace dfly { namespace { template bool ValidateNode(const Node* node, typename Node::KeyT ubound) { typename Policy::KeyCompareTo cmp; for (unsigned i = 1; i < node->NumItems(); ++i) { if (cmp(node->Key(i - 1), node->Key(i)) > -1) return false; } if (!node->IsLeaf()) { unsigned mask = 0; uint32_t subtree_cnt = node->NumItems(); for (unsigned i = 0; i <= node->NumItems(); ++i) { mask |= (1 << node->Child(i)->IsLeaf()); DCHECK_EQ(node->Child(i)->DEBUG_TreeCount(), node->Child(i)->TreeCount()); subtree_cnt += node->Child(i)->TreeCount(); } if (mask == 3) return false; if (subtree_cnt != node->TreeCount()) { LOG(ERROR) << "Expected " << subtree_cnt << " got " << node->TreeCount(); return false; } } return cmp(node->Key(node->NumItems() - 1), ubound) == -1; } struct ZsetPolicy { struct KeyT { double d; sds s; }; struct KeyCompareTo { int operator()(const KeyT& left, const KeyT& right) { if (left.d < right.d) 
return -1; if (left.d > right.d) return 1; // Note that sdscmp can return values outside of [-1, 1] range. return sdscmp(left.s, right.s); } }; }; using SDSTree = BPTree; } // namespace class BPTreeSetTest : public ::testing::Test { using Node = detail::BPTreeNode; protected: static constexpr size_t kNumElems = 7000; BPTreeSetTest() : mi_alloc_(mi_heap_get_backing()), bptree_(&mi_alloc_) { } static void SetUpTestSuite() { } void FillTree(unsigned start, unsigned factor) { for (unsigned i = start; i < kNumElems; ++i) { bptree_.Insert(i * factor); } } void FillTree(unsigned factor = 1) { FillTree(0, factor); } bool Validate(); MiMemoryResource mi_alloc_; BPTree bptree_; mt19937 generator_{1}; }; bool BPTreeSetTest::Validate() { auto* root = bptree_.DEBUG_root(); if (!root) return true; // node, upper bound vector> stack; stack.emplace_back(root, UINT64_MAX); while (!stack.empty()) { const Node* node = stack.back().first; uint64_t ubound = stack.back().second; stack.pop_back(); if (!ValidateNode>(node, ubound)) return false; if (!node->IsLeaf()) { for (unsigned i = 0; i < node->NumItems(); ++i) { stack.emplace_back(node->Child(i), node->Key(i)); } stack.emplace_back(node->Child(node->NumItems()), ubound); } } return true; } TEST_F(BPTreeSetTest, BPtreeInsert) { for (unsigned i = 1; i < 7000; ++i) { ASSERT_TRUE(bptree_.Insert(i)); ASSERT_EQ(i, bptree_.Size()); ASSERT_EQ(i - 1, bptree_.GetRank(i)); // ASSERT_TRUE(Validate()) << i; } ASSERT_TRUE(Validate()); ASSERT_GT(mi_alloc_.used(), 56000u); ASSERT_LT(mi_alloc_.used(), 66000u); for (unsigned i = 1; i < 7000; ++i) { ASSERT_TRUE(bptree_.Contains(i)); } bptree_.Clear(); ASSERT_EQ(mi_alloc_.used(), 0u); uniform_int_distribution dist(0, 100000); for (unsigned i = 0; i < 20000; ++i) { bptree_.Insert(dist(generator_)); // ASSERT_TRUE(Validate()) << i; } ASSERT_TRUE(Validate()); ASSERT_GT(mi_alloc_.used(), 10000u); LOG(INFO) << bptree_.Height() << " " << bptree_.Size(); bptree_.Clear(); ASSERT_EQ(mi_alloc_.used(), 0u); for 
(unsigned i = 20000; i > 1; --i) { bptree_.Insert(i); } ASSERT_TRUE(Validate()); for (unsigned i = 2; i <= 20000; ++i) { ASSERT_EQ(i - 2, bptree_.GetRank(i)); } LOG(INFO) << bptree_.Height() << " " << bptree_.Size(); ASSERT_GT(mi_alloc_.used(), 20000 * 8); ASSERT_LT(mi_alloc_.used(), 20000 * 10); bptree_.Clear(); ASSERT_EQ(mi_alloc_.used(), 0u); } TEST_F(BPTreeSetTest, Delete) { for (unsigned i = 31; i > 10; --i) { bptree_.Insert(i); } for (unsigned i = 1; i < 10; ++i) { ASSERT_FALSE(bptree_.Delete(i)); } for (unsigned i = 11; i < 32; ++i) { ASSERT_TRUE(bptree_.Delete(i)); } ASSERT_EQ(mi_alloc_.used(), 0u); ASSERT_EQ(bptree_.Size(), 0u); FillTree(); ASSERT_GT(bptree_.NodeCount(), 2u); unsigned sz = bptree_.Size(); for (unsigned i = 0; i < kNumElems; ++i) { --sz; ASSERT_EQ(bptree_.GetRank(kNumElems - 1), sz); ASSERT_TRUE(bptree_.Delete(i)); ASSERT_EQ(bptree_.Size(), sz); // ASSERT_TRUE(Validate()) << i; } ASSERT_EQ(mi_alloc_.used(), 0u); ASSERT_EQ(bptree_.Size(), 0u); ASSERT_EQ(bptree_.Height(), 0u); ASSERT_EQ(bptree_.NodeCount(), 0u); FillTree(2); for (unsigned i = 0; i < 20000; ++i) { unsigned val = generator_() % 15000; bool res = bptree_.Delete(val); if (val % 2 == 1) { ASSERT_FALSE(res); } if (res) { ASSERT_TRUE(Validate()); } } } TEST_F(BPTreeSetTest, Iterate) { FillTree(2); unsigned cnt = 0; bool res = bptree_.Iterate(31, 543, [&](uint64_t val) { if ((31 + cnt) * 2 != val) return false; ++cnt; return true; }); ASSERT_EQ(543 - 31 + 1, cnt); ASSERT_TRUE(res); for (unsigned j = 0; j < 10; ++j) { cnt = 0; unsigned from = generator_() % kNumElems; unsigned to = from + generator_() % (kNumElems - from); res = bptree_.Iterate(from, to, [&](uint64_t val) { if ((from + cnt) * 2 != val) return false; ++cnt; return true; }); ASSERT_EQ(to - from + 1, cnt); ASSERT_TRUE(res); } } TEST_F(BPTreeSetTest, Ranges) { FillTree(2); auto path = bptree_.GEQ(31); EXPECT_EQ(32, path.Terminal()); path = bptree_.GEQ(32); EXPECT_EQ(32, path.Terminal()); path = bptree_.GEQ(13998); 
EXPECT_EQ(13998, path.Terminal()); path = bptree_.LEQ(14000); EXPECT_EQ(13998, path.Terminal()); path = bptree_.GEQ(14000); EXPECT_EQ(0, path.Depth()); ASSERT_TRUE(bptree_.Delete(0)); path = bptree_.GEQ(0); EXPECT_EQ(2, path.Terminal()); path = bptree_.LEQ(1); EXPECT_TRUE(path.Empty()); } TEST_F(BPTreeSetTest, HalfRanges) { FillTree(1, 3); // 3, 6, 9 ... auto path = bptree_.FromRank(bptree_.Size() - 1); uint64_t val = path.Terminal(); for (unsigned i = 0; i <= val; ++i) { path = bptree_.GEQ(i); ASSERT_FALSE(path.Empty()) << i; } path = bptree_.GEQ(val + 1); ASSERT_TRUE(path.Empty()); for (unsigned i = 3; i <= val + 10; ++i) { path = bptree_.LEQ(i); ASSERT_FALSE(path.Empty()) << i; } path = bptree_.LEQ(2); ASSERT_TRUE(path.Empty()); } #if 0 TEST_F(BPTreeSetTest, MemoryUsage) { zskiplist* zsl = zslCreate(); std::vector sds_vec; constexpr size_t kLength = 3000; for (size_t i = 0; i < kLength; ++i) { sds_vec.push_back(sdsnew("f")); } size_t sz_before = zmalloc_used_memory_tl; LOG(INFO) << "zskiplist before: " << sz_before << " bytes"; for (size_t i = 0; i < sds_vec.size(); ++i) { zslInsert(zsl, i, sds_vec[i]); } LOG(INFO) << "zskiplist takes: " << double(zmalloc_used_memory_tl - sz_before) / sds_vec.size() << " bytes per entry"; zslFree(zsl); sds_vec.clear(); for (size_t i = 0; i < kLength; ++i) { sds_vec.push_back(sdsnew("f")); } MiMemoryResource mi_alloc(mi_heap_get_backing()); using AllocType = PMR_NS::polymorphic_allocator>; AllocType alloc(&mi_alloc); absl::btree_set, std::greater>, AllocType> btree(alloc); ASSERT_EQ(0, mi_alloc.used()); for (size_t i = 0; i < sds_vec.size(); ++i) { btree.emplace(i, sds_vec[i]); } ASSERT_GT(mi_alloc.used(), 0u); LOG(INFO) << "abseil btree: " << double(mi_alloc.used()) / sds_vec.size() << " bytes per entry"; btree.clear(); ASSERT_EQ(0, mi_alloc.used()); SDSTree df_tree(&mi_alloc); for (size_t i = 0; i < sds_vec.size(); ++i) { btree.emplace(i, sds_vec[i]); VLOG(1) << "df btree: " << i << " " << double(mi_alloc.used()) / btree.size() 
<< " bytes per entry"; } ASSERT_GT(mi_alloc.used(), 0u); LOG(INFO) << "df btree: " << double(mi_alloc.used()) / sds_vec.size() << " bytes per entry"; } #endif TEST_F(BPTreeSetTest, InsertSDS) { vector vals; for (unsigned i = 0; i < 256; ++i) { sds s = sdsempty(); s = sdscatfmt(s, "a%u", i); vals.emplace_back(ZsetPolicy::KeyT{.d = 1000, .s = s}); } SDSTree tree(&mi_alloc_); for (size_t i = 0; i < vals.size(); ++i) { ASSERT_TRUE(tree.Insert(vals[i])); } for (auto v : vals) { sdsfree(v.s); } } TEST_F(BPTreeSetTest, ReverseIterate) { vector vals; for (int i = -1000; i < 1000; ++i) { sds s = sdsempty(); s = sdscatfmt(s, "a%u", i); vals.emplace_back(ZsetPolicy::KeyT{.d = (double)i, .s = s}); } SDSTree tree(&mi_alloc_); for (auto v : vals) { ASSERT_TRUE(tree.Insert(v)); { double score = 0; tree.IterateReverse(0, 0, [&score](auto i) { score = i.d; return false; }); EXPECT_EQ(score, v.d); } { double score = 0; tree.Iterate(0, 0, [&score](auto i) { score = i.d; return false; }); EXPECT_EQ(score, vals[0].d); } } vector res; tree.IterateReverse(0, 1, [&](auto i) { res.push_back(i.d); return true; }); EXPECT_THAT(res, testing::ElementsAre(999, 998)); for (auto v : vals) { sdsfree(v.s); } } static string RandomString(mt19937& rand, unsigned len) { const string_view alpanum = "1234567890abcdefghijklmnopqrstuvwxyz"; string ret; ret.reserve(len); for (size_t i = 0; i < len; ++i) { ret += alpanum[rand() % alpanum.size()]; } return ret; } std::vector GenerateRandomPairs(unsigned len) { mt19937 dre(10); std::vector vals(len, ZsetPolicy::KeyT{}); for (unsigned i = 0; i < len; ++i) { vals[i].d = dre(); vals[i].s = sdsnew(RandomString(dre, 10).c_str()); } return vals; } static void BM_FindRandomBPTree(benchmark::State& state) { unsigned iters = state.range(0); std::vector vals = GenerateRandomPairs(iters); SDSTree bptree; for (unsigned i = 0; i < iters; ++i) { bptree.Insert(vals[i]); } unsigned i = 0; while (state.KeepRunningBatch(10)) { for (unsigned j = 0; j < 10; ++j) { 
benchmark::DoNotOptimize(bptree.GEQ(vals[i])); ++i; if (vals.size() == i) i = 0; } } for (const auto v : vals) { sdsfree(v.s); } } BENCHMARK(BM_FindRandomBPTree)->Arg(1024)->Arg(1 << 16)->Arg(1 << 20); #if 0 static void BM_FindRandomZSL(benchmark::State& state) { zskiplist* zsl = zslCreate(); unsigned iters = state.range(0); std::vector vals = GenerateRandomPairs(iters); for (unsigned i = 0; i < iters; ++i) { zslInsert(zsl, vals[i].d, sdsdup(vals[i].s)); } zrangespec spec; spec.maxex = 0; spec.minex = 0; unsigned i = 0; while (state.KeepRunningBatch(10)) { for (unsigned j = 0; j < 10; ++j) { spec.min = vals[i].d; spec.max = spec.min; benchmark::DoNotOptimize(zslFirstInRange(zsl, &spec)); ++i; if (vals.size() == i) i = 0; } } zslFree(zsl); for (const auto v : vals) { sdsfree(v.s); } } BENCHMARK(BM_FindRandomZSL)->Arg(1024)->Arg(1 << 16)->Arg(1 << 20); #endif void RegisterBPTreeBench() { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); }; REGISTER_MODULE_INITIALIZER(Bptree, RegisterBPTreeBench()); TEST_F(BPTreeSetTest, ForceUpdate) { struct Policy { // Similar to how it's used in SortedMap just a little simpler. using KeyT = int*; struct KeyCompareTo { int operator()(KeyT a, KeyT b) const { if (*a < *b) return -1; if (*a > *b) return 1; return 0; } }; }; auto gen_vector = []() { std::vector> tmp; for (size_t i = 0; i < 1000; ++i) { tmp.push_back(std::make_unique(i)); } return tmp; }; std::vector> original = gen_vector(); std::vector> modified = gen_vector(); BPTree bptree; for (auto& item : original) { bptree.Insert(item.get()); } for (auto& item : modified) { bptree.ForceUpdate(item.get(), item.get()); } original.clear(); size_t index = 0; bptree.Iterate(0, 1000, [&](int* ptr) { EXPECT_EQ(modified[index].get(), ptr); ++index; return true; }); } } // namespace dfly ================================================ FILE: src/core/cms.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. 
// See LICENSE for licensing terms. // #include "core/cms.h" #include #include #include #include #include "base/logging.h" namespace dfly { namespace { uint32_t Offset(uint64_t h1, uint64_t h2, uint32_t row, uint32_t width) { uint32_t idx = static_cast((h1 + (row * h2)) % width); return row * width + idx; } } // namespace CMS::CMS(uint32_t width, uint32_t depth, PMR_NS::memory_resource* mr) : width_(width), depth_(depth), mr_(mr) { size_t len = NumCounters(); counters_ = static_cast(mr_->allocate(len * sizeof(int64_t), alignof(int64_t))); std::fill_n(counters_, len, 0); } CMS::~CMS() { if (counters_) { mr_->deallocate(counters_, NumCounters() * sizeof(int64_t), alignof(int64_t)); } } CMS::CMS(CMS&& other) noexcept : width_(other.width_), depth_(other.depth_), mr_(other.mr_), count_(other.count_), counters_(other.counters_) { other.width_ = 0; other.depth_ = 0; other.count_ = 0; other.counters_ = nullptr; } CMS& CMS::operator=(CMS&& other) noexcept { if (this != &other) { if (counters_) { mr_->deallocate(counters_, NumCounters() * sizeof(int64_t), alignof(int64_t)); } width_ = other.width_; depth_ = other.depth_; mr_ = other.mr_; count_ = other.count_; counters_ = other.counters_; other.width_ = 0; other.depth_ = 0; other.count_ = 0; other.counters_ = nullptr; } return *this; } CMS::CMS(ErrorRateTag /*tag*/, double error, double probability, PMR_NS::memory_resource* mr) : CMS(static_cast(std::ceil(M_E / error)), static_cast(std::ceil(std::log(1.0 / probability))), mr) { } int64_t CMS::IncrBy(std::string_view item, int64_t increment) { count_ += increment; int64_t min_count = std::numeric_limits::max(); XXH128_hash_t hash = XXH3_128bits(item.data(), item.size()); uint64_t h1 = hash.low64; uint64_t h2 = hash.high64; for (uint32_t row = 0; row < depth_; ++row) { uint32_t offset = Offset(h1, h2, row, width_); counters_[offset] += increment; min_count = std::min(min_count, counters_[offset]); } return min_count; } int64_t CMS::Query(std::string_view item) const { 
XXH128_hash_t hash = XXH3_128bits(item.data(), item.size()); uint64_t h1 = hash.low64; uint64_t h2 = hash.high64; int64_t min_count = std::numeric_limits::max(); for (uint32_t row = 0; row < depth_; ++row) { uint32_t offset = Offset(h1, h2, row, width_); min_count = std::min(min_count, counters_[offset]); } return min_count; } bool CMS::MergeFrom(const CMS& other, int64_t weight) { if (width_ != other.width_ || depth_ != other.depth_) { return false; } for (size_t i = 0; i < NumCounters(); ++i) { counters_[i] += other.counters_[i] * weight; } count_ += other.count_ * weight; return true; } void CMS::Reset() { std::fill_n(counters_, NumCounters(), 0); count_ = 0; } void CMS::Load(int64_t total_incr_count, const int64_t* data) { count_ = total_incr_count; std::copy_n(data, NumCounters(), counters_); } } // namespace dfly ================================================ FILE: src/core/cms.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include "base/pmr/memory_resource.h" namespace dfly { /// Count-Min Sketch implementation compatible with Redis CMS commands. class CMS { public: // Create a CMS with given width and depth dimensions. // width: number of counters per row // depth: number of rows (hash functions) CMS(uint32_t width, uint32_t depth, PMR_NS::memory_resource* mr); CMS(const CMS&) = delete; CMS& operator=(const CMS&) = delete; CMS(CMS&& other) noexcept; CMS& operator=(CMS&& other) noexcept; ~CMS(); // Tag type to disambiguate CMS construction by error rate and probability. struct ErrorRateTag {}; // Create a CMS from error rate and probability parameters. // error: relative error (e.g. 0.01 for 1%), must be in (0, 1). // probability: probability of exceeding the error, must be in (0, 1). // width = ceil(e / error), depth = ceil(ln(1 / probability)). 
CMS(ErrorRateTag, double error, double probability, PMR_NS::memory_resource* mr); // Increment the count for an item by the given value. // Returns the new estimated count for the item. int64_t IncrBy(std::string_view item, int64_t increment); // Query the estimated count for an item. int64_t Query(std::string_view item) const; // Merge another CMS into this one with the given weight. // The other CMS must have the same dimensions. // Returns false if dimensions don't match. bool MergeFrom(const CMS& other, int64_t weight = 1); // Reset all counters and total count to zero. void Reset(); // Load serialized counter state. data must have exactly NumCounters() elements. void Load(int64_t total_incr_count, const int64_t* data); // Accessors for CMS properties uint32_t width() const { return width_; } uint32_t depth() const { return depth_; } // Total count of all IncrBy operations (used by CMS.INFO). int64_t total_count() const { return count_; } // Memory usage in bytes size_t MallocUsed() const { return NumCounters() * sizeof(int64_t); } size_t NumCounters() const { return static_cast(width_) * depth_; } const int64_t* Data() const { return counters_; } private: uint32_t width_; uint32_t depth_; PMR_NS::memory_resource* mr_ = nullptr; int64_t count_ = 0; // Total count of all IncrBy operations int64_t* counters_ = nullptr; }; } // namespace dfly ================================================ FILE: src/core/cms_test.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/cms.h" #include #include #include "base/gtest.h" namespace dfly { using namespace std; class CMSTest : public ::testing::Test { protected: CMSTest() : cms_(CMS(1000, 5, PMR_NS::get_default_resource())) { } CMS cms_; }; // A freshly created CMS must return 0 for any item. 
TEST_F(CMSTest, InitialCountIsZero) {
  // Nothing has been added to the fixture sketch yet.
  constexpr string_view kUnseenKeys[] = {"nonexistent", "", "anything"};
  for (string_view key : kUnseenKeys) {
    EXPECT_EQ(cms_.Query(key), 0);
  }
}

// A 1 x 1 sketch has a single counter, so every item maps to it.
// This catches initialization bugs (e.g. counters not zeroed).
TEST(CMSBasic, InitialCountIsZeroSmall) {
  CMS single_cell(1, 1, PMR_NS::get_default_resource());

  EXPECT_EQ(single_cell.Query("x"), 0);
  EXPECT_EQ(single_cell.Query("y"), 0);
}

TEST(CMSBasic, IncrBySmall) {
  CMS single_cell(1, 1, PMR_NS::get_default_resource());

  const int64_t after_incr = single_cell.IncrBy("a", 3);
  EXPECT_EQ(after_incr, 3);

  // With one counter all items collide, so "b" reads the value written for "a".
  EXPECT_EQ(single_cell.Query("b"), 3);
}

// Inspired by fakeredis test_cms_create: initbyprob computes correct dimensions.
TEST(CMSBasic, InitByProb) {
  CMS sketch(CMS::ErrorRateTag{}, 0.01, 0.01, PMR_NS::get_default_resource());

  // width = ceil(e / 0.01) = ceil(271.8..) = 272
  const auto expected_width = static_cast<uint32_t>(std::ceil(M_E / 0.01));
  EXPECT_EQ(sketch.width(), expected_width);

  // depth = ceil(ln(1/0.01)) = ceil(4.605..) = 5
  const auto expected_depth = static_cast<uint32_t>(std::ceil(std::log(100.0)));
  EXPECT_EQ(sketch.depth(), expected_depth);

  EXPECT_EQ(sketch.Query("anything"), 0);
}

// Inspired by fakeredis test_cms_incrby: multiple items, incremental updates.
TEST_F(CMSTest, IncrByMultipleItems) {
  const int64_t first_foo = cms_.IncrBy("foo", 3);
  EXPECT_EQ(first_foo, 3);

  cms_.IncrBy("foo", 4);
  cms_.IncrBy("bar", 1);

  EXPECT_GE(cms_.Query("foo"), 7);
  EXPECT_GE(cms_.Query("bar"), 1);
  EXPECT_EQ(cms_.Query("noexist"), 0);
}

TEST_F(CMSTest, BasicIncrBy) {
  // Each IncrBy reports the running estimate for the item.
  EXPECT_EQ(cms_.IncrBy("foo", 5), 5);
  EXPECT_EQ(cms_.IncrBy("foo", 3), 8);

  EXPECT_EQ(cms_.Query("foo"), 8);
}

TEST_F(CMSTest, QueryReturnsMinimum) {
  cms_.IncrBy("a", 10);
  cms_.IncrBy("b", 20);

  // CMS can overestimate, but never underestimate.
  const int64_t estimate_a = cms_.Query("a");
  EXPECT_GE(estimate_a, 10);
  const int64_t estimate_b = cms_.Query("b");
  EXPECT_GE(estimate_b, 20);
}

TEST_F(CMSTest, NeverUnderestimates) {
  constexpr int kItems = 500;
  const auto key_for = [](int idx) { return absl::StrCat("item", idx); };

  // Item idx is added with weight idx + 1.
  for (int idx = 0; idx < kItems; ++idx) {
    cms_.IncrBy(key_for(idx), idx + 1);
  }

  for (int idx = 0; idx < kItems; ++idx) {
    const string key = key_for(idx);
    EXPECT_GE(cms_.Query(key), idx + 1) << "Underestimate for " << key;
  }
}

TEST_F(CMSTest, UnseenItemIsZero) {
  cms_.IncrBy("known", 100);

  // With width=1000 and depth=5 and only one item inserted, collisions are unlikely.
  const int64_t unseen_estimate = cms_.Query("unknown");
  EXPECT_LE(unseen_estimate, 5);
}

TEST_F(CMSTest, Dimensions) {
  // Must match the dimensions the fixture passes to the constructor.
  EXPECT_EQ(cms_.width(), 1000u);
  EXPECT_EQ(cms_.depth(), 5u);
}

TEST_F(CMSTest, MallocUsed) {
  // One int64_t counter per (row, column) cell.
  const size_t expected_bytes = 1000u * 5 * sizeof(int64_t);
  EXPECT_EQ(cms_.MallocUsed(), expected_bytes);
}

// Inspired by fakeredis test_cms_merge: basic merge of two sketches.
TEST_F(CMSTest, MergeFrom) {
  CMS source(1000, 5, PMR_NS::get_default_resource());

  cms_.IncrBy("foo", 3);
  source.IncrBy("foo", 4);
  source.IncrBy("bar", 1);

  const bool merged = cms_.MergeFrom(source);
  EXPECT_TRUE(merged);

  EXPECT_GE(cms_.Query("foo"), 7);
  EXPECT_GE(cms_.Query("bar"), 1);
}

TEST_F(CMSTest, MergeFromWithWeight) {
  CMS source(1000, 5, PMR_NS::get_default_resource());

  source.IncrBy("x", 5);
  cms_.IncrBy("x", 10);

  const bool merged = cms_.MergeFrom(source, 3);
  EXPECT_TRUE(merged);

  // 10 + 5*3 = 25
  EXPECT_GE(cms_.Query("x"), 25);
}

TEST_F(CMSTest, MergeDimensionMismatch) {
  // Same depth as the fixture, different width.
  CMS narrower(500, 5, PMR_NS::get_default_resource());
  EXPECT_FALSE(cms_.MergeFrom(narrower));

  // Same width as the fixture, different depth.
  CMS shallower(1000, 3, PMR_NS::get_default_resource());
  EXPECT_FALSE(cms_.MergeFrom(shallower));
}

// Inspired by fakeredis test_cms_info: merge multiple sources with weights, verify counts.
// Mirrors the exact sequence: C=A+B, C+=A*1+B*2, C+=A*2+B*3, then check info.count.
TEST(CMSBasic, MergeMultipleWithWeights) { auto* mr = PMR_NS::get_default_resource(); CMS a(1000, 5, mr); CMS b(1000, 5, mr); CMS c(1000, 5, mr); a.IncrBy("foo", 5); a.IncrBy("bar", 3); a.IncrBy("baz", 9); b.IncrBy("foo", 2); b.IncrBy("bar", 3); b.IncrBy("baz", 1); EXPECT_EQ(a.Query("foo"), 5); EXPECT_EQ(a.Query("bar"), 3); EXPECT_EQ(a.Query("baz"), 9); EXPECT_EQ(b.Query("foo"), 2); EXPECT_EQ(b.Query("bar"), 3); EXPECT_EQ(b.Query("baz"), 1); // C = A*1 + B*1 EXPECT_TRUE(c.MergeFrom(a)); EXPECT_TRUE(c.MergeFrom(b)); EXPECT_EQ(c.Query("foo"), 7); EXPECT_EQ(c.Query("bar"), 6); EXPECT_EQ(c.Query("baz"), 10); // C += A*1 + B*2 EXPECT_TRUE(c.MergeFrom(a, 1)); EXPECT_TRUE(c.MergeFrom(b, 2)); EXPECT_EQ(c.Query("foo"), 16); EXPECT_EQ(c.Query("bar"), 15); EXPECT_EQ(c.Query("baz"), 21); // C += A*2 + B*3 EXPECT_TRUE(c.MergeFrom(a, 2)); EXPECT_TRUE(c.MergeFrom(b, 3)); EXPECT_EQ(c.Query("foo"), 32); EXPECT_EQ(c.Query("bar"), 30); EXPECT_EQ(c.Query("baz"), 42); } // Inspired by fakeredis test_cms_info: verify count tracks total of all IncrBy operations. TEST(CMSBasic, CountTracking) { auto* mr = PMR_NS::get_default_resource(); CMS a(1000, 5, mr); EXPECT_EQ(a.total_count(), 0); a.IncrBy("foo", 5); a.IncrBy("bar", 3); a.IncrBy("baz", 9); // total_count = 5 + 3 + 9 = 17 (matches fakeredis test_cms_info assertion) EXPECT_EQ(a.total_count(), 17); } // Inspired by fakeredis test_cms_info: count is updated by MergeFrom. 
TEST(CMSBasic, CountAfterMerge) { auto* mr = PMR_NS::get_default_resource(); CMS a(1000, 5, mr); CMS b(1000, 5, mr); CMS c(1000, 5, mr); a.IncrBy("foo", 5); a.IncrBy("bar", 3); a.IncrBy("baz", 9); EXPECT_EQ(a.total_count(), 17); b.IncrBy("foo", 2); b.IncrBy("bar", 3); b.IncrBy("baz", 1); EXPECT_EQ(b.total_count(), 6); // C = A + B -> total_count = 17 + 6 = 23 c.MergeFrom(a); c.MergeFrom(b); EXPECT_EQ(c.total_count(), 23); // C += A*1 + B*2 -> total_count = 23 + 17*1 + 6*2 = 52 // (matches fakeredis test_cms_merge_fail assertion: count == 52) c.MergeFrom(a, 1); c.MergeFrom(b, 2); EXPECT_EQ(c.total_count(), 52); } TEST_F(CMSTest, MoveConstruct) { cms_.IncrBy("foo", 42); CMS moved(std::move(cms_)); EXPECT_EQ(moved.Query("foo"), 42); EXPECT_EQ(moved.width(), 1000u); EXPECT_EQ(moved.depth(), 5u); } TEST_F(CMSTest, MoveAssign) { cms_.IncrBy("foo", 42); CMS other(500, 3, PMR_NS::get_default_resource()); other = std::move(cms_); EXPECT_EQ(other.Query("foo"), 42); EXPECT_EQ(other.width(), 1000u); EXPECT_EQ(other.depth(), 5u); } } // namespace dfly ================================================ FILE: src/core/collection_entry.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include namespace dfly { // Stores either: // - A single long long value (longval) when value = nullptr // - A single char* (value) when value != nullptr struct CollectionEntry { CollectionEntry(const char* value, size_t length) : value_{value}, length_{length} { } explicit CollectionEntry(long long longval) : value_{nullptr}, longval_{longval} { } CollectionEntry(const CollectionEntry&) = default; CollectionEntry& operator=(const CollectionEntry&) = default; std::string ToString() const { if (value_) return {value_, length_}; else return absl::StrCat(longval_); } bool IsString() const { return value_ != nullptr; } bool is_int() const { return value_ == nullptr; } const char* data() const { return value_; } size_t size() const { return length_; } long long as_long() const { return longval_; } // Assumes value is not null. std::string_view view() const { return {value_, length_}; } // compatibility method std::string to_string() const { return ToString(); } // compatibility method long long ival() const { return longval_; } bool operator==(std::string_view sv) const; friend bool operator==(std::string_view sv, const CollectionEntry& entry) { return entry == sv; } private: const char* value_; union { size_t length_; long long longval_; }; }; inline bool CollectionEntry::operator==(std::string_view sv) const { if (value_ == nullptr) { char buf[absl::numbers_internal::kFastToBufferSize]; char* end = absl::numbers_internal::FastIntToBuffer(longval_, buf); return sv == std::string_view(buf, end - buf); } return view() == sv; } } // namespace dfly ================================================ FILE: src/core/compact_object.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/compact_object.h" // #define XXH_INLINE_ALL #include #include extern "C" { #include "redis/intset.h" #include "redis/listpack.h" #include "redis/redis_aux.h" #include "redis/sds.h" #include "redis/stream.h" #include "redis/util.h" #include "redis/zmalloc.h" // for non-string objects. } #include #include #include "base/flags.h" #include "base/logging.h" #include "base/pod_array.h" #include "core/bloom.h" #include "core/cms.h" #include "core/detail/bitpacking.h" #include "core/huff_coder.h" #include "core/page_usage/page_usage_stats.h" #include "core/qlist.h" #include "core/sorted_map.h" #include "core/string_map.h" #include "core/string_set.h" #include "core/tiering_types.h" #include "core/topk.h" ABSL_FLAG(bool, experimental_flat_json, false, "If true uses flat json implementation."); ABSL_FLAG(bool, disable_json_defragmentation, false, "If true disable json object defragmentation"); namespace dfly { using namespace std; using detail::ascii_len; using detail::binpacked_len; using MemoryResource = detail::RobjWrapper::MemoryResource; namespace { constexpr XXH64_hash_t kHashSeed = 24061983; constexpr size_t kAlignSize = 8u; size_t UpdateSize(size_t size, int64_t update) { int64_t result = static_cast(size) + update; if (result < 0) { DCHECK(false) << "Can't decrease " << size << " from " << -update; LOG_EVERY_T(ERROR, 30) << "Can't decrease " << size << " from " << -update; } return result; } inline void FreeObjSet(unsigned encoding, void* ptr, MemoryResource* mr) { switch (encoding) { case kEncodingStrMap2: { CompactObj::DeleteMR(ptr); break; } case kEncodingIntSet: zfree((void*)ptr); break; default: LOG(FATAL) << "Unknown set encoding type"; } } void FreeList(unsigned encoding, void* ptr, MemoryResource* mr) { if (encoding == kEncodingListPack) { lpFree((uint8_t*)ptr); return; } CHECK_EQ(encoding, kEncodingQL2); CompactObj::DeleteMR(ptr); } size_t MallocUsedSet(unsigned encoding, void* ptr) { switch (encoding) { case kEncodingStrMap2: { StringSet* 
ss = (StringSet*)ptr; return ss->ObjMallocUsed() + ss->SetMallocUsed() + zmalloc_usable_size(ptr); } case kEncodingIntSet: return intsetBlobLen((intset*)ptr); } LOG(DFATAL) << "Unknown set encoding type " << encoding; return 0; } size_t MallocUsedHSet(unsigned encoding, void* ptr) { switch (encoding) { case kEncodingListPack: return zmalloc_usable_size(reinterpret_cast(ptr)); case kEncodingStrMap2: { StringMap* sm = (StringMap*)ptr; return sm->ObjMallocUsed() + sm->SetMallocUsed() + zmalloc_usable_size(ptr); } } LOG(DFATAL) << "Unknown set encoding type " << encoding; return 0; } size_t MallocUsedZSet(unsigned encoding, void* ptr) { switch (encoding) { case OBJ_ENCODING_LISTPACK: return zmalloc_usable_size(reinterpret_cast(ptr)); case OBJ_ENCODING_SKIPLIST: { detail::SortedMap* ss = (detail::SortedMap*)ptr; return ss->MallocSize() + zmalloc_usable_size(ptr); // DictMallocSize(zs->dict); } } LOG(DFATAL) << "Unknown set encoding type " << encoding; return 0; } /* This is a helper function with the goal of estimating the memory * size of a radix tree that is used to store Stream IDs. * * Note: to guess the size of the radix tree is not trivial, so we * approximate it considering 16 bytes of data overhead for each * key (the ID), and then adding the number of bare nodes, plus some * overhead due by the data and child pointers. This secret recipe * was obtained by checking the average radix tree created by real * workloads, and then adjusting the constants to get numbers that * more or less match the real memory usage. * * Actually the number of nodes and keys may be different depending * on the insertion speed and thus the ability of the radix tree * to compress prefixes. */ size_t streamRadixTreeMemoryUsage(rax* rax) { size_t size = sizeof(*rax); size = rax->numele * sizeof(streamID); size += rax->numnodes * sizeof(raxNode); /* Add a fixed overhead due to the aux data pointer, children, ... 
*/ size += rax->numnodes * sizeof(long) * 30; return size; } size_t MallocUsedStream(stream* s) { size_t asize = sizeof(*s); asize += streamRadixTreeMemoryUsage(s->rax); /* Now we have to add the listpacks. The last listpack is often non * complete, so we estimate the size of the first N listpacks, and * use the average to compute the size of the first N-1 listpacks, and * finally add the real size of the last node. */ raxIterator ri; raxStart(&ri, s->rax); raxSeek(&ri, "^", NULL, 0); size_t lpsize = 0, samples = 0; while (raxNext(&ri)) { uint8_t* lp = (uint8_t*)ri.data; /* Use the allocated size, since we overprovision the node initially. */ lpsize += zmalloc_size(lp); samples++; } if (s->rax->numele <= samples) { asize += lpsize; } else { if (samples) lpsize /= samples; /* Compute the average. */ asize += lpsize * (s->rax->numele - 1); /* No need to check if seek succeeded, we enter this branch only * if there are a few elements in the radix tree. */ raxSeek(&ri, "$", NULL, 0); raxNext(&ri); /* Use the allocated size, since we overprovision the node initially. */ asize += zmalloc_size(ri.data); } raxStop(&ri); /* Consumer groups also have a non trivial memory overhead if there * are many consumers and many groups, let's count at least the * overhead of the pending entries in the groups and consumers * PELs. */ if (s->cgroups) { raxStart(&ri, s->cgroups); raxSeek(&ri, "^", NULL, 0); while (raxNext(&ri)) { streamCG* cg = (streamCG*)ri.data; asize += sizeof(*cg); asize += streamRadixTreeMemoryUsage(cg->pel); asize += sizeof(streamNACK) * raxSize(cg->pel); /* For each consumer we also need to add the basic data * structures and the PEL memory usage. 
*/ raxIterator cri; raxStart(&cri, cg->consumers); raxSeek(&cri, "^", NULL, 0); while (raxNext(&cri)) { const streamConsumer* consumer = (const streamConsumer*)cri.data; asize += sizeof(*consumer); asize += sdslen(consumer->name); asize += streamRadixTreeMemoryUsage(consumer->pel); /* Don't count NACKs again, they are shared with the * consumer group PEL. */ } raxStop(&cri); } raxStop(&ri); } return asize; } inline void FreeObjHash(unsigned encoding, void* ptr) { switch (encoding) { case kEncodingStrMap2: CompactObj::DeleteMR(ptr); break; case kEncodingListPack: lpFree((uint8_t*)ptr); break; default: LOG(FATAL) << "Unknown hset encoding type " << encoding; } } inline void FreeObjZset(unsigned encoding, void* ptr) { switch (encoding) { case OBJ_ENCODING_SKIPLIST: CompactObj::DeleteMR(ptr); break; case OBJ_ENCODING_LISTPACK: zfree(ptr); break; default: LOG(FATAL) << "Unknown sorted set encoding" << encoding; } } pair DefragStrMap2(StringMap* sm, PageUsage* page_usage) { bool realloced = false; for (auto it = sm->begin(); it != sm->end(); ++it) realloced |= it.ReallocIfNeeded(page_usage); return {sm, realloced}; } pair DefragListPack(uint8_t* lp, PageUsage* page_usage) { if (!page_usage->IsPageForObjectUnderUtilized(lp)) return {lp, false}; size_t lp_bytes = lpBytes(lp); uint8_t* replacement = lpNew(lpBytes(lp)); memcpy(replacement, lp, lp_bytes); lpFree(lp); return {replacement, true}; } pair DefragIntSet(intset* is, PageUsage* page_usage) { if (!page_usage->IsPageForObjectUnderUtilized(is)) return {is, false}; const size_t blob_len = intsetBlobLen(is); intset* replacement = (intset*)zmalloc(blob_len); memcpy(replacement, is, blob_len); zfree(is); return {replacement, true}; } pair DefragSortedMap(detail::SortedMap* sm, PageUsage* page_usage) { const bool reallocated = sm->DefragIfNeeded(page_usage); return {sm, reallocated}; } pair DefragStrSet(StringSet* ss, PageUsage* page_usage) { bool realloced = false; for (auto it = ss->begin(); it != ss->end(); ++it) 
realloced |= it.ReallocIfNeeded(page_usage); return {ss, realloced}; } // Iterates over allocations of internal hash data structures and re-allocates // them if their pages are underutilized. // Returns pointer to new object ptr and whether any re-allocations happened. pair DefragHash(unsigned encoding, void* ptr, PageUsage* page_usage) { switch (encoding) { // Listpack is stored as a single contiguous array case kEncodingListPack: { return DefragListPack((uint8_t*)ptr, page_usage); } // StringMap supports re-allocation of it's internal nodes case kEncodingStrMap2: { return DefragStrMap2((StringMap*)ptr, page_usage); } default: ABSL_UNREACHABLE(); } } pair DefragSet(unsigned encoding, void* ptr, PageUsage* page_usage) { switch (encoding) { // Int sets have flat storage case kEncodingIntSet: { return DefragIntSet((intset*)ptr, page_usage); } case kEncodingStrMap2: { return DefragStrSet((StringSet*)ptr, page_usage); } default: ABSL_UNREACHABLE(); } } pair DefragZSet(unsigned encoding, void* ptr, PageUsage* page_usage) { switch (encoding) { // Listpack is stored as a single contiguous array case OBJ_ENCODING_LISTPACK: { return DefragListPack((uint8_t*)ptr, page_usage); } // SKIPLIST really means ScoreMap case OBJ_ENCODING_SKIPLIST: { return DefragSortedMap((detail::SortedMap*)ptr, page_usage); } default: ABSL_UNREACHABLE(); } } pair DefragList(unsigned encoding, void* ptr, PageUsage* page_usage) { if (encoding == kEncodingListPack) { return DefragListPack((uint8_t*)ptr, page_usage); } auto* qlist_ptr = static_cast(ptr); bool reallocated = qlist_ptr->DefragIfNeeded(page_usage); return {ptr, reallocated}; } inline void FreeObjStream(void* ptr) { freeStream((stream*)ptr); } inline const uint8_t* to_byte(const void* s) { return reinterpret_cast(s); } static_assert(binpacked_len(7) == 7); static_assert(binpacked_len(8) == 7); static_assert(binpacked_len(15) == 14); static_assert(binpacked_len(16) == 14); static_assert(binpacked_len(17) == 15); 
static_assert(binpacked_len(18) == 16); static_assert(binpacked_len(19) == 17); static_assert(binpacked_len(20) == 18); static_assert(ascii_len(14) == 16); static_assert(ascii_len(15) == 17); static_assert(ascii_len(16) == 18); static_assert(ascii_len(17) == 19); struct Huffman { HuffmanEncoder encoder; HuffmanDecoder decoder; }; struct TL { MemoryResource* local_mr = PMR_NS::get_default_resource(); base::PODArray tmp_buf; string tmp_str; size_t small_str_bytes; Huffman huff_keys, huff_string_values; uint64_t huff_encode_total = 0, huff_encode_success = 0; // success/total metrics. const HuffmanDecoder& GetHuffmanDecoder(uint8_t huffman_domain) const { return huffman_domain == CompactObj::HUFF_KEYS ? huff_keys.decoder : huff_string_values.decoder; } }; thread_local TL tl; constexpr bool kUseAsciiEncoding = true; } // namespace static_assert(sizeof(CompactObj) == 18); namespace detail { size_t RobjWrapper::MallocUsed(bool slow) const { if (!inner_obj_) return 0; switch (type_) { case OBJ_STRING: CHECK_EQ(OBJ_ENCODING_RAW, encoding_); return InnerObjMallocUsed(); case OBJ_LIST: if (encoding_ == kEncodingListPack) { return zmalloc_usable_size(inner_obj_); } return ((QList*)inner_obj_)->MallocUsed(slow); case OBJ_SET: return MallocUsedSet(encoding_, inner_obj_); case OBJ_HASH: return MallocUsedHSet(encoding_, inner_obj_); case OBJ_ZSET: return MallocUsedZSet(encoding_, inner_obj_); case OBJ_STREAM: return slow ? 
MallocUsedStream((stream*)inner_obj_) : sz_; default: LOG(FATAL) << "Not supported " << type_; } return 0; } size_t RobjWrapper::Size() const { switch (type_) { case OBJ_STRING: DCHECK_EQ(OBJ_ENCODING_RAW, encoding_); return sz_; case OBJ_LIST: if (encoding_ == kEncodingListPack) { return lpLength((uint8_t*)inner_obj_); } return ((QList*)inner_obj_)->Size(); case OBJ_ZSET: { switch (encoding_) { case OBJ_ENCODING_SKIPLIST: { SortedMap* ss = (SortedMap*)inner_obj_; return ss->Size(); } case OBJ_ENCODING_LISTPACK: return lpLength((uint8_t*)inner_obj_) / 2; default: LOG(FATAL) << "Unknown sorted set encoding" << encoding_; } } case OBJ_SET: switch (encoding_) { case kEncodingIntSet: { intset* is = (intset*)inner_obj_; return intsetLen(is); } case kEncodingStrMap2: { StringSet* ss = (StringSet*)inner_obj_; return ss->UpperBoundSize(); } default: LOG(FATAL) << "Unexpected encoding " << encoding_; }; case OBJ_HASH: switch (encoding_) { case kEncodingListPack: { uint8_t* lp = (uint8_t*)inner_obj_; return lpLength(lp) / 2; } break; case kEncodingStrMap2: { StringMap* sm = (StringMap*)inner_obj_; return sm->UpperBoundSize(); } default: LOG(FATAL) << "Unexpected encoding " << encoding_; } case OBJ_STREAM: // Size mean malloc bytes for streams return sz_; default:; } return 0; } void RobjWrapper::Free(MemoryResource* mr) { if (!inner_obj_) return; DVLOG(1) << "RobjWrapper::Free " << inner_obj_; switch (type_) { case OBJ_STRING: DVLOG(2) << "Freeing string object"; DCHECK_EQ(OBJ_ENCODING_RAW, encoding_); mr->deallocate(inner_obj_, 0, 8); // we do not keep the allocated size. 
break; case OBJ_LIST: FreeList(encoding_, inner_obj_, mr); break; case OBJ_SET: FreeObjSet(encoding_, inner_obj_, mr); break; case OBJ_ZSET: FreeObjZset(encoding_, inner_obj_); break; case OBJ_HASH: FreeObjHash(encoding_, inner_obj_); break; case OBJ_MODULE: LOG(FATAL) << "Unsupported OBJ_MODULE type"; break; case OBJ_STREAM: FreeObjStream(inner_obj_); break; default: LOG(FATAL) << "Unknown object type"; break; } Set(nullptr, 0); } uint64_t RobjWrapper::HashCode() const { switch (type_) { case OBJ_STRING: DCHECK_EQ(OBJ_ENCODING_RAW, encoding()); { auto str = AsView(); return XXH3_64bits_withSeed(str.data(), str.size(), kHashSeed); } break; default: LOG(FATAL) << "Unsupported type for hashcode " << type_; } return 0; } bool RobjWrapper::Equal(const RobjWrapper& ow) const { if (ow.type_ != type_ || ow.encoding_ != encoding_) return false; if (type_ == OBJ_STRING) { DCHECK_EQ(OBJ_ENCODING_RAW, encoding()); return AsView() == ow.AsView(); } LOG(FATAL) << "Unsupported type " << type_; return false; } bool RobjWrapper::Equal(string_view sv) const { if (type() != OBJ_STRING) return false; DCHECK_EQ(OBJ_ENCODING_RAW, encoding()); return AsView() == sv; } void RobjWrapper::SetString(string_view s, MemoryResource* mr) { type_ = OBJ_STRING; encoding_ = OBJ_ENCODING_RAW; if (s.size() > sz_) { size_t cur_cap = InnerObjMallocUsed(); if (s.size() > cur_cap) { MakeInnerRoom(cur_cap, s.size(), mr); } memcpy(inner_obj_, s.data(), s.size()); sz_ = s.size(); } } void RobjWrapper::ReserveString(size_t size, MemoryResource* mr) { CHECK_EQ(inner_obj_, nullptr); type_ = OBJ_STRING; encoding_ = OBJ_ENCODING_RAW; MakeInnerRoom(0, size, mr); } void RobjWrapper::AppendString(string_view s, MemoryResource* mr) { size_t cur_cap = InnerObjMallocUsed(); CHECK(cur_cap >= sz_ + s.size()) << cur_cap << " " << sz_ << " " << s.size(); memcpy(reinterpret_cast(inner_obj_) + sz_, s.data(), s.size()); sz_ += s.size(); } void RobjWrapper::SetSize(uint64_t size) { sz_ = size; } bool 
RobjWrapper::DefragIfNeeded(PageUsage* page_usage) { auto do_defrag = [this, &page_usage](auto defrag_fun) mutable { auto [new_ptr, realloced] = defrag_fun(encoding_, inner_obj_, page_usage); inner_obj_ = new_ptr; return realloced; }; if (type() == OBJ_STRING) { if (page_usage->IsPageForObjectUnderUtilized(inner_obj())) { ReallocateString(tl.local_mr); return true; } } else if (type() == OBJ_HASH) { return do_defrag(DefragHash); } else if (type() == OBJ_SET) { return do_defrag(DefragSet); } else if (type() == OBJ_ZSET) { return do_defrag(DefragZSet); } else if (type() == OBJ_LIST) { return do_defrag(DefragList); } page_usage->RecordNotSupported(); return false; } void RobjWrapper::ReallocateString(MemoryResource* mr) { DCHECK_EQ(type(), OBJ_STRING); void* old_ptr = inner_obj_; inner_obj_ = mr->allocate(sz_, kAlignSize); memcpy(inner_obj_, old_ptr, sz_); mr->deallocate(old_ptr, 0, kAlignSize); } void RobjWrapper::Init(unsigned type, unsigned encoding, void* inner) { type_ = type; encoding_ = encoding; Set(inner, 0); } inline size_t RobjWrapper::InnerObjMallocUsed() const { return zmalloc_size(inner_obj_); } void RobjWrapper::MakeInnerRoom(size_t current_cap, size_t desired, MemoryResource* mr) { if (current_cap * 2 > desired) { if (desired < SDS_MAX_PREALLOC) desired *= 2; else desired += SDS_MAX_PREALLOC; } void* newp = mr->allocate(desired, kAlignSize); if (sz_) { memcpy(newp, inner_obj_, sz_); } if (current_cap) { mr->deallocate(inner_obj_, current_cap, kAlignSize); } inner_obj_ = newp; } } // namespace detail uint32_t JsonEnconding() { thread_local uint32_t json_enc = absl::GetFlag(FLAGS_experimental_flat_json) ? 
kEncodingJsonFlat : kEncodingJsonCons; return json_enc; } using namespace std; auto CompactObj::GetStatsThreadLocal() -> Stats { Stats res; res.small_string_bytes = tl.small_str_bytes; res.huff_encode_total = tl.huff_encode_total; res.huff_encode_success = tl.huff_encode_success; return res; } void CompactObj::InitThreadLocal(MemoryResource* mr) { tl.local_mr = mr; tl.tmp_buf = base::PODArray{mr}; } bool CompactObj::InitHuffmanThreadLocal(HuffmanDomain domain, std::string_view hufftable) { string err_msg; Huffman* huffman = nullptr; switch (domain) { case HUFF_KEYS: huffman = &tl.huff_keys; break; case HUFF_STRING_VALUES: huffman = &tl.huff_string_values; break; } // We do not allow overriding the existing huffman table once it is set. if (huffman->encoder.valid()) { return false; } if (!huffman->encoder.Load(hufftable, &err_msg)) { LOG(DFATAL) << "Failed to load huffman table: " << err_msg; return false; } if (!huffman->decoder.Load(hufftable, &err_msg)) { LOG(DFATAL) << "Failed to load huffman table: " << err_msg; return false; } return true; } CompactObj::~CompactObj() { if (HasAllocated()) { Free(); } } CompactObj& CompactObj::operator=(CompactObj&& o) noexcept { DCHECK(&o != this); DCHECK_EQ(is_key_, o.is_key_); SetMeta(o.taglen_, o.mask_); // frees own previous resources encoding_ = o.encoding_; memcpy(&u_, &o.u_, sizeof(u_)); o.taglen_ = 0; // forget all data o.encoding_ = 0; o.mask_ = 0; return *this; } size_t CompactObj::Size() const { auto decoded_str_size = [this](size_t raw_size, uint8_t first_byte) { DCHECK_EQ(ObjType(), OBJ_STRING); return GetStrEncoding().DecodedSize(raw_size, first_byte); }; if (IsInline()) return decoded_str_size(taglen_, u_.inline_str[0]); switch (taglen_) { case SMALL_TAG: return decoded_str_size(u_.small_str.size(), u_.small_str.first_byte()); case EXTERNAL_TAG: if (ObjType() == OBJ_STRING) return decoded_str_size(u_.ext_ptr.serialized_size, GetFirstByte()); else return u_.ext_ptr.serialized_size; case ROBJ_TAG: if (size_t size 
= u_.r_obj.Size(); u_.r_obj.type() != OBJ_STRING) return size; else return decoded_str_size(size, *(uint8_t*)u_.r_obj.inner_obj()); case INT_TAG: return absl::AlphaNum(u_.ival).size(); case SDS_TTL_TAG: return decoded_str_size(sdslen(u_.sds_ttl.sds_ptr), u_.sds_ttl.sds_ptr[0]); case JSON_TAG: if (JsonEnconding() == kEncodingJsonFlat) return u_.json_obj.flat.json_len; else return u_.json_obj.cons.json_ptr->size(); case SBF_TAG: return u_.sbf->current_size(); case CMS_TAG: return 0; case TOPK_TAG: return u_.topk->Size(); default: LOG(DFATAL) << "Should not reach " << int(taglen_); return 0; } } uint64_t CompactObj::HashCode() const { DCHECK(taglen_ != JSON_TAG) << "JSON type cannot be used for keys!"; if (encoding_ == NONE_ENC) { if (IsInline()) { return XXH3_64bits_withSeed(u_.inline_str, taglen_, kHashSeed); } switch (taglen_) { case SMALL_TAG: return u_.small_str.HashCode(); case ROBJ_TAG: return u_.r_obj.HashCode(); case INT_TAG: { absl::AlphaNum an(u_.ival); return XXH3_64bits_withSeed(an.data(), an.size(), kHashSeed); } case SDS_TTL_TAG: return XXH3_64bits_withSeed(u_.sds_ttl.sds_ptr, sdslen(u_.sds_ttl.sds_ptr), kHashSeed); } } DCHECK(encoding_); if (IsInline()) { // Buffer must accommodate maximum decompressed size from inline storage // Highly compressible data can achieve ~8x compression (e.g., repeated character) // kInlineLen (16 bytes) compressed -> up to 128 bytes decompressed char buf[kInlineLen * 8]; size_t decoded_len = GetStrEncoding().Decode(string_view{u_.inline_str, taglen_}, buf); return XXH3_64bits_withSeed(buf, decoded_len, kHashSeed); } string_view sv = GetSlice(&tl.tmp_str); return XXH3_64bits_withSeed(sv.data(), sv.size(), kHashSeed); } uint64_t CompactObj::HashCode(string_view str) { return XXH3_64bits_withSeed(str.data(), str.size(), kHashSeed); } CompactObjType CompactObj::ObjType() const { if (IsInline() || taglen_ == INT_TAG || taglen_ == SMALL_TAG || taglen_ == SDS_TTL_TAG) return OBJ_STRING; if (taglen_ == EXTERNAL_TAG) { switch 
(static_cast(u_.ext_ptr.representation)) { case ExternalRep::STRING: return OBJ_STRING; case ExternalRep::SERIALIZED_MAP: return OBJ_HASH; }; } if (taglen_ == ROBJ_TAG) return u_.r_obj.type(); if (taglen_ == JSON_TAG) { return OBJ_JSON; } if (taglen_ == SBF_TAG) { return OBJ_SBF; } if (taglen_ == CMS_TAG) { return OBJ_CMS; } if (taglen_ == TOPK_TAG) { return OBJ_TOPK; } LOG(FATAL) << "TBD " << int(taglen_); return kInvalidCompactObjType; } unsigned CompactObj::Encoding() const { switch (taglen_) { case ROBJ_TAG: return u_.r_obj.encoding(); case INT_TAG: return OBJ_ENCODING_INT; default: return OBJ_ENCODING_RAW; } } void CompactObj::InitRobj(CompactObjType type, unsigned encoding, void* obj) { DCHECK_NE(type, OBJ_STRING); SetMeta(ROBJ_TAG, mask_); u_.r_obj.Init(type, encoding, obj); } void CompactObj::SetInt(int64_t val) { DCHECK(!IsExternal()); if (INT_TAG != taglen_) { SetMeta(INT_TAG, mask_); encoding_ = NONE_ENC; } u_.ival = val; } std::optional CompactObj::TryGetInt() const { if (taglen_ != INT_TAG) return std::nullopt; int64_t val = u_.ival; return val; } auto CompactObj::GetJson() const -> JsonType* { if (ObjType() == OBJ_JSON) { DCHECK_EQ(JsonEnconding(), kEncodingJsonCons); return u_.json_obj.cons.json_ptr; } return nullptr; } void CompactObj::SetJson(JsonType&& j) { if (taglen_ == JSON_TAG && JsonEnconding() == kEncodingJsonCons) { DCHECK(u_.json_obj.cons.json_ptr != nullptr); // must be allocated u_.json_obj.cons.json_ptr->swap(j); DCHECK(jsoncons::is_trivial_storage(u_.json_obj.cons.json_ptr->storage_kind()) || u_.json_obj.cons.json_ptr->get_allocator().resource() == tl.local_mr); // We do not set bytes_used as this is needed. Consider the two following cases: // 1. old json contains 50 bytes. The delta for new one is 50, so the total bytes // the new json occupies is 100. // 2. old json contains 100 bytes. The delta for new one is -50, so the total bytes // the new json occupies is 50. // Both of the cases are covered in SetJsonSize and JsonMemTracker. 
See below. return; } SetMeta(JSON_TAG); u_.json_obj.cons.json_ptr = AllocateMR(std::move(j)); // With trivial storage json_ptr->get_allocator() throws an exception. DCHECK(jsoncons::is_trivial_storage(u_.json_obj.cons.json_ptr->storage_kind()) || u_.json_obj.cons.json_ptr->get_allocator().resource() == tl.local_mr); u_.json_obj.cons.bytes_used = 0; } void CompactObj::SetJsonSize(int64_t size) { if (taglen_ == JSON_TAG && JsonEnconding() == kEncodingJsonCons) { // JSON.SET or if mem hasn't changed from a JSON op then we just update. int64_t result = static_cast(u_.json_obj.cons.bytes_used) + size; if (result < 1) { LOG_EVERY_T(ERROR, 20) << "JSON size underflow: " << u_.json_obj.cons.bytes_used << " + " << size << " = " << result; u_.json_obj.cons.bytes_used = 1; } else { u_.json_obj.cons.bytes_used = static_cast(result); } } } void CompactObj::AddStreamSize(int64_t size) { if (size < 0) { // We might have a negative size. For example, if we remove a consumer, // the tracker will report a negative net (since we deallocated), // so the object now consumes less memory than it did before. This DCHECK // is for fanity and to catch any potential issues with our tracking approach. 
DCHECK(static_cast(u_.r_obj.Size()) >= size); } u_.r_obj.SetSize((u_.r_obj.Size() + size)); } void CompactObj::SetJson(const uint8_t* buf, size_t len) { SetMeta(JSON_TAG); u_.json_obj.flat.flat_ptr = (uint8_t*)tl.local_mr->allocate(len, kAlignSize); memcpy(u_.json_obj.flat.flat_ptr, buf, len); u_.json_obj.flat.json_len = len; } void CompactObj::SetSBF(uint64_t initial_capacity, double fp_prob, double grow_factor) { if (taglen_ == SBF_TAG) { // already json *u_.sbf = SBF(initial_capacity, fp_prob, grow_factor, tl.local_mr); } else { SetMeta(SBF_TAG); u_.sbf = AllocateMR(initial_capacity, fp_prob, grow_factor, tl.local_mr); } } SBF* CompactObj::GetSBF() const { DCHECK_EQ(SBF_TAG, taglen_); return u_.sbf; } void CompactObj::SetCMS(uint32_t width, uint32_t depth) { if (taglen_ == CMS_TAG) { *u_.cms = CMS(width, depth, tl.local_mr); } else { SetMeta(CMS_TAG); u_.cms = AllocateMR(width, depth, tl.local_mr); } } CMS* CompactObj::GetCMS() const { DCHECK_EQ(CMS_TAG, taglen_); return u_.cms; } void CompactObj::SetTOPK(uint32_t k, uint32_t width, uint32_t depth, double decay) { if (taglen_ == TOPK_TAG) { *u_.topk = TOPK(memory_resource(), k, width, depth, decay); } else { SetMeta(TOPK_TAG); u_.topk = AllocateMR(memory_resource(), k, width, depth, decay); } } TOPK* CompactObj::GetTOPK() const { DCHECK_EQ(TOPK_TAG, taglen_); return u_.topk; } void CompactObj::SetString(std::string_view str) { CHECK(!IsExternal()); encoding_ = NONE_ENC; // Trying auto-detection heuristics first. if (str.size() <= 20) { long long ival; static_assert(sizeof(long long) == 8); // We use redis string2ll to be compatible with Redis. 
if (string2ll(str.data(), str.size(), &ival)) { SetMeta(INT_TAG, mask_); u_.ival = ival; return; } if (str.size() <= kInlineLen) { SetMeta(str.size(), mask_); if (!str.empty()) memcpy(u_.inline_str, str.data(), str.size()); return; } } EncodeString(str); } void CompactObj::ReserveString(size_t size) { encoding_ = NONE_ENC; SetMeta(ROBJ_TAG, mask_); u_.r_obj.ReserveString(size, tl.local_mr); } void CompactObj::AppendString(std::string_view str) { u_.r_obj.AppendString(str, tl.local_mr); } string_view CompactObj::GetSlice(string* scratch) const { CHECK(!IsExternal()); if (encoding_) { GetString(scratch); return *scratch; } if (IsInline()) { return string_view{u_.inline_str, taglen_}; } if (taglen_ == INT_TAG) { absl::AlphaNum an(u_.ival); scratch->assign(an.Piece()); return *scratch; } // no encoding. if (taglen_ == ROBJ_TAG) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); return u_.r_obj.AsView(); } if (taglen_ == SMALL_TAG) { u_.small_str.Get(scratch); return *scratch; } if (taglen_ == SDS_TTL_TAG) { return u_.sds_ttl.view(); } LOG(FATAL) << "Bad tag " << int(taglen_); return string_view{}; } bool CompactObj::DefragIfNeeded(PageUsage* page_usage) { static const bool disable_json_defragmentation = absl::GetFlag(FLAGS_disable_json_defragmentation); if (OmitDefrag()) { page_usage->RecordNotRequired(); return false; } switch (taglen_) { case ROBJ_TAG: // currently only these object types are supported for this operation if (u_.r_obj.inner_obj() != nullptr) { return u_.r_obj.DefragIfNeeded(page_usage); } return false; case SMALL_TAG: return u_.small_str.DefragIfNeeded(page_usage); case JSON_TAG: if (disable_json_defragmentation) { return false; } return u_.json_obj.DefragIfNeeded(page_usage); case SDS_TTL_TAG: if (page_usage->IsPageForObjectUnderUtilized(u_.sds_ttl.sds_ptr)) { size_t len = sdslen(u_.sds_ttl.sds_ptr); char* new_sds = sdsnewlen(u_.sds_ttl.sds_ptr, len); sdsfree(u_.sds_ttl.sds_ptr); u_.sds_ttl.sds_ptr = new_sds; 
return true; } return false; case INT_TAG: page_usage->RecordNotRequired(); // this is not relevant in this case return false; case EXTERNAL_TAG: page_usage->RecordNotRequired(); return false; default: page_usage->RecordNotRequired(); // This is the case when the object is at inline_str return false; } } bool CompactObj::HasAllocated() const { if (IsRef() || taglen_ == INT_TAG || IsInline() || taglen_ == EXTERNAL_TAG || (taglen_ == ROBJ_TAG && u_.r_obj.inner_obj() == nullptr)) return false; DCHECK(taglen_ == ROBJ_TAG || taglen_ == SMALL_TAG || taglen_ == JSON_TAG || taglen_ == SBF_TAG || taglen_ == CMS_TAG || taglen_ == SDS_TTL_TAG || taglen_ == TOPK_TAG); return true; } bool CompactObj::TagAllowsEmptyValue() const { const auto type = ObjType(); return type == OBJ_JSON || type == OBJ_STREAM || type == OBJ_STRING || type == OBJ_SBF || type == OBJ_CMS || type == OBJ_TOPK || type == OBJ_SET; } void __attribute__((noinline)) CompactObj::GetString(string* res) const { res->resize(Size()); GetString(res->data()); } void CompactObj::GetString(char* dest) const { CHECK(!IsExternal()); if (IsInline()) { GetStrEncoding().Decode({u_.inline_str, taglen_}, dest); return; } if (taglen_ == INT_TAG) { absl::AlphaNum an(u_.ival); memcpy(dest, an.data(), an.size()); return; } if (encoding_) { StrEncoding str_encoding = GetStrEncoding(); string_view decode_blob = GetEncodedBlob(str_encoding, dest); str_encoding.Decode(decode_blob, dest); return; } // no encoding. 
if (taglen_ == ROBJ_TAG) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); memcpy(dest, u_.r_obj.inner_obj(), u_.r_obj.Size()); return; } if (taglen_ == SDS_TTL_TAG) { memcpy(dest, u_.sds_ttl.sds_ptr, sdslen(u_.sds_ttl.sds_ptr)); return; } if (taglen_ == SMALL_TAG) return u_.small_str.Get(dest); LOG(FATAL) << "Bad tag " << int(taglen_); } void CompactObj::SetExternal(size_t offset, uint32_t sz, ExternalRep rep) { uint8_t first_byte = 0; if (encoding_ == HUFFMAN_ENC) { CHECK(rep == ExternalRep::STRING); first_byte = GetFirstByte(); } SetMeta(EXTERNAL_TAG, mask_); u_.ext_ptr.is_cool = 0; u_.ext_ptr.representation = static_cast(rep); u_.ext_ptr.first_byte = first_byte; u_.ext_ptr.page_offset = offset % 4096; u_.ext_ptr.serialized_size = sz; u_.ext_ptr.offload.page_index = offset / 4096; } CompactObj::ExternalRep CompactObj::GetExternalRep() const { DCHECK(IsExternal()); return static_cast(u_.ext_ptr.representation); } void CompactObj::SetCool(size_t offset, uint32_t sz, ExternalRep rep, tiering::TieredCoolRecord* record) { encoding_ = record->value.encoding_; SetMeta(EXTERNAL_TAG, record->value.mask_); u_.ext_ptr.is_cool = 1; u_.ext_ptr.representation = static_cast(rep); u_.ext_ptr.page_offset = offset % 4096; u_.ext_ptr.serialized_size = sz; u_.ext_ptr.cool_record = record; } auto CompactObj::GetCool() const -> CoolItem { DCHECK(IsExternal() && u_.ext_ptr.is_cool); CoolItem res; res.page_offset = u_.ext_ptr.page_offset; res.serialized_size = u_.ext_ptr.serialized_size; res.record = u_.ext_ptr.cool_record; return res; } void CompactObj::Freeze(size_t offset, size_t sz) { SetExternal(offset, sz, GetExternalRep()); } std::pair CompactObj::GetExternalSlice() const { DCHECK_EQ(EXTERNAL_TAG, taglen_); auto& ext = u_.ext_ptr; size_t offset = ext.page_offset; offset += size_t(ext.is_cool ? 
ext.cool_record->page_index : ext.offload.page_index) * 4096; return {offset, size_t(u_.ext_ptr.serialized_size)}; } string_view CompactObj::GetEncodedBlob(StrEncoding str_encoding, char* opt_dest) const { if (taglen_ == ROBJ_TAG) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); return u_.r_obj.AsView(); } else if (IsInline()) { return {u_.inline_str, taglen_}; } else if (taglen_ == SDS_TTL_TAG) { return u_.sds_ttl.view(); } CHECK_EQ(taglen_, SMALL_TAG); auto& ss = u_.small_str; char* copy_dest = nullptr; if (opt_dest && str_encoding.enc_ != HUFFMAN_ENC) { // Write to rightmost location of dest buffer to leave some bytes for inline unpacking size_t decoded_len = str_encoding.DecodedSize(ss.size(), ss.first_byte()); copy_dest = opt_dest + (decoded_len - ss.size()); } else { tl.tmp_buf.resize(ss.size()); copy_dest = reinterpret_cast(tl.tmp_buf.data()); } ss.Get(copy_dest); return {copy_dest, ss.size()}; } void CompactObj::Materialize(std::string_view blob, bool is_raw) { CHECK(IsExternal()) << int(taglen_); DCHECK_EQ(u_.ext_ptr.representation, static_cast(ExternalRep::STRING)); DCHECK_GT(blob.size(), kInlineLen); // There are no mutable commands that shrink strings if (is_raw) { if (SmallString::CanAllocate(blob.size())) { SetMeta(SMALL_TAG, mask_); tl.small_str_bytes += u_.small_str.Assign(blob); } else { SetMeta(ROBJ_TAG, mask_); u_.r_obj.SetString(blob, tl.local_mr); } } else { encoding_ = NONE_ENC; // reset encoding EncodeString(blob); } } void CompactObj::Reset() { if (HasAllocated()) { Free(); } taglen_ = 0; encoding_ = 0; mask_ = 0; } uint8_t CompactObj::GetFirstByte() const { DCHECK_EQ(ObjType(), OBJ_STRING); if (IsInline()) { return u_.inline_str[0]; } if (taglen_ == ROBJ_TAG) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); return *(uint8_t*)u_.r_obj.inner_obj(); } if (taglen_ == SMALL_TAG) { return u_.small_str.first_byte(); } if (taglen_ == SDS_TTL_TAG) { return 
u_.sds_ttl.sds_ptr[0]; } if (taglen_ == EXTERNAL_TAG) { if (u_.ext_ptr.is_cool) { const CompactObj& cooled_obj = u_.ext_ptr.cool_record->value; return cooled_obj.GetFirstByte(); } return u_.ext_ptr.first_byte; } LOG(DFATAL) << "Bad tag " << int(taglen_); return 0; } bool CompactObj::GetByteAtIndex(size_t idx, uint8_t* res) const { CHECK(!IsExternal()); DCHECK_EQ(ObjType(), OBJ_STRING); if (encoding_) { StrEncoding str_encoding = GetStrEncoding(); string_view decode_blob = GetEncodedBlob(str_encoding, nullptr); if (!str_encoding.DecodeByte(decode_blob, idx, res)) { VLOG(1) << "Offset out of bounds for encoded string: " << idx << " >= " << str_encoding.DecodedSize(decode_blob.size(), decode_blob[0]); *res = 0; return false; } return true; } // No encoding, we can directly access the byte at index. string_view sv = GetSlice(&tl.tmp_str); if (idx >= sv.size()) { VLOG(1) << "Offset out of bounds: " << idx << " >= " << sv.size(); *res = 0; return false; } *res = sv[idx]; return true; } std::pair CompactObj::SetByteAtIndex(size_t idx, uint8_t val) { CHECK(!IsExternal()); DCHECK_EQ(ObjType(), OBJ_STRING); // Inline string without encoding: modify directly. if (IsInline() && !encoding_) { if (idx >= taglen_) { VLOG(1) << "Offset out of bounds for inline string: " << idx << " >= " << int(taglen_); return {false, false}; } u_.inline_str[idx] = val; return {true, true}; } // SDS_TTL_TAG raw string without encoding: modify directly. if (taglen_ == SDS_TTL_TAG && !encoding_) { size_t len = sdslen(u_.sds_ttl.sds_ptr); if (idx >= len) { return {false, false}; } u_.sds_ttl.sds_ptr[idx] = val; return {true, true}; } // ROBJ_TAG raw string without encoding: modify the underlying buffer directly. 
if (taglen_ == ROBJ_TAG && !encoding_) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); if (idx >= u_.r_obj.Size()) { VLOG(1) << "Offset out of bounds for raw string: " << idx << " >= " << u_.r_obj.Size(); return {false, false}; } reinterpret_cast(u_.r_obj.inner_obj())[idx] = val; return {true, true}; } // For ASCII encoded ROBJ strings we can modify the underlying buffer directly. if (encoding_ && (encoding_ == ASCII1_ENC || encoding_ == ASCII2_ENC) && taglen_ == ROBJ_TAG && absl::ascii_isascii(val)) { DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); auto* buf = reinterpret_cast(u_.r_obj.inner_obj()); size_t decoded_len = GetStrEncoding().DecodedSize(u_.r_obj.Size(), buf[0]); if (idx >= decoded_len) { VLOG(1) << "Offset out of bounds for ASCII encoded string: " << idx << " >= " << decoded_len; return {false, false}; } detail::ascii_pack_byte(buf, decoded_len, idx, val); return {true, true}; } // For other encoded strings, INT_TAG, SMALL_TAG we need to decode, modify, and re-encode. string str; GetString(&str); if (idx >= str.size()) { VLOG(1) << "Offset out of bounds: " << idx << " >= " << str.size(); return {false, false}; } str[idx] = val; SetString(str); return {true, false}; } // Frees all resources if owns. 
void CompactObj::Free() { DCHECK(HasAllocated()); if (taglen_ == ROBJ_TAG) { u_.r_obj.Free(tl.local_mr); } else if (taglen_ == SMALL_TAG) { tl.small_str_bytes -= u_.small_str.MallocUsed(); u_.small_str.Free(); } else if (taglen_ == JSON_TAG) { DVLOG(1) << "Freeing JSON object"; if (JsonEnconding() == kEncodingJsonCons) { DeleteMR(u_.json_obj.cons.json_ptr); } else { tl.local_mr->deallocate(u_.json_obj.flat.flat_ptr, u_.json_obj.flat.json_len, kAlignSize); } } else if (taglen_ == SBF_TAG) { DeleteMR(u_.sbf); } else if (taglen_ == TOPK_TAG) { DeleteMR(u_.topk); } else if (taglen_ == CMS_TAG) { DeleteMR(u_.cms); } else if (taglen_ == SDS_TTL_TAG) { sdsfree(u_.sds_ttl.sds_ptr); } else { LOG(FATAL) << "Unsupported tag " << int(taglen_); } memset(u_.inline_str, 0, kInlineLen); } size_t CompactObj::MallocUsed(bool slow) const { if (!HasAllocated()) return 0; if (taglen_ == ROBJ_TAG) { return u_.r_obj.MallocUsed(slow); } if (taglen_ == JSON_TAG) { // TODO fix this once we fully support flat json // This is here because accessing a union field that is not active // is UB. if (JsonEnconding() == kEncodingJsonFlat) { return 0; } return u_.json_obj.cons.bytes_used; } if (taglen_ == SMALL_TAG) { return u_.small_str.MallocUsed(); } if (taglen_ == SBF_TAG) { return u_.sbf->MallocUsed(); } if (taglen_ == CMS_TAG) { return u_.cms->MallocUsed(); } if (taglen_ == SDS_TTL_TAG) { return sdsAllocSize(u_.sds_ttl.sds_ptr); } if (taglen_ == TOPK_TAG) { return u_.topk->MallocUsed(); } LOG(DFATAL) << "should not reach"; return 0; } // TODO: we need this operator ONLY because we search in prime-table based on the ExpireKey // which is a reference to the CompactKey. Therefore operator== currently works // specifically for this particular use-case. // So once we remove the expire table, we can remove this operator too. // In addition - we MUST remove AsRef/IsRef api as well as it will break // once we start using SetExpireTime/ClearExpireTime methods. 
// All in all, we will free up two additional bits. bool CompactKey::operator==(const CompactKey& o) const { DCHECK(taglen_ != JSON_TAG && o.taglen_ != JSON_TAG) << "cannot use JSON type to check equal"; // Cross-tag/encoding comparison: fall back to decoded string comparison for OBJ_STRING. // This handles e.g. SDS_TTL_TAG vs ROBJ_TAG/inline/INT_TAG with same logical content. if (taglen_ != o.taglen_ || encoding_ != o.encoding_) { if (ObjType() == OBJ_STRING && o.ObjType() == OBJ_STRING) { std::string tmp; return *this == o.GetSlice(&tmp); } return false; } if (taglen_ == ROBJ_TAG) return u_.r_obj.Equal(o.u_.r_obj); if (taglen_ == INT_TAG) return u_.ival == o.u_.ival; if (taglen_ == SMALL_TAG) return u_.small_str.Equal(o.u_.small_str); if (taglen_ == SDS_TTL_TAG) return u_.sds_ttl.view() == o.u_.sds_ttl.view(); DCHECK(IsInline() && o.IsInline()); return memcmp(u_.inline_str, o.u_.inline_str, taglen_) == 0; } bool CompactObj::CmpNonInline(std::string_view sv) const { DCHECK_GT(taglen_, kInlineLen); switch (taglen_) { case INT_TAG: return absl::AlphaNum(u_.ival).Piece() == sv; case ROBJ_TAG: return u_.r_obj.Equal(sv); case SMALL_TAG: return u_.small_str.Equal(sv); case SDS_TTL_TAG: return u_.sds_ttl.view() == sv; default: break; } return false; } bool CompactObj::CmpEncoded(string_view sv) const { DCHECK(encoding_); if (encoding_ == HUFFMAN_ENC) { size_t sz = Size(); if (sv.size() != sz) return false; if (IsInline()) { // Buffer must accommodate maximum decompressed size from inline storage (~8x compression) constexpr size_t kMaxHuffLen = kInlineLen * 8; if (sz <= kMaxHuffLen) { char buf[kMaxHuffLen]; auto domain = is_key_ ? 
HUFF_KEYS : HUFF_STRING_VALUES; const auto& decoder = tl.GetHuffmanDecoder(domain); CHECK(decoder.Decode({u_.inline_str + 1, size_t(taglen_ - 1)}, sz, buf)); return sv == string_view(buf, sz); } } tl.tmp_str.resize(sz); GetString(tl.tmp_str.data()); return sv == tl.tmp_str; } size_t encode_len = binpacked_len(sv.size()); if (IsInline()) { if (encode_len != taglen_) return false; char buf[kInlineLen * 2]; detail::ascii_unpack(to_byte(u_.inline_str), sv.size(), buf); return sv == string_view(buf, sv.size()); } if (taglen_ == ROBJ_TAG) { if (u_.r_obj.type() != OBJ_STRING) return false; if (u_.r_obj.Size() != encode_len) return false; if (!detail::validate_ascii_fast(sv.data(), sv.size())) return false; return detail::compare_packed(to_byte(u_.r_obj.inner_obj()), sv.data(), sv.size()); } if (taglen_ == SDS_TTL_TAG) { size_t sds_len = sdslen(u_.sds_ttl.sds_ptr); if (sds_len != encode_len) return false; if (!detail::validate_ascii_fast(sv.data(), sv.size())) return false; return detail::compare_packed(to_byte(u_.sds_ttl.sds_ptr), sv.data(), sv.size()); } if (taglen_ == JSON_TAG) { return false; // cannot compare json with string } if (taglen_ == SMALL_TAG) { if (u_.small_str.size() != encode_len) return false; if (!detail::validate_ascii_fast(sv.data(), sv.size())) return false; // We need to compare an unpacked sv with 2 packed parts. // To compare easily ascii with binary we would need to split ascii at 8 bytes boundaries // so that we could pack it into complete binary bytes (8 ascii chars produce 7 bytes). // I choose a minimal 16 byte prefix: // 1. sv must be longer than 16 if we are here (at least 18 actually). // 2. 16 chars produce 14 byte blob that should cover the first slice (10 bytes) and 4 bytes // of the second slice. // 3. I assume that the first slice is less than 14 bytes which is correct since small string // has only 9-10 bytes in its inline prefix storage. DCHECK_GT(sv.size(), 16u); // we would not be in SMALL_TAG, otherwise. 
auto slice = u_.small_str.Get(); DCHECK_LT(slice[0].size(), 14u); uint8_t tmpbuf[14]; detail::ascii_pack(sv.data(), 16, tmpbuf); // Compare the first slice. if (memcmp(slice[0].data(), tmpbuf, slice[0].size()) != 0) return false; // Compare the prefix of the second slice. size_t pref_len = 14 - slice[0].size(); if (memcmp(slice[1].data(), tmpbuf + slice[0].size(), pref_len) != 0) return false; // We verified that the first 16 chars (or 14 bytes) are equal. // Lets verify the rest - suffix of the second slice and the suffix of sv. return detail::compare_packed(to_byte(slice[1].data() + pref_len), sv.data() + 16, sv.size() - 16); } LOG(FATAL) << "Unsupported tag " << int(taglen_); return false; } void CompactObj::EncodeString(string_view str) { DCHECK_GT(str.size(), kInlineLen); DCHECK_EQ(NONE_ENC, encoding_); string_view encoded = str; bool huff_encoded = false; // We chose such length that we can store the decoded length delta into 1 byte. // The maximum huffman compression is 1/8, so 288 / 8 = 36. // 288 - 36 = 252, which is smaller than 256. // TODO: introduce variable length huffman length. constexpr unsigned kMaxHuffLen = 288; // For sizes 17, 18 we would like to test ascii encoding first as it's more efficient. // And if it succeeds we can squash into the inline buffer. bool is_ascii = kUseAsciiEncoding && str.size() < 19 && detail::validate_ascii_fast(str.data(), str.size()); // if !is_ascii, we try huffman encoding next. if (!is_ascii && str.size() <= kMaxHuffLen) { auto& huffman = is_key_ ? tl.huff_keys : tl.huff_string_values; if (huffman.encoder.valid()) { unsigned dest_len = huffman.encoder.CompressedBound(str.size()); // 1 byte for storing the size delta. tl.tmp_buf.resize(1 + dest_len); string err_msg; ++tl.huff_encode_total; bool res = huffman.encoder.Encode(str, tl.tmp_buf.data() + 1, &dest_len, &err_msg); if (res) { // we accept huffman encoding only if it is: // 1. smaller than the original string by 20% // 2. 
allows us to store the encoded string in the inline buffer if (dest_len && (dest_len < kInlineLen || (dest_len + dest_len / 5) < str.size())) { huff_encoded = true; tl.huff_encode_success++; encoded = string_view{reinterpret_cast(tl.tmp_buf.data()), dest_len + 1}; unsigned delta = str.size() - dest_len; DCHECK_LT(delta, 256u); tl.tmp_buf[0] = static_cast(delta); encoding_ = HUFFMAN_ENC; if (encoded.size() <= kInlineLen) { SetMeta(encoded.size(), mask_); memcpy(u_.inline_str, tl.tmp_buf.data(), encoded.size()); return; } } } else { // Should not happen, means we have an internal buf. LOG(DFATAL) << "Failed to encode string with huffman: " << err_msg; } } } // Finally we try ascii encoding for longer strings if we have not encoded them with huffman. if (kUseAsciiEncoding && !is_ascii && str.size() >= 19 && !huff_encoded) { is_ascii = detail::validate_ascii_fast(str.data(), str.size()); } if (is_ascii) { size_t encode_len = binpacked_len(str.size()); size_t rev_len = ascii_len(encode_len); if (rev_len == str.size()) { encoding_ = ASCII2_ENC; // str hits its highest bound. } else { CHECK_EQ(str.size(), rev_len - 1) << "Bad ascii encoding for len " << str.size(); encoding_ = ASCII1_ENC; // str is shorter than its highest bound. 
} tl.tmp_buf.resize(encode_len); detail::ascii_pack_simd2(str.data(), str.size(), tl.tmp_buf.data()); encoded = string_view{reinterpret_cast(tl.tmp_buf.data()), encode_len}; if (encoded.size() <= kInlineLen) { SetMeta(encoded.size(), mask_); detail::ascii_pack(str.data(), str.size(), reinterpret_cast(u_.inline_str)); return; } } DCHECK_GT(encoded.size(), kInlineLen); if (SmallString::CanAllocate(encoded.size())) { if (taglen_ == SMALL_TAG) tl.small_str_bytes -= u_.small_str.MallocUsed(); else SetMeta(SMALL_TAG, mask_); tl.small_str_bytes += u_.small_str.Assign(encoded); return; } SetMeta(ROBJ_TAG, mask_); u_.r_obj.SetString(encoded, tl.local_mr); } std::array CompactObj::GetRawString() const { DCHECK(!IsExternal()); if (taglen_ == ROBJ_TAG) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); DCHECK_EQ(OBJ_ENCODING_RAW, u_.r_obj.encoding()); return {u_.r_obj.AsView(), {}}; } if (taglen_ == SMALL_TAG) { return u_.small_str.Get(); } if (taglen_ == SDS_TTL_TAG) { return {u_.sds_ttl.view(), {}}; } LOG(FATAL) << "Unsupported tag for GetRawString(): " << int(taglen_); return {}; } MemoryResource* CompactObj::memory_resource() { return tl.local_mr; } string_view CompactObj::SdsTtlString::view() const { return string_view{sds_ptr, sdslen(sds_ptr)}; } bool CompactObj::JsonConsT::DefragIfNeeded(PageUsage* page_usage) { const MiMemoryResource* mr = static_cast(memory_resource()); const int64_t before = static_cast(mr->used()); DCHECK_GE(before, 0) << "Memory usage is more than int64_t max value"; bool did_defragment = Defragment(*json_ptr, page_usage); const int64_t after = static_cast(mr->used()); DCHECK_GE(after, 0) << "Memory usage is more than int64_t max value"; if (const int64_t delta = after - before; delta != 0) { bytes_used = UpdateSize(bytes_used, delta); } return did_defragment; } bool CompactObj::FlatJsonT::DefragIfNeeded(PageUsage* page_usage) { if (uint8_t* old = flat_ptr; page_usage->IsPageForObjectUnderUtilized(old)) { const uint32_t size = json_len; flat_ptr = 
static_cast(tl.local_mr->allocate(size, kAlignSize)); memcpy(flat_ptr, old, size); tl.local_mr->deallocate(old, size, kAlignSize); return true; } return false; } bool CompactObj::JsonWrapper::DefragIfNeeded(PageUsage* page_usage) { if (JsonEnconding() == kEncodingJsonCons) { return cons.DefragIfNeeded(page_usage); } return flat.DefragIfNeeded(page_usage); } constexpr std::pair kObjTypeToString[] = { {OBJ_STRING, "string"sv}, {OBJ_LIST, "list"sv}, {OBJ_SET, "set"sv}, {OBJ_ZSET, "zset"sv}, {OBJ_HASH, "hash"sv}, {OBJ_STREAM, "stream"sv}, {OBJ_KEY, "key"sv}, // pseudo-type used for memory tracking {OBJ_JSON, "ReJSON-RL"sv}, {OBJ_SBF, "MBbloom--"sv}, {OBJ_CMS, "CMSk-TYPE"sv}, {OBJ_TOPK, "TopK-TYPE"sv}}; std::string_view ObjTypeToString(CompactObjType type) { for (auto& p : kObjTypeToString) { if (type == p.first) { return p.second; } } LOG(DFATAL) << "Unsupported type " << type; return "Invalid type"sv; } CompactObjType ObjTypeFromString(std::string_view sv) { for (auto& p : kObjTypeToString) { if (absl::EqualsIgnoreCase(sv, p.second)) { return p.first; } } return kInvalidCompactObjType; } void CompactKey::SetExpireTime(uint64_t abs_ms) { DCHECK(!IsRef() && !IsExternal()); // Already SDS_TTL_TAG — update TTL in place. if (taglen_ == SDS_TTL_TAG) { u_.sds_ttl.exp_ms = abs_ms; return; } char* new_sds = nullptr; if (IsInline()) { new_sds = sdsnewlen(u_.inline_str, taglen_); // encoding_ preserved as-is. 
} else if (taglen_ == INT_TAG) { absl::AlphaNum an(u_.ival); new_sds = sdsnewlen(an.data(), an.size()); encoding_ = NONE_ENC; } else if (taglen_ == SMALL_TAG) { size_t total = u_.small_str.size(); new_sds = sdsnewlen(nullptr, total); u_.small_str.Get(new_sds); tl.small_str_bytes -= u_.small_str.MallocUsed(); u_.small_str.Free(); } else if (taglen_ == ROBJ_TAG) { CHECK_EQ(OBJ_STRING, u_.r_obj.type()); auto view = u_.r_obj.AsView(); new_sds = sdsnewlen(view.data(), view.size()); u_.r_obj.Free(tl.local_mr); } else { LOG(FATAL) << "Unexpected tag for SetExpireTime: " << int(taglen_); } u_.sds_ttl.sds_ptr = new_sds; u_.sds_ttl.exp_ms = abs_ms; taglen_ = SDS_TTL_TAG; mask_bits_.expire = 1; } bool CompactKey::ClearExpireTime() { if (taglen_ != SDS_TTL_TAG) return false; DCHECK(!IsRef() && !IsExternal()); string decoded; GetString(&decoded); SetMeta(0, mask_); encoding_ = NONE_ENC; mask_bits_.expire = 0; SetString(decoded); return true; } uint64_t CompactKey::GetExpireTime() const { if (taglen_ != SDS_TTL_TAG) return 0; DCHECK(!IsRef() && !IsExternal()); return u_.sds_ttl.exp_ms; } size_t CompactObj::StrEncoding::DecodedSize(string_view blob) const { return DecodedSize(blob.size(), blob[0]); } size_t CompactObj::StrEncoding::DecodedSize(size_t blob_size, uint8_t first_byte) const { switch (enc_) { case NONE_ENC: return blob_size; case ASCII1_ENC: case ASCII2_ENC: return ascii_len(blob_size) - (enc_ == ASCII1_ENC); case HUFFMAN_ENC: return blob_size + int(first_byte) - 1; }; return 0; } size_t CompactObj::StrEncoding::Decode(std::string_view blob, char* dest) const { if (blob.empty()) return 0; size_t decoded_len = DecodedSize(blob); switch (enc_) { case NONE_ENC: memcpy(dest, blob.data(), blob.size()); break; case ASCII1_ENC: case ASCII2_ENC: detail::ascii_unpack(reinterpret_cast(blob.data()), decoded_len, dest); break; case HUFFMAN_ENC: { auto domain = is_key_ ? 
HUFF_KEYS : HUFF_STRING_VALUES; const auto& decoder = tl.GetHuffmanDecoder(domain); decoder.Decode(blob.substr(1), decoded_len, dest); break; } }; return decoded_len; } bool CompactObj::StrEncoding::DecodeByte(std::string_view blob, size_t idx, uint8_t* dest) const { if (blob.empty()) { return false; } size_t decoded_len = DecodedSize(blob); if (idx >= decoded_len) { return false; } switch (enc_) { case NONE_ENC: *dest = blob[idx]; break; case ASCII1_ENC: case ASCII2_ENC: *dest = detail::ascii_unpack_byte(reinterpret_cast(blob.data()), decoded_len, idx); break; case HUFFMAN_ENC: { std::string decoded_huff_string(decoded_len, 0); auto domain = is_key_ ? HUFF_KEYS : HUFF_STRING_VALUES; const auto& decoder = tl.GetHuffmanDecoder(domain); decoder.Decode(blob.substr(1), decoded_len, decoded_huff_string.data()); *dest = decoded_huff_string[idx]; break; } }; return true; } StringOrView CompactObj::StrEncoding::Decode(std::string_view blob) const { switch (enc_) { case NONE_ENC: return StringOrView::FromView(blob); default: { string out; out.resize(DecodedSize(blob)); Decode(blob, out.data()); return StringOrView::FromString(std::move(out)); } } return {}; } /* Create a new stream data structure. */ stream* streamNew() { stream* s = (stream*)zmalloc(sizeof(stream)); s->rax = raxNew(); s->length = 0; s->first_id.ms = 0; s->first_id.seq = 0; s->last_id.ms = 0; s->last_id.seq = 0; s->max_deleted_entry_id.seq = 0; s->max_deleted_entry_id.ms = 0; s->entries_added = 0; s->cgroups = NULL; /* Created on demand to save memory when not used. */ return s; } /* Free a consumer and associated data structures. Note that this function * will not reassign the pending messages associated with this consumer * nor will delete them from the stream, so when this function is called * to delete a consumer, and not when the whole stream is destroyed, the caller * should do some work before. 
*/ static void streamFreeConsumer(streamConsumer* sc) { raxFree(sc->pel); /* No value free callback: the PEL entries are shared between the consumer and the main stream PEL. */ sdsfree(sc->name); zfree(sc); } /* Used for generic free functions. */ static void streamFreeConsumerVoid(void* sc) { streamFreeConsumer((streamConsumer*)sc); } /* Used for generic free functions. */ static void streamFreeCGVoid(void* cg_) { streamCG* cg = (streamCG*)cg_; raxFreeWithCallback(cg->pel, zfree); raxFreeWithCallback(cg->consumers, streamFreeConsumerVoid); zfree(cg); } static void lpFreeVoid(void* lp) { lpFree((uint8_t*)lp); } /* Free a stream, including the listpacks stored inside the radix tree. */ void freeStream(stream* s) { raxFreeWithCallback(s->rax, lpFreeVoid); if (s->cgroups) raxFreeWithCallback(s->cgroups, streamFreeCGVoid); zfree(s); } } // namespace dfly ================================================ FILE: src/core/compact_object.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include "base/pmr/memory_resource.h" #include "common/string_or_view.h" #include "core/json/json_object.h" #include "core/mi_memory_resource.h" #include "core/small_string.h" typedef struct stream stream; namespace dfly { namespace tiering { struct TieredCoolRecord; } constexpr unsigned kEncodingIntSet = 0; constexpr unsigned kEncodingStrMap2 = 2; // for set/map encodings of strings using DenseSet constexpr unsigned kEncodingQL2 = 1; constexpr unsigned kEncodingListPack = 3; constexpr unsigned kEncodingJsonCons = 0; constexpr unsigned kEncodingJsonFlat = 1; class SBF; class TOPK; class CMS; class PageUsage; using cmn::StringOrView; namespace detail { // redis objects or blobs of upto 4GB size. 
class RobjWrapper { public: using MemoryResource = PMR_NS::memory_resource; RobjWrapper() : sz_(0), type_(0), encoding_(0) { } size_t MallocUsed(bool slow) const; uint64_t HashCode() const; bool Equal(const RobjWrapper& ow) const; bool Equal(std::string_view sv) const; size_t Size() const; void Free(MemoryResource* mr); void SetString(std::string_view s, MemoryResource* mr); void ReserveString(size_t size, MemoryResource* mr); void AppendString(std::string_view s, MemoryResource* mr); // Used when sz_ is used to denote memory usage void SetSize(uint64_t size); void Init(unsigned type, unsigned encoding, void* inner); unsigned type() const { return type_; } unsigned encoding() const { return encoding_; } void* inner_obj() const { return inner_obj_; } void set_inner_obj(void* ptr) { inner_obj_ = ptr; } std::string_view AsView() const { return std::string_view{reinterpret_cast(inner_obj_), sz_}; } // Try reducing memory fragmentation by re-allocating values from underutilized pages. // Returns true if re-allocated. bool DefragIfNeeded(PageUsage* page_usage); private: void ReallocateString(MemoryResource* mr); size_t InnerObjMallocUsed() const; void MakeInnerRoom(size_t current_cap, size_t desired, MemoryResource* mr); void Set(void* p, size_t s) { inner_obj_ = p; sz_ = s; } void* inner_obj_ = nullptr; // semantics depend on the type. For OBJ_STRING it's string length. uint64_t sz_ : 56; uint64_t type_ : 4; uint64_t encoding_ : 4; } __attribute__((packed)); static_assert(sizeof(RobjWrapper) == 16); } // namespace detail using CompactObjType = unsigned; constexpr CompactObjType kInvalidCompactObjType = std::numeric_limits::max(); uint32_t JsonEnconding(); class CompactObj { static constexpr unsigned kInlineLen = 16; void operator=(const CompactObj&) = delete; CompactObj(const CompactObj&) = delete; protected: // 0-16 is reserved for inline lengths of string type. 
enum TagEnum : uint8_t { INT_TAG = 17, SMALL_TAG = 18, ROBJ_TAG = 19, EXTERNAL_TAG = 20, JSON_TAG = 21, SBF_TAG = 22, CMS_TAG = 23, SDS_TTL_TAG = 24, TOPK_TAG = 25, }; // String encoding types. // With ascii compression it compresses 8 bytes to 7 but also 7 to 7. // Therefore, in order to know the original length we introduce 2 states that // correct the length upon decoding. ASCII1_ENC rounds down the decoded length, // while ASCII2_ENC rounds it up. See DecodedLen implementation for more info. enum EncodingEnum : uint8_t { NONE_ENC = 0, ASCII1_ENC = 1, ASCII2_ENC = 2, HUFFMAN_ENC = 3, }; public: // Utility class for working with different string encodings (ascii, huffman, etc) struct StrEncoding { size_t DecodedSize(std::string_view blob) const; // Size of decoded blob size_t Decode(std::string_view blob, char* dest) const; // Decode into dest, return size StringOrView Decode(std::string_view blob) const; // Decode a byte at offset into dest. Return true if decoded successfully, // false if idx is out of bounds. bool DecodeByte(std::string_view blob, size_t idx, uint8_t* dest) const; private: friend class CompactObj; explicit StrEncoding(uint8_t enc, bool is_key) : enc_(static_cast(enc)), is_key_(is_key) { } size_t DecodedSize(size_t compr_size, uint8_t first_byte) const; EncodingEnum enc_; bool is_key_; }; using MemoryResource = detail::RobjWrapper::MemoryResource; // Different representations of external values enum class ExternalRep : uint8_t { STRING, // OBJ_STRING, Basic representation with various string encodings SERIALIZED_MAP // OBJ_HASH, Serialized map }; explicit CompactObj(bool is_key) : is_key_{is_key}, taglen_{0}, encoding_{0} { // default - empty string } CompactObj(std::string_view str, bool is_key) : CompactObj(is_key) { SetString(str); } CompactObj(CompactObj&& cs) noexcept : CompactObj(cs.is_key_) { operator=(std::move(cs)); }; ~CompactObj(); CompactObj& operator=(CompactObj&& o) noexcept; // Returns object size depending on the semantics. 
// For strings - returns the length of the string. // For containers - returns number of elements in the container. size_t Size() const; bool IsRef() const { return mask_bits_.ref; } std::string_view GetSlice(std::string* scratch) const; std::string ToString() const { std::string res; GetString(&res); return res; } uint64_t HashCode() const; static uint64_t HashCode(std::string_view str); bool HasFlag() const { return mask_bits_.mc_flag; } void SetFlag(bool e) { mask_bits_.mc_flag = e; } bool WasTouched() const { return mask_bits_.touched; } void SetTouched(bool e) { mask_bits_.touched = e; } bool DefragIfNeeded(PageUsage* page_usage); void SetOmitDefrag(bool v) { mask_bits_.omit_defrag = v; } bool OmitDefrag() const { return mask_bits_.omit_defrag; } bool HasStashPending() const { return mask_bits_.io_pending; } void SetStashPending(bool b) { mask_bits_.io_pending = b; } bool IsSticky() const { return mask_bits_.sticky; } void SetSticky(bool e) { mask_bits_.sticky = e; } unsigned Encoding() const; CompactObjType ObjType() const; void* RObjPtr() const { return u_.r_obj.inner_obj(); } void SetRObjPtr(void* ptr) { u_.r_obj.Init(u_.r_obj.type(), u_.r_obj.encoding(), ptr); } // takes ownership over obj_inner. // type should not be OBJ_STRING. void InitRobj(CompactObjType type, unsigned encoding, void* obj_inner); // For STR object. void SetInt(int64_t val); std::optional TryGetInt() const; void GetString(std::string* res) const; void SetString(std::string_view str); void ReserveString(size_t size); void AppendString(std::string_view str); // Will set this to hold OBJ_JSON, after that it is safe to call GetJson // NOTE: in order to avid copy which can be expensive in this case, // you need to move an object that created with the function JsonFromString // into here, no copying is allowed! 
void SetJson(JsonType&& j); void SetJson(const uint8_t* buf, size_t len); // Adjusts the size used by json void SetJsonSize(int64_t size); // Adjusts the size used by a stream void AddStreamSize(int64_t size); // pre condition - the type here is OBJ_JSON and was set with SetJson JsonType* GetJson() const; void SetSBF(SBF* sbf) { SetMeta(SBF_TAG); u_.sbf = sbf; } void SetSBF(uint64_t initial_capacity, double fp_prob, double grow_factor); SBF* GetSBF() const; void SetTOPK(TOPK* topk) { SetMeta(TOPK_TAG); u_.topk = topk; } void SetTOPK(uint32_t k, uint32_t width, uint32_t depth, double decay); TOPK* GetTOPK() const; void SetCMS(CMS* cms) { SetMeta(CMS_TAG); u_.cms = cms; } void SetCMS(uint32_t width, uint32_t depth); CMS* GetCMS() const; // dest must have at least Size() bytes available void GetString(char* dest) const; bool IsExternal() const { return taglen_ == EXTERNAL_TAG; } // returns true if the value is stored in the cooling storage. Cooling storage has an item both // on disk and in memory. bool IsCool() const { assert(IsExternal()); return u_.ext_ptr.is_cool; } void SetExternal(size_t offset, uint32_t sz, ExternalRep rep); ExternalRep GetExternalRep() const; // Switches to empty, non-external string. // Preserves all the attributes. void RemoveExternal() { encoding_ = NONE_ENC; SetMeta(0, mask_); } // Assigns a cooling record to the object together with its external slice. void SetCool(size_t offset, uint32_t serialized_size, ExternalRep rep, tiering::TieredCoolRecord* record); struct CoolItem { uint16_t page_offset; size_t serialized_size; tiering::TieredCoolRecord* record; }; // Prerequisite: IsCool() is true. // Returns the external data of the object incuding its ColdRecord. CoolItem GetCool() const; // Prequisite: IsCool() is true. // Keeps cool record only as external value and discard in-memory part. 
void Freeze(size_t offset, size_t sz);

std::pair GetExternalSlice() const;

// Injects either the raw string (extracted with GetRawString()) or the usual string
// back to the compact object. In the latter case, encoding is performed.
// Precondition: The object must be in the EXTERNAL state.
// Postcondition: The object is an in-memory string.
void Materialize(std::string_view str, bool is_raw);

// Returns the approximation of memory used by the object.
// If slow is true, may use more expensive methods to calculate the precise size.
size_t MallocUsed(bool slow = false) const;

// Resets the object to empty state (string).
void Reset();

// True when the tag value is within the inline-length range, i.e. taglen_ <= kInlineLen.
bool IsInline() const {
  return taglen_ <= kInlineLen;
}

uint8_t GetFirstByte() const;

// Returns true if the byte was decoded successfully, false if idx is out of bounds.
bool GetByteAtIndex(size_t idx, uint8_t* res) const;

// Returns a pair of booleans: {success, in_place}. success is false if offset is out of bounds
// in_place is true if the byte was set without needing to rewrite the string.
std::pair SetByteAtIndex(size_t idx, uint8_t val);

// Counters returned by GetStatsThreadLocal().
struct Stats {
  size_t small_string_bytes = 0;
  uint64_t huff_encode_total = 0, huff_encode_success = 0;
};

static Stats GetStatsThreadLocal();

static void InitThreadLocal(MemoryResource* mr);

// Selects which Huffman table InitHuffmanThreadLocal() installs.
enum HuffmanDomain : uint8_t {
  HUFF_KEYS = 0,
  HUFF_STRING_VALUES = 1,
  // TODO: add more domains.
};

static bool InitHuffmanThreadLocal(HuffmanDomain domain, std::string_view hufftable);

static MemoryResource* memory_resource();  // thread-local.

// Allocates storage for a T from memory_resource() and constructs the object in place.
// When no arguments are given and the is_constructible_v check holds, the memory resource itself
// is passed to T's constructor; otherwise the arguments are forwarded.
template static T* AllocateMR(Args&&... args) {
  void* ptr = memory_resource()->allocate(sizeof(T), alignof(T));
  if constexpr (std::is_constructible_v && sizeof...(args) == 0)
    return new (ptr) T{memory_resource()};
  else
    return new (ptr) T{std::forward(args)...};
}

// Counterpart of AllocateMR: runs the destructor and returns the storage to memory_resource()
// using the same size and alignment that AllocateMR requested.
template static void DeleteMR(void* ptr) {
  T* t = (T*)ptr;
  t->~T();
  memory_resource()->deallocate(ptr, sizeof(T), alignof(T));
}

// Return raw (non-decoded) string as two views. First is guaranteed to be non-empty.
// Precondition: the object is a non-inline string.
std::array GetRawString() const;

// Bundles the current string encoding with the key/value role of the object.
StrEncoding GetStrEncoding() const {
  return StrEncoding{encoding_, is_key_};
}

bool HasAllocated() const;

bool TagAllowsEmptyValue() const;

// Raw tag value: either the inline length or the type tag.
uint8_t Tag() const {
  return taglen_;
}

private:
// Returns a string_view corresponding to the serialized encoded blob.
// If opt_dest is provided, it may be used to decode directly into the destination buffer.
std::string_view GetEncodedBlob(StrEncoding str_encoding, char* opt_dest) const;

protected:
void EncodeString(std::string_view str);

// Requires: HasAllocated() - true.
void Free();

bool CmpEncoded(std::string_view sv) const;
bool CmpNonInline(std::string_view sv) const;

// Installs a new tag and mask. Any existing allocation is released first; when there is none the
// inline buffer is zeroed instead. Note that the default mask of 0 clears every flag bit.
void SetMeta(uint8_t taglen, uint8_t mask = 0) {
  if (HasAllocated()) {
    Free();
  } else {
    memset(u_.inline_str, 0, kInlineLen);
  }
  taglen_ = taglen;
  mask_ = mask;
}

// Descriptor of a value that lives in external (tiered) storage. Must stay 16 bytes.
struct ExternalPtr {
  uint32_t serialized_size;
  uint16_t page_offset;  // 0 for multi-page blobs. != 0 for small blobs.
  uint8_t is_cool : 1;
  uint8_t representation : 2;  // See ExternalRep
  uint8_t is_reserved : 5;
  uint8_t first_byte;

  // We do not have enough space in the common area to store page_index together with
  // cool_record pointer. Therefore, we moved this field into TieredCoolRecord itself.
  struct Offload {
    uint32_t page_index;
    uint32_t reserved;
  };

  // Shares storage between `offload` and `cool_record`.
  union {
    Offload offload;
    tiering::TieredCoolRecord* cool_record;
  };
} __attribute__((packed));

static_assert(sizeof(ExternalPtr) == 16);

// String stored as SDS together with its expiry time.
struct SdsTtlString {
  char* sds_ptr;     // SDS string (length via sdslen)
  uint64_t exp_ms;   // absolute expiry time in ms
  std::string_view view() const;
} __attribute__((packed));

struct JsonConsT {
  JsonType* json_ptr;
  size_t bytes_used;
  bool DefragIfNeeded(PageUsage* page_usage);
};

struct FlatJsonT {
  uint32_t json_len;
  uint8_t* flat_ptr;
  bool DefragIfNeeded(PageUsage* page_usage);
};

// Holds one of the two JSON representations above.
struct JsonWrapper {
  union {
    JsonConsT cons;
    FlatJsonT flat;
  };
  bool DefragIfNeeded(PageUsage* page_usage);
};

// Union of different representations
union U {
  char inline_str[kInlineLen];

  SmallString small_str;
  detail::RobjWrapper r_obj;

  // using 'packed' to reduce alignment of U to 1.
  JsonWrapper json_obj __attribute__((packed));
  SBF* sbf __attribute__((packed));
  TOPK* topk __attribute__((packed));
  CMS* cms __attribute__((packed));
  int64_t ival __attribute__((packed));
  ExternalPtr ext_ptr;
  SdsTtlString sds_ttl;

  U() : r_obj() {
  }
} u_;

static_assert(sizeof(u_) == 16);

// Flag bits; `mask_` gives access to all of them as a single byte.
union {
  uint8_t mask_ = 0;
  struct {
    uint8_t ref : 1;      // Mark objects that don't own their allocation.
    uint8_t expire : 1;   // Mark objects that have expiry timestamp assigned.
    uint8_t mc_flag : 1;  // Marks keys that have memcache flags assigned.

    // IO_PENDING is set when the tiered storage has issued an i/o request to save the value.
    // It is cleared when the io request finishes or is cancelled.
    uint8_t io_pending : 1;
    uint8_t sticky : 1;

    // TOUCHED used to determine which items are hot/cold.
    // by checking if the item was touched from the last time we
    // reached this item while traversing the database to set items as cold.
    // https://junchengyang.com/publication/nsdi24-SIEVE.pdf
    uint8_t touched : 1;  // used to mark keys that were accessed.

    uint8_t omit_defrag : 1;  // mark object to skip defragmentation.
  } mask_bits_;
};

// TODO: use c++20 bitfield initializers
const bool is_key_ : 1;
uint8_t taglen_ : 5;    // Either length of inline string or tag of type
uint8_t encoding_ : 2;  // Encoding of string values
};

// CompactObj flavour used for keys: constructs the base with the key flag set to true and adds
// expiry helpers and comparisons.
struct CompactKey : public CompactObj {
  CompactKey() : CompactObj(true) {
  }

  explicit CompactKey(std::string_view str) : CompactObj{str, true} {
  }

  // Returns a copy that shares this key's payload (bitwise copy of u_, tag, encoding and mask)
  // and has the `ref` bit set, i.e. it is marked as not owning the allocation.
  CompactKey AsRef() const {
    CompactKey res;
    memcpy(&res.u_, &u_, sizeof(u_));
    res.encoding_ = encoding_;
    res.taglen_ = taglen_;
    res.mask_ = mask_;
    res.mask_bits_.ref = 1;

    return res;
  }

  // Accessors for the `expire` mask bit.
  bool HasExpire() const {
    return mask_bits_.expire;
  }

  void SetExpire(bool e) {
    mask_bits_.expire = e;
  }

  // Embed expire time directly in the key by converting to SDS_TTL_TAG.
  void SetExpireTime(uint64_t abs_ms);

  // Remove embedded expire time and convert back to optimal string form.
  bool ClearExpireTime();

  // Read the embedded expire time.
  // Returns 0 if there is no embedded expire time, otherwise
  // returns the absolute expire time in ms.
  uint64_t GetExpireTime() const;

  // Assigning a string view stores it through SetString().
  CompactKey& operator=(std::string_view sv) noexcept {
    SetString(sv);
    return *this;
  }

  bool operator==(const CompactKey& o) const;

  bool operator==(std::string_view sl) const;

  bool operator!=(std::string_view sl) const {
    return !(*this == sl);
  }

  friend bool operator!=(const CompactKey& lhs, const CompactKey& rhs) {
    return !(lhs == rhs);
  }

  friend bool operator==(std::string_view sl, const CompactKey& o) {
    return o.operator==(sl);
  }
};

// Compares the key with a plain string. Encoded keys go through CmpEncoded(), inline keys are
// compared directly against the inline buffer (taglen_ is the length), everything else is
// delegated to CmpNonInline().
inline bool CompactKey::operator==(std::string_view sv) const {
  if (encoding_)
    return CmpEncoded(sv);

  if (IsInline()) {
    return std::string_view{u_.inline_str, taglen_} == sv;
  }
  return CmpNonInline(sv);
}

// CompactObj flavour used for values: constructs the base with the key flag set to false.
struct CompactValue : public CompactObj {
  CompactValue() : CompactObj(false) {
  }

  explicit CompactValue(std::string_view str) : CompactObj{str, false} {
  }
};

std::string_view ObjTypeToString(CompactObjType type);

// Returns kInvalidCompactObjType if sv is not a valid type.
CompactObjType ObjTypeFromString(std::string_view sv);

stream* streamNew();
void freeStream(stream* s);

}  // namespace dfly

================================================
FILE: src/core/compact_object_test.cc
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#include "core/compact_object.h"

#include
#include
#include
#include
#include
#include
#include

#include "base/gtest.h"
#include "base/logging.h"
#include "core/detail/bitpacking.h"
#include "core/huff_coder.h"
#include "core/mi_memory_resource.h"
#include "core/page_usage/page_usage_stats.h"
#include "core/string_map.h"
#include "core/string_set.h"

extern "C" {
#include "redis/intset.h"
#include "redis/redis_aux.h"
#include "redis/stream.h"
#include "redis/zmalloc.h"
}

namespace dfly {

// Seed passed to XXH3_64bits_withSeed() when the tests compute expected hash values.
XXH64_hash_t kSeed = 24061983;

// DeallocateAtRandom() starts freeing at this index; tests pass kRandomStep as the stride.
constexpr size_t kRandomStartIndex = 24;
constexpr size_t kRandomStep = 26;
constexpr float kUnderUtilizedRatio = 1.0f;  // ensure that we would detect

using namespace std;
using namespace jsoncons;
using namespace jsoncons::jsonpath;

// Prints a CompactObj to a stream: strings are printed quoted, every other type is printed as its
// numeric object type. NOTE(review): presumably picked up by gtest for failure messages.
void PrintTo(const CompactObj& cobj, std::ostream* os) {
  if (cobj.ObjType() == OBJ_STRING) {
    *os << "'" << cobj.ToString() << "' ";
    return;
  }
  *os << "cobj: [" << cobj.ObjType() << "]";
}

// This is for the mimalloc test - being able to find an address in memory
// where we have memory underutilization
// see issue number 448 (https://github.com/dragonflydb/dragonfly/issues/448)
//
// Allocates `size` blocks from the backing mimalloc heap and returns their addresses.
// Every kAllocRandomChangeSize-th block (index 0, 13, 26, ...) is allocate_size * factor1 bytes,
// all the others are allocate_size * factor2 bytes.
std::vector AllocateForTest(int size, std::size_t allocate_size, int factor1 = 1,
                            int factor2 = 1) {
  const int kAllocRandomChangeSize = 13;  // just some random value
  std::vector ptrs;
  for (int index = 0; index < size; index++) {
    auto alloc_size =
        index % kAllocRandomChangeSize == 0 ? allocate_size * factor1 : allocate_size * factor2;
    auto heap_alloc = mi_heap_get_backing();
    void* ptr = mi_heap_malloc(heap_alloc, alloc_size);
    ptrs.push_back(ptr);
  }
  return ptrs;
}

// Returns true if at least one non-null pointer in `ptrs` is reported by PageUsage as residing on
// an underutilized page for the given ratio. Null entries (already freed) are skipped.
bool HasUnderutilizedMemory(const std::vector& ptrs, float ratio) {
  PageUsage page_usage{CollectPageStats::NO, ratio};
  auto it = std::find_if(ptrs.begin(), ptrs.end(), [&](auto p) {
    int r = p && page_usage.IsPageForObjectUnderUtilized(p);
    return r > 0;
  });
  return it != ptrs.end();
}

// Go over ptrs vector and free memory at locations every "steps".
// This is so that we will trigger the under utilization - some
// pages will have "holes" in them and we are expecting to find these pages.
// Freed slots are set to nullptr so that callers can tell them apart.
void DeallocateAtRandom(size_t steps, std::vector* ptrs) {
  for (size_t i = kRandomStartIndex; i < ptrs->size(); i += steps) {
    mi_free(ptrs->at(i));
    ptrs->at(i) = nullptr;
  }
}

// Wires the thread-local allocator state used by the tests (zmalloc, SmallString, CompactObj and
// the stateless allocator) to the backing mimalloc heap.
static void InitThreadStructs() {
  auto* tlh = mi_heap_get_backing();
  init_zmalloc_threadlocal(tlh);
  SmallString::InitThreadLocal(tlh);
  thread_local MiMemoryResource mi_resource(tlh);
  CompactObj::InitThreadLocal(&mi_resource);
  InitTLStatelessAllocMR(&mi_resource);
};

// Collects the backing heap and then visits what is left in it, logging an error for every
// invocation of the visitor.
static void CheckEverythingDeallocated() {
  mi_heap_collect(mi_heap_get_backing(), true);

  auto cb_visit = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block,
                     size_t block_size, void* arg) {
    LOG(ERROR) << "Unfreed allocations: block_size " << block_size
               << ", allocated: " << area->used * block_size;
    return true;
  };

  mi_heap_visit_blocks(mi_heap_get_backing(), false /* do not visit all blocks*/, cb_visit,
                       nullptr);
}

// Fixture shared by all tests below. Suite set-up initializes the thread-local structures;
// suite tear-down reports allocations that were left behind.
class CompactObjectTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite() {
    InitRedisTables();  // to initialize server struct.

    InitThreadStructs();
  }

  static void TearDownTestSuite() {
    CheckEverythingDeallocated();
    CleanupStatelessAllocMR();
  }

  CompactValue cobj_;
  CompactKey ckey_;
  string tmp_;
};

// With a high threshold every page with live blocks counts as wasted, so after each step
// `wasted` is expected to equal committed - allocated.
TEST_F(CompactObjectTest, WastedMemoryDetection) {
  size_t allocated = 0, commited = 0, wasted = 0;
  // By setting the threshold to high value we are expecting
  // To find locations where we have wasted memory
  float ratio = 0.8;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  EXPECT_EQ(allocated, 0);
  EXPECT_EQ(commited, 0);
  EXPECT_EQ(wasted, (commited - allocated));

  // Expected number of allocated bytes; starts at 64 to account for p1 below.
  std::size_t allocated_mem = 64;
  auto* myheap = mi_heap_get_backing();
  void* p1 = mi_heap_malloc(myheap, 64);

  void* ptrs_end[50];
  for (size_t i = 0; i < 50; ++i) {
    ptrs_end[i] = mi_heap_malloc(myheap, 128);
    allocated_mem += 128;
  }

  allocated = commited = wasted = 0;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  EXPECT_EQ(allocated, allocated_mem);
  EXPECT_GT(commited, allocated_mem);
  EXPECT_EQ(wasted, (commited - allocated));
  void* ptr[50];
  // allocate 50
  for (size_t i = 0; i < 50; ++i) {
    ptr[i] = mi_heap_malloc(myheap, 256);
    allocated_mem += 256;
  }

  // At this point all the blocks has committed > 0 and used > 0
  // and since we expecting to find these locations, the size of
  // wasted == commited memory - allocated memory.
  allocated = commited = wasted = 0;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  EXPECT_EQ(allocated, allocated_mem);
  EXPECT_GT(commited, allocated_mem);
  EXPECT_EQ(wasted, (commited - allocated));
  // free 50/50 -
  for (size_t i = 0; i < 50; ++i) {
    mi_free(ptr[i]);
    allocated_mem -= 256;
  }

  // After all the memory at block size 256 is free, we would have commited there
  // but the used is expected to be 0, so the number now is different from the
  // case above
  allocated = commited = wasted = 0;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  EXPECT_EQ(allocated, allocated_mem);
  EXPECT_GT(commited, allocated_mem);
  // since we release all 256 memory block, it should not be counted
  EXPECT_EQ(wasted, (commited - allocated));
  for (size_t i = 0; i < 50; ++i) {
    mi_free(ptrs_end[i]);
  }
  mi_free(p1);

  // Now that its all freed, we are not expecting to have any wasted memory any more
  allocated = commited = wasted = 0;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  EXPECT_EQ(allocated, 0);
  EXPECT_GT(commited, allocated);
  EXPECT_EQ(wasted, (commited - allocated));

  mi_collect(false);
}

// With a very low threshold, pages that still hold live blocks are not reported as wasted; only
// the fully freed 256-byte class is.
TEST_F(CompactObjectTest, WastedMemoryDontCount) {
  // The commited memory per blocks are:
  // 64bit => 4K
  // 128bit => 8k
  // 256 => 16k
  // and so on, which mean every n * sizeof(ptr) ^ 2 == 2^11*2*(n-1) (where n starts with 1)
  constexpr std::size_t kExpectedFor256MemWasted = 0x4000;  // memory block 256
  auto* myheap = mi_heap_get_backing();

  size_t allocated = 0, commited = 0, wasted = 0;
  // By setting the threshold to a very low number
  // we don't expect to find and locations where memory is wasted
  float ratio = 0.01;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  EXPECT_EQ(allocated, 0);
  EXPECT_EQ(commited, 0);
  EXPECT_EQ(wasted, (commited - allocated));

  // Expected number of allocated bytes; starts at 64 to account for p1 below.
  std::size_t allocated_mem = 64;
  void* p1 = mi_heap_malloc(myheap, 64);

  void* ptrs_end[50];
  for (size_t i = 0; i < 50; ++i) {
    ptrs_end[i] = mi_heap_malloc(myheap, 128);
    (void)p1;
    allocated_mem += 128;
  }

  void* ptr[50];

  // allocate 50
  for (size_t i = 0; i < 50; ++i) {
    ptr[i] = mi_heap_malloc(myheap, 256);
    allocated_mem += 256;
  }
  allocated = commited = wasted = 0;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);
  // Threshold is low so we are not expecting any wasted memory to be found.
  EXPECT_EQ(allocated, allocated_mem);
  EXPECT_GT(commited, allocated_mem);
  EXPECT_EQ(wasted, 0);

  // free 50/50 -
  for (size_t i = 0; i < 50; ++i) {
    mi_free(ptr[i]);
    allocated_mem -= 256;
  }
  allocated = commited = wasted = 0;
  zmalloc_get_allocator_wasted_blocks(ratio, &allocated, &commited, &wasted);

  EXPECT_EQ(allocated, allocated_mem);
  EXPECT_GT(commited, allocated_mem);
  // We will detect only wasted memory for block size of
  // 256 - and all of it is wasted.
  EXPECT_EQ(wasted, kExpectedFor256MemWasted);
  // Threshold is low so we are not expecting any wasted memory to be found.
  for (size_t i = 0; i < 50; ++i) {
    mi_free(ptrs_end[i]);
  }
  mi_free(p1);

  mi_collect(false);
}

// A 22-character key hashes exactly like the raw string and compares equal to it, also after
// being replaced by a 25-character string.
TEST_F(CompactObjectTest, NonInline) {
  string s(22, 'a');
  CompactKey obj{s};

  uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), kSeed);
  EXPECT_EQ(18261733907982517826UL, expected_val);
  EXPECT_EQ(expected_val, obj.HashCode());
  EXPECT_EQ(s, obj);

  s.assign(25, 'b');
  obj.SetString(s);
  EXPECT_EQ(s, obj);
  EXPECT_EQ(s.size(), obj.Size());
}

// Hash and size of a short ascii value must match those of the source string.
TEST_F(CompactObjectTest, InlineAsciiEncoded) {
  string s = "key:0000000000000";
  uint64_t expected_val = XXH3_64bits_withSeed(s.data(), s.size(), kSeed);
  CompactValue obj{s};
  EXPECT_EQ(expected_val, obj.HashCode());
  EXPECT_EQ(s.size(), obj.Size());
}

// A numeric string is retrievable as an integer while still behaving like the string "0".
TEST_F(CompactObjectTest, Int) {
  ckey_.SetString("0");
  EXPECT_EQ(0, ckey_.TryGetInt());
  EXPECT_EQ(1, ckey_.Size());
  EXPECT_EQ(ckey_, "0");
  EXPECT_EQ("0", ckey_.GetSlice(&tmp_));
  EXPECT_EQ(OBJ_STRING, ckey_.ObjType());
}

// The expire flag set before SetString() must still be set afterwards.
TEST_F(CompactObjectTest, Expire) {
  CompactKey key;
  key.SetExpire(true);
  key.SetString("42");
  EXPECT_EQ(8181779779123079347, key.HashCode());
  EXPECT_EQ(OBJ_ENCODING_INT, key.Encoding());
  EXPECT_EQ(2, key.Size());
  EXPECT_TRUE(key.HasExpire());
}

// Exercises SetExpireTime()/ClearExpireTime()/GetExpireTime() on keys of every string form and
// checks that content, hash and size are unaffected by embedding the expiry.
TEST_F(CompactObjectTest, SdsTtlTag) {
  // 1. Inline key + SetTtl
  {
    CompactKey key("hello");
    ASSERT_TRUE(key.IsInline());
    uint64_t hash_before = key.HashCode();

    key.SetExpireTime(1000);
    EXPECT_TRUE(key.HasExpire());
    EXPECT_EQ(1000, key.GetExpireTime());
    EXPECT_EQ(hash_before, key.HashCode());
    EXPECT_TRUE(key == string_view("hello"));
    EXPECT_EQ(5, key.Size());
    EXPECT_EQ(OBJ_STRING, key.ObjType());

    string slice;
    EXPECT_EQ("hello", key.GetSlice(&slice));
    EXPECT_GT(key.MallocUsed(), 0u);
  }

  // 2. INT_TAG key + SetTtl
  {
    CompactKey key("42");
    ASSERT_TRUE(key.TryGetInt().has_value());
    uint64_t hash_before = key.HashCode();

    key.SetExpireTime(2000);
    EXPECT_TRUE(key.HasExpire());
    EXPECT_EQ(2000, key.GetExpireTime());
    EXPECT_TRUE(key == string_view("42"));
    EXPECT_EQ(hash_before, key.HashCode());
    // No longer INT_TAG — TryGetInt should return nullopt.
    EXPECT_FALSE(key.TryGetInt().has_value());
  }

  // 3. SMALL_TAG key + SetTtl
  {
    string s(64, 'x');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'a' + (i % 26);
    CompactKey key(s);
    uint64_t hash_before = key.HashCode();

    key.SetExpireTime(3000);
    EXPECT_TRUE(key.HasExpire());
    EXPECT_EQ(3000, key.GetExpireTime());
    EXPECT_TRUE(key == string_view(s));
    EXPECT_EQ(hash_before, key.HashCode());
    EXPECT_EQ(s.size(), key.Size());
  }

  // 4. ROBJ_TAG key + SetExpireTime
  {
    string s(512, 'z');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = static_cast(128 + (i % 128));
    CompactKey key(s);
    uint64_t hash_before = key.HashCode();

    key.SetExpireTime(4000);
    EXPECT_TRUE(key.HasExpire());
    EXPECT_EQ(4000, key.GetExpireTime());
    EXPECT_TRUE(key == string_view(s));
    EXPECT_EQ(hash_before, key.HashCode());
    EXPECT_EQ(s.size(), key.Size());
  }

  // 5. ExpireTime update in-place
  {
    CompactKey key("hello");
    key.SetExpireTime(1000);
    EXPECT_EQ(1000, key.GetExpireTime());

    key.SetExpireTime(2000);
    EXPECT_EQ(2000, key.GetExpireTime());
    EXPECT_TRUE(key == string_view("hello"));
  }

  // 6. ClearTtl (inline recovery)
  {
    CompactKey key("hello");
    key.SetExpireTime(1000);
    EXPECT_TRUE(key.ClearExpireTime());
    EXPECT_FALSE(key.HasExpire());
    EXPECT_TRUE(key.IsInline());
    EXPECT_TRUE(key == string_view("hello"));
  }

  // 7. ClearTtl (INT recovery)
  {
    CompactKey key("42");
    key.SetExpireTime(1000);
    EXPECT_TRUE(key.ClearExpireTime());
    EXPECT_FALSE(key.HasExpire());
    EXPECT_TRUE(key.TryGetInt().has_value());
    EXPECT_EQ(42, key.TryGetInt().value());
  }

  // 8. ClearTtl (SMALL recovery)
  {
    string s(64, 'x');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'a' + (i % 26);
    CompactKey key(s);
    key.SetExpireTime(1000);
    EXPECT_TRUE(key.ClearExpireTime());
    EXPECT_FALSE(key.HasExpire());
    EXPECT_TRUE(key == string_view(s));
  }

  // 9. Move semantics
  {
    CompactKey a("test");
    a.SetExpireTime(100);
    CompactKey b(std::move(a));
    EXPECT_TRUE(b.HasExpire());
    EXPECT_EQ(100, b.GetExpireTime());
    EXPECT_TRUE(b == string_view("test"));
  }

  // 10. Free/destructor — just verify no leaks (TearDown catches them).
  {
    CompactKey key("hello");
    key.SetExpireTime(5000);
  }

  // 11. Cross-tag operator== (SDS_TTL_TAG vs inline/INT_TAG).
  {
    CompactKey a("hello");
    CompactKey b("hello");
    b.SetExpireTime(999);
    // b is SDS_TTL_TAG, a is inline — must compare equal as OBJ_STRING.
    EXPECT_TRUE(a == b);
    EXPECT_TRUE(b == a);

    CompactKey c("42");
    CompactKey d("42");
    d.SetExpireTime(1);
    EXPECT_TRUE(c == d);
    EXPECT_TRUE(d == c);

    // Different content must not compare equal.
    CompactKey e("world");
    e.SetExpireTime(1);
    EXPECT_FALSE(a == e);
  }
}

// Size() must report the string length for a 511-byte string (set twice in a row) and, after
// Reset(), for a 27463-byte string.
TEST_F(CompactObjectTest, MediumString) {
  string tmp(511, 'b');

  cobj_.SetString(tmp);
  EXPECT_EQ(tmp.size(), cobj_.Size());

  cobj_.SetString(tmp);
  EXPECT_EQ(tmp.size(), cobj_.Size());
  cobj_.Reset();

  tmp.assign(27463, 'c');
  cobj_.SetString(tmp);
  EXPECT_EQ(27463, cobj_.Size());
}

// Round-trips ascii data through the SIMD pack/unpack helpers.
TEST_F(CompactObjectTest, AsciiUtil) {
  std::string_view data{"aaaaaabb"};
  uint8_t buf[32];

  char outbuf[32] = "xxxxxxxxxxxxxx";
  detail::ascii_pack_simd(data.data(), 7, buf);
  detail::ascii_unpack_simd(buf, 7, outbuf);

  // Only 7 characters were unpacked, so the 8th byte must still hold the filler.
  ASSERT_EQ('x', outbuf[7]) << outbuf;
  std::string_view actual{outbuf, 7};
  ASSERT_EQ(data.substr(0, 7), actual);

  string data3;
  for (unsigned i = 0; i < 13; ++i) {
    data3.append("12345678910");
  }
  string act_str(data3.size(), 'y');
  std::vector binvec(detail::binpacked_len(data3.size()));
  detail::ascii_pack_simd2(data3.data(), data3.size(), binvec.data());
  detail::ascii_unpack_simd(binvec.data(), data3.size(), act_str.data());

  ASSERT_EQ(data3, act_str);
}

TEST_F(CompactObjectTest, AsciiPackByte) {
  // Test ascii_pack_byte and ascii_unpack_byte for correctness.
  for (size_t len : {8, 16, 24, 31, 32, 33, 64, 100}) {
    string original(len, 'a');
    for (size_t i = 0; i < len; ++i)
      original[i] = 'A' + (i % 26);

    size_t packed_len = detail::binpacked_len(len);
    vector packed(packed_len);
    detail::ascii_pack(original.data(), len, packed.data());

    // Verify initial pack/unpack round-trip at byte level.
    for (size_t i = 0; i < len; ++i) {
      uint8_t got = detail::ascii_unpack_byte(packed.data(), len, i);
      ASSERT_EQ(static_cast(original[i]), got) << "len=" << len << " offset=" << i;
    }

    // Now set each byte to a different value via ascii_pack_byte, verify round-trip.
    for (size_t i = 0; i < len; ++i) {
      uint8_t new_val = 'a' + ((i + 3) % 26);

      // Pack the full string, then modify one byte.
      vector modified(packed);
      detail::ascii_pack_byte(modified.data(), len, i, new_val);

      // The modified byte should read back correctly.
      uint8_t got = detail::ascii_unpack_byte(modified.data(), len, i);
      EXPECT_EQ(new_val, got) << "len=" << len << " set offset=" << i;

      // All other bytes should be unchanged.
      for (size_t j = 0; j < len; ++j) {
        if (j == i)
          continue;
        uint8_t other = detail::ascii_unpack_byte(modified.data(), len, j);
        EXPECT_EQ(static_cast(original[j]), other)
            << "len=" << len << " set offset=" << i << " check offset=" << j;
      }
    }

    // Test setting all bytes to zero (edge case: clearing bits).
    {
      vector zeroed(packed);
      string expected = original;
      for (size_t i = 0; i < len; ++i) {
        detail::ascii_pack_byte(zeroed.data(), len, i, 0);
        expected[i] = '\0';
      }
      for (size_t i = 0; i < len; ++i) {
        uint8_t got = detail::ascii_unpack_byte(zeroed.data(), len, i);
        EXPECT_EQ(0, got) << "len=" << len << " zero check offset=" << i;
      }
    }

    // Test setting all bytes to 0x7F (all bits set in 7-bit ASCII).
    {
      vector maxed(packed);
      for (size_t i = 0; i < len; ++i) {
        detail::ascii_pack_byte(maxed.data(), len, i, 0x7F);
      }
      for (size_t i = 0; i < len; ++i) {
        uint8_t got = detail::ascii_unpack_byte(maxed.data(), len, i);
        EXPECT_EQ(0x7F, got) << "len=" << len << " max check offset=" << i;
      }
    }
  }
}

// Wraps an intset, grows it outside of the object and re-attaches the (possibly moved) pointer
// with SetRObjPtr().
TEST_F(CompactObjectTest, IntSet) {
  intset* is = intsetNew();
  cobj_.InitRobj(OBJ_SET, kEncodingIntSet, is);

  EXPECT_EQ(0, cobj_.Size());
  is = (intset*)cobj_.RObjPtr();
  uint8_t success = 0;

  is = intsetAdd(is, 10, &success);
  EXPECT_EQ(1, success);
  // Adding the same value again must report failure.
  is = intsetAdd(is, 10, &success);
  EXPECT_EQ(0, success);
  cobj_.SetRObjPtr(is);

  EXPECT_GT(cobj_.MallocUsed(), 0);
}

TEST_F(CompactObjectTest, ZSet) {
  // unrelated, checking that sds static encoding works.
  // it is used in zset special strings.
  char kMinStrData[] =
      "\110"
      "minstring";
  EXPECT_EQ(9, sdslen(kMinStrData + 1));

  cobj_.InitRobj(OBJ_ZSET, OBJ_ENCODING_LISTPACK, lpNew(0));

  EXPECT_EQ(OBJ_ZSET, cobj_.ObjType());
  EXPECT_EQ(OBJ_ENCODING_LISTPACK, cobj_.Encoding());
}

// A listpack-encoded hash holding one field/value pair reports a size of 1.
TEST_F(CompactObjectTest, Hash) {
  uint8_t* lp = lpNew(0);
  lp = lpAppend(lp, reinterpret_cast("foo"), 3);
  lp = lpAppend(lp, reinterpret_cast("barrr"), 5);
  cobj_.InitRobj(OBJ_HASH, kEncodingListPack, lp);
  EXPECT_EQ(OBJ_HASH, cobj_.ObjType());
  EXPECT_EQ(1, cobj_.Size());
}

// An object initialized through SetSBF() reports OBJ_SBF and a non-zero memory footprint.
TEST_F(CompactObjectTest, SBF) {
  cobj_.SetSBF(1000, 0.001, 2);
  EXPECT_EQ(cobj_.ObjType(), OBJ_SBF);
  EXPECT_GT(cobj_.MallocUsed(), 0);
}

TEST_F(CompactObjectTest, MimallocUnderutilzation) {
  // We are testing with the same object size allocation here
  // This test is for https://github.com/dragonflydb/dragonfly/issues/448
  size_t allocation_size = 94;
  int count = 2000;
  std::vector ptrs = AllocateForTest(count, allocation_size);
  // Nothing has been freed yet, so no page may be reported as underutilized.
  bool found = HasUnderutilizedMemory(ptrs, kUnderUtilizedRatio);
  ASSERT_FALSE(found);
  DeallocateAtRandom(kRandomStep, &ptrs);
  found = HasUnderutilizedMemory(ptrs, kUnderUtilizedRatio);
  ASSERT_TRUE(found);
  for (auto* ptr : ptrs) {
    mi_free(ptr);
  }
}

TEST_F(CompactObjectTest, MimallocUnderutilzationDifferentSizes) {
  // This test uses different objects sizes to cover more use cases
  // related to issue https://github.com/dragonflydb/dragonfly/issues/448
  size_t allocation_size = 97;
  int count = 2000;
  int mem_factor_1 = 3;
  int mem_factor_2 = 2;
  std::vector ptrs = AllocateForTest(count, allocation_size, mem_factor_1, mem_factor_2);
  bool found = HasUnderutilizedMemory(ptrs, kUnderUtilizedRatio);
  ASSERT_FALSE(found);
  DeallocateAtRandom(kRandomStep, &ptrs);
  found = HasUnderutilizedMemory(ptrs, kUnderUtilizedRatio);
  ASSERT_TRUE(found);
  for (auto* ptr : ptrs) {
    mi_free(ptr);
  }
}

TEST_F(CompactObjectTest, MimallocUnderutilzationWithRealloc) {
  // This test is checking underutilization with reallocation as well as deallocation
  // of the memory - see issue https://github.com/dragonflydb/dragonfly/issues/448
  size_t allocation_size = 102;
  int count = 2000;
  int mem_factor_1 = 4;
  int mem_factor_2 = 1;

  std::vector ptrs = AllocateForTest(count, allocation_size, mem_factor_1, mem_factor_2);
  bool found = HasUnderutilizedMemory(ptrs, kUnderUtilizedRatio);
  ASSERT_FALSE(found);
  DeallocateAtRandom(kRandomStep, &ptrs);

  // This is another case, where we are filling the "gaps" by doing re-allocations
  // in this case, since we are not setting all the values back it should still have
  // places that are not used. Plus since we are not looking at the first page
  // other pages should be underutilized.
  for (size_t i = kRandomStartIndex; i < ptrs.size(); i += kRandomStep) {
    if (!ptrs[i]) {
      ptrs[i] = mi_heap_malloc(mi_heap_get_backing(), allocation_size);
    }
  }
  found = HasUnderutilizedMemory(ptrs, kUnderUtilizedRatio);
  ASSERT_TRUE(found);
  for (auto* ptr : ptrs) {
    mi_free(ptr);
  }
}

TEST_F(CompactObjectTest, JsonTypeTest) {
  using namespace jsoncons;
  // This test verify that we can set a json type
  // and that we "know", it JSON and not a string
  std::string_view json_str = R"( {"firstName":"John","lastName":"Smith","age":27,"weight":135.25,"isAlive":true, "address":{"street":"21 2nd Street","city":"New York","state":"NY","zipcode":"10021-3100"}, "phoneNumbers":[{"type":"home","number":"212 555-1234"},{"type":"office","number":"646 555-4567"}], "children":[],"spouse":null} )";
  std::optional json_option2 =
      ParseJsonUsingShardHeap(R"({"a":{}, "b":{"a":1}, "c":{"a":1, "b":2}})");

  cobj_.SetString(json_str);
  ASSERT_TRUE(cobj_.ObjType() == OBJ_STRING);  // we set this as a string
  // GetJson() on a string object is expected to yield nullptr and leave the type untouched.
  JsonType* failed_json = cobj_.GetJson();
  ASSERT_TRUE(failed_json == nullptr);
  ASSERT_TRUE(cobj_.ObjType() == OBJ_STRING);
  std::optional json_option = ParseJsonUsingShardHeap(json_str);
  ASSERT_TRUE(json_option.has_value());
  cobj_.SetJson(std::move(json_option.value()));
  ASSERT_TRUE(cobj_.ObjType() == OBJ_JSON);  // and now this is a JSON type
  JsonType* json = cobj_.GetJson();
ASSERT_TRUE(json != nullptr); ASSERT_TRUE(json->contains("firstName")); // set second object make sure that we don't have any memory issue ASSERT_TRUE(json_option2.has_value()); cobj_.SetJson(std::move(json_option2.value())); ASSERT_TRUE(cobj_.ObjType() == OBJ_JSON); // still is a JSON type json = cobj_.GetJson(); ASSERT_TRUE(json != nullptr); ASSERT_TRUE(json->contains("b")); ASSERT_FALSE(json->contains("firstName")); std::optional set_array = ParseJsonUsingShardHeap(""); // now set it to string again cobj_.SetString(R"({"a":{}, "b":{"a":1}, "c":{"a":1, "b":2}})"); ASSERT_TRUE(cobj_.ObjType() == OBJ_STRING); // we set this as a string failed_json = cobj_.GetJson(); ASSERT_TRUE(failed_json == nullptr); } TEST_F(CompactObjectTest, JsonTypeWithPathTest) { std::string_view books_json = R"({"books":[{ "category": "fiction", "title" : "A Wild Sheep Chase", "author" : "Haruki Murakami" },{ "category": "fiction", "title" : "The Night Watch", "author" : "Sergei Lukyanenko" },{ "category": "fiction", "title" : "The Comedians", "author" : "Graham Greene" },{ "category": "memoir", "title" : "The Night Watch", "author" : "Phillips, David Atlee" }]})"; std::optional json_array = ParseJsonUsingShardHeap(books_json); ASSERT_TRUE(json_array.has_value()); cobj_.SetJson(std::move(json_array.value())); ASSERT_TRUE(cobj_.ObjType() == OBJ_JSON); // and now this is a JSON type auto f = [](const auto& /*path*/, JsonType& book) { if (book.at("category") == "memoir" && !book.contains("price")) { book.try_emplace("price", 140.0); } }; JsonType* json = cobj_.GetJson(); ASSERT_TRUE(json != nullptr); auto allocator_set = jsoncons::combine_allocators(json->get_allocator()); jsonpath::json_replace(allocator_set, *json, "$.books[*]"sv, f); // Check whether we've changed the entry for json in place // we should have prices only for memoir books JsonType* json2 = cobj_.GetJson(); ASSERT_TRUE(json != nullptr); ASSERT_TRUE(json->contains("books")); for (auto&& book : (*json2)["books"].array_range()) 
  {
    // make sure that we add prices only to "memoir"
    if (book.at("category") == "memoir") {
      ASSERT_TRUE(book.contains("price"));
    } else {
      ASSERT_FALSE(book.contains("price"));
    }
  }
}

// Test listpack defragmentation.
// StringMap has built-in defragmentation that is tested in its own test suite.
TEST_F(CompactObjectTest, DefragHash) {
  auto build_str = [](size_t i) { return string(111, 'v') + to_string(i); };

  // Build 1000 listpacks with 100 entries each.
  vector lps(10'00);

  for (size_t i = 0; i < lps.size(); i++) {
    uint8_t* lp = lpNew(100);
    for (size_t j = 0; j < 100; j++) {
      auto s = build_str(j);
      lp = lpAppend(lp, reinterpret_cast(s.data()), s.length());
    }
    DCHECK_EQ(lpLength(lp), 100u);
    lps[i] = lp;
  }

  // Keep every 10th listpack and free the rest to leave holes in the pages.
  for (size_t i = 0; i < lps.size(); i++) {
    if (i % 10 == 0)
      continue;
    lpFree(lps[i]);
  }

  // Find a listpack that is located on a underutilized page
  uint8_t* target_lp = nullptr;
  PageUsage page_usage{CollectPageStats::NO, 0.8};
  for (size_t i = 0; i < lps.size(); i += 10) {
    if (page_usage.IsPageForObjectUnderUtilized(lps[i]))
      target_lp = lps[i];
  }
  CHECK_NE(target_lp, nullptr);

  // Trigger re-allocation
  cobj_.InitRobj(OBJ_HASH, kEncodingListPack, target_lp);
  ASSERT_TRUE(cobj_.DefragIfNeeded(&page_usage));

  // Check the pointer changes as the listpack needed defragmentation
  auto lp = (uint8_t*)cobj_.RObjPtr();
  EXPECT_NE(lp, target_lp) << "must have changed due to realloc";

  // The content must have survived the move.
  uint8_t* fptr = lpFirst(lp);
  for (size_t i = 0; i < 100; i++) {
    int64_t len;
    auto* s = lpGet(fptr, &len, nullptr);

    string_view sv{reinterpret_cast(s), static_cast(len)};
    EXPECT_EQ(sv, build_str(i));
    fptr = lpNext(lp, fptr);
  }

  // Free the surviving listpacks, except for the one that was handed over to cobj_.
  for (size_t i = 0; i < lps.size(); i += 10) {
    if (lps[i] != target_lp)
      lpFree(lps[i]);
  }
}

TEST_F(CompactObjectTest, DefragSet) {
  // This is still not implemented
  StringSet* s = CompactObj::AllocateMR();
  s->Add("str");
  cobj_.InitRobj(OBJ_SET, kEncodingStrMap2, s);
  PageUsage page_usage{CollectPageStats::NO, 0.8};
  ASSERT_FALSE(cobj_.DefragIfNeeded(&page_usage));
}

// For ascii and non-ascii strings of several lengths: the raw string must decode back to the
// original through StrEncoding, and Materialize() must restore an external object both from the
// raw representation and from a plain string.
TEST_F(CompactObjectTest, StrEncodingAndMaterialize) {
  for (bool ascii : {true, false}) {
    for (size_t len : {64, 128, 256, 512, 1024}) {
      string test_str(len, 'a');
      for (size_t i = 0; i < len; i++)
        test_str[i] = char('a' + (i % 10));
      if (!ascii)
        test_str.push_back(char(200));  // non-ascii

      CompactValue obj;
      obj.SetString(test_str);

      // Test StrEncoding helper
      auto strs = obj.GetRawString();
      string raw_str = string{strs[0]} + string{strs[1]};
      CompactObj::StrEncoding enc = obj.GetStrEncoding();
      EXPECT_EQ(test_str, enc.Decode(raw_str).Take());

      // Test Materialize
      obj.SetExternal(0, 0, CompactObj::ExternalRep::STRING);  // dummy values
      obj.Materialize(raw_str, true);
      EXPECT_EQ(test_str, obj.ToString());

      // Restore from external again, but not as a raw value
      obj.SetExternal(0, 0, CompactObj::ExternalRep::STRING);
      auto test_str2 = test_str + "updated";
      obj.Materialize(test_str2, false);
      EXPECT_EQ(obj.ToString(), test_str2);
    }
  }
}

// The object type reported for an external value follows the ExternalRep it was stored with.
TEST_F(CompactObjectTest, ExternalRepresentation) {
  {
    CompactValue obj;
    obj.SetString("test");
    obj.SetExternal(0, 4, CompactObj::ExternalRep::STRING);
    EXPECT_EQ(obj.ObjType(), OBJ_STRING);
  }

  {
    StringMap sm{};
    CompactValue obj;
    obj.SetRObjPtr(&sm);
    obj.SetExternal(0, 4, CompactObj::ExternalRep::SERIALIZED_MAP);
    EXPECT_EQ(obj.ObjType(), OBJ_HASH);
  }
}

// Copies a 32-byte string into a buffer of exactly 32 bytes.
// NOTE(review): judging by the name, this is meant to let ASAN flag any read or write past the
// buffer - confirm.
TEST_F(CompactObjectTest, AsanTriggerReadOverflow) {
  cobj_.SetString(string(32, 'a'));
  auto dest = make_unique(32);
  cobj_.GetString(dest.get());
}

// Cross-checks lpGetInteger() against lpGet() on a listpack that mixes integer and string
// entries.
TEST_F(CompactObjectTest, lpGetInteger) {
  int64_t val = -1;
  uint8_t* lp = lpNew(0);
  // 60 integer entries: -1, -2, -4, ...
  for (int j = 0; j < 60; ++j) {
    lp = lpAppendInteger(lp, val);
    val *= 2;
  }
  val = 1;
  // 600 string entries of growing length (0, 500, 1000, ... bytes).
  for (int j = 0; j < 600; ++j) {
    string str(j * 500, 'a');
    lp = lpAppend(lp, reinterpret_cast(str.data()), str.size());
  }
  uint8_t* ptr = lpFirst(lp);
  while (ptr) {
    int64_t len1, len2;
    uint8_t* val1 = lpGet(ptr, &len1, nullptr);
    int res = lpGetInteger(ptr, &len2);
    if (res) {
      // Integer entry: both calls must agree on the value and lpGet() returns no string pointer.
      ASSERT_EQ(len1, len2);
      ASSERT_TRUE(val1 == NULL);
    } else {
      ASSERT_TRUE(val1 != NULL);
    }
    ptr = lpNext(lp, ptr);
  }
  lpFree(lp);
}

// Builds a Huffman encoder from a histogram in which 'a' and 'b' dominate.
static void BuildEncoderAB(HuffmanEncoder* encoder) {
  array hist;
hist.fill(1);
  hist['a'] = 100;
  hist['b'] = 50;
  CHECK(encoder->Build(hist.data(), hist.size() - 1, nullptr));
}

// For both Huffman domains, stores all-'a' strings of lengths 30, 40, ..., 2040 and checks size,
// hash code, round-trip and whether heap memory is used.
TEST_F(CompactObjectTest, Huffman) {
  HuffmanEncoder encoder;
  BuildEncoderAB(&encoder);
  string bindata = encoder.Export();

  for (CompactObj::HuffmanDomain domain : {CompactObj::HUFF_KEYS, CompactObj::HUFF_STRING_VALUES}) {
    ASSERT_TRUE(CompactObj::InitHuffmanThreadLocal(domain, bindata));
    for (unsigned i = 30; i < 2048; i += 10) {
      string data(i, 'a');
      // A non-zero domain switches the backing object to CompactValue; otherwise the variant
      // keeps its default alternative.
      variant obj_backing;
      if (domain)
        obj_backing = CompactValue{};
      // Access whichever alternative is active through the common CompactObj interface.
      auto& cobj = visit([&](auto& co) -> CompactObj& { return co; }, obj_backing);
      visit([&](auto& co) { co.SetString(data); }, obj_backing);
      // The test expects heap usage exactly for lengths of 60 and above.
      bool malloc_used = i >= 60;
      ASSERT_EQ(malloc_used, cobj.MallocUsed() > 0) << i;
      ASSERT_EQ(data.size(), cobj.Size());
      ASSERT_EQ(CompactObj::HashCode(data), cobj.HashCode());

      string actual;
      cobj.GetString(&actual);
      EXPECT_EQ(data, actual);
      // Comparison against the string is exercised for CompactKey only.
      visit(absl::Overload{[&](CompactKey& co) { EXPECT_EQ(co, data); }, [&](CompactValue& co) {}},
            obj_backing);
    }
  }
}

// Reads every byte through GetByteAtIndex() for strings of different sizes and contents; the
// section comments name the storage tag each case is meant to exercise.
TEST_F(CompactObjectTest, GetByteAtOffset) {
  // Inline string (INLINE_TAG)
  {
    string s = "hello";
    cobj_.SetString(s);
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(s[i], res) << "inline offset " << i;
    }
  }

  // Integer-encoded string (INT_TAG)
  {
    cobj_.SetString("12345");
    string expected = "12345";
    for (size_t i = 0; i < expected.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(expected[i], res) << "int offset " << i;
    }
  }

  // ASCII string with SMALL_TAG
  {
    string s(64, 'x');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'a' + (i % 26);
    cobj_.SetString(s);
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "long ascii offset " << i;
    }
  }

  // Non-ASCII string with SMALL_TAG
  {
    string s(64, '\xC0');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = static_cast(128 + (i % 128));
    cobj_.SetString(s);
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "non-ascii offset " << i;
    }
  }

  // ASCII string ROBJ_TAG
  {
    string s(512, 'z');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'A' + (i % 26);
    cobj_.SetString(s);
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "medium offset " << i;
    }
  }

  // Non-ASCII string ROBJ_TAG
  {
    string s(512, 'z');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = static_cast(128 + (i % 128));
    cobj_.SetString(s);
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "medium offset " << i;
    }
  }
  cobj_.Reset();
}

// Exercises SetByteAtIndex(). In the returned pair, `first` is checked as "the write succeeded"
// and `second` as "the write was done in place" (see the comments on the individual cases).
TEST_F(CompactObjectTest, SetByteAtOffset) {
  // Inline string (INLINE_TAG)
  {
    string s = "abcde";
    cobj_.SetString(s);
    for (size_t i = 0; i < s.size(); ++i) {
      std::pair res_set_byte = cobj_.SetByteAtIndex(i, 'Z');
      EXPECT_TRUE(res_set_byte.first);
      EXPECT_TRUE(res_set_byte.second);
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ('Z', res) << "inline set offset " << i;
    }
    // All bytes should now be 'Z'
    string result;
    cobj_.GetString(&result);
    EXPECT_EQ(string(5, 'Z'), result);
  }

  // Integer-encoded string (INT_TAG)
  {
    cobj_.SetString("999");
    std::pair res_set_byte = cobj_.SetByteAtIndex(0, 'x');
    EXPECT_TRUE(res_set_byte.first);
    // We didn't modify in-place, SetString is called
    EXPECT_FALSE(res_set_byte.second);
    string result;
    cobj_.GetString(&result);
    EXPECT_EQ("x99", result);
  }

  // ASCII string with SMALL_TAG
  {
    string s(64, 'a');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'a' + (i % 26);
    cobj_.SetString(s);
    // Modify every 10th byte
    for (size_t i = 0; i < s.size(); i += 10) {
      std::pair res_set_byte = cobj_.SetByteAtIndex(i, '!');
      EXPECT_TRUE(res_set_byte.first);
      EXPECT_FALSE(res_set_byte.second);
      s[i] = '!';
    }
    // Verify all bytes
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "long ascii set offset " << i;
    }
  }

  // Non-ASCII string with SMALL_TAG
  {
    string s(64, '\x80');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = static_cast(128 + (i % 128));
    cobj_.SetString(s);
    std::pair res_set_byte = cobj_.SetByteAtIndex(63, 0xFF);
    EXPECT_TRUE(res_set_byte.first);
    EXPECT_FALSE(res_set_byte.second);
    s[63] = '\xFF';
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "non-ascii set offset " << i;
    }
  }

  // ASCII string with ROBJ_TAG
  {
    string s(512, 'a');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'a' + (i % 26);
    cobj_.SetString(s);
    // Modify every 10th byte
    for (size_t i = 0; i < s.size(); i += 10) {
      std::pair res_set_byte = cobj_.SetByteAtIndex(i, '!');
      EXPECT_TRUE(res_set_byte.first);
      EXPECT_TRUE(res_set_byte.second);
      s[i] = '!';
    }
    // Verify all bytes
    for (size_t i = 0; i < s.size(); ++i) {
      uint8_t res = 0;
      EXPECT_TRUE(cobj_.GetByteAtIndex(i, &res));
      EXPECT_EQ(static_cast(s[i]), res) << "long ascii set offset " << i;
    }
  }

  // ASCII string with ROBJ_TAG modified to non-ASCII
  {
    string s(512, 'a');
    for (size_t i = 0; i < s.size(); ++i)
      s[i] = 'a' + (i % 26);
    cobj_.SetString(s);

    // Modify in-place ascii packed string
    std::pair res_set_byte = cobj_.SetByteAtIndex(0, 'A');
    EXPECT_TRUE(res_set_byte.first);
    EXPECT_TRUE(res_set_byte.second);

    // Adding non-ascii byte modification should still succeed, but not in-place
    res_set_byte = cobj_.SetByteAtIndex(255, 0xFF);
    EXPECT_TRUE(res_set_byte.first);
    EXPECT_FALSE(res_set_byte.second);

    // Modification of non-ascii ROBJ string should succeed and in-place
    res_set_byte = cobj_.SetByteAtIndex(511, 'C');
    EXPECT_TRUE(res_set_byte.first);
    EXPECT_TRUE(res_set_byte.second);

    // All three writes must be visible when reading back.
    uint8_t res;
    EXPECT_TRUE(cobj_.GetByteAtIndex(0, &res));
    EXPECT_EQ('A', res);
    EXPECT_TRUE(cobj_.GetByteAtIndex(255, &res));
    EXPECT_EQ(0xFF, res);
    EXPECT_TRUE(cobj_.GetByteAtIndex(511, &res));
    EXPECT_EQ('C', res);
  }

  // Out-of-bounds access should be handled gracefully.
  {
    string s = "abc";
    cobj_.SetString(s);

    // SetByteAtIndex: index equal to size() is out-of-bounds.
    auto res_pair = cobj_.SetByteAtIndex(s.size(), 'X');
    EXPECT_FALSE(res_pair.first);
    EXPECT_FALSE(res_pair.second);

    // GetByteAtIndex: out-of-bounds should set result to 0.
    uint8_t res = 123;  // sentinel non-zero value
    EXPECT_FALSE(cobj_.GetByteAtIndex(s.size(), &res));
    EXPECT_EQ(0u, res);
  }
  cobj_.Reset();
}

// Straightforward packing routine used as the baseline in BM_PackNaive.
// Every full group of 8 input bytes produces 7 output bytes: output byte i of a group is
// (in[i] >> i) | (in[i + 1] << (7 - i)), truncated to 8 bits by the store into uint8_t.
// A tail of fewer than 8 bytes is copied unchanged.
// `bin` must have room for the output; the function performs no bounds checks.
static void ascii_pack_naive(const char* ascii, size_t len, uint8_t* bin) {
  const char* end = ascii + len;

  unsigned i = 0;
  while (ascii + 8 <= end) {
    for (i = 0; i < 7; ++i) {
      *bin++ = (ascii[0] >> i) | (ascii[1] << (7 - i));
      ++ascii;
    }
    // Skip the 8th byte of the group: it was already folded into the 7th output byte.
    ++ascii;
  }

  // epilog - we do not pack since we have less than 8 bytes.
  while (ascii < end) {
    *bin++ = *ascii++;
  }
}

// Benchmark: ascii_pack_naive() over a 1024-byte all-'a' string.
static void BM_PackNaive(benchmark::State& state) {
  string val(1024, 'a');
  uint8_t buf[1024];

  while (state.KeepRunning()) {
    ascii_pack_naive(val.data(), val.size(), buf);
  }
}
BENCHMARK(BM_PackNaive);

// Benchmark: detail::ascii_pack() over the same input.
static void BM_Pack(benchmark::State& state) {
  string val(1024, 'a');
  uint8_t buf[1024];

  while (state.KeepRunning()) {
    detail::ascii_pack(val.data(), val.size(), buf);
  }
}
BENCHMARK(BM_Pack);

// Benchmark: detail::ascii_pack_simd() over the same input.
static void BM_PackSimd(benchmark::State& state) {
  string val(1024, 'a');
  uint8_t buf[1024];

  while (state.KeepRunning()) {
    detail::ascii_pack_simd(val.data(), val.size(), buf);
  }
}
BENCHMARK(BM_PackSimd);

// Benchmark: detail::ascii_pack_simd2() over the same input.
static void BM_PackSimd2(benchmark::State& state) {
  string val(1024, 'a');
  uint8_t buf[1024];

  while (state.KeepRunning()) {
    detail::ascii_pack_simd2(val.data(), val.size(), buf);
  }
}
BENCHMARK(BM_PackSimd2);

// Benchmark: detail::ascii_unpack() of a buffer packed once, outside the timed loop.
static void BM_Unpack(benchmark::State& state) {
  string val(1024, 'a');
  uint8_t buf[1024];

  detail::ascii_pack(val.data(), val.size(), buf);

  while (state.KeepRunning()) {
    detail::ascii_unpack(buf, val.size(), val.data());
  }
}
BENCHMARK(BM_Unpack);

// Benchmark: detail::ascii_unpack_simd() of a buffer packed once, outside the timed loop.
static void BM_UnpackSimd(benchmark::State& state) {
  string val(1024, 'a');
  uint8_t buf[1024];
detail::ascii_pack(val.data(), val.size(), buf);

  while (state.KeepRunning()) {
    detail::ascii_unpack_simd(buf, val.size(), val.data());
  }
}
BENCHMARK(BM_UnpackSimd);

// Benchmark: lpCompare() of every element (walking from last to first) against a decimal string.
// The listpack holds 100 integers below 2^48, while the probe is 2^49, so it matches no element.
static void BM_LpCompare(benchmark::State& state) {
  std::mt19937_64 rd;
  uint8_t* lp = lpNew(0);
  for (unsigned i = 0; i < 100; ++i) {
    lp = lpAppendInteger(lp, rd() % (1ULL << 48));
  }

  string val = absl::StrCat(1ULL << 49);
  while (state.KeepRunning()) {
    uint8_t* elem = lpLast(lp);
    while (elem) {
      lpCompare(elem, reinterpret_cast(val.data()), val.size());
      elem = lpPrev(lp, elem);
    }
  }
  lpFree(lp);
}
BENCHMARK(BM_LpCompare);

// Benchmark: same walk as BM_LpCompare, but each element is decoded with lpGetInteger() and
// compared numerically.
static void BM_LpCompareInt(benchmark::State& state) {
  std::mt19937_64 rd;
  uint8_t* lp = lpNew(0);
  for (unsigned i = 0; i < 100; ++i) {
    lp = lpAppendInteger(lp, rd() % (1ULL << 48));
  }

  int64_t val = 1ULL << 49;
  while (state.KeepRunning()) {
    uint8_t* elem = lpLast(lp);
    int64_t sz;
    while (elem) {
      DCHECK_NE(0xFF, *elem);
      lpGetInteger(elem, &sz);
      int res = sz == val;
      // Keep the comparison result alive so the compiler cannot drop the loop body.
      benchmark::DoNotOptimize(res);
      elem = lpPrev(lp, elem);
    }
  }
  lpFree(lp);
}
BENCHMARK(BM_LpCompareInt);

// Benchmark: decodes 60 integers (-1, -2, -4, ...) walking from last to first.
// Arg 1 uses lpGet(); any other argument uses lpGetInteger().
static void BM_LpGet(benchmark::State& state) {
  unsigned version = state.range(0);
  uint8_t* lp = lpNew(0);
  int64_t val = -1;
  for (unsigned i = 0; i < 60; ++i) {
    lp = lpAppendInteger(lp, val);
    val *= 2;
  }

  while (state.KeepRunning()) {
    uint8_t* elem = lpLast(lp);
    int64_t ival;
    if (version == 1) {
      while (elem) {
        unsigned char* value = lpGet(elem, &ival, NULL);
        benchmark::DoNotOptimize(value);
        elem = lpPrev(lp, elem);
      }
    } else {
      while (elem) {
        int res = lpGetInteger(elem, &ival);
        benchmark::DoNotOptimize(res);
        elem = lpPrev(lp, elem);
      }
    }
  }
  lpFree(lp);
}
BENCHMARK(BM_LpGet)->Arg(1)->Arg(2);

// Declared here by hand with C linkage so the benchmark below can call it.
extern "C" int lpStringToInt64(const char* s, unsigned long slen, int64_t* value);

// Benchmark: parses 1000 decimal strings of random 64-bit integers.
// Arg 1 uses lpStringToInt64(); any other argument uses absl::SimpleAtoi().
static void BM_LpString2Int(benchmark::State& state) {
  int version = state.range(0);
  std::mt19937_64 rd;
  vector values;
  for (unsigned i = 0; i < 1000; ++i) {
    int64_t val = rd();
    values.push_back(absl::StrCat(val));
  }

  int64_t ival = 0;
  while (state.KeepRunning()) {
    for (const auto& val : values) {
      int res = version == 1 ? lpStringToInt64(val.data(), val.size(), &ival)
                             : absl::SimpleAtoi(val, &ival);
      benchmark::DoNotOptimize(res);
    }
  }
}
BENCHMARK(BM_LpString2Int)->Arg(1)->Arg(2);

}  // namespace dfly

================================================
FILE: src/core/dash.h
================================================
// Copyright 2022, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include

#include "absl/random/random.h"
#include "base/pmr/memory_resource.h"
#include "core/dash_internal.h"

namespace dfly {

// DASH: Dynamic And Scalable Hashing.
//
// The table keeps a directory of segment pointers (segment_). Several directory entries may
// point to the same segment. Hashing and key/value handling are delegated to Policy
// (HashFn, Equal, DestroyKey and DestroyValue are the members used in this header).
// The class is not copyable.
template class DashTable : public detail::DashTableBase {
  DashTable(const DashTable&) = delete;
  DashTable& operator=(const DashTable&) = delete;

  using Base = detail::DashTableBase;
  using SegmentType = detail::Segment<_Key, _Value, Policy>;
  using SegmentIterator = typename SegmentType::Iterator;

 public:
  using Key_t = _Key;
  using Value_t = _Value;
  using Segment_t = SegmentType;

  //! Total number of buckets in a segment (including stash).
  static constexpr double kTaxAmount = SegmentType::kTaxSize;
  // Size in bytes of a whole segment object.
  static constexpr size_t kSegBytes = sizeof(SegmentType);
  // How many bytes the non-stash part is taking.
  static constexpr size_t kSegRegularBytes =
      kSegBytes - (SegmentType::kStashBucketNum * SegmentType::kBucketSz);
  static constexpr size_t kSegCapacity = SegmentType::capacity();
  static constexpr size_t kSlotNum = SegmentType::kSlotNum;
  static constexpr size_t kBucketNum = SegmentType::kBucketNum;

  // if IsSingleBucket is true - iterates only over a single bucket.
template class Iterator;

  using const_iterator = Iterator;
  using iterator = Iterator;
  using const_bucket_iterator = Iterator;
  using bucket_iterator = Iterator;
  using Cursor = detail::DashCursor;

  // A fixed-size set of bucket iterators: kRegularBuckets regular buckets followed by all the
  // stash buckets of a segment. The same storage can be addressed by type (probes.by_type) or as
  // one flat array (probes.arr / at()).
  struct HotBuckets {
    static constexpr size_t kRegularBuckets = 4;
    static constexpr size_t kNumBuckets = kRegularBuckets + SegmentType::kStashBucketNum;

    struct ByType {
      bucket_iterator regular_buckets[kRegularBuckets];
      bucket_iterator stash_buckets[SegmentType::kStashBucketNum];
    };

    union Probes {
      ByType by_type;
      bucket_iterator arr[kNumBuckets];

      // Value-initializes the flat array view.
      Probes() : arr() {
      }
    } probes;

    // id must be in the range [0, kNumBuckets).
    bucket_iterator at(unsigned id) const {
      return probes.arr[id];
    }

    unsigned num_buckets;

    // key_hash of a key that we try to insert.
    // I use it as pseudo-random number in my gc/eviction heuristics.
    uint64_t key_hash;
  };

  // Policy used by the Insert()/InsertNew() overloads that take no explicit policy:
  // garbage collection and eviction are disabled, growth is always allowed and the
  // OnMove/RecordSplit notifications are ignored.
  struct DefaultEvictionPolicy {
    static constexpr bool can_gc = false;
    static constexpr bool can_evict = false;

    bool CanGrow(const DashTable&) {
      return true;
    }

    void OnMove(Cursor source, Cursor dest) {
    }

    void RecordSplit(SegmentType* segment) {
    }
    /*
       /// Required interface in case can_gc is true
       // Returns number of garbage collected items deleted. 0 - means nothing has been
       // deleted.
       unsigned GarbageCollect(const EvictionBuckets& eb, DashTable* me) const {
         return 0;
       }

       // Required interface in case can_gc is true
       // returns number of items evicted from the table.
       // 0 means - nothing has been evicted.
       unsigned Evict(const EvictionBuckets& eb, DashTable* me) {
         return 0;
       }
   */
  };

  // capacity_log is forwarded to the base class; the segment directory allocates from `mr`.
  DashTable(size_t capacity_log = 1, const Policy& policy = Policy{},
            PMR_NS::memory_resource* mr = PMR_NS::get_default_resource());
  ~DashTable();

  // Grows the directory depth if `size` keys would not fit; never shrinks the table.
  void Reserve(size_t size);

  // false for duplicate, true if inserted.
template std::pair Insert(U&& key, V&& value) {
    DefaultEvictionPolicy policy;
    return InsertInternal(std::forward(key), std::forward(value), policy,
                          InsertMode::kInsertIfNotFound);
  }

  // Same as above, but with a caller-supplied eviction policy.
  template std::pair Insert(U&& key, V&& value, EvictionPolicy& ev) {
    return InsertInternal(std::forward(key), std::forward(value), ev,
                          InsertMode::kInsertIfNotFound);
  }

  // Inserts using InsertMode::kForceInsert with the default policy.
  // NOTE(review): presumably the caller guarantees the key is not in the table yet - confirm
  // against InsertInternal.
  template iterator InsertNew(U&& key, V&& value) {
    DefaultEvictionPolicy policy;
    return InsertNew(std::forward(key), std::forward(value), policy);
  }

  // Same as above with a caller-supplied policy; only the iterator part of the result is returned.
  template iterator InsertNew(U&& key, V&& value, EvictionPolicy& ev) {
    return InsertInternal(std::forward(key), std::forward(value), ev, InsertMode::kForceInsert)
        .first;
  }

  // Both return a default-constructed (done) iterator when the key is not found.
  template const_iterator Find(U&& key) const;
  template iterator Find(U&& key);

  // Prefetches into the cache the memory where the key would reside.
  template void Prefetch(U&& key) const;

  // Find first entry with given key hash that evaluates to true on pred.
  // Pred accepts either (const key&) or (const key&, const value&)
  template iterator FindFirst(uint64_t key_hash, Pred&& pred);

  // it must be valid.
  void Erase(iterator it);

  // Returns the number of erased entries: 1 if the key was found, 0 otherwise.
  size_t Erase(const Key_t& k);

  // Iterator at the first occupied slot; it is already done if Seek2Occupied() finds none.
  iterator begin() {
    iterator it{this, 0, 0, 0};
    it.Seek2Occupied();
    return it;
  }

  const_iterator cbegin() const {
    const_iterator it{this, 0, 0, 0};
    it.Seek2Occupied();
    return it;
  }

  // The end iterators are default-constructed, i.e. they have no owner.
  iterator end() const {
    return iterator{};
  }
  const_iterator cend() const {
    return const_iterator{};
  }

  using Base::depth;
  using Base::Empty;
  using Base::size;
  using Base::unique_segments;

  // Direct access to the segment for debugging purposes.
  Segment_t* GetSegment(unsigned segment_id) {
    return segment_[segment_id];
  }

  // - If there is no buddy for segment_id return segment_id.
  //   Otherwise, return buddy_id.
  // - A buddy is a sibling segment that was created from the
  //   same parent during split and can be merged back together.
  //   It's the adjacent subtree of the same depth.
unsigned FindBuddyId(unsigned segment_id) {
    auto* seg = GetSegment(segment_id);
    uint8_t depth = seg->local_depth();

    // Segments with local depth 0 or 1 are reported as having no buddy.
    if (depth <= 1) {
      return segment_id;
    }

    // Flip the directory-index bit that separates the two siblings at this depth.
    const size_t bit_pos = global_depth_ - depth;
    const size_t buddy_idx = segment_id ^ (1u << bit_pos);
    assert(buddy_idx < segment_.size());
    auto* buddy = GetSegment(buddy_idx);

    // There is no adjacent subtree of the same depth
    if (buddy->local_depth() != depth) {
      return segment_id;
    }
    return buddy_idx;
  }

  // - Moves all items from `buddy_id` to `keep_id` (merges the two segments).
  //   After merge completes, `buddy_id` segment is deleted.
  // - Return true if the two segments merged successfully.
  // - If an insertion fails we rollback and abort the merge (return false).
  // - Returns false without touching anything if the merge would bring the local depth
  //   below initial_depth_.
  // - Merge can run only if there are no active snapshots.
  // - Prefer calling this function only when the combined size of both segments is less
  //   than x * segment_capacity, with 0 < x < 0.25, as statistically this won't
  //   trigger rollbacks.
  bool Merge(unsigned keep_id, unsigned buddy_id) {
    auto* keep = GetSegment(keep_id);
    auto* buddy = GetSegment(buddy_id);
    assert((keep->local_depth() == buddy->local_depth()));
    // assert((keep->SlowSize() + buddy->SlowSize() < (0.25 * buddy->capacity())));
    assert(keep->local_depth() != 1);
    assert(keep != buddy);
    assert(keep_id < buddy_id);  // Callers must iterate low to high to ensure correct orientation

    // Don't merge below initial_depth to maintain Clear() invariant
    // After merge, keep will have depth-1, which determines unique_segments
    uint8_t depth_after_merge = keep->local_depth() - 1;
    if (depth_after_merge < initial_depth_) {
      return false;
    }

    bool should_rollback = false;

    // Decrease depth (merge back to parent)
    keep->set_local_depth(keep->local_depth() - 1);

    // Move all items from buddy to keep
    buddy->TraverseAll([&](const auto& it) {
      // Once an insertion has failed, skip the remaining items.
      if (should_rollback) {
        return;
      }

      uint64_t hash = DoHash(buddy->Key(it.index, it.slot));
      auto& src_bucket = buddy->GetBucket(it.index);
      // NOTE(review): the key and value are passed as rvalues; this assumes InsertUniq leaves
      // them intact when it fails, otherwise the rollback would see moved-from entries - confirm.
      auto res =
          keep->InsertUniq(std::move(src_bucket.key[it.slot]), std::move(src_bucket.value[it.slot]),
                           hash, false, [](auto&&...) {});
      if (!res.found()) {
        should_rollback = true;
        return;
      }

      // Clear the slot in buddy so rollback can reuse the space
      src_bucket.Delete(it.slot);
    });

    if (should_rollback) {
      // Undo the partial merge by splitting `keep` back into `buddy`.
      auto hash_fn = [this](const auto& k) { return policy_.HashFn(k); };
      keep->Split(hash_fn, buddy, [](auto&&...) {});
      return false;
    }

    // Same as Split()
    // Repoint every directory entry that referred to buddy so that it refers to keep.
    uint32_t buddy_chunk_size = 1u << (global_depth_ - buddy->local_depth());
    uint32_t buddy_start = buddy_id & ~(buddy_chunk_size - 1u);
    for (size_t i = buddy_start; i < buddy_start + buddy_chunk_size; ++i) {
      segment_[i] = keep;
    }

    // Free buddy segment
    PMR_NS::polymorphic_allocator pa(segment_.get_allocator());
    using alloc_traits = std::allocator_traits;
    alloc_traits::destroy(pa, buddy);
    alloc_traits::deallocate(pa, buddy, 1);

    // Decrement unique segment counter
    --unique_segments_;
    // buddy is already destroyed here, so the bucket count is taken from keep.
    bucket_count_ -= keep->num_buckets();

    return true;
  }

  // Number of directory entries (not the number of distinct segments).
  size_t GetSegmentCount() const {
    return segment_.size();
  }

  // Advances `sid` by 2^(global_depth_ - local_depth) of the segment stored at `sid`.
  // Used to step from one distinct segment to the next when walking the directory.
  size_t NextSeg(size_t sid) const {
    size_t delta = (1u << (global_depth_ - segment_[sid]->local_depth()));
    return sid + delta;
  }

  // Hashes `k` with the policy's hash function.
  template uint64_t DoHash(const U& k) const {
    return policy_.HashFn(k);
  }

  // Flat memory usage (allocated) of the table, not including the memory allocated
  // by the hosted objects.
  size_t mem_usage() const {
    return segment_.capacity() * sizeof(void*) + sizeof(SegmentType) * unique_segments_;
  }

  // Returns the total number of buckets in the table, in contrast to capacity() which
  // returns the total number of slots.
  size_t bucket_count() const {
    return bucket_count_;
  }

  // Overall capacity of the table (including stash buckets) in number of keys.
size_t capacity() const {
    return bucket_count() * kSlotNum;
  }

  // Ratio of stored entries to capacity().
  double load_factor() const {
    return double(size()) / capacity();
  }

  // Largest physical bucket id inside a segment, counting regular and stash buckets.
  static constexpr unsigned LargestBucketId() {
    return SegmentType::kBucketNum + SegmentType::kStashBucketNum - 1;
  }

  // Gets a random cursor based on the available segments and buckets.
  // Returns: cursor with a random position
  Cursor GetRandomCursor(absl::BitGen* bitgen);

  // Traverses over a single logical bucket in table and calls cb(iterator) 0 or more
  // times. if cursor=0 starts traversing from the beginning, otherwise continues from where it
  // stopped. returns 0 if the supplied cursor reached end of traversal. Traverse iterates at bucket
  // logical granularity, which means for each non-empty bucket it calls cb per each entry in the
  // logical bucket before returning. Unlike begin/end interface, traverse is stable during table
  // mutations. It guarantees that if key exists (1)at the beginning of traversal, (2) stays in the
  // table during the traversal, then Traverse() will eventually reach it even when the table
  // shrinks or grows. Returns: cursor that is guaranteed to be less than 2^40.
  template Cursor Traverse(Cursor curs, Cb&& cb);

  // Traverses over physical buckets. It calls cb once for each bucket by passing a bucket iterator.
  // if cursor=0 starts traversing from the beginning, otherwise continues from where
  // it stopped. returns 0 if the supplied cursor reached end of traversal.
  // Unlike Traverse, TraverseBuckets calls cb once on bucket iterator and not on each entry in
  // bucket. TraverseBuckets is stable during table mutations. It guarantees traversing all buckets
  // that existed at the beginning of traversal.
  template Cursor TraverseBuckets(Cursor curs, Cb&& cb);

  // Traverses over a single bucket in table and calls cb(iterator). The traverse order will be
  // segment by segment over physical buckets.
  // Traversal by segment order does not guarantee coverage if the table grows/shrinks; it is
  // useful when formal full coverage is not critically important.
  template Cursor TraverseBySegmentOrder(Cursor curs, Cb&& cb);

  // Discards slots information.
  static const_bucket_iterator BucketIt(const_iterator it) {
    return const_bucket_iterator{it.owner_, it.seg_id_, it.bucket_id_, 0};
  }

  // Seeks to the first occupied slot if exists in the bucket.
  const_bucket_iterator BucketIt(unsigned segment_id, unsigned bucket_id) const {
    return const_bucket_iterator{this, segment_id, uint8_t(bucket_id)};
  }

  // Non-const flavour of the overload above.
  bucket_iterator BucketIt(unsigned segment_id, unsigned bucket_id) {
    return bucket_iterator{this, segment_id, uint8_t(bucket_id)};
  }

  // Builds an iterator for an exact position; no seeking is performed.
  // bucket_id and slot_id are narrowed to uint8_t.
  iterator GetIterator(unsigned segment_id, unsigned bucket_id, unsigned slot_id) {
    return iterator{this, segment_id, uint8_t(bucket_id), uint8_t(slot_id)};
  }

  // Converts a cursor to a bucket iterator at slot 0, using the current global depth to
  // resolve the segment id. No seeking is performed.
  const_bucket_iterator CursorToBucketIt(Cursor c) const {
    return const_bucket_iterator{this, c.segment_id(global_depth_), c.bucket_id(), 0};
  }

  bucket_iterator CursorToBucketIt(Cursor c) {
    return bucket_iterator{this, c.segment_id(global_depth_), c.bucket_id(), 0};
  }

  // Capture Version Change. Runs cb(it) on every bucket! (not entry) in the table whose version
  // would potentially change upon insertion of 'k'.
  // In practice traversal is limited to a single segment. The operation is read-only and
  // simulates insertion process. 'cb' must accept bucket_iterator.
  // Note: the interface is a bit hacky.
  // The functions call cb on physical buckets with version smaller than ver_threshold that
  // due to entry movements might update its version to version greater than ver_threshold.
  //
  // These are not const functions because they send non-const iterators that allow
  // updating contents/versions of the passed iterators.
template void CVCUponInsert(uint64_t ver_threshold, const U& key, Cb&& cb);

  template void CVCUponBump(uint64_t ver_threshold, const_iterator it, Cb&& cb);

  // Destroys all keys and values and, if the table grew beyond its initial depth, shrinks the
  // segment directory back to the initial size.
  void Clear();

  // Returns true if an element was deleted i.e the rightmost slot was busy.
  bool ShiftRight(bucket_iterator it);

  // Delegates to the segment's BumpUp() for the entry at `it` and returns an iterator to the
  // position the segment reports back. The segment stays the same.
  template iterator BumpUp(iterator it, BumpPolicy& bp) {
    SegmentIterator seg_it = segment_[it.seg_id_]->BumpUp(
        it.bucket_id_, it.slot_id_, DoHash(it->first), bp,
        [&](uint32_t segment_id, detail::PhysicalBid from, detail::PhysicalBid to) {
          // OnMove is used to notify policy about the items moves across buckets.
          bp.OnMove(Cursor{global_depth_, segment_id, from}, Cursor{global_depth_, segment_id, to});
        });

    return iterator{this, it.seg_id_, seg_it.index, seg_it.slot};
  }

  // Plain accessors for the counters declared at the bottom of the class.
  uint64_t garbage_collected() const {
    return garbage_collected_;
  }

  uint64_t stash_unloaded() const {
    return stash_unloaded_;
  }

 private:
  // kInsertIfNotFound is used by Insert(), kForceInsert by InsertNew().
  enum class InsertMode {
    kInsertIfNotFound,
    kForceInsert,
  };

  Cursor AdvanceCursorBucketOrder(Cursor cursor);

  template std::pair InsertInternal(U&& key, V&& value, EvictionPolicy& policy, InsertMode mode);

  void IncreaseDepth(unsigned new_depth);

  template void Split(uint32_t seg_id, EvictionPolicy& ev);

  // Segment directory contains multiple segment pointers, some of them pointing to
  // the same object. IterateDistinct goes over all distinct segments in the table.
template void IterateDistinct(Cb&& cb);

  // Returns a predicate that compares a probed key with `key` via Policy::Equal.
  // `key` is captured by reference, so the predicate must not outlive it.
  template auto EqPred(const K& key) const {
    return [p = &policy_, &key](const auto& probe) -> bool { return p->Equal(probe, key); };
  }

  // Allocates and constructs a segment from the directory's memory resource and adds its
  // buckets to bucket_count_.
  SegmentType* ConstructSegment(uint8_t depth, uint32_t id) {
    auto* mr = segment_.get_allocator().resource();
    PMR_NS::polymorphic_allocator pa(mr);
    SegmentType* res = pa.allocate(1);
    pa.construct(res, depth, id, mr);  //   new SegmentType(depth);
    bucket_count_ += res->num_buckets();
    return res;
  }

  Policy policy_;
  // Segment directory; several entries may point to the same segment.
  std::vector> segment_;

  uint64_t garbage_collected_ = 0;
  uint64_t stash_unloaded_ = 0;
};  // DashTable

// Forward iterator over the table's entries or, when IsSingleBucket is true, over the slots of a
// single bucket. A position is (segment id, physical bucket id, slot id).
// A null owner_ marks a finished ("done") iterator.
template template class DashTable<_Key, _Value, Policy>::Iterator {
  using Owner = std::conditional_t;

  Owner* owner_;
  uint32_t seg_id_;
  detail::PhysicalBid bucket_id_;
  uint8_t slot_id_;

  friend class DashTable;

  // Positions the iterator exactly at (seg_id, bid, sid) without seeking.
  Iterator(Owner* me, uint32_t seg_id, detail::PhysicalBid bid, uint8_t sid)
      : owner_(me), seg_id_(seg_id), bucket_id_(bid), slot_id_(sid) {
  }

  // Starts at slot 0 of the bucket and seeks to the first occupied slot.
  Iterator(Owner* me, uint32_t seg_id, detail::PhysicalBid bid)
      : owner_(me), seg_id_(seg_id), bucket_id_(bid), slot_id_(0) {
    Seek2Occupied();
  }

 public:
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using IteratorPairType = std::conditional_t, detail::IteratorPair>;

  // Copy constructor from iterator to const_iterator.
  template ::type* = nullptr>
  Iterator(const Iterator& other) noexcept
      : owner_(other.owner_),
        seg_id_(other.seg_id_),
        bucket_id_(other.bucket_id_),
        slot_id_(other.slot_id_) {
  }

  // Copy constructor from iterator to bucket_iterator and vice versa.
  template Iterator(const Iterator& other) noexcept
      : owner_(other.owner_),
        seg_id_(other.seg_id_),
        bucket_id_(other.bucket_id_),
        slot_id_(IsSingleBucket ? 0 : other.slot_id_) {
    // if this - is a bucket_iterator - we reset slot_id to the first occupied space.
    if constexpr (IsSingleBucket) {
      Seek2Occupied();
    }
  }

  // A default-constructed iterator has no owner, i.e. it is done.
  Iterator() : owner_(nullptr), seg_id_(0), bucket_id_(0), slot_id_(0) {
  }

  Iterator(const Iterator& other) = default;
  Iterator(Iterator&& other) = default;

  Iterator& operator=(const Iterator& other) = default;
  Iterator& operator=(Iterator&& other) = default;

  // pre
  Iterator& operator++() {
    ++slot_id_;
    Seek2Occupied();
    return *this;
  }

  // Moves slot_id_ by `delta` and then seeks to the next occupied slot, so the iterator may
  // land further than `delta` slots away.
  Iterator& operator+=(int delta) {
    slot_id_ += delta;
    Seek2Occupied();
    return *this;
  }

  // Advances only if the current position holds no entry.
  Iterator& AdvanceIfNotOccupied() {
    if (!IsOccupied()) {
      this->operator++();
    }
    return *this;
  }

  // Returns a key/value pair view of the current entry. The iterator must not be done.
  IteratorPairType operator->() const {
    auto* seg = owner_->segment_[seg_id_];
    return {seg->Key(bucket_id_, slot_id_), seg->Value(bucket_id_, slot_id_)};
  }

  // Make it self-contained. Does not need container::end().
  bool is_done() const {
    return owner_ == nullptr;
  }

  // True if the current position is inside the directory and its slot is busy.
  // NOTE(review): dereferences owner_ without a null check, so it presumably must not be called
  // on a done iterator - confirm.
  bool IsOccupied() const {
    return (seg_id_ < owner_->segment_.size()) &&
           ((owner_->segment_[seg_id_]->IsBusy(bucket_id_, slot_id_)));
  }

  Owner& owner() const {
    return *owner_;
  }

  // Version of the bucket the iterator points to.
  template std::enable_if_t GetVersion() const {
    assert(owner_ && seg_id_ < owner_->segment_.size());
    return owner_->segment_[seg_id_]->GetVersion(bucket_id_);
  }

  // Sets the version of the bucket the iterator points to.
  template std::enable_if_t SetVersion(uint64_t v) {
    return owner_->segment_[seg_id_]->SetVersion(bucket_id_, v);
  }

  // Two done iterators are equal regardless of their remaining fields; otherwise all four
  // fields must match.
  friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
    if (lhs.owner_ == nullptr && rhs.owner_ == nullptr)
      return true;
    return lhs.owner_ == rhs.owner_ && lhs.seg_id_ == rhs.seg_id_ &&
           lhs.bucket_id_ == rhs.bucket_id_ && lhs.slot_id_ == rhs.slot_id_;
  }

  friend bool operator!=(const Iterator& lhs, const Iterator& rhs) {
    return !(lhs == rhs);
  }

  // Bucket resolution cursor that is safe to use with insertions/removals.
  // Serves as a hint really to the placement of the original item, i.e. the item
  // could have moved.
detail::DashCursor bucket_cursor() const {
    return detail::DashCursor(owner_->global_depth_, seg_id_, bucket_id_);
  }

  detail::PhysicalBid bucket_id() const {
    return bucket_id_;
  }

  // Returns the unique address of the physical bucket as an integer.
  // Stable for the lifetime of a serialization (mutations that could trigger
  // segment splits are blocked while a snapshot version is registered).
  uintptr_t bucket_address() const {
    assert(owner_ && seg_id_ < owner_->segment_.size());
    return reinterpret_cast(&owner_->segment_[seg_id_]->GetBucket(bucket_id_));
  }

  unsigned slot_id() const {
    return slot_id_;
  }

  unsigned segment_id() const {
    return seg_id_;
  }

 private:
  // Moves forward to the nearest occupied slot at or after the current position.
  void Seek2Occupied();
};  // Iterator

/**
  _____                 _                           _        _   _
 |_   _|               | |                         | |      | | (_)
   | |  _ __ ___  _ __ | | ___ _ __ ___   ___ _ __ | |_ __ _| |_ _  ___  _ __
   | | | '_ ` _ \| '_ \| |/ _ \ '_ ` _ \ / _ \ '_ \| __/ _` | __| |/ _ \| '_ \
  _| |_| | | | | | |_) | |  __/ | | | | |  __/ | | | || (_| | |_| | (_) | | | |
 |_____|_| |_| |_| .__/|_|\___|_| |_| |_|\___|_| |_|\__\__,_|\__|_|\___/|_| |_|
                 | |
                 |_|

**/

// Advances the iterator to the first occupied slot at or after its current position.
// - Single-bucket iterators only look at the remaining slots of their bucket.
// - Regular iterators continue into the following buckets and then the following distinct
//   segments.
// If nothing is found, owner_ is reset to nullptr, which marks the iterator as done.
template template void DashTable<_Key, _Value, Policy>::Iterator::Seek2Occupied() {
  if (owner_ == nullptr)
    return;
  assert(seg_id_ < owner_->segment_.size());

  if constexpr (IsSingleBucket) {
    const auto& b = owner_->segment_[seg_id_]->GetBucket(bucket_id_);
    // Drop the slots below slot_id_; the lowest set bit left is the next busy slot.
    uint32_t mask = b.GetBusy() >> slot_id_;
    if (mask) {
      int slot = __builtin_ctz(mask);
      slot_id_ += slot;
      return;
    }
  } else {
    while (seg_id_ < owner_->segment_.size()) {
      auto seg_it = owner_->segment_[seg_id_]->FindValidStartingFrom(bucket_id_, slot_id_);
      if (seg_it.found()) {
        bucket_id_ = seg_it.index;
        slot_id_ = seg_it.slot;
        return;
      }
      // Nothing left in this segment: continue from the start of the next distinct one.
      seg_id_ = owner_->NextSeg(seg_id_);
      bucket_id_ = slot_id_ = 0;
    }
  }
  owner_ = nullptr;
}

// Creates a directory of unique_segments_ entries (the value comes from the base class, which
// receives capacity_log) and allocates one distinct segment per entry.
template DashTable<_Key, _Value, Policy>::DashTable(size_t capacity_log, const Policy& policy,
                                                   PMR_NS::memory_resource* mr)
    : Base(capacity_log), policy_(policy), segment_(mr) {
  segment_.resize(unique_segments_);

  // I assume we have enough memory to create the initial table and do not check allocations.
  for (uint32_t i = 0; i < segment_.size(); ++i) {
    segment_[i] = ConstructSegment(global_depth_, i);  // new SegmentType(global_depth_);
  }
}

// Destroys all entries via Clear() and then frees every distinct segment.
template DashTable<_Key, _Value, Policy>::~DashTable() {
  Clear();
  auto* resource = segment_.get_allocator().resource();
  PMR_NS::polymorphic_allocator pa(resource);
  using alloc_traits = std::allocator_traits;

  IterateDistinct([&](SegmentType* seg) {
    alloc_traits::destroy(pa, seg);
    alloc_traits::deallocate(pa, seg, 1);
    return false;  // false == keep iterating
  });
}

// Calls cb(bucket_iterator) for the buckets of the key's segment that are reported by the
// segment's CVCOnInsert(). See the declaration comment for the intended semantics.
template template void DashTable<_Key, _Value, Policy>::CVCUponInsert(uint64_t ver_threshold,
                                                                      const U& key, Cb&& cb) {
  uint64_t key_hash = DoHash(key);
  uint32_t seg_id = SegmentId(key_hash);
  assert(seg_id < segment_.size());
  const SegmentType* target = segment_[seg_id];

  uint8_t bids[2];
  unsigned num_touched = target->CVCOnInsert(ver_threshold, key_hash, bids);
  // A result below UINT16_MAX is the number of bucket ids written to `bids`.
  // Any other value is treated below as "the segment is full".
  if (num_touched < UINT16_MAX) {
    for (unsigned i = 0; i < num_touched; ++i) {
      cb(bucket_iterator{this, seg_id, bids[i]});
    }
    return;
  }

  // Segment is full, we need to return the whole segment, because it can be split
  // and its entries can be reshuffled into different buckets.
for (uint8_t i = 0; i < target->num_buckets(); ++i) {
    // NOTE(review): `i` is uint8_t, so this loop terminates only if num_buckets() is at most
    // 255 - presumably guaranteed by the segment layout; confirm.
    if (target->GetVersion(i) < ver_threshold && !target->GetBucket(i).IsEmpty()) {
      cb(bucket_iterator{this, seg_id, i});
    }
  }
}

// Calls cb(bucket_iterator) for each of the (at most 3) buckets reported by the segment's
// CVCOnBump() for the entry `it` points to.
template template void DashTable<_Key, _Value, Policy>::CVCUponBump(uint64_t ver_upperbound,
                                                                    const_iterator it, Cb&& cb) {
  uint64_t key_hash = DoHash(it->first);
  uint32_t seg_id = it.segment_id();
  assert(seg_id < segment_.size());
  const SegmentType* target = segment_[seg_id];

  uint8_t bids[3];
  unsigned num_touched =
      target->CVCOnBump(ver_upperbound, it.bucket_id(), it.slot_id(), key_hash, bids);

  for (unsigned i = 0; i < num_touched; ++i) {
    cb(bucket_iterator{this, seg_id, bids[i]});
  }
}

// Destroys every key and value, clears all distinct segments and, if the table has grown beyond
// initial_depth_, shrinks the directory back to 2^initial_depth_ segments and frees the rest.
template void DashTable<_Key, _Value, Policy>::Clear() {
  auto cb = [this](SegmentType* seg) {
    seg->TraverseAll([this, seg](const SegmentIterator& it) {
      policy_.DestroyKey(seg->Key(it.index, it.slot));
      policy_.DestroyValue(seg->Value(it.index, it.slot));
    });
    seg->Clear();
    return false;  // false == keep iterating
  };

  IterateDistinct(cb);
  size_ = 0;

  // Consider the following case: table with 8 segments overall, 4 distinct.
  // S1, S1, S1, S1, S2, S3, S4, S4
  /* This corresponds to the tree:
            R
          /  \
        S1   /\
            /\ S4
           S2 S3
     We want to collapse this tree into, say, 2 segment directory.
     That means we need to keep S1, S2 but delete S3, S4.
     That means, we need to move representative segments until we reached the desired size
     and then erase all other distinct segments.
  **********/
  if (global_depth_ > initial_depth_) {
    PMR_NS::polymorphic_allocator pa(segment_.get_allocator());
    using alloc_traits = std::allocator_traits;

    // `src` walks the distinct segments; `dest` is the next free entry of the compacted
    // directory.
    size_t dest = 0, src = 0;
    size_t new_size = (1 << initial_depth_);
    bucket_count_ = 0;  // recomputed below from the segments that are kept
    while (src < segment_.size()) {
      auto* seg = segment_[src];
      size_t next_src = NextSeg(src);  // must do before because NextSeg is dependent on seg.
      if (dest < new_size) {
        // Keep this segment as one of the first new_size directory entries.
        seg->set_local_depth(initial_depth_);
        bucket_count_ += seg->num_buckets();
        segment_[dest++] = seg;
      } else {
        // Surplus segment: free it.
        alloc_traits::destroy(pa, seg);
        alloc_traits::deallocate(pa, seg, 1);
      }
      src = next_src;
    }

    global_depth_ = initial_depth_;
    unique_segments_ = new_size;
    segment_.resize(new_size);
  }
}

// If the last slot (kSlotNum - 1) of the bucket is busy, destroys its key and value, then asks
// the segment to shift the bucket right. Returns what the segment reports (true if an element
// was deleted) and decrements size_ accordingly.
template bool DashTable<_Key, _Value, Policy>::ShiftRight(bucket_iterator it) {
  auto* seg = segment_[it.seg_id_];

  // Stays 0 unless the last slot is busy.
  typename Segment_t::Hash_t hash_val = 0;
  auto& bucket = seg->GetBucket(it.bucket_id_);

  if (bucket.GetBusy() & (1 << (kSlotNum - 1))) {
    // `it` is a by-value copy, so repositioning it to the last slot is local to this function.
    it.slot_id_ = kSlotNum - 1;
    hash_val = DoHash(it->first);
    policy_.DestroyKey(it->first);
    policy_.DestroyValue(it->second);
  }

  bool deleted = seg->ShiftRight(it.bucket_id_, hash_val);
  size_ -= unsigned(deleted);

  return deleted;
}

// Calls cb(segment) once per distinct segment, in directory order.
// Iteration stops early as soon as cb returns true.
template template void DashTable<_Key, _Value, Policy>::IterateDistinct(Cb&& cb) {
  size_t i = 0;
  while (i < segment_.size()) {
    auto* seg = segment_[i];
    // Computed before invoking cb, which may destroy `seg` (as the destructor's callback does).
    size_t next_id = NextSeg(i);
    if (cb(seg))
      break;
    i = next_id;
  }
}

// Looks `key` up and returns a const_iterator to it, or a done iterator if it is absent.
template template auto DashTable<_Key, _Value, Policy>::Find(U&& key) const -> const_iterator {
  uint64_t key_hash = DoHash(key);
  uint32_t seg_id = SegmentId(key_hash);  // seg_id takes up global_depth_ high bits.

  // Hash structure is like this: [SSUUUUBF], where S is segment id, U - unused,
  // B - bucket id and F is a fingerprint. Segment id is needed to identify the correct segment.
  // Once identified, the segment instance uses the lower part of hash to locate the key.
  // It uses 8 least significant bits for a fingerprint and few more bits for bucket id.
if (auto seg_it = segment_[seg_id]->FindIt(key_hash, EqPred(key)); seg_it.found()) { return {this, seg_id, seg_it.index, seg_it.slot}; } return {}; } template template auto DashTable<_Key, _Value, Policy>::Find(U&& key) -> iterator { return FindFirst(DoHash(key), EqPred(key)); } template template void DashTable<_Key, _Value, Policy>::Prefetch(U&& key) const { uint64_t key_hash = DoHash(key); uint32_t seg_id = SegmentId(key_hash); segment_[seg_id]->Prefetch(key_hash); } template template auto DashTable<_Key, _Value, Policy>::FindFirst(uint64_t key_hash, Pred&& pred) -> iterator { uint32_t seg_id = SegmentId(key_hash); if (auto seg_it = segment_[seg_id]->FindIt(key_hash, pred); seg_it.found()) { return {this, seg_id, seg_it.index, seg_it.slot}; } return {}; } template size_t DashTable<_Key, _Value, Policy>::Erase(const Key_t& key) { uint64_t key_hash = DoHash(key); size_t x = SegmentId(key_hash); auto* target = segment_[x]; auto it = target->FindIt(key_hash, EqPred(key)); if (!it.found()) return 0; policy_.DestroyKey(target->Key(it.index, it.slot)); policy_.DestroyValue(target->Value(it.index, it.slot)); target->Delete(it, key_hash); --size_; return 1; } template void DashTable<_Key, _Value, Policy>::Erase(iterator it) { auto* target = segment_[it.seg_id_]; uint64_t key_hash = DoHash(it->first); SegmentIterator sit{it.bucket_id_, it.slot_id_}; policy_.DestroyKey(it->first); policy_.DestroyValue(it->second); target->Delete(sit, key_hash); --size_; } template void DashTable<_Key, _Value, Policy>::Reserve(size_t size) { if (size <= capacity()) return; size_t sg_floor = (size - 1) / SegmentType::capacity(); if (sg_floor < segment_.size()) { return; } assert(sg_floor > 1u); unsigned new_depth = 1 + (63 ^ __builtin_clzll(sg_floor)); IncreaseDepth(new_depth); } template template auto DashTable<_Key, _Value, Policy>::InsertInternal(U&& key, V&& value, EvictionPolicy& ev, InsertMode mode) -> std::pair { uint64_t key_hash = DoHash(key); uint32_t target_seg_id = 
SegmentId(key_hash); while (true) { // Keep last global_depth_ msb bits of the hash. assert(target_seg_id < segment_.size()); SegmentType* target = segment_[target_seg_id]; // Load heap allocated segment data - to avoid TLB miss when accessing the bucket. __builtin_prefetch(target, 0, 1); typename SegmentType::Iterator it; bool res = true; unsigned num_buckets = target->num_buckets(); auto move_cb = [&](uint32_t segment_id, detail::PhysicalBid from, detail::PhysicalBid to) { // OnMove is used to notify policy about the move of items across buckets. ev.OnMove(Cursor{global_depth_, segment_id, from}, Cursor{global_depth_, segment_id, to}); }; if (mode == InsertMode::kForceInsert) { it = target->InsertUniq(std::forward(key), std::forward(value), key_hash, true, move_cb); res = it.found(); } else { std::tie(it, res) = target->Insert(std::forward(key), std::forward(value), key_hash, EqPred(key), move_cb); } if (res) { // success // in case segment bucket count changed, we need to update total bucket count. bucket_count_ += (target->num_buckets() - num_buckets); ++size_; return std::make_pair(iterator{this, target_seg_id, it.index, it.slot}, true); } /*duplicate insert, insertion failure*/ if (it.found()) { return std::make_pair(iterator{this, target_seg_id, it.index, it.slot}, false); } bool consider_throw = true; // At this point we must split the segment. // try garbage collect or evict. if constexpr (EvictionPolicy::can_evict || EvictionPolicy::can_gc) { // Try gc. 
uint8_t bid[HotBuckets::kRegularBuckets]; SegmentType::FillProbeArray(key_hash, bid); HotBuckets hotspot; hotspot.key_hash = key_hash; for (unsigned j = 0; j < HotBuckets::kRegularBuckets; ++j) { hotspot.probes.by_type.regular_buckets[j] = bucket_iterator{this, target_seg_id, bid[j]}; } for (unsigned i = 0; i < SegmentType::kStashBucketNum; ++i) { hotspot.probes.by_type.stash_buckets[i] = bucket_iterator{this, target_seg_id, uint8_t(Policy::kBucketNum + i), 0}; } hotspot.num_buckets = HotBuckets::kNumBuckets; // The difference between gc and eviction is that gc can be applied even if // the table can grow since we throw away logically deleted items. // For eviction to be applied we should reach the growth limit. if constexpr (EvictionPolicy::can_gc) { unsigned res = ev.GarbageCollect(hotspot, this); garbage_collected_ += res; if (res) { // We succeeded to gc. Lets continue with the momentum. // In terms of API abuse it's an awful hack, just to see if it works. /*unsigned start = (bid[HotBuckets::kNumBuckets - 1] + 1) % kLogicalBucketNum; for (unsigned i = 0; i < HotBuckets::kNumBuckets; ++i) { uint8_t id = (start + i) % kLogicalBucketNum; buckets.probes.arr[i] = bucket_iterator{this, target_seg_id, id}; } garbage_collected_ += ev.GarbageCollect(buckets, this); */ continue; } } auto hash_fn = [this](const auto& k) { return policy_.HashFn(k); }; unsigned moved = target->UnloadStash(hash_fn, move_cb); if (moved > 0) { stash_unloaded_ += moved; continue; } // We evict only if our policy says we can not grow if constexpr (EvictionPolicy::can_evict) { bool can_grow = ev.CanGrow(*this); if (can_grow) { consider_throw = false; } else { unsigned res = ev.Evict(hotspot, this); if (res) continue; } } } if (consider_throw && !ev.CanGrow(*this)) { throw std::bad_alloc{}; } // Split the segment. 
if (target->local_depth() == global_depth_) { IncreaseDepth(global_depth_ + 1); target_seg_id = SegmentId(key_hash); assert(target_seg_id < segment_.size() && segment_[target_seg_id] == target); } ev.RecordSplit(target); Split(target_seg_id, ev); } return std::make_pair(iterator{}, false); } template void DashTable<_Key, _Value, Policy>::IncreaseDepth(unsigned new_depth) { assert(!segment_.empty()); assert(new_depth > global_depth_); size_t prev_sz = segment_.size(); size_t repl_cnt = 1ul << (new_depth - global_depth_); segment_.resize(1ul << new_depth); for (int i = prev_sz - 1; i >= 0; --i) { size_t offs = i * repl_cnt; std::fill(segment_.begin() + offs, segment_.begin() + offs + repl_cnt, segment_[i]); segment_[i]->set_segment_id(offs); // update segment id. } global_depth_ = new_depth; } template template void DashTable<_Key, _Value, Policy>::Split(uint32_t seg_id, EvictionPolicy& ev) { SegmentType* source = segment_[seg_id]; uint32_t chunk_size = 1u << (global_depth_ - source->local_depth()); uint32_t start_idx = seg_id & (~(chunk_size - 1)); assert(segment_[start_idx] == source && segment_[start_idx + chunk_size - 1] == source); uint32_t target_id = start_idx + chunk_size / 2; SegmentType* target = ConstructSegment(source->local_depth() + 1, target_id); auto hash_fn = [this](const auto& k) { return policy_.HashFn(k); }; // remove current segment bucket count. bucket_count_ -= (source->num_buckets() + target->num_buckets()); source->Split( std::move(hash_fn), target, [&](uint32_t segment_from, detail::PhysicalBid from, uint32_t segment_to, detail::PhysicalBid to) { // OnMove is used to notify eviction policy about the moves across // buckets/segments during the split. ev.OnMove(Cursor{global_depth_, segment_from, from}, Cursor{global_depth_, segment_to, to}); }); // add back the updated bucket count. 
bucket_count_ += (target->num_buckets() + source->num_buckets()); ++unique_segments_; for (size_t i = target_id; i < start_idx + chunk_size; ++i) { segment_[i] = target; } } template template auto DashTable<_Key, _Value, Policy>::TraverseBySegmentOrder(Cursor curs, Cb&& cb) -> Cursor { uint32_t sid = curs.segment_id(global_depth_); assert(sid < segment_.size()); SegmentType* s = segment_[sid]; assert(s); uint8_t bid = curs.bucket_id(); auto dt_cb = [&](const SegmentIterator& it) { cb(iterator{this, sid, it.index, it.slot}); }; s->TraverseBucket(bid, std::move(dt_cb)); ++bid; if (SegmentType::OutOfRange(bid)) { sid = NextSeg(sid); if (sid >= segment_.size()) { return Cursor::end(); } bid = 0; } return Cursor{global_depth_, sid, bid}; } template auto DashTable<_Key, _Value, Policy>::GetRandomCursor(absl::BitGen* bitgen) -> Cursor { uint32_t sid = absl::Uniform(*bitgen, 0, segment_.size()); uint8_t bid = absl::Uniform(*bitgen, 0, Policy::kBucketNum); return Cursor{global_depth_, sid, bid}; } template template auto DashTable<_Key, _Value, Policy>::Traverse(Cursor curs, Cb&& cb) -> Cursor { uint32_t sid = curs.segment_id(global_depth_); uint8_t bid = curs.bucket_id(); // Test validity of the cursor. if (bid >= Policy::kBucketNum || sid >= segment_.size()) return Cursor::end(); auto hash_fun = [this](const auto& k) { return policy_.HashFn(k); }; bool fetched = false; // We fix bid and go over all segments. Once we reach the end we increase bid and repeat. 
do { SegmentType* s = segment_[sid]; assert(s); auto dt_cb = [&](const SegmentIterator& it) { cb(iterator{this, sid, it.index, it.slot}); }; fetched = s->TraverseLogicalBucket(bid, hash_fun, std::move(dt_cb)); sid = NextSeg(sid); if (sid >= segment_.size()) { sid = 0; ++bid; if (bid >= Policy::kBucketNum) return Cursor::end(); } } while (!fetched); return Cursor{global_depth_, sid, bid}; } template auto DashTable<_Key, _Value, Policy>::AdvanceCursorBucketOrder(Cursor cursor) -> Cursor { // We fix bid and go over all segments. Once we reach the end we increase bid and repeat. uint32_t sid = cursor.segment_id(global_depth_); uint8_t bid = cursor.bucket_id(); sid = NextSeg(sid); if (sid >= segment_.size()) { sid = 0; ++bid; if (SegmentType::OutOfRange(bid)) return Cursor::end(); } return Cursor{global_depth_, sid, bid}; } template template auto DashTable<_Key, _Value, Policy>::TraverseBuckets(Cursor cursor, Cb&& cb) -> Cursor { if (SegmentType::OutOfRange(cursor.bucket_id())) // sanity. return Cursor::end(); constexpr uint32_t kMaxIterations = 8; bool invoked = false; for (uint32_t i = 0; i < kMaxIterations; ++i) { uint32_t sid = cursor.segment_id(global_depth_); uint8_t bid = cursor.bucket_id(); SegmentType* s = segment_[sid]; assert(s); if (bid < s->num_buckets()) { const auto& bucket = s->GetBucket(bid); if (bucket.GetBusy()) { // Invoke callback only if bucket has elements. cb(BucketIt(sid, bid)); invoked = true; } } cursor = AdvanceCursorBucketOrder(cursor); if (invoked || !cursor) // Break end of traversal or callback invoked. return cursor; } return cursor; } } // namespace dfly ================================================ FILE: src/core/dash_bench.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include #include #include #include "base/hash.h" #include "base/histogram.h" #include "base/init.h" #include "core/dash.h" extern "C" { #include "redis/dict.h" #include "redis/sds.h" #include "redis/zmalloc.h" } using namespace std; ABSL_FLAG(uint32_t, n, 100000, "num items"); ABSL_FLAG(string, type, "dash", ""); ABSL_FLAG(bool, sds, false, "If true, uses sds as primary key"); namespace dfly { static uint64_t dictSdsHash(const void* key) { return dictGenHashFunction((unsigned char*)key, sdslen((char*)key)); } static int dictSdsKeyCompare(dict*, const void* key1, const void* key2) { int l1, l2; l1 = sdslen((sds)key1); l2 = sdslen((sds)key2); if (l1 != l2) return 0; return memcmp(key1, key2, l1) == 0; } static dictType SdsDict = { dictSdsHash, /* hash function */ NULL, /* key dup */ NULL, /* val dup */ dictSdsKeyCompare, /* key compare */ NULL, // dictSdsDestructor, /* key destructor */ NULL, /* val destructor */ NULL, }; struct UInt64Policy { enum { kSlotNum = 12, kBucketNum = 64, kStashBucketNum = 2 }; static constexpr bool kUseVersion = false; static uint64_t HashFn(uint64_t v) { return XXH3_64bits(&v, sizeof(v)); } template static void DestroyValue(const U&) { } template static void DestroyKey(const U&) { } template static bool Equal(U&& u, V&& v) { return u == v; } }; struct SdsDashPolicy { enum { kSlotNum = 14, kBucketNum = 56, kStashBucketNum = 4 }; static constexpr bool kUseVersion = false; static uint64_t HashFn(sds u) { return XXH3_64bits(reinterpret_cast(u), sdslen(u)); } static uint64_t HashFn(std::string_view u) { return XXH3_64bits(u.data(), u.size()); } static void DestroyKey(sds s) { sdsfree(s); } static void DestroyValue(uint64_t) { } static bool Equal(sds u1, sds u2) { return dictSdsKeyCompare(nullptr, u1, u2) == 0; } static bool Equal(sds u1, std::string_view u2) { return u2 == std::string_view{u1, sdslen(u1)}; } }; using Dash64 = DashTable; using DashSds = DashTable; using absl::GetFlag; inline void Sample(int64_t start, int64_t end, 
base::Histogram* hist) { hist->Add((end - start) / 100); } Dash64 udt; DashSds sds_dt; base::Histogram hist; #define USE_TIME 1 int64_t GetNow() { #if USE_TIME return absl::GetCurrentTimeNanos(); #else return absl::base_internal::CycleClock::Now(); #endif } #if defined(__i386__) || defined(__amd64__) #define LFENCE __asm__ __volatile__("lfence") #else #define LFENCE __asm__ __volatile__("ISB") #endif absl::flat_hash_map mymap; void BenchFlat(uint64_t num) { for (uint64_t i = 0; i < num; ++i) { time_t start = GetNow(); mymap.emplace(i, 0); LFENCE; time_t end = GetNow(); Sample(start, end, &hist); } } void BenchDash(uint64_t num) { for (uint64_t i = 0; i < num; ++i) { time_t start = GetNow(); udt.Insert(i, 0); LFENCE; time_t end = GetNow(); Sample(start, end, &hist); } } inline sds Prefix() { return sdsnew("xxxxxxxxxxxxxxxxxxxxxxx"); } void BenchDashSds(uint64_t num) { sds key = sdscatsds(Prefix(), sdsfromlonglong(0)); for (uint64_t i = 0; i < num; ++i) { time_t start = GetNow(); sds_dt.Insert(key, 0); time_t end = GetNow(); Sample(start, end, &hist); key = sdscatsds(Prefix(), sdsfromlonglong(i + 1)); } } static uint64_t callbackHash(const void* key) { return XXH64(&key, sizeof(key), 0); } static dictType IntDict = {callbackHash, NULL, NULL, NULL, NULL, NULL, NULL}; dict* redis_dict = nullptr; void BenchDict(uint64_t num) { redis_dict = dictCreate(&IntDict); for (uint64_t i = 0; i < num; ++i) { time_t start = GetNow(); dictAdd(redis_dict, (void*)i, nullptr); LFENCE; time_t end = GetNow(); Sample(start, end, &hist); } } void BenchDictSds() { uint64_t num = GetFlag(FLAGS_n); sds key = sdscat(Prefix(), sdsfromlonglong(0)); redis_dict = dictCreate(&SdsDict); for (uint64_t i = 0; i < num; ++i) { time_t start = GetNow(); dictAdd(redis_dict, key, nullptr); time_t end = GetNow(); Sample(start, end, &hist); key = sdscatsds(Prefix(), sdsfromlonglong(i + 1)); } } } // namespace dfly using namespace dfly; int main(int argc, char* argv[]) { MainInitGuard guard(&argc, &argv); 
init_zmalloc_threadlocal(mi_heap_get_backing()); string table_type = GetFlag(FLAGS_type); bool is_sds = GetFlag(FLAGS_sds); uint64_t start = absl::GetCurrentTimeNanos(); uint64_t num = GetFlag(FLAGS_n); if (table_type == "dash") { if (is_sds) { BenchDashSds(num); } else { BenchDash(num); } } else if (table_type == "dict") { if (is_sds) { BenchDictSds(); } else { BenchDict(num); } } else if (table_type == "flat") { BenchFlat(num); } else { LOG(FATAL) << "Unknown type " << table_type; } CONSOLE_INFO << "latencies histogram (jiffies, 100ns):\n" << hist.ToString(); uint64_t delta = (absl::GetCurrentTimeNanos() - start) / 1000000; CONSOLE_INFO << "Took " << delta << " ms"; return 0; } ================================================ FILE: src/core/dash_internal.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include #include "base/pmr/memory_resource.h" #include "core/sse_port.h" namespace dfly { namespace detail { template class SlotBitmap { static_assert(NUM_SLOTS > 0 && NUM_SLOTS <= 28); static constexpr bool SINGLE = NUM_SLOTS <= 14; static constexpr unsigned kLen = SINGLE ? 1 : 2; static constexpr unsigned kAllocMask = (1u << NUM_SLOTS) - 1; static constexpr unsigned kBitmapLenMask = (1 << 4) - 1; public: // probe - true means the entry is probing, i.e. not owning. // probe=true GetProbe returns index of probing entries, i.e. hosted but not owned by this bucket. // probe=false - mask of owning entries uint32_t GetProbe(bool probe) const { if constexpr (SINGLE) return ((val_[0].d >> 4) & kAllocMask) ^ ((!probe) * kAllocMask); else return (val_[1].d & kAllocMask) ^ ((!probe) * kAllocMask); } // GetBusy returns the busy mask. uint32_t GetBusy() const { return SINGLE ? val_[0].d >> 18 : val_[0].d; } bool IsFull() const { return Size() == NUM_SLOTS; } unsigned Size() const { return SINGLE ? 
(val_[0].d & kBitmapLenMask) : __builtin_popcount(val_[0].d); } // Precondition: Must have empty slot // returns result in [0, NUM_SLOTS) range. int FindEmptySlot() const { uint32_t mask = ~(GetBusy()); // returns the index for first set bit (FindLSBSetNonZero). mask must be non-zero. int slot = __builtin_ctz(mask); assert(slot < int(NUM_SLOTS)); return slot; } // mask is NUM_SLOTS bits saying which slots needs to be freed (1 - should clear). void ClearSlots(uint32_t mask); void Clear() { if (SINGLE) { val_[0].d = 0; } else { val_[0].d = val_[1].d = 0; } } void ClearSlot(unsigned index); void SetSlot(unsigned index, bool probe); // cell 0 corresponds to first lsb bit in the busy mask, hence we need to shift left // the bitmap in order to shift right the cell-array. // Returns true if discarded the last slot (i.e. it was busy). bool ShiftLeft(); void Swap(unsigned slot_a, unsigned slot_b); private: // SINGLE: // val_[0] is [14 bit- busy][14bit-probing, whether the key does not belong to this // bucket][4bit-count] // kLen == 2: // val_[0] is 28 bit busy // val_[1] is 28 bit probing // count is implemented via popcount of val_[0]. struct Unaligned { // Apparently with wrapping struct we can persuade compiler to declare an unaligned int. // https://stackoverflow.com/questions/19915303/packed-qualifier-ignored uint32_t d __attribute__((packed, aligned(1))); Unaligned() : d(0) { } }; Unaligned val_[kLen]; }; // SlotBitmap template class BucketBase { // We can not allow more than 4 stash fps because we hold stash positions in single byte // stash_pos_ variable that uses 2 bits per stash bucket to point which bucket holds that fp. // Hence we can point at most from 4 fps to 4 stash buckets. // If any of those limits need to be raised we should increase stash_pos_ similarly to how we did // with SlotBitmap. 
static constexpr unsigned kStashFpLen = 4; static constexpr unsigned kStashPresentBit = 1 << 4; using FpArray = std::array; using StashFpArray = std::array; public: using SlotId = uint8_t; static constexpr SlotId kNanSlot = 255; bool IsFull() const { return Size() == NUM_SLOTS; } bool IsEmpty() const { return GetBusy() == 0; } unsigned Size() const { return slotb_.Size(); } void Delete(SlotId sid) { slotb_.ClearSlot(sid); } unsigned Find(uint8_t fp_hash, bool probe) const { unsigned mask = CompareFP(fp_hash) & GetBusy(); return mask & GetProbe(probe); } uint8_t Fp(unsigned i) const { assert(i < finger_arr_.size()); return finger_arr_[i]; } void SetStashPtr(unsigned stash_pos, uint8_t meta_hash, BucketBase* next); // returns 0 if stash was cleared from this bucket, 1 if it was cleared from next bucket. unsigned UnsetStashPtr(uint8_t fp_hash, unsigned stash_pos, BucketBase* next); // probe - true means the entry is probing, i.e. not owning. // probe=true GetProbe returns index of probing entries, i.e. hosted but not owned by this bucket. // probe=false - mask of owning entries uint32_t GetProbe(bool probe) const { return slotb_.GetProbe(probe); } // GetBusy returns the busy mask. uint32_t GetBusy() const { return slotb_.GetBusy(); } bool IsBusy(unsigned slot) const { return (GetBusy() & (1u << slot)) != 0; } // mask is saying which slots needs to be freed (1 - should clear). void ClearSlots(uint32_t mask) { slotb_.ClearSlots(mask); } void Clear() { slotb_.Clear(); } void ClearStashPtrs() { stash_busy_ = 0; stash_pos_ = 0; stash_probe_mask_ = 0; overflow_count_ = 0; } bool HasStash() const { return stash_busy_ & kStashPresentBit; } void SetHash(unsigned slot_id, uint8_t meta_hash, bool probe); bool HasStashOverflow() const { return overflow_count_ > 0; } // func accepts an fp_index in range [0, kStashFpLen) and // stash position [0, STASH_BUCKET_NUM) that with fingerprint=fp. 
func must return // a slot id if it found whatever it searched for when iterating or kNanSlot to continue. // IterateStash returns: first - stash position [0, STASH_BUCKET_NUM), second - slot id // pointing to that stash. template std::pair IterateStash(uint8_t fp, bool is_probe, F&& func) const; void Swap(unsigned slot_a, unsigned slot_b) { slotb_.Swap(slot_a, slot_b); std::swap(finger_arr_[slot_a], finger_arr_[slot_b]); } protected: uint32_t CompareFP(uint8_t fp) const; bool ShiftRight(); // Returns true if stash_pos was stored, false overwise bool SetStash(uint8_t fp, unsigned stash_pos, bool probe); bool ClearStash(uint8_t fp, unsigned stash_pos, bool probe); SlotBitmap slotb_; // allocation bitmap + pointer bitmap + counter /*only use the first 14 bytes, can be accelerated by SSE instruction,0-13 for finger, 14-17 for overflowed*/ FpArray finger_arr_; StashFpArray stash_arr_; uint8_t stash_busy_ = 0; // kStashFpLen+1 bits are used uint8_t stash_pos_ = 0; // 4x2 bits for pointing to stash bucket. // stash_probe_mask_ indicates whether the overflow fingerprint is for the neighbour (1) // or for this bucket (0). kStashFpLen bits are used. uint8_t stash_probe_mask_ = 0; // number of overflowed items stored in stash buckets that do not have fp hashes. uint8_t overflow_count_ = 0; }; // BucketBase static_assert(sizeof(BucketBase<12>) == 24); static_assert(alignof(BucketBase<14>) == 1); static_assert(alignof(BucketBase<12>) == 1); // Optional version support as part of DashTable. // This works like this: each slot has 2 bytes for version and a bucket has another 6. // therefore all slots in the bucket shared the same 6 high bytes of 8-byte version. // In order to achieve this we store high6(max{version(entry)}) for every entry. // Hence our version control may have false positives, i.e. signal that an entry has changed // when in practice its neighbour incremented the high6 part of its bucket. 
template class VersionedBB : public BucketBase { using Base = BucketBase; public: // one common version per bucket. void SetVersion(uint64_t version); uint64_t GetVersion() const { uint64_t c = absl::little_endian::Load64(version_); // c |= low_[slot_id]; return c; } void UpdateVersion(uint64_t version) { uint64_t c = std::max(GetVersion(), version); absl::little_endian::Store64(version_, c); } void Clear() { Base::Clear(); // low_.fill(0); memset(version_, 0, sizeof(version_)); } bool ShiftRight() { bool res = Base::ShiftRight(); return res; } void Swap(unsigned slot_a, unsigned slot_b) { Base::Swap(slot_a, slot_b); } private: uint8_t version_[8] = {0}; }; static_assert(alignof(VersionedBB<14>) == 1); static_assert(sizeof(VersionedBB<12>) == 12 * 2 + 8); static_assert(sizeof(VersionedBB<14>) <= 14 * 2 + 8); // Segment - static-hashtable of size kSlotNum*(kBucketNum + kStashBucketNum). struct DefaultSegmentPolicy { static constexpr unsigned kSlotNum = 12; static constexpr unsigned kBucketNum = 64; static constexpr bool kUseVersion = true; }; using PhysicalBid = uint8_t; using LogicalBid = uint8_t; template class Segment { public: static constexpr unsigned kSlotNum = Policy::kSlotNum; static constexpr unsigned kBucketNum = Policy::kBucketNum; static constexpr unsigned kStashBucketNum = 4; static constexpr bool kUseVersion = Policy::kUseVersion; private: static_assert(kBucketNum + kStashBucketNum < 255); static constexpr unsigned kFingerBits = 8; using BucketType = std::conditional_t, BucketBase>; struct Bucket : public BucketType { using BucketType::kNanSlot; using typename BucketType::SlotId; KeyType key[kSlotNum]; ValueType value[kSlotNum]; template void Insert(uint8_t slot, U&& u, V&& v, uint8_t meta_hash, bool probe) { assert(slot < kSlotNum); key[slot] = std::forward(u); value[slot] = std::forward(v); this->SetHash(slot, meta_hash, probe); } // Returns slot id if insertion is successful, -1 if no free slots are found. 
template int TryInsertToBucket(U&& key, V&& value, uint8_t meta_hash, bool probe) { if (this->IsFull()) { return -1; // no free space in the bucket. } int slot = this->slotb_.FindEmptySlot(); assert(slot >= 0); Insert(slot, std::forward(key), std::forward(value), meta_hash, probe); return slot; } template SlotId FindByFp(uint8_t fp_hash, bool probe, Pred&& pred) const; bool ShiftRight(); void Swap(unsigned slot_a, unsigned slot_b) { BucketType::Swap(slot_a, slot_b); std::swap(key[slot_a], key[slot_b]); std::swap(value[slot_a], value[slot_b]); } template void ForEachSlotImpl(This obj, Cb&& cb) const { uint32_t mask = this->GetBusy(); uint32_t probe_mask = this->GetProbe(true); for (unsigned j = 0; j < kSlotNum; ++j) { if (mask & 1) { cb(obj, j, probe_mask & 1); } mask >>= 1; probe_mask >>= 1; } } // calls for each busy slot: cb(iterator, probe) template void ForEachSlot(Cb&& cb) const { ForEachSlotImpl(this, std::forward(cb)); } // calls for each busy slot: cb(iterator, probe) template void ForEachSlot(Cb&& cb) { ForEachSlotImpl(this, std::forward(cb)); } }; // class Bucket static constexpr PhysicalBid kNanBid = 0xFF; using SlotId = typename BucketType::SlotId; public: struct Iterator { PhysicalBid index; // bucket index uint8_t slot; Iterator() : index(kNanBid), slot(BucketType::kNanSlot) { } Iterator(PhysicalBid bi, uint8_t sid) : index(bi), slot(sid) { } bool found() const { return index != kNanBid; } }; struct Stats { size_t neighbour_probes = 0; size_t stash_probes = 0; size_t stash_overflow_probes = 0; }; static constexpr size_t kFpMask = (1 << kFingerBits) - 1; using Value_t = ValueType; using Key_t = KeyType; using Hash_t = uint64_t; explicit Segment(size_t depth, uint32_t id, PMR_NS::memory_resource* mr) : local_depth_(depth), segment_id_(id), mr_(mr) { } ~Segment() { Clear(); } Segment(const Segment&) = delete; Segment& operator=(const Segment&) = delete; // Returns (iterator, true) if insert succeeds, // (iterator, false) for duplicate and 
(invalid-iterator, false) if it's full template std::pair Insert(K&& key, V&& value, Hash_t key_hash, Pred&& pred, OnMoveCb&& on_move_cb); template void Split(HashFn&& hfunc, Segment* dest, OnMoveCb&& on_move_cb); void Delete(const Iterator& it, Hash_t key_hash); void Clear(); // clears the segment. size_t SlowSize() const; static constexpr size_t capacity() { return kMaxSize; } static constexpr bool OutOfRange(PhysicalBid bid) { return bid >= kBucketNum + kStashBucketNum; } size_t local_depth() const { return local_depth_; } void set_local_depth(uint32_t depth) { local_depth_ = depth; } template std::enable_if_t GetVersion(PhysicalBid bid) const { return GetBucket(bid).GetVersion(); } template std::enable_if_t SetVersion(PhysicalBid bid, uint64_t v) { return GetBucket(bid).SetVersion(v); } // Traverses over Segment's bucket bid and calls cb(const Iterator& it) 0 or more times // for each slot in the bucket. returns false if bucket is empty. // Please note that `it` will not necessary point to bid due to probing and stash buckets // containing items that should have been resided in bid. template bool TraverseLogicalBucket(LogicalBid bid, HashFn&& hfun, Cb&& cb) const; // Cb accepts (const Iterator&). template void TraverseAll(Cb&& cb) const; // Traverses over Segment's bucket bid and calls cb(Iterator& it) // for each slot in the bucket. The iteration goes over a physical bucket. template void TraverseBucket(PhysicalBid bid, Cb&& cb); // Used in test. 
unsigned NumProbingBuckets() const { unsigned res = 0; for (PhysicalBid i = 0; i < kBucketNum; ++i) { res += (bucket_[i].GetProbe(true) != 0); } return res; }; const Bucket& GetBucket(PhysicalBid i) const { return bucket_[i]; } Bucket& GetBucket(PhysicalBid i) { return bucket_[i]; } bool IsBusy(PhysicalBid bid, unsigned slot) const { return GetBucket(bid).GetBusy() & (1U << slot); } Key_t& Key(PhysicalBid bid, unsigned slot) { assert(IsBusy(bid, slot)); return GetBucket(bid).key[slot]; } const Key_t& Key(PhysicalBid bid, unsigned slot) const { assert(IsBusy(bid, slot)); return GetBucket(bid).key[slot]; } Value_t& Value(PhysicalBid bid, unsigned slot) { assert(IsBusy(bid, slot)); return GetBucket(bid).value[slot]; } const Value_t& Value(PhysicalBid bid, unsigned slot) const { assert(IsBusy(bid, slot)); return GetBucket(bid).value[slot]; } // fill bucket ids that may be used probing for this key_hash. // The order is: exact, neighbour buckets. static void FillProbeArray(Hash_t key_hash, uint8_t dest[4]) { dest[1] = HomeIndex(key_hash); dest[0] = PrevBid(dest[1]); dest[2] = NextBid(dest[1]); dest[3] = NextBid(dest[2]); } // Find item with given key hash and truthy predicate template Iterator FindIt(Hash_t key_hash, Pred&& pred) const; void Prefetch(Hash_t key_hash) const; // Returns valid iterator if succeeded or invalid if not (it's full). // Requires: key should be not present in the segment. // if spread is true, tries to spread the load between neighbour and home buckets, // otherwise chooses home bucket first. // TODO: I am actually not sure if spread optimization is helpful. Worth checking // whether we get higher occupancy rates when using it. template Iterator InsertUniq(U&& key, V&& value, Hash_t key_hash, bool spread, OnMoveCb&& on_move_cb); // capture version change in case of insert. // Returns ids of buckets whose version would cross ver_threshold upon insertion of key_hash // into the segment. // Returns UINT16_MAX if segment is full. 
Otherwise, returns number of touched bucket ids (1 or 2) // if the insertion would happen. The ids are put into bid array that should have at least 2 // spaces. template std::enable_if_t CVCOnInsert(uint64_t ver_threshold, Hash_t key_hash, PhysicalBid bid[2]) const; // Returns bucket ids whose versions will change as a result of bumping up the item // Can return upto 3 buckets. template std::enable_if_t CVCOnBump(uint64_t ver_threshold, unsigned bid, unsigned slot, Hash_t hash, PhysicalBid result_bid[3]) const; // Finds a valid entry going from specified indices up. Iterator FindValidStartingFrom(PhysicalBid bid, unsigned slot) const; // Shifts all slots in the bucket right. // Returns true if the last slot was busy and the entry has been deleted. bool ShiftRight(PhysicalBid bid, Hash_t right_hashval) { if (bid >= kBucketNum) { // Stash constexpr auto kLastSlotMask = 1u << (kSlotNum - 1); if (GetBucket(bid).GetBusy() & kLastSlotMask) RemoveStashReference(bid - kBucketNum, right_hashval); } return bucket_[bid].ShiftRight(); } // Bumps up this entry making it more "important" for the eviction policy. template Iterator BumpUp(PhysicalBid bid, SlotId slot, Hash_t key_hash, const BumpPolicy& ev, OnMoveCb&& cb); // Tries to move stash entries back to their normal buckets (exact or neighbour). // Returns number of entries that succeeded to unload. // Important! Affects versions of the moved items and the items in the destination // buckets. template unsigned UnloadStash(HFunc&& hfunc, OnMoveCb&& cb); unsigned num_buckets() const { return kBucketNum + kStashBucketNum; } uint32_t segment_id() const { return segment_id_; } // needed only when DashTable grows its segment table. void set_segment_id(uint32_t new_id) { segment_id_ = new_id; } private: static_assert(sizeof(Iterator) == 2); static LogicalBid HomeIndex(Hash_t hash) { return (hash >> kFingerBits) % kBucketNum; } static LogicalBid NextBid(LogicalBid bid) { return bid < kBucketNum - 1 ? 
bid + 1 : 0; } static LogicalBid PrevBid(LogicalBid bid) { return bid ? bid - 1 : kBucketNum - 1; } // if own_items is true it means we try to move owned item to probing bucket. // if own_items false it means we try to move non-owned item from probing bucket back to its host. int MoveToOther(bool own_items, unsigned from, unsigned to); // dry-run version of MoveToOther. bool CheckIfMovesToOther(bool own_items, unsigned from, unsigned to) const; /*both clear this bucket and its neighbor bucket*/ void RemoveStashReference(unsigned stash_pos, Hash_t key_hash); // returns a valid iterator if succeeded. Iterator TryMoveFromStash(unsigned stash_id, unsigned stash_slot_id, Hash_t key_hash); const static unsigned kTotalBuckets = kBucketNum + kStashBucketNum; static_assert(kTotalBuckets < 0xFF); Bucket bucket_[kTotalBuckets]; uint8_t local_depth_; uint32_t segment_id_; // segment id in the table. PMR_NS::memory_resource* mr_ = nullptr; public: static constexpr size_t kBucketSz = sizeof(Bucket); static constexpr size_t kMaxSize = (kBucketNum + kStashBucketNum) * kSlotNum; static constexpr double kTaxSize = (double(sizeof(Segment)) / kMaxSize) - sizeof(Key_t) - sizeof(Value_t); #ifdef ENABLE_DASH_STATS mutable Stats stats; #endif }; // Segment class DashTableBase { public: explicit DashTableBase(uint32_t gd) : unique_segments_(1 << gd), initial_depth_(gd), global_depth_(gd) { } DashTableBase(const DashTableBase&) = delete; DashTableBase& operator=(const DashTableBase&) = delete; uint32_t unique_segments() const { return unique_segments_; } uint16_t depth() const { return global_depth_; } size_t size() const { return size_; } size_t Empty() const { return size_ == 0; } protected: uint32_t SegmentId(size_t hash) const { if (global_depth_) { return hash >> (64 - global_depth_); } return 0; } size_t size_ = 0; uint32_t unique_segments_ = 0, bucket_count_ = 0; uint8_t initial_depth_; uint8_t global_depth_; }; // DashTableBase template class IteratorPair { public: 
IteratorPair(KeyType& k, ValueType& v) : first(k), second(v) { } IteratorPair* operator->() { return this; } const IteratorPair* operator->() const { return this; } KeyType& first; ValueType& second; }; // Represents a cursor that points to a bucket in dash table. // One major difference with iterator is that the cursor survives dash table resizes and // will always point to the most appropriate segment with the same bucket. // It uses 40 lsb bits out of 64 assuming that number of segments does not cross 4B. // It's a reasonable assumption in shared nothing architecture when we usually have no more than // 32GB per CPU. Each segment spawns hundreds of entries so we can not grow segment table // to billions. class DashCursor { public: explicit DashCursor(uint64_t token = 0) : val_(token) { } DashCursor(uint8_t depth, uint32_t seg_id, PhysicalBid bid) : val_((uint64_t(seg_id) << (40 - depth)) | bid) { } static DashCursor end() { return DashCursor{}; } PhysicalBid bucket_id() const { return val_ & 0xFF; } // segment_id is padded to the left of 32 bit region: // | segment_id......| bucket_id // 40 8 0 // By using depth we take most significant bits of segment_id if depth has decreased // since the cursor has been created, or extend the least significant bits with zeros, // if depth was increased. uint32_t segment_id(uint8_t depth) const { return val_ >> (40 - depth); } uint64_t token() const { return val_; } explicit operator bool() const { return val_ != 0; } private: uint64_t val_; }; /*********************************************************** * Implementation section. 
*/ template void SlotBitmap::SetSlot(unsigned index, bool probe) { if constexpr (SINGLE) { assert(((val_[0].d >> (index + 18)) & 1) == 0); val_[0].d |= (1 << (index + 18)); val_[0].d |= (unsigned(probe) << (index + 4)); assert((val_[0].d & kBitmapLenMask) < NUM_SLOTS); ++val_[0].d; assert(__builtin_popcount(val_[0].d >> 18) == (val_[0].d & kBitmapLenMask)); } else { assert(((val_[0].d >> index) & 1) == 0); val_[0].d |= (1u << index); val_[1].d |= (unsigned(probe) << index); } } template void SlotBitmap::ClearSlot(unsigned index) { assert(Size() > 0); if constexpr (SINGLE) { uint32_t new_bitmap = val_[0].d & (~(1u << (index + 18))) & (~(1u << (index + 4))); new_bitmap -= 1; val_[0].d = new_bitmap; } else { uint32_t mask = 1u << index; val_[0].d &= ~mask; val_[1].d &= ~mask; } } template bool SlotBitmap::ShiftLeft() { constexpr uint32_t kBusyLastSlot = (kAllocMask >> 1) + 1; bool res; if constexpr (SINGLE) { constexpr uint32_t kShlMask = kAllocMask - 1; // reset lsb res = (val_[0].d & (kBusyLastSlot << 18)) != 0; uint32_t l = (val_[0].d << 1) & (kShlMask << 4); uint32_t p = (val_[0].d << 1) & (kShlMask << 18); val_[0].d = __builtin_popcount(p) | l | p; } else { res = (val_[0].d & kBusyLastSlot) != 0; val_[0].d <<= 1; val_[0].d &= kAllocMask; val_[1].d <<= 1; val_[1].d &= kAllocMask; } return res; } template void SlotBitmap::ClearSlots(uint32_t mask) { if (SINGLE) { uint32_t count = __builtin_popcount(mask); assert(count <= (val_[0].d & 0xFF)); mask = (mask << 4) | (mask << 18); val_[0].d &= ~mask; val_[0].d -= count; } else { val_[0].d &= ~mask; val_[1].d &= ~mask; } } template void SlotBitmap::Swap(unsigned slot_a, unsigned slot_b) { if (slot_a > slot_b) std::swap(slot_a, slot_b); if constexpr (SINGLE) { uint32_t a = (val_[0].d << (slot_b - slot_a)) ^ val_[0].d; uint32_t bm = (1 << (slot_b + 4)) | (1 << (slot_b + 18)); a &= bm; a |= (a >> (slot_b - slot_a)); val_[0].d ^= a; } else { uint32_t a = (val_[0].d << (slot_b - slot_a)) ^ val_[0].d; a &= (1 << slot_b); a |= 
(a >> (slot_b - slot_a)); val_[0].d ^= a; a = (val_[1].d << (slot_b - slot_a)) ^ val_[1].d; a &= (1 << slot_b); a |= (a >> (slot_b - slot_a)); val_[1].d ^= a; } } /* ___ _ _ ____ _ _ ____ ___ ___ ____ ____ ____ |__] | | | |_/ |___ | |__] |__| [__ |___ |__] |__| |___ | \_ |___ | |__] | | ___] |___ */ template bool BucketBase::ClearStash(uint8_t fp, unsigned stash_pos, bool probe) { auto cb = [stash_pos, this](unsigned i, unsigned pos) -> SlotId { if (pos == stash_pos) { stash_busy_ &= (~(1u << i)); stash_probe_mask_ &= (~(1u << i)); stash_pos_ &= (~(3u << (i * 2))); assert(0u == ((stash_pos_ >> (i * 2)) & 3)); return 0; } return kNanSlot; }; std::pair res = IterateStash(fp, probe, std::move(cb)); return res.second != kNanSlot; } template void BucketBase::SetHash(unsigned slot_id, uint8_t meta_hash, bool probe) { assert(slot_id < finger_arr_.size()); finger_arr_[slot_id] = meta_hash; slotb_.SetSlot(slot_id, probe); } template bool BucketBase::SetStash(uint8_t fp, unsigned stash_pos, bool probe) { // stash_busy_ is never 0xFFFFF so it's safe to run __builtin_ctz below. unsigned free_slot = __builtin_ctz(~stash_busy_); if (free_slot >= kStashFpLen) return false; stash_arr_[free_slot] = fp; stash_busy_ |= (1u << free_slot); // set the overflow slot // stash_probe_mask_ specifies which records relate to other bucket. stash_probe_mask_ |= (unsigned(probe) << free_slot); // 2 bits denote the bucket index. free_slot *= 2; stash_pos_ &= (~(3 << free_slot)); // clear (can be removed?) stash_pos_ |= (stash_pos << free_slot); // and set return true; } template void BucketBase::SetStashPtr(unsigned stash_pos, uint8_t meta_hash, BucketBase* next) { assert(stash_pos < 4); // we use only kStashFpLen fp slots for handling stash buckets, // therefore if all those slots are used we try neighbor (probing bucket) as a fallback to point // to stash buckets. otherwise we increment overflow count. 
// if overflow is incremented we will need to check all the stash buckets when looking for a key, // otherwise we can use overflow_index_ to find the the stash bucket efficiently. if (!SetStash(meta_hash, stash_pos, false)) { if (!next->SetStash(meta_hash, stash_pos, true)) { overflow_count_++; } } stash_busy_ |= kStashPresentBit; } template unsigned BucketBase::UnsetStashPtr(uint8_t fp_hash, unsigned stash_pos, BucketBase* next) { /*also needs to ensure that this meta_hash must belongs to other bucket*/ bool clear_success = ClearStash(fp_hash, stash_pos, false); unsigned res = 0; if (!clear_success) { clear_success = next->ClearStash(fp_hash, stash_pos, true); res += clear_success; } if (!clear_success) { assert(overflow_count_ > 0); overflow_count_--; } // kStashPresentBit helps with summarizing all the stash states into a single binary flag. // We need it because of the next, though if we make sure to move stash pointers upon split/delete // towards the owner we should not reach the state where mask1 == 0 but mask2 & // next->stash_probe_mask_ != 0. unsigned mask1 = stash_busy_ & (kStashPresentBit - 1); unsigned mask2 = next->stash_busy_ & (kStashPresentBit - 1); if (((mask1 & (~stash_probe_mask_)) == 0) && (overflow_count_ == 0) && ((mask2 & next->stash_probe_mask_) == 0)) { stash_busy_ &= ~kStashPresentBit; } return res; } #ifdef __s390x__ template uint32_t BucketBase::CompareFP(uint8_t fp) const { static_assert(FpArray{}.size() <= 16); vector unsigned char v1; // Replicate 16 times fp to key_data. for (int i = 0; i < 16; i++) { v1[i] = fp; } // Loads 16 bytes of src into seg_data. vector unsigned char v2 = vec_load_len(finger_arr_.data(), 16); // compare 1-byte vectors seg_data and key_data, dst[i] := ( a[i] == b[i] ) ? 0xFF : 0. vector bool char rv_mask = vec_cmpeq(v1, v2); // collapses 16 msb bits from each byte in rv_mask into mask. 
int mask = 0; for (int i = 0; i < 16; i++) { if (rv_mask[i]) { mask |= 1 << i; } } return mask; } #else template uint32_t BucketBase::CompareFP(uint8_t fp) const { static_assert(FpArray{}.size() <= 16); // Replicate 16 times fp to key_data. const __m128i key_data = _mm_set1_epi8(fp); // Loads 16 bytes of src into seg_data. __m128i seg_data = mm_loadu_si128(reinterpret_cast(finger_arr_.data())); // compare 16-byte vectors seg_data and key_data, dst[i] := ( a[i] == b[i] ) ? 0xFF : 0. __m128i rv_mask = _mm_cmpeq_epi8(seg_data, key_data); // collapses 16 msb bits from each byte in rv_mask into mask. int mask = _mm_movemask_epi8(rv_mask); // Note: Last 2 operations can be combined in skylake with _mm_cmpeq_epi8_mask. return mask; } #endif // Bucket slot array goes from left to right: [x, x, ...] // Shift right vacates the first slot on the left by shifting all the elements right and // possibly deleting the last one on the right. template bool BucketBase::ShiftRight() { for (int i = NUM_SLOTS - 1; i > 0; --i) { finger_arr_[i] = finger_arr_[i - 1]; } // confusing but correct - slot bit mask LSB corresponds to left part of slot array. // therefore, we shift left slot mask. bool res = slotb_.ShiftLeft(); assert(slotb_.FindEmptySlot() == 0); return res; } template template auto BucketBase::IterateStash(uint8_t fp, bool is_probe, F&& func) const -> ::std::pair { unsigned om = is_probe ? 
stash_probe_mask_ : ~stash_probe_mask_; unsigned ob = stash_busy_; for (unsigned i = 0; i < kStashFpLen; ++i) { if ((ob & 1) && (stash_arr_[i] == fp) && (om & 1)) { unsigned pos = (stash_pos_ >> (i * 2)) & 3; auto sid = func(i, pos); if (sid != BucketBase::kNanSlot) { return std::pair(pos, sid); } } ob >>= 1; om >>= 1; } return {0, BucketBase::kNanSlot}; } template void VersionedBB::SetVersion(uint64_t version) { absl::little_endian::Store64(version_, version); } /* ____ ____ ____ _ _ ____ _ _ ___ [__ |___ | __ |\/| |___ |\ | | ___] |___ |__] | | |___ | \| | */ // for clang ignore -Wunused-lambda-capture #ifdef __clang__ #pragma clang diagnostic ignored "-Wunused-lambda-capture" #endif template template auto Segment::Bucket::FindByFp(uint8_t fp_hash, bool probe, Pred&& pred) const -> SlotId { unsigned mask = this->Find(fp_hash, probe); if (!mask) return kNanSlot; unsigned delta = __builtin_ctz(mask); mask >>= delta; for (unsigned i = delta; i < kSlotNum; ++i) { // Filterable just by key if constexpr (std::is_invocable_v) { if ((mask & 1) && pred(key[i])) return i; } // Filterable by key and value if constexpr (std::is_invocable_v) { if ((mask & 1) && pred(key[i], value[i])) return i; } mask >>= 1; }; return kNanSlot; } template bool Segment::Bucket::ShiftRight() { bool res = BucketType::ShiftRight(); for (int i = kSlotNum - 1; i > 0; i--) { std::swap(key[i], key[i - 1]); std::swap(value[i], value[i - 1]); } return res; } // stash_pos is index of the stash bucket, in the range of [0, STASH_BUCKET_NUM). 
template void Segment::RemoveStashReference(unsigned stash_pos, Hash_t key_hash) { LogicalBid y = HomeIndex(key_hash); uint8_t fp_hash = key_hash & kFpMask; auto* target = &bucket_[y]; auto* next = &bucket_[NextBid(y)]; target->UnsetStashPtr(fp_hash, stash_pos, next); } template auto Segment::TryMoveFromStash(unsigned stash_id, unsigned stash_slot_id, Hash_t key_hash) -> Iterator { LogicalBid bid = HomeIndex(key_hash); uint8_t hash_fp = key_hash & kFpMask; PhysicalBid stash_bid = kBucketNum + stash_id; auto& key = Key(stash_bid, stash_slot_id); auto& value = Value(stash_bid, stash_slot_id); int reg_slot = bucket_[bid].TryInsertToBucket(std::forward(key), std::forward(value), hash_fp, false); if (reg_slot < 0) { bid = NextBid(bid); reg_slot = bucket_[bid].TryInsertToBucket(std::forward(key), std::forward(value), hash_fp, true); } if (reg_slot >= 0) { if constexpr (kUseVersion) { // We maintain the invariant for the physical bucket by updating the version when // the entries move between buckets. uint64_t ver = bucket_[stash_bid].GetVersion(); bucket_[bid].UpdateVersion(ver); } RemoveStashReference(stash_id, key_hash); return Iterator{bid, SlotId(reg_slot)}; } return Iterator{}; } template template auto Segment::Insert(U&& key, V&& value, Hash_t key_hash, Pred&& pred, OnMoveCb&& on_move_cb) -> std::pair { Iterator it = FindIt(key_hash, pred); if (it.found()) { return std::make_pair(it, false); /* duplicate insert*/ } it = InsertUniq(std::forward(key), std::forward(value), key_hash, true, std::forward(on_move_cb)); return std::make_pair(it, it.found()); } template template auto Segment::FindIt(Hash_t key_hash, Pred&& pred) const -> Iterator { LogicalBid bidx = HomeIndex(key_hash); const Bucket& target = bucket_[bidx]; // It helps a bit (10% on my home machine) and more importantly, it does not hurt // since we are going to access this memory in a bit. 
__builtin_prefetch(&target); uint8_t fp_hash = key_hash & kFpMask; SlotId sid = target.FindByFp(fp_hash, false, pred); if (sid != BucketType::kNanSlot) { return Iterator{bidx, sid}; } LogicalBid nid = NextBid(bidx); const Bucket& probe = GetBucket(nid); sid = probe.FindByFp(fp_hash, true, pred); #ifdef ENABLE_DASH_STATS stats.neighbour_probes++; #endif if (sid != BucketType::kNanSlot) { return Iterator{nid, sid}; } if (!target.HasStash()) { return Iterator{}; } auto stash_cb = [&](unsigned overflow_index, PhysicalBid pos) -> SlotId { assert(pos < kStashBucketNum); pos += kBucketNum; const Bucket& bucket = bucket_[pos]; return bucket.FindByFp(fp_hash, false, pred); }; if (target.HasStashOverflow()) { #ifdef ENABLE_DASH_STATS stats.stash_overflow_probes++; #endif for (unsigned i = 0; i < kStashBucketNum; ++i) { auto sid = stash_cb(0, i); if (sid != BucketType::kNanSlot) { return Iterator{PhysicalBid(kBucketNum + i), sid}; } } // We exit because we searched through all stash buckets anyway, no need to use overflow fps. return Iterator{}; } #ifdef ENABLE_DASH_STATS stats.stash_probes++; #endif auto stash_res = target.IterateStash(fp_hash, false, stash_cb); if (stash_res.second != BucketType::kNanSlot) { return Iterator{PhysicalBid(kBucketNum + stash_res.first), stash_res.second}; } stash_res = probe.IterateStash(fp_hash, true, stash_cb); if (stash_res.second != BucketType::kNanSlot) { return Iterator{PhysicalBid(kBucketNum + stash_res.first), stash_res.second}; } return Iterator{}; } template void Segment::Prefetch(Hash_t key_hash) const { LogicalBid bidx = HomeIndex(key_hash); const Bucket& target = bucket_[bidx]; // Prefetch the home bucket that might hold the key with high probability. 
__builtin_prefetch(&target, 0, 1); } template template void Segment::TraverseAll(Cb&& cb) const { for (uint8_t i = 0; i < kTotalBuckets; ++i) { bucket_[i].ForEachSlot([&](auto*, SlotId slot, bool) { cb(Iterator{i, slot}); }); } } template void Segment::Clear() { for (unsigned i = 0; i < kTotalBuckets; ++i) { bucket_[i].Clear(); bucket_[i].ClearStashPtrs(); } } template void Segment::Delete(const Iterator& it, Hash_t key_hash) { assert(it.found()); auto& b = bucket_[it.index]; if (it.index >= kBucketNum) { RemoveStashReference(it.index - kBucketNum, key_hash); } b.Delete(it.slot); } // Split items from the left segment to the right during the growth phase. // right segment will have all the items with lsb at local_depth ==1 . template template void Segment::Split(HFunc&& hfn, Segment* dest_right, MoveCb&& on_move_cb) { ++local_depth_; dest_right->local_depth_ = local_depth_; // versioning does not work when entries move across buckets. // we need to setup rules on how we do that // do_versioning(); auto is_mine = [this](Hash_t hash) { return (hash >> (64 - local_depth_) & 1) == 0; }; auto update_version = [dest_right](const Bucket& src, PhysicalBid dest_id) { (void)dest_id; if constexpr (kUseVersion) { // Maintaining consistent versioning. uint64_t ver = src.GetVersion(); dest_right->bucket_[dest_id].UpdateVersion(ver); } }; for (unsigned i = 0; i < kBucketNum; ++i) { uint32_t invalid_mask = 0; auto cb = [&](auto* bucket, unsigned slot, bool probe) { auto& key = bucket->key[slot]; Hash_t hash = hfn(key); // we extract local_depth bits from the left part of the hash. Since we extended local_depth, // we added an additional bit to the right, therefore we need to look at lsb of the extract. if (is_mine(hash)) return; // keep this key in the source invalid_mask |= (1u << slot); // We pass dummy callback because we are not interested to track movements in the newly // created segment. 
Iterator it = dest_right->InsertUniq(std::forward(bucket->key[slot]), std::forward(bucket->value[slot]), hash, false, [](auto&&...) {}); // we move items residing in a regular bucket to a new segment. // Note 1: in case we are somehow attacked with items that after the split // will go into the same segment, we may have a problem. // It is highly unlikely that this happens with real world data. // Note 2: Dragonfly replication is in fact is such unlikely attack. Since we go over // the source table in a special order (go over all the segments for bucket 0, // then for all the segments for bucket 1 etc), what happens is that the rdb stream is full // of items with the same bucket id, say 0. Lots of items will go to the initial segment // into bucket 0, which will become full, then bucket 1 will get full, // and then the 4 stash buckets in the segment. Then the segment will have to split even // though only 6 buckets are used just because of this // extreme skewness of keys distribution. When a segment splits, we will still // have items going into bucket 0 in the new segment. To alleviate this effect we usually // reserve dash table to have enough segments during full sync to avoid handling those // ill-formed splits. // TODO: To protect ourselves again such situations we should use random seed // for our dash hash function, thus avoiding the case where someone, on purpose or due to // selective bias will be able to hit our dashtable with items with the same bucket id. 
assert(it.found()); update_version(*bucket, it.index); on_move_cb(segment_id_, i, dest_right->segment_id_, it.index); }; bucket_[i].ForEachSlot(std::move(cb)); bucket_[i].ClearSlots(invalid_mask); } for (unsigned i = 0; i < kStashBucketNum; ++i) { uint32_t invalid_mask = 0; PhysicalBid bid = kBucketNum + i; Bucket& stash = bucket_[bid]; auto cb = [&](auto* bucket, unsigned slot, bool probe) { auto& key = bucket->key[slot]; Hash_t hash = hfn(key); if (is_mine(hash)) { // If the entry stays in the same segment we try to unload it back to the regular bucket. Iterator it = TryMoveFromStash(i, slot, hash); if (it.found()) { invalid_mask |= (1u << slot); on_move_cb(segment_id_, i, segment_id_, it.index); } return; } invalid_mask |= (1u << slot); auto it = dest_right->InsertUniq(std::forward(bucket->key[slot]), std::forward(bucket->value[slot]), hash, false, /* not interested in these movements */ [](auto&&...) {}); (void)it; assert(it.index != kNanBid); update_version(*bucket, it.index); on_move_cb(segment_id_, i, dest_right->segment_id_, it.index); // Remove stash reference pointing to stash bucket i. RemoveStashReference(i, hash); }; stash.ForEachSlot(std::move(cb)); stash.ClearSlots(invalid_mask); } } template int Segment::MoveToOther(bool own_items, unsigned from_bid, unsigned to_bid) { assert(from_bid < kBucketNum && to_bid < kBucketNum); auto& src = bucket_[from_bid]; uint32_t mask = src.GetProbe(!own_items); if (mask == 0) { return -1; } int src_slot = __builtin_ctz(mask); int dst_slot = bucket_[to_bid].TryInsertToBucket(std::forward(src.key[src_slot]), std::forward(src.value[src_slot]), src.Fp(src_slot), own_items); if (dst_slot < 0) return -1; // We never decrease the version of the entry. 
if constexpr (kUseVersion) { auto& dst = bucket_[to_bid]; dst.UpdateVersion(src.GetVersion()); } src.Delete(src_slot); return src_slot; } template bool Segment::CheckIfMovesToOther(bool own_items, unsigned from, unsigned to) const { const auto& src = GetBucket(from); uint32_t mask = src.GetProbe(!own_items); if (mask == 0) { return false; } const auto& dest = GetBucket(to); return dest.IsFull() ? false : true; } template template auto Segment::InsertUniq(U&& key, V&& value, Hash_t key_hash, bool spread, OnMoveCb&& on_move_cb) -> Iterator { const uint8_t bid = HomeIndex(key_hash); const uint8_t nid = NextBid(bid); Bucket& target = bucket_[bid]; Bucket& neighbor = bucket_[nid]; Bucket* insert_first = ⌖ uint8_t meta_hash = key_hash & kFpMask; unsigned ts = target.Size(), ns = neighbor.Size(); bool probe = false; if (spread && ts > ns) { insert_first = &neighbor; probe = true; } int slot = insert_first->TryInsertToBucket(std::forward(key), std::forward(value), meta_hash, probe); if (slot >= 0) { return Iterator{PhysicalBid(insert_first - bucket_), uint8_t(slot)}; } if (!spread) { int slot = neighbor.TryInsertToBucket(std::forward(key), std::forward(value), meta_hash, true); if (slot >= 0) { return Iterator{nid, uint8_t(slot)}; } } int displace_index = MoveToOther(true, nid, NextBid(nid)); if (displace_index >= 0) { neighbor.Insert(displace_index, std::forward(key), std::forward(value), meta_hash, true); on_move_cb(segment_id_, nid, NextBid(nid)); return Iterator{nid, uint8_t(displace_index)}; } unsigned prev_idx = PrevBid(bid); displace_index = MoveToOther(false, bid, prev_idx); if (displace_index >= 0) { target.Insert(displace_index, std::forward(key), std::forward(value), meta_hash, false); on_move_cb(segment_id_, bid, prev_idx); return Iterator{bid, uint8_t(displace_index)}; } // we balance stash fill rate by starting from y % STASH_BUCKET_NUM. 
for (unsigned i = 0; i < kStashBucketNum; ++i) { unsigned stash_pos = (bid + i) % kStashBucketNum; int stash_slot = bucket_[kBucketNum + stash_pos].TryInsertToBucket( std::forward(key), std::forward(value), meta_hash, false); if (stash_slot >= 0) { target.SetStashPtr(stash_pos, meta_hash, &neighbor); return Iterator{PhysicalBid(kBucketNum + stash_pos), uint8_t(stash_slot)}; } } return Iterator{}; } template template std::enable_if_t Segment::CVCOnInsert(uint64_t ver_threshold, Hash_t key_hash, uint8_t bid_res[2]) const { const LogicalBid bid = HomeIndex(key_hash); const LogicalBid nid = NextBid(bid); const Bucket& target = GetBucket(bid); const Bucket& neighbor = GetBucket(nid); uint8_t first = target.Size() > neighbor.Size() ? nid : bid; const Bucket& bfirst = bucket_[first]; if (!bfirst.IsFull()) { unsigned cnt = 0; if (!bfirst.IsEmpty() && bfirst.GetVersion() < ver_threshold) { bid_res[cnt++] = first; } return cnt; } // both nid and bid are full. const LogicalBid after_next = NextBid(nid); auto do_fun = [this, ver_threshold, &bid_res](auto bid, auto nid) { unsigned cnt = 0; // We could tighten the checks here and below because // if nid is less than ver_threshold, than nid won't be affected and won't cross // ver_threshold as well. if (GetBucket(bid).GetVersion() < ver_threshold) bid_res[cnt++] = bid; if (!GetBucket(nid).IsEmpty() && GetBucket(nid).GetVersion() < ver_threshold) bid_res[cnt++] = nid; return cnt; }; if (CheckIfMovesToOther(true, nid, after_next)) { return do_fun(nid, after_next); } const uint8_t prev_bid = PrevBid(bid); if (CheckIfMovesToOther(false, bid, prev_bid)) { return do_fun(bid, prev_bid); } // Important to repeat exactly the insertion logic of InsertUnique. 
for (unsigned i = 0; i < kStashBucketNum; ++i) { PhysicalBid stash_bid = kBucketNum + ((bid + i) % kStashBucketNum); const Bucket& stash = GetBucket(stash_bid); if (!stash.IsFull()) { unsigned cnt = 0; if (!stash.IsEmpty() && stash.GetVersion() < ver_threshold) bid_res[cnt++] = stash_bid; return cnt; } } return UINT16_MAX; } template template std::enable_if_t Segment::CVCOnBump(uint64_t ver_threshold, unsigned bid, unsigned slot, Hash_t hash, uint8_t result_bid[3]) const { if (bid < kBucketNum) { // Right now we do not migrate entries from nid to bid, only from stash to normal buckets. // The reason for this is that CVCOnBump implementation swaps the slots of the same bucket // so there is no further action needed. return 0; } // Stash case. // There are three actors (interesting buckets). The stash bucket, the target bucket and its // adjacent bucket (probe). To understand the code below consider the cases in CVCOnBump: // 1. If the bid is not a stash bucket, then just swap the slots of the target. // 2. If there is empty space in target or probe bucket insert the slot there and remove // it from the stash bucket. // 3. If there is no empty space then we need to swap slots with either the target or the probe // bucket. Furthermore, if the target or the probe have one of their stash bits reference the // stash, then the stash bit entry is cleared. In total 2 buckets are modified. // Case 1 is handled by the if statement above and cases 2 and 3 below. We should return via // result_bid all the buckets(with version less than threshold) that CVCOnBump will modify. // Note, that for case 2 & 3 we might return an extra bucket id even though this bucket was not // changed. An example of that is TryMoveFromStash which will first try to insert on the target // bucket and if that fails it will retry with the probe bucket. Since we don't really know // which of the two we insert to we are pesimistic and assume that both of them got modified. 
I // suspect we could optimize this out by looking at the fingerprints but for now I care about // correctness and returning the correct modified buckets. Besides, we are on a path of updating // the version anyway which will assert that the bucket won't be send again during snapshotting. unsigned result = 0; if (bucket_[bid].GetVersion() < ver_threshold) { result_bid[result++] = bid; } const uint8_t target_bid = HomeIndex(hash); result_bid[result++] = target_bid; const uint8_t probing_bid = NextBid(target_bid); result_bid[result++] = probing_bid; return result; } template template void Segment::TraverseBucket(PhysicalBid bid, Cb&& cb) { assert(bid < kTotalBuckets); const Bucket& b = GetBucket(bid); b.ForEachSlot([&](auto* bucket, uint8_t slot, bool probe) { cb(Iterator{bid, slot}); }); } template template bool Segment::TraverseLogicalBucket(LogicalBid bid, HashFn&& hfun, Cb&& cb) const { assert(bid < kBucketNum); const Bucket& b = bucket_[bid]; bool found = false; if (b.GetProbe(false)) { // Check items that this bucket owns. b.ForEachSlot([&](auto* bucket, SlotId slot, bool probe) { if (!probe) { found = true; cb(Iterator{bid, slot}); } }); } uint8_t nid = NextBid(bid); const Bucket& next = GetBucket(nid); // check for probing entries in the next bucket, i.e. those that should reside in b. if (next.GetProbe(true)) { next.ForEachSlot([&](auto* bucket, SlotId slot, bool probe) { if (probe) { found = true; assert(HomeIndex(hfun(bucket->key[slot])) == bid); cb(Iterator{nid, slot}); } }); } // Finally go over stash buckets and find those entries that belong to b. if (b.HasStash()) { // do not bother with overflow fps. Just go over all the stash buckets. 
for (uint8_t j = kBucketNum; j < kTotalBuckets; ++j) { const auto& stashb = bucket_[j]; stashb.ForEachSlot([&](auto* bucket, SlotId slot, bool probe) { if (HomeIndex(hfun(bucket->key[slot])) == bid) { found = true; cb(Iterator{j, slot}); } }); } } return found; } template size_t Segment::SlowSize() const { size_t res = 0; for (unsigned i = 0; i < kTotalBuckets; ++i) { res += bucket_[i].Size(); } return res; } template auto Segment::FindValidStartingFrom(PhysicalBid bid, unsigned slot) const -> Iterator { while (bid < kTotalBuckets) { uint32_t mask = bucket_[bid].GetBusy(); mask >>= slot; if (mask) { return Iterator(bid, slot + __builtin_ctz(mask)); } ++bid; slot = 0; } return Iterator{}; } template template auto Segment::BumpUp(uint8_t bid, SlotId slot, Hash_t key_hash, const BumpPolicy& bp, OnMoveCb&& on_move_cb) -> Iterator { auto& from = GetBucket(bid); if (!bp.CanBump(from.key[slot])) { return Iterator{bid, slot}; } if (bid < kBucketNum) { // non stash case. if (slot > 0 && bp.CanBump(from.key[slot - 1])) { from.Swap(slot - 1, slot); return Iterator{bid, uint8_t(slot - 1)}; } // TODO: We could promote further, by swapping probing bucket with its previous one. return Iterator{bid, slot}; } // stash bucket // We swap the item with the item in the "normal" bucket in the last slot. unsigned stash_pos = bid - kBucketNum; // If we have an empty space for some reason just unload the stash entry. if (Iterator it = TryMoveFromStash(stash_pos, slot, key_hash); it.found()) { // TryMoveFromStash handles versions internally. from.Delete(slot); on_move_cb(segment_id_, bid, it.index); return it; } uint8_t target_bid = HomeIndex(key_hash); uint8_t nid = NextBid(target_bid); uint8_t fp_hash = key_hash & kFpMask; assert(fp_hash == from.Fp(slot)); // determine which bucket one we gonna swap. // we swap with the bucket the references the stash entry, not necessary its owning // bucket. 
auto& target = bucket_[target_bid]; auto& next = bucket_[nid]; // bucket_offs - 0 if exact bucket, 1 if neighbour unsigned bucket_offs = target.UnsetStashPtr(fp_hash, stash_pos, &next); uint8_t swap_bid = (target_bid + bucket_offs) % kBucketNum; auto& swapb = bucket_[swap_bid]; constexpr unsigned kLastSlot = kSlotNum - 1; assert(swapb.GetBusy() & (1 << kLastSlot)); // Don't move sticky items back to the stash because they're not evictable // TODO: search for first swappable item if (!bp.CanBump(swapb.key[kLastSlot])) { target.SetStashPtr(stash_pos, fp_hash, &next); return Iterator{bid, slot}; } uint8_t swap_fp = swapb.Fp(kLastSlot); // is_probing for the existing entry in swapb. It's unrelated to bucket_offs, // i.e. it could be true even if bucket_offs is 0. bool is_probing = swapb.GetProbe(true) & (1 << kLastSlot); // swap keys, values and fps. update slots meta. std::swap(from.key[slot], swapb.key[kLastSlot]); std::swap(from.value[slot], swapb.value[kLastSlot]); from.Delete(slot); from.SetHash(slot, swap_fp, false); swapb.Delete(kLastSlot); swapb.SetHash(kLastSlot, fp_hash, bucket_offs == 1); // update versions. if constexpr (kUseVersion) { uint64_t from_ver = from.GetVersion(); uint64_t swap_ver = swapb.GetVersion(); if (from_ver < swap_ver) { from.SetVersion(swap_ver); } else { swapb.SetVersion(from_ver); } } // update ptr for swapped items if (is_probing) { LogicalBid prev_bid = PrevBid(swap_bid); auto& prevb = bucket_[prev_bid]; prevb.SetStashPtr(stash_pos, swap_fp, &swapb); } else { // stash_ptr resides in the current or the next bucket. 
LogicalBid next_bid = NextBid(swap_bid); swapb.SetStashPtr(stash_pos, swap_fp, bucket_ + next_bid); } on_move_cb(segment_id_, bid, swap_bid); on_move_cb(segment_id_, swap_bid, bid); return Iterator{swap_bid, kLastSlot}; } template template unsigned Segment::UnloadStash(HFunc&& hfunc, OnMoveCb&& on_move_cb) { unsigned moved = 0; for (unsigned i = 0; i < kStashBucketNum; ++i) { unsigned bid = kBucketNum + i; Bucket& stash = bucket_[bid]; uint32_t invalid_mask = 0; auto cb = [&](auto* bucket, unsigned slot, bool probe) { auto& key = bucket->key[slot]; Hash_t hash = hfunc(key); Iterator res = TryMoveFromStash(i, slot, hash); if (res.found()) { ++moved; invalid_mask |= (1u << slot); on_move_cb(segment_id_, i, res.index); } }; stash.ForEachSlot(cb); stash.ClearSlots(invalid_mask); } return moved; } } // namespace detail } // namespace dfly ================================================ FILE: src/core/dash_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include #define ENABLE_DASH_STATS #include #include #include #include #include #include "base/gtest.h" #include "base/hash.h" #include "base/logging.h" #include "base/zipf_gen.h" #include "core/dash.h" #include "io/file.h" #include "io/line_reader.h" extern "C" { #include "redis/dict.h" #include "redis/sds.h" #include "redis/zmalloc.h" } #if defined(__clang__) #pragma clang diagnostic ignored "-Wunused-const-variable" #endif namespace dfly { static uint64_t callbackHash(const void* key) { return XXH64(&key, sizeof(key), 0); } template auto EqTo(const K& key) { return [&key](const auto& probe) { return key == probe; }; } static dictType IntDict = {callbackHash, NULL, NULL, NULL, NULL, NULL, NULL}; static uint64_t dictSdsHash(const void* key) { return dictGenHashFunction((unsigned char*)key, sdslen((char*)key)); } static int dictSdsKeyCompare(dict*, const void* key1, const void* key2) { int l1, l2; l1 = sdslen((sds)key1); l2 = sdslen((sds)key2); if (l1 != l2) return 0; return memcmp(key1, key2, l1) == 0; } static dictType SdsDict = { dictSdsHash, /* hash function */ NULL, /* key dup */ NULL, /* val dup */ dictSdsKeyCompare, /* key compare */ NULL, // dictSdsDestructor, /* key destructor */ NULL, /* val destructor */ NULL, }; using namespace std; struct Buf24 { char buf[20]; uint32_t index; Buf24(uint32_t i = 0) : index(i) { } }; struct BasicDashPolicy { enum { kSlotNum = 12, kBucketNum = 64 }; static constexpr bool kUseVersion = false; template static void DestroyValue(const U&) { } template static void DestroyKey(const U&) { } template static bool Equal(U&& u, V&& v) { return u == v; } }; struct UInt64Policy : public BasicDashPolicy { static uint64_t HashFn(uint64_t v) { return XXH3_64bits(&v, sizeof(v)); } }; class CappedResource final : public PMR_NS::memory_resource { public: explicit CappedResource(size_t cap) : cap_(cap) { } size_t used() const { return used_; } private: void* do_allocate(std::size_t size, std::size_t align) { if (used_ + size > cap_) throw 
std::bad_alloc{}; void* res = PMR_NS::get_default_resource()->allocate(size, align); used_ += size; return res; } void do_deallocate(void* ptr, std::size_t size, std::size_t align) { used_ -= size; PMR_NS::get_default_resource()->deallocate(ptr, size, align); } bool do_is_equal(const PMR_NS::memory_resource& o) const noexcept { return this == &o; } size_t cap_; size_t used_ = 0; }; using Segment = detail::Segment; using Dash64 = DashTable; struct RelaxedBumpPolicy { bool CanBump(uint64_t key) const { return true; } void OnMove(Dash64::Cursor source, Dash64::Cursor dest) { } }; constexpr auto kSegTax = Segment::kTaxSize; constexpr size_t kMaxSize = Segment::kMaxSize; constexpr size_t kSegSize = sizeof(Segment); class DashTest : public testing::Test { protected: static void SetUpTestSuite() { init_zmalloc_threadlocal(mi_heap_get_backing()); } DashTest() : segment_(1, 0, PMR_NS::get_default_resource()) { } bool Find(Segment::Key_t key, Segment::Value_t* val) const { uint64_t hash = dt_.DoHash(key); auto it = segment_.FindIt(hash, EqTo(key)); if (!it.found()) return false; *val = segment_.Value(it.index, it.slot); return true; } bool Contains(Segment::Key_t key) const { uint64_t hash = dt_.DoHash(key); auto it = segment_.FindIt(hash, EqTo(key)); return it.found(); } set FillSegment(unsigned bid); Segment segment_; Dash64 dt_; }; set DashTest::FillSegment(unsigned bid) { std::set keys; for (Segment::Key_t key = 0; key < 1000000u; ++key) { uint64_t hash = dt_.DoHash(key); unsigned bi = (hash >> 8) % Segment::kBucketNum; if (bi != bid) continue; uint8_t fp = hash & 0xFF; if (fp > 2) // limit fps considerably to find interesting cases. continue; auto [it, success] = segment_.Insert(key, 0, hash, EqTo(key), [](auto&&...) 
{}); if (!success) { LOG(INFO) << "Stopped at " << key; break; } CHECK(it.found()); keys.insert(key); } return keys; } TEST_F(DashTest, Hash) { for (uint64_t i = 0; i < 100; ++i) { uint64_t hash = dt_.DoHash(i); if (hash >> 63) { VLOG(1) << "i " << i << ", Hash " << hash; } } } TEST_F(DashTest, SlotBitmap) { detail::SlotBitmap<14> slot; slot.SetSlot(1, true); slot.SetSlot(5, false); EXPECT_EQ(34, slot.GetBusy()); EXPECT_EQ(2, slot.GetProbe(true)); } TEST_F(DashTest, Basic) { Segment::Key_t key = 0; Segment::Value_t val = 0; uint64_t hash = dt_.DoHash(key); EXPECT_TRUE(segment_.Insert(key, val, hash, EqTo(key), [](auto&&...) {}).second); auto [it, res] = segment_.Insert(key, val, hash, EqTo(key), [](auto&&...) {}); EXPECT_TRUE(!res && it.found()); EXPECT_TRUE(Find(key, &val)); EXPECT_EQ(0, val.index); EXPECT_FALSE(Find(1, &val)); EXPECT_EQ(1, segment_.SlowSize()); unsigned has_called = 0; auto cb = [&](const auto& it) { ++has_called; }; auto hfun = &UInt64Policy::HashFn; auto cursor = segment_.TraverseLogicalBucket((hash >> 8) % Segment::kBucketNum, hfun, cb); ASSERT_EQ(1, has_called); ASSERT_EQ(0, segment_.TraverseLogicalBucket(cursor, hfun, cb)); ASSERT_EQ(1, has_called); EXPECT_EQ(0, segment_.GetVersion(0)); } TEST_F(DashTest, Segment) { std::unique_ptr seg(new Segment(1, 0, PMR_NS::get_default_resource())); #ifndef __APPLE__ LOG(INFO) << "Segment size " << sizeof(Segment) << " malloc size: " << malloc_usable_size(seg.get()); #endif set keys = FillSegment(0); EXPECT_TRUE(segment_.GetBucket(0).IsFull() && segment_.GetBucket(1).IsFull()); for (size_t i = 2; i < Segment::kBucketNum; ++i) { EXPECT_EQ(0, segment_.GetBucket(i).Size()); } EXPECT_EQ(6 * Segment::kSlotNum, keys.size()); EXPECT_EQ(6 * Segment::kSlotNum, segment_.SlowSize()); auto hfun = &UInt64Policy::HashFn; unsigned has_called = 0; auto cb = [&](const Segment::Iterator& it) { ++has_called; ASSERT_EQ(1, keys.count(segment_.Key(it.index, it.slot))); }; segment_.TraverseAll(cb); ASSERT_EQ(keys.size(), 
has_called); ASSERT_TRUE(segment_.GetBucket(Segment::kBucketNum).IsFull()); std::array arr; uint64_t* next = arr.begin(); for (unsigned i = Segment::kBucketNum; i < Segment::kBucketNum + 2; ++i) { const auto* k = &segment_.Key(i, 0); next = std::copy(k, k + Segment::kSlotNum, next); } for (auto k : arr) { auto hash = hfun(k); auto it = segment_.FindIt(hash, [&k](const auto& probe) { return k == probe; }); ASSERT_TRUE(it.found()); segment_.Delete(it, hash); } EXPECT_EQ(4 * Segment::kSlotNum, segment_.SlowSize()); ASSERT_FALSE(Contains(arr.front())); } TEST_F(DashTest, SegmentFull) { std::equal_to<> eq; for (Segment::Key_t key = 8000; key < 15000u; ++key) { uint64_t hash = dt_.DoHash(key); bool res = segment_.Insert(key, 0, hash, eq, [](auto&&...) {}).second; if (!res) { LOG(INFO) << "Stopped at " << key; break; } } EXPECT_GT(segment_.SlowSize(), Segment::capacity() * 0.85); LOG(INFO) << "Utilization " << double(segment_.SlowSize()) / Segment::capacity() << " num probing buckets: " << segment_.NumProbingBuckets(); LOG(INFO) << "NB: " << segment_.stats.neighbour_probes << " SP: " << segment_.stats.stash_probes << " SOP: " << segment_.stats.stash_overflow_probes; segment_.stats.neighbour_probes = segment_.stats.stash_overflow_probes = segment_.stats.stash_probes = 0; for (Segment::Key_t key = 0; key < 10000u; ++key) { Contains(key); } LOG(INFO) << segment_.stats.neighbour_probes << " " << segment_.stats.stash_probes << " " << segment_.stats.stash_overflow_probes; uint32_t busy = segment_.GetBucket(0).GetBusy(); uint32_t probe = segment_.GetBucket(0).GetProbe(true); EXPECT_EQ((1 << 12) - 1, busy); // Size 12 EXPECT_EQ(539, probe); // verified by running since the test is deterministic. 
unsigned keys[12] = {8045, 8085, 8217, 8330, 8337, 8381, 8432, 8506, 8587, 8605, 8612, 8725}; for (unsigned i = 0; i < 12; ++i) { ASSERT_EQ(keys[i], segment_.Key(0, i)); } } TEST_F(DashTest, FirstStash) { constexpr unsigned kRegularCapacity = Segment::kBucketNum * Segment::kSlotNum; unsigned less_seventy = 0; for (unsigned j = 0; j < 100; ++j) { unsigned num_items = 0; for (unsigned i = 0; i < 1000; ++i) { uint64_t key = i + j * 2000; uint64_t hash = dt_.DoHash(key); auto [it, inserted] = segment_.Insert(key, 0, hash, equal_to<>{}, [](auto&&...) {}); ASSERT_TRUE(inserted); if (it.index >= Segment::kBucketNum) { // stash iterator break; } ++num_items; } segment_.Clear(); // With high probability, we can expect 66% of the keys added without stashes. ASSERT_GT(num_items, kRegularCapacity * 0.66); if (num_items < kRegularCapacity * 0.7) { ++less_seventy; } } LOG(INFO) << "Less than 70% of keys in regular buckets: " << less_seventy; } TEST_F(DashTest, Split) { // fills segment with maximum keys that must reside in bucket id 0. set keys = FillSegment(0); Segment::Value_t val; Segment s2{2, 0, PMR_NS::get_default_resource()}; // segment with local depth 2. segment_.Split(&UInt64Policy::HashFn, &s2, [](auto&...) 
{}); unsigned sum[2] = {0}; for (auto key : keys) { auto eq = [key](const auto& probe) { return key == probe; }; auto it1 = segment_.FindIt(dt_.DoHash(key), eq); auto it2 = s2.FindIt(dt_.DoHash(key), eq); ASSERT_NE(it1.found(), it2.found()) << key; sum[0] += it1.found(); sum[1] += it2.found(); } ASSERT_EQ(segment_.SlowSize(), sum[0]); EXPECT_EQ(s2.SlowSize(), sum[1]); EXPECT_EQ(keys.size(), sum[0] + sum[1]); EXPECT_EQ(6 * Segment::kSlotNum, keys.size()); } TEST_F(DashTest, Merge) { constexpr size_t kNumItems = 4000; std::vector keys; for (uint64_t i = 0; i < kNumItems; ++i) { auto [it, inserted] = dt_.Insert(i, i); if (inserted) { keys.push_back(i); } } EXPECT_EQ(dt_.depth(), 3); // keep only ~5% size_t keys_to_keep = keys.size() * 0.05; for (size_t i = keys_to_keep; i < keys.size(); ++i) { dt_.Erase(keys[i]); } keys.resize(keys_to_keep); EXPECT_EQ(dt_.unique_segments(), 8); size_t dir_size = dt_.GetSegmentCount(); // Iteratively merge segments until all reach depth 1 // Use multiple passes since merging changes buddy relationships while (true) { bool merged_any = false; for (size_t seg_id = 0; seg_id < dir_size; seg_id++) { auto* seg = dt_.GetSegment(seg_id); size_t local_depth = seg->local_depth(); if (local_depth == 1) continue; size_t buddy_id = dt_.FindBuddyId(seg_id); if (buddy_id == seg_id) continue; // Skip if seg_id > buddy_id to avoid processing the same pair twice // (FindBuddyId is symmetric, so we see each pair from both directions) if (seg_id > buddy_id) continue; auto* buddy = dt_.GetSegment(buddy_id); // Preconditions to merge: (< 25% of capacity) size_t combined_size = seg->SlowSize() + buddy->SlowSize(); size_t safe_threshold = static_cast(0.25 * seg->capacity()); if (combined_size <= safe_threshold) { dt_.Merge(seg_id, buddy_id); merged_any = true; } } if (!merged_any) break; } EXPECT_EQ(dt_.unique_segments(), 2); for (size_t seg_id = 0; seg_id < dir_size; seg_id++) { auto* seg = dt_.GetSegment(seg_id); EXPECT_EQ(seg->local_depth(), 1); } for 
(size_t key : keys) { EXPECT_EQ(dt_.Find(key).is_done(), false); } EXPECT_EQ(dt_.bucket_count(), (Segment::kBucketNum + Segment::kStashBucketNum) * 2); } TEST_F(DashTest, MergeFailureRollback) { std::vector all_keys; std::vector keep_keys; std::vector buddy_keys; // Insert enough items to create 4 segments (depth 2) and fill them more for (uint64_t i = 0; i < 5000; ++i) { auto [it, inserted] = dt_.Insert(i, i); if (inserted) { all_keys.push_back(i); } } EXPECT_GE(dt_.depth(), 2); unsigned sid = 0; size_t buddy_id = dt_.FindBuddyId(sid); EXPECT_NE(buddy_id, sid); auto* src = dt_.GetSegment(sid); auto* buddy = dt_.GetSegment(buddy_id); for (uint64_t key : all_keys) { auto it = dt_.Find(key); if (!it.is_done()) { uint64_t hash = dt_.DoHash(key); uint32_t seg_id = hash >> (64 - dt_.depth()); if (seg_id == 0) { keep_keys.push_back(key); } else if (seg_id == buddy_id) { buddy_keys.push_back(key); } } } size_t total_size_before = dt_.size(); bool merge_succeeded = dt_.Merge(sid, buddy_id); EXPECT_EQ(dt_.size(), total_size_before); // Bucket layout might change after rollback. We only get data parity, not // a complete layout rollback. // For example, InsertUniq can displace existing items in the keep segment // to make room for items being moved from buddy. // After rollback, src and buddy pointers should still be valid for (auto key : keep_keys) { uint64_t hash = dt_.DoHash(key); auto it = src->FindIt(hash, EqTo(key)); EXPECT_TRUE(it.found()); } for (auto key : buddy_keys) { uint64_t hash = dt_.DoHash(key); auto it = buddy->FindIt(hash, EqTo(key)); EXPECT_TRUE(it.found()); } EXPECT_FALSE(merge_succeeded); } // Verify that FindBuddyId is symmetric: if FindBuddyId(x) = y, then FindBuddyId(y) = x. 
TEST_F(DashTest, FindBuddySymmetry) {
  for (uint64_t i = 0; i < 4000; ++i) {
    dt_.Insert(i, i);
  }
  EXPECT_GE(dt_.depth(), 3);

  const size_t dir_size = dt_.GetSegmentCount();
  for (size_t seg_id = 0; seg_id < dir_size; seg_id++) {
    auto* seg = dt_.GetSegment(seg_id);
    if (seg->local_depth() == 1)
      continue;

    // FindBuddyId returns its argument when the segment has no buddy.
    size_t buddy_id = dt_.FindBuddyId(seg_id);
    if (buddy_id == seg_id)
      continue;

    // Symmetry check
    size_t reverse_buddy_id = dt_.FindBuddyId(buddy_id);
    EXPECT_EQ(reverse_buddy_id, seg_id)
        << "FindBuddyId not symmetric: FindBuddyId(" << seg_id << ")=" << buddy_id
        << " but FindBuddyId(" << buddy_id << ")=" << reverse_buddy_id;
  }
}

// Verify dt_.size() is unchanged after merge (items moved, not deleted).
TEST_F(DashTest, MergePreservesSize) {
  for (uint64_t i = 0; i < 4000; ++i) {
    dt_.Insert(i, i);
  }

  // Delete most keys to make merge feasible
  for (uint64_t i = 200; i < 4000; ++i) {
    dt_.Erase(i);
  }

  const size_t size_before = dt_.size();
  const size_t dir_size = dt_.GetSegmentCount();

  // Do one merge pass
  for (size_t seg_id = 0; seg_id < dir_size; seg_id++) {
    auto* seg = dt_.GetSegment(seg_id);
    if (seg->local_depth() == 1)
      continue;

    // Handle each buddy pair once, from its lower directory index.
    size_t buddy_id = dt_.FindBuddyId(seg_id);
    if (buddy_id == seg_id || seg_id > buddy_id)
      continue;

    auto* buddy = dt_.GetSegment(buddy_id);
    size_t combined_size = seg->SlowSize() + buddy->SlowSize();
    if (combined_size <= static_cast<size_t>(0.25 * seg->capacity())) {
      bool merged = dt_.Merge(seg_id, buddy_id);
      if (merged) {
        // Size must be unchanged after each merge
        EXPECT_EQ(dt_.size(), size_before)
            << "size changed after merging seg_id=" << seg_id << " buddy_id=" << buddy_id;
      }
    }
  }
}

// After merging, verify all remaining keys are still findable via dt_.Find().
// This tests that directory routing is correct after merge.
TEST_F(DashTest, MergeKeyLookupConsistency) {
  constexpr size_t kNumItems = 4000;
  std::vector<uint64_t> all_keys;

  for (uint64_t i = 0; i < kNumItems; ++i) {
    auto [it, inserted] = dt_.Insert(i, i);
    if (inserted)
      all_keys.push_back(i);
  }

  // Keep only ~10% of keys
  const size_t keep_count = all_keys.size() / 10;
  for (size_t i = keep_count; i < all_keys.size(); ++i) {
    dt_.Erase(all_keys[i]);
  }
  all_keys.resize(keep_count);

  const size_t dir_size = dt_.GetSegmentCount();

  // Merge all eligible pairs; repeat until a full pass performs no successful merge.
  bool merged_any = true;
  while (merged_any) {
    merged_any = false;
    for (size_t seg_id = 0; seg_id < dir_size; seg_id++) {
      auto* seg = dt_.GetSegment(seg_id);
      if (seg->local_depth() == 1)
        continue;

      // Handle each buddy pair once, from its lower directory index.
      size_t buddy_id = dt_.FindBuddyId(seg_id);
      if (buddy_id == seg_id || seg_id > buddy_id)
        continue;

      auto* buddy = dt_.GetSegment(buddy_id);
      size_t combined_size = seg->SlowSize() + buddy->SlowSize();
      if (combined_size <= static_cast<size_t>(0.25 * seg->capacity())) {
        if (dt_.Merge(seg_id, buddy_id)) {
          merged_any = true;
        }
      }
    }
  }

  // All remaining keys must be findable via the table-level Find
  for (uint64_t key : all_keys) {
    auto it = dt_.Find(key);
    EXPECT_FALSE(it.is_done()) << "Key " << key << " not found after merge";
  }
}

// Test that after merging to depth 1, inserting more keys works correctly —
// the table can split again and all data remains intact.
TEST_F(DashTest, MergeAndGrow) {
  constexpr size_t kPhase1 = 4000;
  std::vector<uint64_t> surviving_keys;

  for (uint64_t i = 0; i < kPhase1; ++i) {
    dt_.Insert(i, i);
  }

  // Delete enough to enable merge
  const size_t keep_count = kPhase1 / 20;  // ~5%
  for (uint64_t i = keep_count; i < kPhase1; ++i) {
    dt_.Erase(i);
  }
  for (uint64_t i = 0; i < keep_count; ++i) {
    surviving_keys.push_back(i);
  }

  const size_t dir_size = dt_.GetSegmentCount();

  // Merge eligible buddy pairs until a full pass makes no progress.
  bool merged_any = true;
  while (merged_any) {
    merged_any = false;
    for (size_t seg_id = 0; seg_id < dir_size; seg_id++) {
      auto* seg = dt_.GetSegment(seg_id);
      if (seg->local_depth() == 1)
        continue;

      // Handle each buddy pair once, from its lower directory index.
      size_t buddy_id = dt_.FindBuddyId(seg_id);
      if (buddy_id == seg_id || seg_id > buddy_id)
        continue;

      auto* buddy = dt_.GetSegment(buddy_id);
      size_t combined = seg->SlowSize() + buddy->SlowSize();
      if (combined <= static_cast<size_t>(0.25 * seg->capacity())) {
        // Only a successful merge counts as progress; otherwise a pair that keeps failing to
        // merge would make this loop spin forever.
        if (dt_.Merge(seg_id, buddy_id)) {
          merged_any = true;
        }
      }
    }
  }

  EXPECT_EQ(dt_.unique_segments(), 2);

  // Now insert a new batch — the table should grow (split) again
  constexpr size_t kPhase2 = 3000;
  for (uint64_t i = kPhase1; i < kPhase1 + kPhase2; ++i) {
    auto [it, inserted] = dt_.Insert(i, i);
    if (inserted)
      surviving_keys.push_back(i);
  }

  EXPECT_GT(dt_.depth(), 1);

  // ALL surviving keys must be findable after growth
  for (uint64_t key : surviving_keys) {
    auto it = dt_.Find(key);
    EXPECT_FALSE(it.is_done()) << "Key " << key << " lost after merge+grow";
  }
}

// Verify that after merging, all directory entries that span the merged
// segment range point to the same segment object (the kept one).
TEST_F(DashTest, MergeDirectoryConsistency) {
  // Insert enough for depth 2 (4 segments)
  for (uint64_t i = 0; i < 2000; ++i) {
    dt_.Insert(i, i);
  }
  EXPECT_GE(dt_.depth(), 2);

  // Delete most items to enable merge
  for (uint64_t i = 50; i < 2000; ++i) {
    dt_.Erase(i);
  }

  unsigned keep_id = 0;
  unsigned buddy_id = dt_.FindBuddyId(0);
  if (buddy_id == 0) {
    // No buddy for segment 0 - try segment 2
    keep_id = 2;
    buddy_id = dt_.FindBuddyId(2);
  }

  // Only proceed if we found a mergeable buddy pair
  if (buddy_id != keep_id) {
    auto* keep = dt_.GetSegment(keep_id);
    auto* buddy = dt_.GetSegment(buddy_id);

    if (keep->local_depth() == buddy->local_depth() && keep->local_depth() > 1 &&
        keep_id < buddy_id) {
      uint8_t depth = keep->local_depth();
      size_t combined = keep->SlowSize() + buddy->SlowSize();

      if (combined <= static_cast<size_t>(0.25 * keep->capacity())) {
        bool merged = dt_.Merge(keep_id, buddy_id);
        ASSERT_TRUE(merged);

        // After merge, all dir entries that covered buddy must now point to keep
        auto* kept_seg = dt_.GetSegment(keep_id);
        // The merged segment has local depth (depth - 1), so it spans
        // 2^(global_depth - (depth - 1)) directory entries.
        uint32_t chunk_size = 1u << (dt_.depth() - (depth - 1));
        uint32_t start = keep_id & ~(chunk_size - 1u);
        for (size_t i = start; i < start + chunk_size; ++i) {
          EXPECT_EQ(dt_.GetSegment(i), kept_seg)
              << "Directory entry " << i << " does not point to merged segment";
        }
      }
    }
  }
}

// Test merging a table with global_depth > local_depth (aliased directory entries).
// When a segment at depth D < global_depth is merged with its buddy,
// the merged segment at depth D-1 should span the correct directory range.
TEST_F(DashTest, MergeWithAliasedEntries) {
  // Create depth-3 table (8 dir entries), then merge two depth-3 pairs to get depth-2 segments
  // alongside other depth-3 segments. This creates aliased entries.
  for (uint64_t i = 0; i < 4000; ++i) {
    dt_.Insert(i, i);
  }
  EXPECT_EQ(dt_.depth(), 3);

  // Delete most items
  for (uint64_t i = 200; i < 4000; ++i) {
    dt_.Erase(i);
  }

  // Merge segments 0 and 1 (both at depth 3) -> depth 2 segment spanning entries {0,1}
  auto* seg0 = dt_.GetSegment(0);
  auto* seg1 = dt_.GetSegment(1);

  if (seg0->local_depth() == 3 && seg1->local_depth() == 3) {
    size_t combined = seg0->SlowSize() + seg1->SlowSize();
    size_t threshold = static_cast<size_t>(0.25 * seg0->capacity());

    if (combined <= threshold) {
      bool ok = dt_.Merge(0, 1);
      ASSERT_TRUE(ok);

      // Now segment at entries 0 and 1 is the same depth-2 object
      EXPECT_EQ(dt_.GetSegment(0), dt_.GetSegment(1));
      EXPECT_EQ(dt_.GetSegment(0)->local_depth(), 2);

      // global_depth should still be 3
      EXPECT_EQ(dt_.depth(), 3);

      // Entries 2 and 3 should still be distinct depth-3 segments
      EXPECT_NE(dt_.GetSegment(2), dt_.GetSegment(3));

      // Since entries 2 and 3 are still at depth 3 (not yet merged into a depth-2 segment),
      // the true buddy of the depth-2 segment {0,1} does NOT yet exist.
      // FindBuddyId computes: bit_pos = global_depth(3) - local_depth(2) = 1
      // FindBuddyId(0) -> buddy_idx = 0^2 = 2, GetSegment(2)->local_depth() = 3 != 2 -> returns 0
      // FindBuddyId(1) -> buddy_idx = 1^2 = 3, GetSegment(3)->local_depth() = 3 != 2 -> returns 1
      // Both aliased entries correctly report "no buddy" (returning themselves).
      EXPECT_EQ(dt_.FindBuddyId(0), 0u)
          << "No buddy exists for depth-2 segment when entries 2,3 are still depth-3";
      EXPECT_EQ(dt_.FindBuddyId(1), 1u)
          << "Aliased entry 1 of same depth-2 segment also finds no buddy";

      // Now merge entries 2 and 3 to create a second depth-2 segment covering {2,3}
      auto* seg2 = dt_.GetSegment(2);
      auto* seg3 = dt_.GetSegment(3);
      if (seg2 != seg3) {
        size_t combined23 = seg2->SlowSize() + seg3->SlowSize();
        if (combined23 <= static_cast<size_t>(0.25 * seg2->capacity())) {
          bool ok23 = dt_.Merge(2, 3);
          if (ok23) {
            // Now both {0,1} and {2,3} are depth-2 segments — they ARE buddies
            // FindBuddyId(0): bit_pos=1, buddy_idx=0^2=2, GetSegment(2)->local_depth()=2 == 2 -> 2
            // FindBuddyId(2): bit_pos=1, buddy_idx=2^2=0, GetSegment(0)->local_depth()=2 == 2 -> 0
            EXPECT_EQ(dt_.FindBuddyId(0), 2u)
                << "After both pairs merged to depth-2, FindBuddyId(0)=2";
            EXPECT_EQ(dt_.FindBuddyId(2), 0u) << "FindBuddyId(2) should return 0 (symmetric)";
            // Aliased entry 1 looks for buddy at 1^2=3
            EXPECT_EQ(dt_.FindBuddyId(1), 3u) << "FindBuddyId(1) returns 3 (alias buddy)";
          }
        }
      }
    }
  }
}

// Test that FindBuddyId resolves to the same buddy *instance* for all alias ids in a stripe.
//
// When global_depth > local_depth a segment is referenced by a contiguous "stripe" of
// stripe_size = 2^(global_depth - local_depth) directory entries that all point to the
// same segment object.
// The canonical id is the stripe's first entry (lowest index).
//
// FindBuddyId(alias) computes:
//   depth    = GetSegment(alias)->local_depth()  // reads from the instance, same for all
//   bit_pos  = global_depth - depth              // same for every alias in the stripe
//   buddy_ix = alias ^ (1 << bit_pos)            // XOR differs per alias
//
// For a stripe starting at canonical id C (i.e. C is a multiple of stripe_size):
//   alias k      = C + k                    (0 <= k < stripe_size)
//   buddy_ix(k)  = (C + k) ^ (1 << bit_pos)
//                = C ^ (1 << bit_pos) + k   (because k < stripe_size = 1 << bit_pos, so the
//                                            flipped bit lies above every bit of k)
//
// Hence alias k maps to directory index canonical_buddy + k, which lies inside the buddy's
// stripe, so every alias resolves to the same buddy segment.
//
// NOTE(review): the end of the comment above and this TEST_F header were missing in this view
// of the file (only the test's closing brace survived); the test name is a placeholder —
// confirm against the original.
TEST_F(DashTest, FindBuddyIdAliasesResolveToSameInstance) {
  // Build a table with global_depth >= 3, giving segments at local_depth 3.
  for (uint64_t i = 0; i < 8000; ++i) {
    dt_.Insert(i, i);
  }
  ASSERT_GE(dt_.depth(), 3u);

  // Erase most items so segments are sparse enough to merge.
  for (uint64_t i = 100; i < 8000; ++i) {
    dt_.Erase(i);
  }

  // To get a real buddy we must merge TWO adjacent pairs at the same depth.
  // After merging pair A (keep_a, buddy_a) the kept segment drops to depth d-1,
  // but its buddy stripe still has the old depth d, so FindBuddyId returns self.
  // Only after merging the adjacent pair B (keep_b, buddy_b) to d-1 as well do
  // the two resulting stripes become buddies of each other.
  //
  // We find four consecutive canonical segments at the same depth d > 2 and merge
  // pairs (0,1) and (2,3) within that group.
  unsigned keep_a = UINT_MAX, bud_a = UINT_MAX, keep_b = UINT_MAX, bud_b = UINT_MAX;
  for (size_t i = 0; i < dt_.GetSegmentCount();) {
    auto* s0 = dt_.GetSegment(i);
    uint8_t d = s0->local_depth();
    if (d <= 2) {
      i = dt_.NextSeg(i);
      continue;
    }
    size_t i1 = dt_.NextSeg(i);
    if (i1 >= dt_.GetSegmentCount())
      break;
    size_t i2 = dt_.NextSeg(i1);
    if (i2 >= dt_.GetSegmentCount())
      break;
    size_t i3 = dt_.NextSeg(i2);
    if (i3 >= dt_.GetSegmentCount())
      break;

    auto* s1 = dt_.GetSegment(i1);
    auto* s2 = dt_.GetSegment(i2);
    auto* s3 = dt_.GetSegment(i3);
    size_t cap = s0->capacity();
    if (s1->local_depth() == d && s2->local_depth() == d && s3->local_depth() == d &&
        s0->SlowSize() + s1->SlowSize() <= static_cast<size_t>(0.25 * cap) &&
        s2->SlowSize() + s3->SlowSize() <= static_cast<size_t>(0.25 * cap)) {
      keep_a = static_cast<unsigned>(i);
      bud_a = static_cast<unsigned>(i1);
      keep_b = static_cast<unsigned>(i2);
      bud_b = static_cast<unsigned>(i3);
      break;
    }
    i = dt_.NextSeg(i);
  }
  ASSERT_NE(keep_a, UINT_MAX);

  ASSERT_TRUE(dt_.Merge(keep_a, bud_a));
  ASSERT_TRUE(dt_.Merge(keep_b, bud_b));

  // After both merges:
  //   - segment at keep_a has local_depth = d-1, aliased by stripe {keep_a, keep_a+1}
  //   - segment at keep_b has local_depth = d-1, aliased by stripe {keep_b, keep_b+1}
  //   - The two stripes are buddies of each other (same depth, adjacent subtrees).
  auto* seg_a = dt_.GetSegment(keep_a);
  uint8_t new_depth = seg_a->local_depth();
  ASSERT_GE(new_depth, 2u);  // depth<=1 guard in FindBuddyId must not fire

  size_t stripe_size = 1u << (dt_.depth() - new_depth);
  size_t stripe_start = keep_a & ~(stripe_size - 1);

  // FindBuddyId from the canonical id of stripe A must resolve to seg_b.
  auto* seg_b = dt_.GetSegment(keep_b);
  unsigned canonical_bid = dt_.FindBuddyId(static_cast<unsigned>(stripe_start));
  ASSERT_EQ(dt_.GetSegment(canonical_bid), seg_b)
      << "FindBuddyId from canonical id must resolve to the buddy segment";

  EXPECT_EQ(stripe_size, 2);
  for (size_t k = 0; k < stripe_size; ++k) {
    size_t alias = stripe_start + k;
    EXPECT_EQ(dt_.GetSegment(alias), seg_a) << "Directory entry " << alias << " must alias seg_a";

    unsigned bid = dt_.FindBuddyId(static_cast<unsigned>(alias));

    // Different alias -> different buddy id value, but same buddy instance.
    EXPECT_EQ(bid, canonical_bid + k)
        << "FindBuddyId(" << alias << ") should equal canonical_bid + " << k;
    EXPECT_EQ(dt_.GetSegment(bid), seg_b)
        << "FindBuddyId(" << alias << ") must resolve to seg_b for all aliases";

    // Stripe B is at higher indices than stripe A (Merge requires keep_id < buddy_id).
    EXPECT_GT(bid, alias);
  }
}

// Test that NextSeg is correct when called with the canonical (first) id of a stripe,
// and documents the expected behavior for non-canonical (middle-of-stripe) ids.
//
// NextSeg(sid) computes:
//   delta = 1 << (global_depth - segment_[sid]->local_depth())
//   return sid + delta
//
// For the canonical (first) id of a stripe, sid is already aligned to a multiple of
// delta, so sid + delta is exactly the first id of the next stripe — correct.
//
// For a non-canonical id sid = canonical + k (0 < k < delta), the result is
//   (canonical + k) + delta
// which lands k positions into the next stripe, not at its start.
TEST_F(DashTest, NextSegCanonicalBehavior) {
  // Build a table large enough for global_depth >= 2.
  for (uint64_t i = 0; i < 2000; ++i) {
    dt_.Insert(i, i);
  }
  ASSERT_GE(dt_.depth(), 2u);

  // NextSeg from id 0 always uses canonical ids (0 is always canonical).
  // Verify it visits every distinct segment exactly once by comparing against
  // unique_segments() which is maintained as a counter by Insert/Merge.
  size_t visited = 0;
  for (size_t i = 0; i < dt_.GetSegmentCount(); i = dt_.NextSeg(i)) {
    ++visited;
  }
  EXPECT_EQ(visited, dt_.unique_segments())
      << "NextSeg traversal from id 0 (canonical) must visit each unique segment once";

  // Erase most entries and merge to create a stripe (local_depth < global_depth).
  for (uint64_t i = 100; i < 2000; ++i) {
    dt_.Erase(i);
  }

  // Find and perform a merge to produce a stripe.
  for (size_t i = 0; i < dt_.GetSegmentCount(); i = dt_.NextSeg(i)) {
    auto* seg = dt_.GetSegment(i);
    if (seg->local_depth() <= 1)
      continue;
    size_t next = dt_.NextSeg(i);
    if (next >= dt_.GetSegmentCount())
      break;
    auto* buddy = dt_.GetSegment(next);
    if (buddy->local_depth() == seg->local_depth() &&
        seg->SlowSize() + buddy->SlowSize() <= static_cast<size_t>(0.25 * seg->capacity())) {
      bool ok = dt_.Merge(static_cast<unsigned>(i), static_cast<unsigned>(next));
      if (ok)
        break;
    }
  }

  // After a potential merge, re-verify that canonical traversal is consistent.
  size_t manual2 = 0;
  for (size_t i = 0; i < dt_.GetSegmentCount(); i = dt_.NextSeg(i)) {
    ++manual2;
  }
  EXPECT_EQ(manual2, dt_.unique_segments())
      << "After merge, canonical NextSeg traversal must still match unique_segments()";

  // Show the non-canonical case: for any stripe of size > 1, NextSeg from a non-first
  // alias does NOT land on the start of the next stripe.
  for (size_t i = 0; i < dt_.GetSegmentCount(); i = dt_.NextSeg(i)) {
    auto* seg = dt_.GetSegment(i);
    size_t delta = 1u << (dt_.depth() - seg->local_depth());
    if (delta <= 1)
      continue;  // no stripe aliases for this segment

    // i is canonical; i+1 is a non-canonical alias of the same segment.
    size_t non_canonical = i + 1;
    ASSERT_LT(non_canonical, i + delta) << "non_canonical must still be within the stripe";

    // NextSeg from the non-canonical id lands at (non_canonical + delta), which is
    // one position past the start of the next stripe — demonstrating the offset.
    size_t next_from_canonical = dt_.NextSeg(i);          // i + delta (correct)
    size_t next_from_alias = dt_.NextSeg(non_canonical);  // i+1+delta (offset by 1)
    EXPECT_EQ(next_from_alias, next_from_canonical + 1)
        << "NextSeg from a non-canonical alias is offset by the same amount as the alias "
           "itself; callers must always use canonical (stripe-start) ids";
    break;  // one example is sufficient to document the behavior
  }
}

TEST_F(DashTest, BumpUp) { set keys = FillSegment(0); constexpr unsigned kFirstStashId = Segment::kBucketNum; constexpr unsigned kSecondStashId = Segment::kBucketNum + 1; constexpr unsigned kSlotNum = Segment::kSlotNum; EXPECT_TRUE(segment_.GetBucket(0).IsFull()); EXPECT_TRUE(segment_.GetBucket(1).IsFull()); EXPECT_TRUE(segment_.GetBucket(kFirstStashId).IsFull()); EXPECT_TRUE(segment_.GetBucket(kSecondStashId).IsFull()); // Segment::Iterator it{kFirstStashId, 1}; Segment::Key_t key = segment_.Key(1, 2); // key at bucket 1, slot 2 uint8_t touched_bid[3]; uint64_t hash = dt_.DoHash(key); segment_.Delete(Segment::Iterator{1, 2}, hash); EXPECT_FALSE(segment_.GetBucket(1).IsFull()); segment_.SetVersion(kFirstStashId, 1); key = segment_.Key(kFirstStashId, 5); hash = dt_.DoHash(key); EXPECT_EQ(2, segment_.CVCOnBump(1, kFirstStashId, 5, hash, touched_bid)); EXPECT_EQ(touched_bid[0], 0); EXPECT_EQ(touched_bid[1], 1); // Bump up std::vector> moved_buckets; auto move_cb = [&moved_buckets](uint32_t /* segment_id */, uint8_t a, uint8_t b) { moved_buckets.emplace_back(a, b); }; segment_.BumpUp(kFirstStashId, 5, hash, RelaxedBumpPolicy{}, move_cb); // expect the key to move EXPECT_TRUE(segment_.GetBucket(1).IsFull()); EXPECT_FALSE(segment_.GetBucket(kFirstStashId).IsFull()); EXPECT_EQ(segment_.Key(1, 2),
key); EXPECT_EQ(moved_buckets.size(), 1); EXPECT_EQ(moved_buckets.at(0).first, kFirstStashId); EXPECT_EQ(moved_buckets.at(0).second, 1); moved_buckets.clear(); EXPECT_TRUE(Contains(key)); // 9 is just a random slot id. key = segment_.Key(kSecondStashId, 9); hash = dt_.DoHash(key); EXPECT_EQ(3, segment_.CVCOnBump(2, kSecondStashId, 9, hash, touched_bid)); EXPECT_EQ(touched_bid[0], kSecondStashId); // Bumpup will move the key to either its original bucket or a probing bucket. // Since we can't determine the exact bucket before calling bumpup, CVCOnBump // returns both the original bucket and the probing bucket. EXPECT_EQ(touched_bid[1], 0); EXPECT_EQ(touched_bid[2], 1); auto it = segment_.BumpUp(kSecondStashId, 9, hash, RelaxedBumpPolicy{}, move_cb); ASSERT_TRUE(key == segment_.Key(0, kSlotNum - 1) || key == segment_.Key(1, kSlotNum - 1)); EXPECT_TRUE(segment_.GetBucket(kSecondStashId).IsFull()); EXPECT_TRUE(Contains(key)); EXPECT_TRUE(segment_.Key(kSecondStashId, 9)); EXPECT_EQ(moved_buckets.size(), 2); EXPECT_EQ(moved_buckets.at(0).first, kSecondStashId); EXPECT_EQ(moved_buckets.at(0).second, it.index); EXPECT_EQ(moved_buckets.at(1).first, it.index); EXPECT_EQ(moved_buckets.at(1).second, kSecondStashId); } TEST_F(DashTest, BumpPolicy) { struct RestrictedBumpPolicy { bool CanBump(uint64_t key) const { return false; } void OnMove(Dash64::Cursor source, Dash64::Cursor dest) { } }; set keys = FillSegment(0); constexpr unsigned kFirstStashId = Segment::kBucketNum; EXPECT_TRUE(segment_.GetBucket(0).IsFull()); EXPECT_TRUE(segment_.GetBucket(1).IsFull()); EXPECT_TRUE(segment_.GetBucket(kFirstStashId).IsFull()); // check items are immovable in bucket Segment::Key_t key = segment_.Key(1, 2); uint64_t hash = dt_.DoHash(key); segment_.BumpUp(1, 2, hash, RestrictedBumpPolicy{}, [](auto&&...) 
{}); EXPECT_EQ(key, segment_.Key(1, 2)); // check items don't swap from stash key = segment_.Key(kFirstStashId, 2); hash = dt_.DoHash(key); segment_.BumpUp(kFirstStashId, 2, hash, RestrictedBumpPolicy{}, [](auto&&...) {}); EXPECT_EQ(key, segment_.Key(kFirstStashId, 2)); } TEST_F(DashTest, Insert2) { uint64_t k = 1191; ASSERT_EQ(2019837007031366716, UInt64Policy::HashFn(k)); Dash64 dt; for (unsigned i = 0; i < 2000; ++i) { dt.Insert(i, 0); } } TEST_F(DashTest, InsertOOM) { CappedResource resource(1 << 15); Dash64 dt{1, UInt64Policy{}, &resource}; ASSERT_THROW( { for (size_t i = 0; i < (1 << 14); ++i) { dt.Insert(i, 0); } }, bad_alloc); } struct Item { char buf[24]; }; constexpr size_t ItemAlign = alignof(Item); struct MyBucket : public detail::BucketBase<16> { Item key[14]; }; constexpr size_t kMySz = sizeof(MyBucket); constexpr size_t kBBSz = sizeof(detail::BucketBase<16>); TEST_F(DashTest, Custom) { using ItemSegment = detail::Segment; constexpr double kTax = ItemSegment::kTaxSize; constexpr size_t kMaxSize = ItemSegment::kMaxSize; constexpr size_t kSegSize = sizeof(ItemSegment); constexpr size_t kBuckSz = ItemSegment::kBucketSz; (void)kTax; (void)kMaxSize; (void)kSegSize; (void)kBuckSz; ItemSegment seg{2, 0, PMR_NS::get_default_resource()}; auto eq = [v = Item{1, 1}](auto u) { return v.buf[0] == u.buf[0] && v.buf[1] == u.buf[1]; }; auto it = seg.FindIt(42, eq); ASSERT_FALSE(it.found()); } TEST_F(DashTest, FindByValue) { using ItemSegment = detail::Segment; auto no_op_cb = [](auto&&...) 
{}; // Insert three different values with the same hash ItemSegment segment{2, 0, PMR_NS::get_default_resource()}; segment.Insert( Item{1}, 1, 42, [](const auto& pred) { return pred.buf[0] == 1; }, no_op_cb); segment.Insert( Item{2}, 2, 42, [](const auto& pred) { return pred.buf[0] == 2; }, no_op_cb); segment.Insert( Item{3}, 3, 42, [](const auto& pred) { return pred.buf[0] == 3; }, no_op_cb); // We should be able to find the middle one by value auto it = segment.FindIt(42, [](const auto& key, const auto& value) { return value == 2; }); EXPECT_TRUE(it.found()); EXPECT_EQ(segment.Value(it.index, it.slot), 2); } TEST_F(DashTest, Reserve) { unsigned bc = dt_.capacity(); for (unsigned i = 0; i <= bc * 2; ++i) { dt_.Reserve(i); ASSERT_GE((1 << dt_.depth()) * Dash64::kSegCapacity, i); } } TEST_F(DashTest, Insert) { constexpr size_t kNumItems = 10000; double sum = 0; for (size_t i = 0; i < kNumItems; ++i) { dt_.Insert(i, i); double u = (dt_.size() * 100.0) / (dt_.unique_segments() * Segment::capacity()); sum += u; VLOG(1) << "Num items " << dt_.size() << ", load factor " << u << ", size per entry " << double(dt_.mem_usage()) / dt_.size(); } EXPECT_EQ(kNumItems, dt_.size()); LOG(INFO) << "Average load factor is " << sum / kNumItems; for (size_t i = 0; i < kNumItems; ++i) { Dash64::const_iterator it = dt_.Find(i); ASSERT_TRUE(it != dt_.end()); ASSERT_EQ(it->second, i); ASSERT_LE(dt_.load_factor(), 1) << i; } for (size_t i = kNumItems; i < kNumItems * 10; ++i) { Dash64::const_iterator it = dt_.Find(i); ASSERT_TRUE(it == dt_.end()); } EXPECT_EQ(kNumItems, dt_.size()); EXPECT_EQ(1, dt_.Erase(0)); EXPECT_EQ(0, dt_.Erase(0)); EXPECT_EQ(kNumItems - 1, dt_.size()); auto it = dt_.begin(); ASSERT_FALSE(it.is_done()); auto some_val = it->second; dt_.Erase(it); ASSERT_TRUE(dt_.Find(some_val).is_done()); } TEST_F(DashTest, Traverse) { constexpr auto kNumItems = 50; for (size_t i = 0; i < kNumItems; ++i) { dt_.Insert(i, i); } Dash64::Cursor cursor; vector nums; auto tr_cb = 
[&](Dash64::iterator it) { nums.push_back(it->first); VLOG(1) << it.bucket_id() << " " << it.slot_id() << " " << it->first; }; do { cursor = dt_.Traverse(cursor, tr_cb); } while (cursor); sort(nums.begin(), nums.end()); nums.resize(unique(nums.begin(), nums.end()) - nums.begin()); ASSERT_EQ(kNumItems, nums.size()); EXPECT_EQ(0, nums[0]); EXPECT_EQ(kNumItems - 1, nums.back()); } TEST_F(DashTest, TraverseSegmentOrder) { constexpr auto kNumItems = 50; for (size_t i = 0; i < kNumItems; ++i) { dt_.Insert(i, i); } vector nums; auto tr_cb = [&](Dash64::iterator it) { nums.push_back(it->first); VLOG(1) << it.bucket_id() << " " << it.slot_id() << " " << it->first; }; Dash64::Cursor cursor; do { cursor = dt_.TraverseBySegmentOrder(cursor, tr_cb); } while (cursor); sort(nums.begin(), nums.end()); nums.resize(unique(nums.begin(), nums.end()) - nums.begin()); ASSERT_EQ(kNumItems, nums.size()); EXPECT_EQ(0, nums[0]); EXPECT_EQ(kNumItems - 1, nums.back()); } TEST_F(DashTest, TraverseBucketOrder) { constexpr auto kNumItems = 18000; for (size_t i = 0; i < kNumItems; ++i) { dt_.Insert(i, i); } for (size_t i = 0; i < kNumItems; ++i) { dt_.Erase(i); } constexpr auto kSparseItems = kNumItems / 50; for (size_t i = 0; i < kSparseItems; ++i) { // create sparse table dt_.Insert(i, i); } vector nums; auto tr_cb = [&](Dash64::bucket_iterator it) { VLOG(1) << "call cb"; while (!it.is_done()) { nums.push_back(it->first); VLOG(1) << it.bucket_id() << " " << it.slot_id() << " " << it->first; ++it; } }; Dash64::Cursor cursor; do { cursor = dt_.TraverseBuckets(cursor, tr_cb); } while (cursor); sort(nums.begin(), nums.end()); nums.resize(unique(nums.begin(), nums.end()) - nums.begin()); ASSERT_EQ(kSparseItems, nums.size()); EXPECT_EQ(0, nums[0]); EXPECT_EQ(kSparseItems - 1, nums.back()); } struct TestEvictionPolicy { static constexpr bool can_evict = true; static constexpr bool can_gc = false; explicit TestEvictionPolicy(unsigned max_cap) : max_capacity(max_cap) { } bool CanGrow(const Dash64& tbl) 
const { return tbl.capacity() < max_capacity; } void OnMove(Dash64::Cursor source, Dash64::Cursor dest) { } void RecordSplit(Dash64::Segment_t*) { } unsigned Evict(const Dash64::HotBuckets& hotb, Dash64* me) const { if (!evict_enabled) return 0; auto it = hotb.probes.by_type.regular_buckets[0]; unsigned res = 0; for (; !it.is_done(); ++it) { LOG(INFO) << "Deleting " << it->first; me->Erase(it); ++res; } return res; } bool evict_enabled = false; unsigned max_capacity; }; TEST_F(DashTest, Eviction) { TestEvictionPolicy ev(1540); size_t num = 0; auto loop = [&] { for (; num < 5000; ++num) { dt_.Insert(num, 0, ev); } }; ASSERT_THROW(loop(), bad_alloc); ASSERT_LT(num, 5000); ASSERT_EQ(2, dt_.unique_segments()); EXPECT_LT(dt_.size(), ev.max_capacity); LOG(INFO) << "size is " << dt_.size(); set keys; Dash64::bucket_iterator bit = dt_.begin(); unsigned last_slot = 0; while (!bit.is_done()) { keys.insert(bit->first); last_slot = bit.slot_id(); ++bit; } ASSERT_LT(last_slot, Dash64::kSlotNum); bit = dt_.begin(); dt_.ShiftRight(bit); bit = dt_.begin(); size_t sz = 0; while (!bit.is_done()) { EXPECT_EQ(1, keys.count(bit->first)); ++sz; ++bit; } EXPECT_EQ(sz, keys.size()); while (!dt_.GetSegment(0)->GetBucket(0).IsFull()) { try { dt_.Insert(num++, 0, ev); } catch (bad_alloc&) { } } // Now the bucket is full. 
keys.clear(); uint64_t last_key = dt_.GetSegment(0)->Key(0, Dash64::kSlotNum - 1); for (Dash64::bucket_iterator bit = dt_.begin(); !bit.is_done(); ++bit) { keys.insert(bit->first); } bit = dt_.begin(); dt_.ShiftRight(bit); bit = dt_.begin(); sz = 0; while (!bit.is_done()) { EXPECT_NE(last_key, bit->first); EXPECT_EQ(1, keys.count(bit->first)); ++sz; ++bit; } EXPECT_EQ(sz + 1, keys.size()); ev.evict_enabled = true; unsigned bucket_cnt = dt_.bucket_count(); auto [it, res] = dt_.Insert(num, 0, ev); EXPECT_TRUE(res); EXPECT_EQ(bucket_cnt, dt_.bucket_count()); } struct VersionPolicy : public BasicDashPolicy { static constexpr bool kUseVersion = true; static uint64_t HashFn(int v) { return XXH3_64bits(&v, sizeof(v)); } }; using VersionDT = DashTable; TEST_F(DashTest, Version) { VersionDT dt; auto [it, inserted] = dt.Insert(1, 1); EXPECT_EQ(0, it.GetVersion()); it.SetVersion(5); EXPECT_EQ(5, it.GetVersion()); dt.Clear(); ASSERT_EQ(0, dt.size()); ASSERT_EQ(2, dt.unique_segments()); ASSERT_EQ(136, dt.bucket_count()); constexpr int kNum = 68000; for (int i = 0; i < kNum; ++i) { auto it = dt.Insert(i, 0).first; it.SetVersion(i + 65000); if (i) { auto p = dt.Find(i - 1); ASSERT_GE(p.GetVersion(), i - 1 + 65000) << i; } } unsigned items = 0; for (auto it = dt.begin(); it != dt.end(); ++it) { ASSERT_FALSE(it.is_done()); ASSERT_GE(it.GetVersion(), it->first + 65000) << it.segment_id() << " " << it.bucket_id() << " " << it.slot_id(); ++items; } ASSERT_EQ(kNum, items); } TEST_F(DashTest, CVCUponInsert) { VersionDT dt; auto [it, added] = dt.Insert(10, 20); // added to slot 0 ASSERT_TRUE(added); int i = 11; while (true) { auto [it2, added] = dt.Insert(i, 30); if (it2.bucket_id() == it.bucket_id() && it2.segment_id() == it.segment_id()) { ASSERT_EQ(1, it2.slot_id()); break; } ++i; } // freed slot 0 but the bucket still has i at slot 1. 
dt.Erase(10); auto cb = [](VersionDT::bucket_iterator bit) { LOG(INFO) << "sid: " << bit.segment_id() << " " << bit.bucket_id(); while (!bit.is_done()) { LOG(INFO) << "key: " << bit->first; ++bit; } }; dt.CVCUponInsert(1, i, cb); } TEST_F(DashTest, CVCUponInsertStress) { VersionDT dt; for (int i = 0; i < 5000; ++i) { dt.CVCUponInsert(1, i, [](VersionDT::bucket_iterator) { // empty callback }); dt.Insert(i, 0); } } struct A { int a = 0; unsigned moved = 0; A(int i = 0) : a(i) { } A(const A&) = delete; A(A&& o) : a(o.a), moved(o.moved + 1) { o.a = -1; } A& operator=(const A&) = delete; A& operator=(A&& o) noexcept { o.moved = o.moved + 1; a = o.a; o.a = -1; return *this; } bool operator==(const A& o) const { return o.a == a; } }; struct ADashPolicy : public BasicDashPolicy { static uint64_t HashFn(const A& a) { auto val = XXH3_64bits(&a.a, sizeof(a.a)); return val; } }; TEST_F(DashTest, Moveable) { using DType = DashTable; DType table{1}; ASSERT_TRUE(table.Insert(A{1}, A{2}).second); ASSERT_FALSE(table.Insert(A{1}, A{3}).second); EXPECT_EQ(1, table.size()); table.Clear(); EXPECT_EQ(0, table.size()); } struct SdsDashPolicy { enum { kSlotNum = 12, kBucketNum = 64, kStashBucketNum = 2 }; static constexpr bool kUseVersion = false; static uint64_t HashFn(sds u) { return XXH3_64bits(reinterpret_cast(u), sdslen(u)); } static uint64_t HashFn(std::string_view u) { return XXH3_64bits(u.data(), u.size()); } static void DestroyValue(uint64_t) { } static void DestroyKey(sds s) { sdsfree(s); } static bool Equal(sds u1, sds u2) { return dictSdsKeyCompare(nullptr, u1, u2) == 0; } static bool Equal(sds u1, std::string_view u2) { return u2 == std::string_view{u1, sdslen(u1)}; } }; TEST_F(DashTest, Sds) { DashTable dt; sds foo = sdscatlen(sdsempty(), "foo", 3); dt.Insert(foo, 0); // dt.Insert(std::string_view{"bar"}, 1); } struct BlankPolicy : public BasicDashPolicy { static uint64_t HashFn(uint64_t v) { return v; } }; // The bug was that for very rare cases when during segment 
// splitting we move all the items
// into a new segment, not every item finds a place.
//
// NOTE(review): in this copy of the file the template argument lists appear to have been
// stripped (e.g. `DashTable table;`, `io::Result src`, `std::vector strs`). The code below is
// reproduced token-for-token as found; confirm the arguments against the upstream source.
TEST_F(DashTest, SplitBug) {
  DashTable table;

  // Resolve and open the regression data set.
  string path = base::ProgramRunfile("testdata/ids.txt.zst");
  io::Result src = io::OpenUncompressed(path);
  ASSERT_TRUE(src) << src.error();

  // Each line is parsed as a hexadecimal number and inserted as a key with value 0.
  io::LineReader lr(*src, TAKE_OWNERSHIP);
  string_view line;
  uint64_t val;
  while (lr.Next(&line)) {
    CHECK(absl::SimpleHexAtoi(line, &val)) << line;
    table.Insert(val, 0);
  }

  // The table is expected to hold exactly 746 entries after the whole file was inserted.
  EXPECT_EQ(746, table.size());
}

/**
 ______      _      _   _               _______        _
|  ____|    (_)    | | (_)             |__   __|      | |
| |____   ___  ___| |_ _  ___  _ __      | | ___  ___| |_ ___
|  __\ \ / / |/ __| __| |/ _ \| '_ \     | |/ _ \/ __| __/ __|
| |___\ V /| | (__| |_| | (_) | | | |    | |  __/\__ \ |_\__ \
|______\_/ |_|\___|\__|_|\___/|_| |_|    |_|\___||___/\__|___/

 *
 */

// Parameters of a single EvictionPolicyTest instantiation.
struct EvictParams {
  bool use_bumpups;   // when true, the tests call BumpUp on every hit.
  double zipf_param;  // > 0 selects the zipfian key generator, otherwise keys are uniform.

  // Builds the test-name suffix: "bumps" or "nobumps" followed by zipf_param * 1000
  // truncated to unsigned.
  string PrintTo() const {
    string name = absl::StrCat(use_bumpups ? "" : "no", "bumps");
    absl::StrAppend(&name, unsigned(zipf_param * 1000));
    return name;
  }
};

// Name generator passed to INSTANTIATE_TEST_SUITE_P; delegates to EvictParams::PrintTo.
string PrintParams(const testing::TestParamInfo& info) {
  return info.param.PrintTo();
}

// Dash policy for uint64_t keys and values: 14 slots per bucket, 64 regular buckets and
// 4 stash buckets, no versioning, no-op destruction and XXH3 hashing of the raw key bytes.
struct U64DashPolicy {
  enum { kSlotNum = 14, kBucketNum = 64, kStashBucketNum = 4 };
  static constexpr bool kUseVersion = false;

  static void DestroyValue(uint64_t) {
  }
  static void DestroyKey(uint64_t) {
  }

  static bool Equal(uint64_t u, uint64_t v) {
    return u == v;
  }

  static uint64_t HashFn(uint64_t v) {
    return XXH3_64bits(&v, sizeof(v));
  }
};

using U64Dash = DashTable;

// Eviction policy that erases a single item chosen from the hot buckets using bits of the
// hash of the key being inserted.
struct SimpleEvictPolicy {
  static constexpr bool can_gc = false;
  static constexpr bool can_evict = true;

  // Growth is allowed only while one more segment still fits below max_capacity.
  bool CanGrow(const U64Dash& tbl) {
    return tbl.capacity() + U64Dash::kSegCapacity < max_capacity;
  }

  void OnMove(U64Dash::Cursor source, U64Dash::Cursor dest) {
  }

  void RecordSplit(U64Dash::Segment_t* segment) {
  }

  // Required interface in case can_gc is true
  // returns number of items evicted from the table.
  // 0 means - nothing has been evicted.
  // NOTE(review): this policy sets can_evict (not can_gc) to true, so the line above
  // presumably means can_evict - confirm against the DashTable policy contract.
  unsigned Evict(const U64Dash::HotBuckets& hotb, U64Dash* me) {
    constexpr unsigned kBucketNum = U64Dash::HotBuckets::kNumBuckets;

    // The starting bucket comes from the hash modulo the number of hot buckets, the slot
    // offset from its upper 32 bits.
    uint32_t bid = hotb.key_hash % kBucketNum;
    unsigned slot_index = (hotb.key_hash >> 32) % U64Dash::kSlotNum;

    // Walk the hot buckets in a ring starting at bid; erase the first item found after
    // advancing the bucket iterator by slot_index.
    for (unsigned i = 0; i < kBucketNum; ++i) {
      auto it = hotb.at((bid + i) % kBucketNum);
      it += slot_index;
      if (it.is_done())
        continue;

      me->Erase(it);
      ++evicted;

      return 1;
    }
    return 0;
  }

  size_t max_capacity = SIZE_MAX;
  unsigned evicted = 0;  // number of items erased by Evict so far.

  // default_random_engine rand_eng_{42};
};

// Eviction policy that calls ShiftRight on one stash bucket per eviction and records, per key,
// how many times the key found in the last slot of that bucket was pushed out.
struct ShiftRightPolicy {
  absl::flat_hash_map evicted;  // key -> number of times it was recorded as evicted.
  size_t max_capacity = SIZE_MAX;
  unsigned evicted_sum = 0;  // total number of Evict calls.

  static constexpr bool can_gc = false;
  static constexpr bool can_evict = true;

  // Same growth rule as SimpleEvictPolicy: one more segment must fit below max_capacity.
  bool CanGrow(const U64Dash& tbl) {
    return tbl.capacity() + U64Dash::kSegCapacity < max_capacity;
  }

  void RecordSplit(U64Dash::Segment_t* segment) {
  }

  void OnMove(U64Dash::Cursor source, U64Dash::Cursor dest) {
  }

  // Picks a stash bucket by hash, records the key stored in its last slot and shifts the
  // bucket right. Always reports one evicted item; a failed ShiftRight aborts via CHECK.
  unsigned Evict(const U64Dash::HotBuckets& hotb, U64Dash* me) {
    constexpr unsigned kNumStashBuckets = ABSL_ARRAYSIZE(hotb.probes.by_type.stash_buckets);

    unsigned stash_pos = hotb.key_hash % kNumStashBuckets;
    auto stash_it = hotb.probes.by_type.stash_buckets[stash_pos];
    stash_it += (U64Dash::kSlotNum - 1);  // go to the last slot.

    // NOTE(review): stash_it is dereferenced without an is_done() check, i.e. the last slot
    // is assumed to be occupied - confirm that Evict is only invoked with full stash buckets.
    uint64_t k = stash_it->first;
    DVLOG(1) << "Deleting key " << k << " from " << unsigned(stash_it.bucket_id()) << "/"
             << stash_it.slot_id();

    evicted[k]++;

    CHECK(me->ShiftRight(stash_it));
    ++evicted_sum;

    return 1;
  };
};

// Fixture for the parameterized eviction tests. Keys come either from a zipfian generator or
// from a uniform distribution, both configured in SetUp with the bounds 0 and 15000.
class EvictionPolicyTest : public testing::TestWithParam {
 protected:
  // Pre-fills dt_ with uniformly distributed keys, see the definition below.
  template void FillUniform(unsigned max_range, Policy& policy);

  // Returns the next key: zipfian if zipf_ was configured in SetUp, uniform otherwise.
  uint64_t Rand() {
    return zipf_ ? zipf_->Next(rand_eng_) : udist_(rand_eng_);
  }

  // Selects the key generator according to the zipf_param of the current test parameter.
  void SetUp() final {
    if (GetParam().zipf_param > 0)
      zipf_.emplace(0, 15000, GetParam().zipf_param);
    else {
      uniform_int_distribution::param_type p{0, 15000};
      udist_.param(p);
    }
  }

  default_random_engine rand_eng_{42};  // fixed seed, so the key sequences are repeatable.
  U64Dash dt_;
  std::optional zipf_;
  uniform_int_distribution udist_;
};

// Inserts up to 100000 keys drawn uniformly from [0, max_range - 1] into dt_ using `policy`.
// Stops early once an insertion neither adds a key nor returns a valid iterator. Logs the
// resulting table size.
template void EvictionPolicyTest::FillUniform(unsigned max_range, Policy& policy) {
  std::uniform_int_distribution dist(0, max_range - 1);
  for (unsigned i = 0; i < 100000; ++i) {
    auto [it, res] = dt_.Insert(dist(rand_eng_), 0, policy);
    if (!res && it.is_done())  // filled up till the capacity limit
      break;
  }
  LOG(INFO) << dt_.size();
}

// Pre-fills the table with SimpleEvictPolicy (max_capacity 3000), then performs 150000
// inserts with keys from Rand() and counts how many keys were already present (hits).
// Only logs the numbers; the sole checks are that zipf_param < 1 and that every insert
// returns a valid iterator.
TEST_P(EvictionPolicyTest, HitRate) {
  CHECK_LT(GetParam().zipf_param, 1);

  SimpleEvictPolicy ev_policy;
  ev_policy.max_capacity = 3000;

  FillUniform(15000, ev_policy);

  unsigned hits = 0;
  for (unsigned i = 0; i < 150000; ++i) {
    auto [it, res] = dt_.Insert(Rand(), 0, ev_policy);
    CHECK(!it.is_done());

    // res == false: the key was already in the table.
    if (!res) {
      ++hits;
    }
  }
  LOG(INFO) << "Zipf: " << GetParam().zipf_param << ", hits " << hits << " evictions "
            << ev_policy.evicted;
}

// Same flow as HitRate, but on every hit the entry is additionally bumped up with
// RelaxedBumpPolicy when the test parameter enables bump-ups.
TEST_P(EvictionPolicyTest, HitRateZipf) {
  // NOTE(review): `gen` is constructed but never used in this test; keys come from Rand().
  base::ZipfianGenerator gen(1, 15000, 0.9);
  SimpleEvictPolicy ev_policy;
  ev_policy.max_capacity = 3000;

  FillUniform(15000, ev_policy);

  bool use_bumps = GetParam().use_bumpups;
  unsigned hits = 0;
  for (unsigned i = 0; i < 150000; ++i) {
    uint64_t key = Rand();
    auto [it, res] = dt_.Insert(key, 0, ev_policy);
    CHECK(!it.is_done());

    if (res) {
      DVLOG(1) << "Inserted new key " << key << " to bucket " << it.bucket_id() << " slot "
               << it.slot_id();
    } else {
      if (use_bumps) {
        RelaxedBumpPolicy policy;
        dt_.BumpUp(it, policy);
      }
      ++hits;
    }
  }
  LOG(INFO) << "Zipf: " << GetParam().PrintTo() << " hits " << hits << " evictions "
            << ev_policy.evicted;
}

// Variant that evicts through ShiftRightPolicy. Besides hits it counts how many newly
// inserted keys had been recorded as evicted before, and logs the most frequently evicted
// keys at the end. The test only logs; it contains no expectations.
TEST_P(EvictionPolicyTest, HitRateZipfShr) {
  ShiftRightPolicy ev_policy;
  ev_policy.max_capacity = 3000;

  FillUniform(15000, ev_policy);

  unsigned hits = 0;
  unsigned inserted_evicted = 0;
  bool use_bumps = GetParam().use_bumpups;
  for (unsigned i = 0; i < 150000; ++i) {
    // Rand() returns uint64_t; the narrowing to unsigned is lossless for the 0..15000 bounds
    // configured in SetUp.
    unsigned key = Rand();
    auto [it, res] = dt_.Insert(key, 0, ev_policy);

    // Unlike the tests above, an insert that returns a done iterator is skipped, not fatal.
    if (!it.is_done()) {
      if (res) {
        DVLOG(1) << "Inserted new key " << key << " to bucket " << it.bucket_id() << " slot "
                 << it.slot_id();
        if (ev_policy.evicted.contains(key)) {
          ++inserted_evicted;
        }
      } else {
        if (use_bumps) {
          RelaxedBumpPolicy policy;
          dt_.BumpUp(it, policy);
          DVLOG(1) << "Bump up key " << key << " " << it.bucket_id() << " slot " << it.slot_id();
        } else {
          DVLOG(1) << "Hit on key " << key;
        }
        ++hits;
      }
    }
  }

  // Build (eviction count, key) pairs and sort them in descending order of eviction count.
  vector> freq_evicted;
  for (const auto& k_v : ev_policy.evicted) {
    freq_evicted.emplace_back(k_v.second, k_v.first);
  }
  sort(freq_evicted.rbegin(), freq_evicted.rend());

  LOG(INFO) << "Params " << GetParam().PrintTo() << " hits " << hits << " evictions "
            << ev_policy.evicted_sum << " "
            << "reinserted " << inserted_evicted;

  // Log the top entries; stop after more than 100 of them, or right after the first entry
  // whose eviction count is below 5 (that entry is still logged).
  unsigned num_outs = 0;
  for (const auto& k_v : freq_evicted) {
    LOG(INFO) << "Evicted " << k_v.first << " : " << k_v.second;
    if (++num_outs > 100 || k_v.first < 5)
      break;
  }
}

// Three configurations: uniform keys without bump-ups, and zipf(0.9) keys without and with
// bump-ups.
INSTANTIATE_TEST_SUITE_P(Eviction, EvictionPolicyTest,
                         testing::Values(EvictParams{false, 0}, EvictParams{false, 0.9},
                                         EvictParams{true, 0.9}),
                         PrintParams);

// Benchmarks

// Each iteration builds a fresh Dash64 table and inserts `count` keys into it. `next` keeps
// increasing across iterations, so every iteration inserts a new range of keys.
static void BM_Insert(benchmark::State& state) {
  unsigned count = state.range(0);

  size_t next = 0;
  while (state.KeepRunning()) {
    Dash64 dt;

    for (unsigned i = 0; i < count; ++i) {
      dt.Insert(next++, 0);
    }
  }
}
BENCHMARK(BM_Insert)->Arg(10000)->Arg(100000)->Arg(1000000);

// SdsDashPolicy whose DestroyKey does nothing, so a table using it does not free its keys;
// BM_StringInsert below frees the strings itself.
struct NoDestroySdsPolicy : public SdsDashPolicy {
  static void DestroyKey(sds s) {
  }
};

// Each iteration builds a fresh table and inserts `count` pre-allocated sds strings of the
// form "key__<hex>". The strings are created before and freed after the measured loop.
static void BM_StringInsert(benchmark::State& state) {
  unsigned count = state.range(0);

  std::vector strs(count);
  for (unsigned i = 0; i < count; ++i) {
    strs[i] = sdscatprintf(sdsempty(), "key__%x", 100 + i);
  }

  while (state.KeepRunning()) {
    DashTable dt;

    for (unsigned i = 0; i < count; ++i) {
      dt.Insert(strs[i], 0);
    }
  }

  for (sds s : strs) {
    sdsfree(s);
  }
}
BENCHMARK(BM_StringInsert)->Arg(1000)->Arg(10000)->Arg(100000);
static void BM_FindExisting(benchmark::State& state) { unsigned count = state.range(0); Dash64 dt; for (unsigned i = 0; i < count; ++i) { dt.Insert(i, 0); } size_t next = 0; while (state.KeepRunning()) { for (unsigned i = 0; i < 100; ++i) { dt.Find(next++); } } } BENCHMARK(BM_FindExisting)->Arg(1000000)->Arg(2000000); // dict memory usage is in [32*n + 8*n, 32*n + 16*n], or // per entry usage is [40, 48]. static void BM_RedisDictFind(benchmark::State& state) { unsigned count = state.range(0); dict* d = dictCreate(&IntDict); for (unsigned i = 0; i < count; ++i) { size_t key = i; dictAdd(d, (void*)key, nullptr); } size_t next = 0; while (state.KeepRunning()) { for (size_t i = 0; i < 100; ++i) { size_t k = next++; dictFind(d, (void*)k); } } dictRelease(d); } BENCHMARK(BM_RedisDictFind)->Arg(1000000)->Arg(2000000); // dict memory usage is in [32*n + 8*n, 32*n + 16*n], or // per entry usage is [40, 48]. static void BM_RedisDictInsert(benchmark::State& state) { unsigned count = state.range(0); size_t next = 0; while (state.KeepRunning()) { dict* d = dictCreate(&IntDict); for (unsigned i = 0; i < count; ++i) { dictAdd(d, (void*)next, nullptr); ++next; } dictRelease(d); } } BENCHMARK(BM_RedisDictInsert)->Arg(10000)->Arg(100000)->Arg(1000000); static void BM_RedisStringInsert(benchmark::State& state) { unsigned count = state.range(0); std::vector strs(count); for (unsigned i = 0; i < count; ++i) { strs[i] = sdscatprintf(sdsempty(), "key__%x", 100 + i); } while (state.KeepRunning()) { dict* d = dictCreate(&SdsDict); for (unsigned i = 0; i < count; ++i) { dictAdd(d, strs[i], nullptr); } dictRelease(d); } for (sds s : strs) { sdsfree(s); } } BENCHMARK(BM_RedisStringInsert)->Arg(1000)->Arg(10000)->Arg(100000); } // namespace dfly ================================================ FILE: src/core/dense_set.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/dense_set.h" #include #include #include #include #include #include #include "absl/random/distributions.h" #include "absl/random/random.h" #include "base/logging.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly { using namespace std; constexpr size_t kMinSizeShift = 2; constexpr size_t kMinSize = 1 << kMinSizeShift; constexpr bool kAllowDisplacements = true; thread_local absl::InsecureBitGen tl_bit_gen; #define PREFETCH_READ(x) __builtin_prefetch(x, 0, 1) DenseSet::IteratorBase::IteratorBase(const DenseSet* owner, bool is_end) : owner_(const_cast(owner)), curr_entry_(nullptr) { curr_list_ = is_end ? owner_->entries_.end() : owner_->entries_.begin(); // Even if `is_end` is `false`, the list can be empty. if (curr_list_ == owner->entries_.end()) { curr_entry_ = nullptr; owner_ = nullptr; } else { curr_entry_ = &(*curr_list_); owner->ExpireIfNeeded(nullptr, curr_entry_); // find the first non null entry if (curr_entry_->IsEmpty()) { Advance(); } } } void DenseSet::IteratorBase::SetExpiryTime(uint32_t ttl_sec) { DensePtr* ptr = curr_entry_->IsLink() ? curr_entry_->AsLink() : curr_entry_; void* src = ptr->GetObject(); if (!HasExpiry()) { const size_t old_size = owner_->ObjectAllocSize(ptr->Raw()); void* new_obj = owner_->ObjectClone(src, false, true); ptr->SetObject(new_obj); const size_t new_size = owner_->ObjectAllocSize(ptr->Raw()); // Important: we set the ttl bit on the wrapping pointer. curr_entry_->SetTtl(true); owner_->ObjDelete(src, false); src = new_obj; // Because setting TTL requires an extra 4 bytes for the key, the allocated size may push the // object into a different mi-malloc page category (e.g. 16 byte page -> 32 byte page). This // results in increased reporting in ObjAllocSize. // // If this size increase is not accounted for, it will cause an overflow in // DenseSet::AddOrReplaceObj due to subtracting larger size from smaller and the type of // obj_malloc_used_ being size_t. 
if (old_size != new_size) { owner_->DecreaseMallocUsed(old_size); owner_->IncreaseMallocUsed(new_size); } } owner_->ObjUpdateExpireTime(src, ttl_sec); } void DenseSet::IteratorBase::Advance() { bool step_link = false; DCHECK(curr_entry_); if (curr_entry_->IsLink()) { DenseLinkKey* plink = curr_entry_->AsLink(); if (!owner_->ExpireIfNeeded(curr_entry_, &plink->next) || curr_entry_->IsLink()) { curr_entry_ = &plink->next; step_link = true; } } if (!step_link) { DCHECK(curr_list_ != owner_->entries_.end()); do { ++curr_list_; if (curr_list_ == owner_->entries_.end()) { curr_entry_ = nullptr; owner_ = nullptr; return; } owner_->ExpireIfNeeded(nullptr, &(*curr_list_)); } while (curr_list_->IsEmpty()); DCHECK(curr_list_ != owner_->entries_.end()); curr_entry_ = &(*curr_list_); } DCHECK(!curr_entry_->IsEmpty()); } DenseSet::DenseSet() { static_assert(sizeof(entries_) == 24); } DenseSet::~DenseSet() { // We can not call Clear from the base class because it internally calls ObjDelete which is // a virtual function. Therefore, destructor of the derived classes must clean up the table. 
CHECK(entries_.empty()); } size_t DenseSet::PushFront(DenseSet::ChainVectorIterator it, void* data, bool has_ttl) { // if this is an empty list assign the value to the empty placeholder pointer DCHECK(!it->IsDisplaced()); if (it->IsEmpty()) { it->SetObject(data); } else { // otherwise make a new link and connect it to the front of the list it->SetLink(NewLink(data, *it)); } if (has_ttl) { it->SetTtl(true); expiration_used_ = true; } return ObjectAllocSize(data); } void DenseSet::PushFront(DenseSet::ChainVectorIterator it, DenseSet::DensePtr ptr) { DVLOG(2) << "PushFront to " << distance(entries_.begin(), it) << ", " << ObjectAllocSize(ptr.GetObject()); DCHECK(!it->IsDisplaced()); if (it->IsEmpty()) { it->SetObject(ptr.GetObject()); if (ptr.HasTtl()) { it->SetTtl(true); expiration_used_ = true; } if (ptr.IsLink()) { FreeLink(ptr.AsLink()); } } else if (ptr.IsLink()) { // if the pointer is already a link then no allocation needed. *ptr.Next() = *it; *it = ptr; DCHECK(!it->AsLink()->next.IsEmpty()); } else { DCHECK(ptr.IsObject()); // allocate a new link if needed and copy the pointer to the new link it->SetLink(NewLink(ptr.Raw(), *it)); if (ptr.HasTtl()) { it->SetTtl(true); expiration_used_ = true; } DCHECK(!it->AsLink()->next.IsEmpty()); } } auto DenseSet::PopPtrFront(DenseSet::ChainVectorIterator it) -> DensePtr { if (it->IsEmpty()) { return DensePtr{}; } DensePtr front = *it; // if this is an object, then it's also the only record in this chain. // therefore, we should just reset DensePtr. 
if (it->IsObject()) { it->Reset(); } else { DCHECK(it->IsLink()); DenseLinkKey* link = it->AsLink(); *it = link->next; } return front; } uint32_t DenseSet::ClearStep(uint32_t start, uint32_t count) { constexpr unsigned kArrLen = 32; ClearItem arr[kArrLen]; unsigned len = 0; size_t end = min(entries_.size(), start + count); for (size_t i = start; i < end; ++i) { DensePtr& ptr = entries_[i]; if (ptr.IsEmpty()) continue; auto& dest = arr[len++]; dest.has_ttl = ptr.HasTtl(); PREFETCH_READ(ptr.Raw()); if (ptr.IsObject()) { dest.obj = ptr.Raw(); dest.ptr.Reset(); } else { dest.ptr = ptr; dest.obj = nullptr; } ptr.Reset(); if (len == kArrLen) { ClearBatch(kArrLen, arr); len = 0; } } ClearBatch(len, arr); if (size_ == 0) { entries_.clear(); num_links_ = 0; obj_malloc_used_ = 0; expiration_used_ = false; } return end; } bool DenseSet::Equal(DensePtr dptr, const void* ptr, uint32_t cookie) const { if (dptr.IsEmpty()) { return false; } return ObjEqual(dptr.GetObject(), ptr, cookie); } void DenseSet::CloneBatch(unsigned len, CloneItem* items, DenseSet* other) const { // We handle a batch of items to minimize data dependencies when accessing memory for a single // item. We prefetch the memory for entire batch before actually reading data from any of the // elements. auto clone = [this](void* obj, bool has_ttl, DenseSet* other) { // The majority of the CPU is spent in this block. void* new_obj = other->ObjectClone(obj, has_ttl, false); uint64_t hash = this->Hash(obj, 0); other->AddUnique(new_obj, has_ttl, hash); }; while (len) { unsigned dest_id = 0; // we walk "len" linked lists in parallel, and prefetch their next, obj pointers // before actually processing them. 
for (unsigned i = 0; i < len; ++i) { auto& src = items[i]; if (src.obj) { clone(src.obj, src.has_ttl, other); src.obj = nullptr; } if (src.ptr.IsEmpty()) { continue; } if (src.ptr.IsObject()) { clone(src.ptr.Raw(), src.has_ttl, other); } else { auto& dest = items[dest_id++]; DenseLinkKey* link = src.ptr.AsLink(); dest.obj = link->Raw(); DCHECK(!link->HasTtl()); // ttl is attached to the wrapping pointer. dest.has_ttl = src.ptr.HasTtl(); dest.ptr = link->next; PREFETCH_READ(dest.ptr.Raw()); PREFETCH_READ(dest.obj); } } // update the length of the batch for the next iteration. len = dest_id; } } void DenseSet::ClearBatch(unsigned len, ClearItem* items) { while (len) { unsigned dest_id = 0; // we walk "len" linked lists in parallel, and prefetch their next, obj pointers // before actually processing them. for (unsigned i = 0; i < len; ++i) { auto& src = items[i]; if (src.obj) { ObjDelete(src.obj, src.has_ttl); --size_; src.obj = nullptr; } if (src.ptr.IsEmpty()) continue; if (src.ptr.IsObject()) { ObjDelete(src.ptr.Raw(), src.has_ttl); --size_; } else { auto& dest = items[dest_id++]; DenseLinkKey* link = src.ptr.AsLink(); DCHECK(!link->HasTtl()); dest.obj = link->Raw(); dest.has_ttl = src.ptr.HasTtl(); dest.ptr = link->next; PREFETCH_READ(dest.ptr.Raw()); PREFETCH_READ(dest.obj); FreeLink(link); } } // update the length of the batch for the next iteration. 
len = dest_id; } } bool DenseSet::NoItemBelongsBucket(uint32_t bid) const { auto& entries = const_cast(this)->entries_; DensePtr* curr = &entries[bid]; ExpireIfNeeded(nullptr, curr); if (!curr->IsEmpty() && !curr->IsDisplaced()) { return false; } if (bid + 1 < entries_.size()) { DensePtr* right_bucket = &entries[bid + 1]; ExpireIfNeeded(nullptr, right_bucket); if (!right_bucket->IsEmpty() && right_bucket->IsDisplaced() && right_bucket->GetDisplacedDirection() == 1) return false; } if (bid > 0) { DensePtr* left_bucket = &entries[bid - 1]; ExpireIfNeeded(nullptr, left_bucket); if (!left_bucket->IsEmpty() && left_bucket->IsDisplaced() && left_bucket->GetDisplacedDirection() == -1) return false; } return true; } auto DenseSet::FindEmptyAround(uint32_t bid) -> ChainVectorIterator { ExpireIfNeeded(nullptr, &entries_[bid]); if (entries_[bid].IsEmpty()) { return entries_.begin() + bid; } if (!kAllowDisplacements) { return entries_.end(); } if (bid + 1 < entries_.size()) { auto it = next(entries_.begin(), bid + 1); ExpireIfNeeded(nullptr, &(*it)); if (it->IsEmpty()) return it; } if (bid) { auto it = next(entries_.begin(), bid - 1); ExpireIfNeeded(nullptr, &(*it)); if (it->IsEmpty()) return it; } return entries_.end(); } void DenseSet::Reserve(size_t sz) { sz = std::max(sz, kMinSize); sz = absl::bit_ceil(sz); if (sz > entries_.size()) { size_t prev_size = entries_.size(); entries_.resize(sz); capacity_log_ = absl::bit_width(sz) - 1; Grow(prev_size); } } void DenseSet::ShrinkBucket(size_t bucket_idx) { // Take the entire bucket to avoid infinite loop when new_bid == bucket_idx DensePtr bucket = entries_[bucket_idx]; entries_[bucket_idx].Reset(); // Process the taken bucket chain while (!bucket.IsEmpty()) { // Pop front from local chain DensePtr dptr = bucket; bucket = bucket.IsObject() ? 
DensePtr{} : bucket.AsLink()->next; void* obj = dptr.GetObject(); bool has_ttl = dptr.HasTtl(); // Free link unconditionally - PushFront will create new one if needed if (dptr.IsLink()) { FreeLink(dptr.AsLink()); } if (has_ttl && ObjExpireTime(obj) <= time_now_) { ObjDelete(obj, true); --size_; continue; } uint32_t new_bid = BucketId(obj, 0); DVLOG(2) << " Shrink: Moving from " << bucket_idx << " to " << new_bid; PushFront(entries_.begin() + new_bid, obj, has_ttl); } } void DenseSet::Shrink(size_t new_size) { DCHECK(absl::has_single_bit(new_size)); DCHECK_GE(new_size, kMinSize); DCHECK_LT(new_size, entries_.size()); size_t prev_size = entries_.size(); capacity_log_ = absl::bit_width(new_size) - 1; // Process from low to high (opposite of Grow). // This prevents double-processing: when moving elements from bucket i to bucket j < i, // bucket j has already been processed, so the element won't be processed again. for (size_t i = 0; i < prev_size; ++i) { ShrinkBucket(i); } entries_.resize(new_size); } void DenseSet::Fill(DenseSet* other) const { DCHECK(other->entries_.empty()); other->Reserve(UpperBoundSize()); constexpr unsigned kArrLen = 32; CloneItem arr[kArrLen]; unsigned len = 0; for (auto it = entries_.begin(); it != entries_.end(); ++it) { DensePtr ptr = *it; if (ptr.IsEmpty()) continue; auto& item = arr[len++]; item.has_ttl = ptr.HasTtl(); if (ptr.IsObject()) { item.ptr.Reset(); item.obj = ptr.Raw(); PREFETCH_READ(item.obj); } else { item.ptr = ptr; item.obj = nullptr; PREFETCH_READ(item.ptr.Raw()); } if (len == kArrLen) { CloneBatch(kArrLen, arr, other); len = 0; } } CloneBatch(len, arr, other); } void DenseSet::Grow(size_t prev_size) { DensePtr first; // Corner case. Usually elements are moved to higher buckets during rehashing. // By moving upper elements first we make sure that there are no displaced elements // when we move the lower elements. 
// However the (displaced) elements at bucket_id=1 can move to bucket 0, and // bucket 0 can host displaced elements from bucket 1. To avoid this situation, we // stash the displaced element from bucket 0 and move it to the correct bucket at the end. if (entries_.front().IsDisplaced()) { first = PopPtrFront(entries_.begin()); } // perform rehashing of items in the array, chain by chain. for (long i = prev_size - 1; i >= 0; --i) { DensePtr* curr = &entries_[i]; DensePtr* prev = nullptr; do { if (ExpireIfNeeded(prev, curr)) { // if curr has disappeared due to expiry and prev was converted from Link to a // regular DensePtr if (prev && !prev->IsLink()) break; } if (curr->IsEmpty()) break; void* ptr = curr->GetObject(); DCHECK(ptr != nullptr && ObjectAllocSize(ptr)); uint32_t bid = BucketId(ptr, 0); // if the item does not move from the current chain, ensure // it is not marked as displaced and move to the next item in the chain if (bid == i) { curr->ClearDisplaced(); prev = curr; curr = curr->Next(); if (curr == nullptr) break; } else { // if the entry is in the wrong chain remove it and // add it to the correct chain. This will also correct // displaced entries auto dest = entries_.begin() + bid; DensePtr dptr = *curr; if (curr->IsObject()) { if (prev) { DCHECK(prev->IsLink()); DenseLinkKey* plink = prev->AsLink(); DCHECK(&plink->next == curr); // we want to make *prev a DensePtr instead of DenseLink and we // want to deallocate the link. DensePtr tmp = DensePtr::From(plink); // Important to transfer the ttl flag. tmp.SetTtl(prev->HasTtl()); DCHECK(ObjectAllocSize(tmp.GetObject())); FreeLink(plink); // we deallocated the link, curr is invalid now. curr = nullptr; *prev = tmp; } else { // prev == nullptr curr->Reset(); // reset the root placeholder. 
} } else { // !curr.IsObject *curr = *dptr.Next(); DCHECK(!curr->IsEmpty()); } DVLOG(2) << " Pushing to " << bid << " " << dptr.GetObject(); DCHECK_EQ(BucketId(dptr.GetObject(), 0), bid); PushFront(dest, dptr); } } while (curr); } if (!first.IsEmpty()) { uint32_t bid = BucketId(first.GetObject(), 0); PushFront(entries_.begin() + bid, first); } } // Assumes that the object does not exist in the set. void DenseSet::AddUnique(void* obj, bool has_ttl, uint64_t hashcode) { if (entries_.empty()) { capacity_log_ = kMinSizeShift; entries_.resize(kMinSize); } uint32_t bucket_id = BucketId(hashcode); DCHECK_LT(bucket_id, entries_.size()); // Try insert into flat surface first. Also handle the grow case // if utilization is too high. for (unsigned j = 0; j < 2; ++j) { ChainVectorIterator list = FindEmptyAround(bucket_id); if (list != entries_.end()) { obj_malloc_used_ += PushFront(list, obj, has_ttl); if (std::distance(entries_.begin(), list) != bucket_id) { list->SetDisplaced(std::distance(entries_.begin() + bucket_id, list)); } ++size_; return; } if (size_ < entries_.size()) { break; } size_t prev_size = entries_.size(); entries_.resize(prev_size * 2); ++capacity_log_; Grow(prev_size); bucket_id = BucketId(hashcode); } DCHECK(!entries_[bucket_id].IsEmpty()); /** * Since the current entry is not empty, it is either a valid chain * or there is a displaced node here. In the latter case it is best to * move the displaced node to its correct bucket. However there could be * a displaced node there and so forth. 
Keep to avoid having to keep a stack * of displacements we can keep track of the current displaced node, add it * to the correct chain, and if the correct chain contains a displaced node * unlink it and repeat the steps */ DensePtr to_insert(obj); if (has_ttl) { to_insert.SetTtl(true); expiration_used_ = true; } while (!entries_[bucket_id].IsEmpty() && entries_[bucket_id].IsDisplaced()) { DensePtr unlinked = PopPtrFront(entries_.begin() + bucket_id); PushFront(entries_.begin() + bucket_id, to_insert); to_insert = unlinked; bucket_id -= unlinked.GetDisplacedDirection(); } DCHECK_EQ(BucketId(to_insert.GetObject(), 0), bucket_id); ChainVectorIterator list = entries_.begin() + bucket_id; PushFront(list, to_insert); obj_malloc_used_ += ObjectAllocSize(obj); DCHECK(!entries_[bucket_id].IsDisplaced()); ++size_; } void DenseSet::Prefetch(uint64_t hash) { uint32_t bid = BucketId(hash); PREFETCH_READ(&entries_[bid]); } auto DenseSet::Find2(const void* ptr, uint32_t bid, uint32_t cookie) -> tuple { DCHECK_LT(bid, entries_.size()); DensePtr* curr = &entries_[bid]; ExpireIfNeeded(nullptr, curr); if (Equal(*curr, ptr, cookie)) { return {bid, nullptr, curr}; } // first look for displaced nodes since this is quicker than iterating a potential long chain if (bid > 0) { curr = &entries_[bid - 1]; if (curr->IsDisplaced() && curr->GetDisplacedDirection() == -1) { ExpireIfNeeded(nullptr, curr); if (Equal(*curr, ptr, cookie)) { return {bid - 1, nullptr, curr}; } } } if (bid + 1 < entries_.size()) { curr = &entries_[bid + 1]; if (curr->IsDisplaced() && curr->GetDisplacedDirection() == 1) { ExpireIfNeeded(nullptr, curr); if (Equal(*curr, ptr, cookie)) { return {bid + 1, nullptr, curr}; } } } // if the node is not displaced, search the correct chain DensePtr* prev = &entries_[bid]; curr = prev->Next(); while (curr != nullptr) { ExpireIfNeeded(prev, curr); if (Equal(*curr, ptr, cookie)) { return {bid, prev, curr}; } prev = curr; curr = curr->Next(); } // not in the Set return {0, nullptr, 
nullptr}; } void* DenseSet::Delete(DensePtr* prev, DensePtr* ptr, bool detach) { void* obj = nullptr; if (ptr->IsObject()) { obj = ptr->Raw(); ptr->Reset(); if (prev) { DCHECK(prev->IsLink()); DenseLinkKey* plink = prev->AsLink(); DensePtr tmp = DensePtr::From(plink); // Transfer TTL flag tmp.SetTtl(prev->HasTtl()); DCHECK(ObjectAllocSize(tmp.GetObject())); FreeLink(plink); *prev = tmp; DCHECK(!prev->IsLink()); } } else { DCHECK(ptr->IsLink()); DenseLinkKey* link = ptr->AsLink(); obj = link->Raw(); *ptr = link->next; FreeLink(link); } obj_malloc_used_ -= ObjectAllocSize(obj); --size_; if (detach) { return obj; } ObjDelete(obj, false); return nullptr; } DenseSet::ChainVectorIterator DenseSet::GetRandomChain() { if (entries_.empty() || size_ == 0) { return entries_.end(); } size_t offset = absl::Uniform(tl_bit_gen, 0u, entries_.size()); // Start at random position and scan linearly with wrap-around auto it = entries_.begin() + offset; for (size_t n = 0; n < entries_.size(); n++) { // Check IsEmpty first to avoid ExpireIfNeeded overhead on empty buckets if (!it->IsEmpty()) { ExpireIfNeeded(nullptr, &*it); if (!it->IsEmpty()) { return it; } } if (++it == entries_.end()) { it = entries_.begin(); } } return entries_.end(); } DenseSet::IteratorBase DenseSet::GetRandomIterator() { ChainVectorIterator chain_it = GetRandomChain(); if (chain_it == entries_.end()) return IteratorBase{}; DensePtr* ptr = &*chain_it; while (ptr->IsLink() && absl::Bernoulli(tl_bit_gen, 0.5)) { DensePtr* next = ptr->Next(); if (ExpireIfNeeded(ptr, next)) // stop if we break the chain with expiration break; ptr = next; } return IteratorBase{(DenseSet*)this, chain_it, ptr}; } void* DenseSet::PopInternal() { auto bucket_iter = GetRandomChain(); // Find first non empty chain if (bucket_iter == entries_.end()) return nullptr; // unlink the first node in the first non-empty chain obj_malloc_used_ -= ObjectAllocSize(bucket_iter->GetObject()); DensePtr front = PopPtrFront(bucket_iter); void* ret = 
front.GetObject(); if (front.IsLink()) { FreeLink(front.AsLink()); } --size_; return ret; } void* DenseSet::AddOrReplaceObj(void* obj, bool has_ttl) { uint64_t hc = Hash(obj, 0); DensePtr* dptr = entries_.empty() ? nullptr : Find(obj, BucketId(hc), 0).second; if (dptr) { // replace existing object. // A bit confusing design: ttl bit is located on the wrapping pointer, // therefore we must set ttl bit before unrapping below. dptr->SetTtl(has_ttl); if (dptr->IsLink()) // unwrap the pointer. dptr = dptr->AsLink(); void* res = dptr->Raw(); const size_t res_sz = ObjectAllocSize(res); DCHECK_GE(obj_malloc_used_, res_sz); obj_malloc_used_ -= res_sz; obj_malloc_used_ += ObjectAllocSize(obj); dptr->SetObject(obj); return res; } AddUnique(obj, has_ttl, hc); return nullptr; } /** * stable scanning api. has the same guarantees as redis scan command. * we avoid doing bit-reverse by using a different function to derive a bucket id * from hash values. By using msb part of hash we make it "stable" with respect to * rehashes. For example, with table log size 4 (size 16), entries in bucket id * 1110 come from hashes 1110XXXXX.... When a table grows to log size 5, * these entries can move either to 11100 or 11101. So if we traversed with our cursor * range [0000-1110], it's guaranteed that in grown table we do not need to cover again * [00000-11100]. Similarly with shrinkage, if a table is shrunk to log size 3, * keys from 1110 and 1111 will move to bucket 111. Again, it's guaranteed that we * covered the range [000-111] (all keys in that case). * Returns: next cursor or 0 if reached the end of scan. * cursor = 0 - initiates a new scan. */ uint32_t DenseSet::Scan(uint32_t cursor, const ItemCb& cb) const { // empty set if (capacity_log_ == 0) { return 0; } uint32_t entries_idx = cursor >> (32 - capacity_log_); auto& entries = const_cast(this)->entries_; // First find the bucket to scan, skip empty buckets. 
// A bucket is empty if the current index is empty and the data is not displaced // to the right or to the left. while (entries_idx < entries_.size() && NoItemBelongsBucket(entries_idx)) { ++entries_idx; } if (entries_idx == entries_.size()) { return 0; } DensePtr* curr = &entries[entries_idx]; // Check home bucket if (!curr->IsEmpty() && !curr->IsDisplaced()) { // scanning add all entries in a given chain while (true) { cb(curr->GetObject()); if (!curr->IsLink()) break; DensePtr* mcurr = const_cast(curr); if (ExpireIfNeeded(mcurr, &mcurr->AsLink()->next) && !mcurr->IsLink()) { break; } curr = &curr->AsLink()->next; } } // Check if the bucket on the left belongs to the home bucket. if (entries_idx > 0) { DensePtr* left_bucket = &entries[entries_idx - 1]; ExpireIfNeeded(nullptr, left_bucket); if (left_bucket->IsDisplaced() && left_bucket->GetDisplacedDirection() == -1) { // left of the home bucket cb(left_bucket->GetObject()); } } // move to the next index for the next scan and check if we are done ++entries_idx; if (entries_idx >= entries_.size()) { return 0; } // Check if the bucket on the right belongs to the home bucket. DensePtr* right_bucket = &entries[entries_idx]; ExpireIfNeeded(nullptr, right_bucket); if (right_bucket->IsDisplaced() && right_bucket->GetDisplacedDirection() == 1) { // right of the home bucket cb(right_bucket->GetObject()); } return entries_idx << (32 - capacity_log_); } auto DenseSet::NewLink(void* data, DensePtr next) -> DenseLinkKey* { using LinkAllocator = StatelessAllocator; LinkAllocator la; DenseLinkKey* lk = la.allocate(1); la.construct(lk); lk->next = next; lk->SetObject(data); ++num_links_; return lk; } bool DenseSet::ExpireIfNeededInternal(DensePtr* prev, DensePtr* node) const { DCHECK(node != nullptr); DCHECK(node->HasTtl()); bool deleted = false; do { uint32_t obj_time = ObjExpireTime(node->GetObject()); if (obj_time > time_now_) { break; } // updates the *node to next item if relevant or resets it to empty. 
const_cast(this)->Delete(prev, node); deleted = true; } while (node->HasTtl()); return deleted; } void DenseSet::CollectExpired() { // Simply iterating over all items will remove expired auto it = IteratorBase(this, false); while (it.curr_entry_ != nullptr) { it.Advance(); } } size_t DenseSet::SizeSlow() { CollectExpired(); return size_; } } // namespace dfly ================================================ FILE: src/core/dense_set.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include "core/detail/stateless_allocator.h" namespace dfly { // DenseSet is a nice but over-optimized data-structure. Probably is not worth it in the first // place but sometimes the OCD kicks in and one can not resist. // The advantage of it over redis-dict is smaller meta-data waste. // dictEntry is 24 bytes, i.e it uses at least 32N bytes where N is the expected length. // dict requires to allocate dictEntry per each addition in addition to the supplied key. // It also wastes space in case of a set because it stores a value pointer inside dictEntry. // To summarize: // 100% utilized dict uses N*24 + N*8 = 32N bytes not including the key space. // for 75% utilization (1/0.75 buckets): N*1.33*8 + N*24 = 35N // // This class uses 8 bytes per bucket (similarly to dictEntry*) but it used it for both // links and keys. For most cases, we remove the need for another redirection layer // and just store the key, so no "dictEntry" allocations occur. // For those cells that require chaining, the bucket is // changed in run-time to represent a linked chain. // Additional feature - in order to to reduce collisions, we insert items into // neighbour cells but only if they are empty (not chains). This way we reduce the number of // empty (unused) spaces at full utilization from 36% to ~21%. 
// 100% utilized table requires: N*8 + 0.2N*16 = 11.2N bytes or ~20 bytes savings.
// 75% utilization: N*1.33*8 + 0.12N*16 = 13N or ~22 bytes savings per record.
// with potential replacements of hset/zset data structures.
// static_assert(sizeof(dictEntry) == 24);

class DenseSet {
  struct DenseLinkKey;
  // we can assume that high 12 bits of user address space
  // can be used for tagging. At most 52 bits of address are reserved for
  // some configurations, and usually it's 48 bits.
  // https://docs.kernel.org/arch/arm64/memory.html
  static constexpr size_t kLinkBit = 1ULL << 52;               // pointer refers to a DenseLinkKey
  static constexpr size_t kDisplaceBit = 1ULL << 53;           // entry sits next to its home bucket
  static constexpr size_t kDisplaceDirectionBit = 1ULL << 54;  // set: displaced to the right
  static constexpr size_t kTtlBit = 1ULL << 55;                // entry carries a TTL
  static constexpr size_t kTagMask = 4095ULL << 52;            // we reserve 12 high bits.

  // A tagged pointer: the low bits hold either an object address or a DenseLinkKey address,
  // the high bits (see kTagMask) hold the flags above.
  class DensePtr {
   public:
    explicit DensePtr(void* p = nullptr) : ptr_(p) {
    }

    // Imports the object with its metadata except the link bit that is reset.
    static DensePtr From(DenseLinkKey* o) {
      DensePtr res;
      res.ptr_ = (void*)(o->uptr() & (~kLinkBit));
      return res;
    }

    // Raw tagged value as an integer.
    uint64_t uptr() const {
      return uint64_t(ptr_);
    }

    bool IsObject() const {
      return (uptr() & kLinkBit) == 0;
    }

    bool IsLink() const {
      return (uptr() & kLinkBit) != 0;
    }

    bool HasTtl() const {
      return (uptr() & kTtlBit) != 0;
    }

    bool IsEmpty() const {
      return ptr_ == nullptr;
    }

    // The stored address with all tag bits stripped.
    void* Raw() const {
      return (void*)(uptr() & ~kTagMask);
    }

    bool IsDisplaced() const {
      return (uptr() & kDisplaceBit) == kDisplaceBit;
    }

    // Overwrites the whole value (including previous tags) with a link pointer.
    void SetLink(DenseLinkKey* lk) {
      ptr_ = (void*)(uintptr_t(lk) | kLinkBit);
    }

    // Marks the entry as displaced; direction == 1 additionally sets the direction bit.
    // The direction bit is never cleared here, use ClearDisplaced() for that.
    void SetDisplaced(int direction) {
      ptr_ = (void*)(uptr() | kDisplaceBit);
      if (direction == 1) {
        ptr_ = (void*)(uptr() | kDisplaceDirectionBit);
      }
    }

    void ClearDisplaced() {
      ptr_ = (void*)(uptr() & ~(kDisplaceBit | kDisplaceDirectionBit));
    }

    // returns 1 if the displaced node is right of the correct bucket and -1 if it is left
    int GetDisplacedDirection() const {
      return (uptr() & kDisplaceDirectionBit) == kDisplaceDirectionBit ? 1 : -1;
    }

    void SetTtl(bool b) {
      if (b)
        ptr_ = (void*)(uptr() | kTtlBit);
      else
        ptr_ = (void*)(uptr() & (~kTtlBit));
    }

    void Reset() {
      ptr_ = nullptr;
    }

    // Returns the user object: the pointer itself, or the object stored in the link.
    void* GetObject() const {
      if (IsObject()) {
        return Raw();
      }

      return AsLink()->Raw();
    }

    // Sets pointer but preserves tagging info
    void SetObject(void* obj) {
      assert(IsObject());
      ptr_ = (void*)((uptr() & kTagMask) | (uintptr_t(obj) & ~kTagMask));
    }

    // Reinterprets the untagged address as a link. Does not check IsLink().
    DenseLinkKey* AsLink() {
      return (DenseLinkKey*)Raw();
    }

    const DenseLinkKey* AsLink() const {
      return (const DenseLinkKey*)Raw();
    }

    // Returns the next pointer of the chain, or nullptr if this is not a link.
    DensePtr* Next() {
      if (!IsLink()) {
        return nullptr;
      }

      return &AsLink()->next;
    }

    const DensePtr* Next() const {
      if (!IsLink()) {
        return nullptr;
      }

      return &AsLink()->next;
    }

   private:
    void* ptr_ = nullptr;
  };

  // A chain node: the inherited DensePtr part stores the object, `next` continues the chain.
  struct DenseLinkKey : public DensePtr {
    DensePtr next;  // could be LinkKey* or Object *.
  };

  static_assert(sizeof(DensePtr) == sizeof(uintptr_t));
  static_assert(sizeof(DenseLinkKey) == 2 * sizeof(uintptr_t));

 protected:
  using DensePtrAllocator = StatelessAllocator;
  using ChainVectorIterator = std::vector::iterator;
  using ChainVectorConstIterator = std::vector::const_iterator;

  // Points at one entry (curr_entry_) inside the bucket curr_list_ of owner_.
  class IteratorBase {
    friend class DenseSet;

   public:
    IteratorBase(DenseSet* owner, ChainVectorIterator list_it, DensePtr* e)
        : owner_(owner), curr_list_(list_it), curr_entry_(e) {
    }

    // returns the expiry time of the current entry or UINT32_MAX if no ttl is set.
    uint32_t ExpiryTime() const {
      return curr_entry_->HasTtl() ? owner_->ObjExpireTime(curr_entry_->GetObject()) : UINT32_MAX;
    }

    void SetExpiryTime(uint32_t ttl_sec);

    bool HasExpiry() const {
      return curr_entry_->HasTtl();
    }

   protected:
    // Creates an iterator with null owner_ and curr_entry_; curr_list_ is left
    // default-initialized.
    IteratorBase() : owner_(nullptr), curr_entry_(nullptr) {
    }

    IteratorBase(const DenseSet* owner, bool is_end);

    void Advance();

    DenseSet* owner_;
    ChainVectorIterator curr_list_;
    DensePtr* curr_entry_;
  };

 public:
  static constexpr uint32_t kMaxBatchLen = 32;

  explicit DenseSet();
  virtual ~DenseSet();

  void Clear() {
    ClearStep(0, entries_.size());
  }

  // Returns the next bucket index that should be cleared.
  // Returns BucketCount when all objects are erased.
  uint32_t ClearStep(uint32_t start, uint32_t count);

  // Returns the number of elements in the map. Note that it might be that some of these elements
  // have expired and can't be accessed.
  size_t UpperBoundSize() const {
    return size_;
  }

  // Returns an accurate size, post-expiration. O(n).
  size_t SizeSlow();

  bool Empty() const {
    return size_ == 0;
  }

  size_t BucketCount() const {
    return entries_.size();
  }

  // Sum of ObjectAllocSize() over the stored objects, as tracked by the set.
  size_t ObjMallocUsed() const {
    return obj_malloc_used_;
  }

  // Memory used by the set itself: the bucket array plus all chain links.
  size_t SetMallocUsed() const {
    return entries_.capacity() * sizeof(DensePtr) + num_links_ * sizeof(DenseLinkKey);
  }

  using ItemCb = std::function;

  uint32_t Scan(uint32_t cursor, const ItemCb& cb) const;
  void Reserve(size_t sz);

  // Shrinks the table to the specified size. The size must be a power of 2,
  // >= kMinSize, and >= current number of elements.
  // This method should be called explicitly when memory reclamation is needed.
  void Shrink(size_t new_size);

  void Fill(DenseSet* other) const;

  // set an abstract time that allows expiry.
  void set_time(uint32_t val) {
    time_now_ = val;
  }

  uint32_t time_now() const {
    return time_now_;
  }

  bool ExpirationUsed() const {
    return expiration_used_;
  }

 protected:
  // Virtual functions to be implemented for generic data
  virtual uint64_t Hash(const void* obj, uint32_t cookie) const = 0;
  virtual bool ObjEqual(const void* left, const void* right, uint32_t right_cookie) const = 0;
  virtual size_t ObjectAllocSize(const void* obj) const = 0;
  virtual uint32_t ObjExpireTime(const void* obj) const = 0;
  virtual void ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) = 0;
  virtual void ObjDelete(void* obj, bool has_ttl) const = 0;
  virtual void* ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const = 0;

  void CollectExpired();

  // Finds and deletes the object equal to `obj`. Returns true if it was found.
  // NOTE(review): BucketId() asserts capacity_log_ > 0, so this presumably must not be
  // called on a set that never allocated buckets - confirm against callers.
  bool EraseInternal(void* obj, uint32_t cookie) {
    auto [prev, found] = Find(obj, BucketId(obj, cookie), cookie);
    if (found) {
      Delete(prev, found);
      return true;
    }
    return false;
  }

  // Like EraseInternal but returns the detached object instead of deleting it.
  // Returns nullptr if the object was not found.
  void* DetachInternal(void* obj, uint32_t cookie) {
    auto [prev, found] = Find(obj, BucketId(obj, cookie), cookie);
    if (found) {
      return Delete(prev, found, true);
    }
    return nullptr;
  }

  void* FindInternal(const void* obj, uint64_t hashcode, uint32_t cookie) const;

  // Returns an iterator to the object equal to `ptr`, or a default IteratorBase if absent.
  IteratorBase FindIt(const void* ptr, uint32_t cookie) {
    if (Empty())
      return IteratorBase{};

    auto [bid, _, curr] = Find2(ptr, BucketId(ptr, cookie), cookie);
    if (curr) {
      return IteratorBase(this, entries_.begin() + bid, curr);
    }
    return IteratorBase{};
  }

  // Get iterator to start of random non-empty chain (bucket)
  ChainVectorIterator GetRandomChain();

  // Wrap RandomChain() into iterator and advance with reservoir sampling
  IteratorBase GetRandomIterator();

  void* PopInternal();

  void IncreaseMallocUsed(size_t delta) {
    obj_malloc_used_ += delta;
  }

  void DecreaseMallocUsed(size_t delta) {
    obj_malloc_used_ -= delta;
  }

  // Returns the previous object if it has been replaced.
  // nullptr, if obj was added.
  void* AddOrReplaceObj(void* obj, bool has_ttl);

  // Assumes that the object does not exist in the set.
  void AddUnique(void* obj, bool has_ttl, uint64_t hashcode);

  void Prefetch(uint64_t hash);

 private:
  DenseSet(const DenseSet&) = delete;
  DenseSet& operator=(DenseSet&) = delete;

  bool Equal(DensePtr dptr, const void* ptr, uint32_t cookie) const;

  // One collected bucket root: either a chain (`ptr`) or a single object (`obj`),
  // as filled in by Fill().
  struct CloneItem {
    DensePtr ptr;
    void* obj = nullptr;
    bool has_ttl = false;
  };

  void CloneBatch(unsigned len, CloneItem* items, DenseSet* other) const;

  using ClearItem = CloneItem;
  void ClearBatch(unsigned len, ClearItem* items);

  // Maps a hash to a bucket index using its capacity_log_ most significant bits.
  uint32_t BucketId(uint64_t hash) const {
    assert(capacity_log_ > 0);
    return hash >> (64 - capacity_log_);
  }

  uint32_t BucketId(const void* ptr, uint32_t cookie) const {
    return BucketId(Hash(ptr, cookie));
  }

  // return a ChainVectorIterator (a.k.a iterator) or end if there is an empty chain found
  ChainVectorIterator FindEmptyAround(uint32_t bid);

  // Return if bucket has no item which is not displaced and right/left bucket has no displaced item
  // belong to given bid
  bool NoItemBelongsBucket(uint32_t bid) const;
  void Grow(size_t prev_size);

  // ============ Pseudo Linked List Functions for interacting with Chains ==================
  size_t PushFront(ChainVectorIterator, void* obj, bool has_ttl);
  void PushFront(ChainVectorIterator, DensePtr);

  DensePtr PopPtrFront(ChainVectorIterator);

  // ============ Pseudo Linked List in DenseSet end ==================

  // returns (prev, item) pair. If item is root, then prev is null.
  std::pair Find(const void* ptr, uint32_t bid, uint32_t cookie) {
    auto [_, p, c] = Find2(ptr, bid, cookie);
    return {p, c};
  }

  // returns bid and (prev, item) pair. If item is root, then prev is null.
  std::tuple Find2(const void* ptr, uint32_t bid, uint32_t cookie);

  DenseLinkKey* NewLink(void* data, DensePtr next);

  // Returns the link's memory to the allocator and decrements num_links_.
  inline void FreeLink(DenseLinkKey* plink) {
    // deallocate the link if it is no longer a link as it is now in an empty list
    DensePtrAllocator::resource()->deallocate(plink, sizeof(DenseLinkKey), alignof(DenseLinkKey));
    --num_links_;
  }

  // Returns true if *node was deleted.
  bool ExpireIfNeeded(DensePtr* prev, DensePtr* node) const {
    if (node->HasTtl()) {
      return ExpireIfNeededInternal(prev, node);
    }
    return false;
  }

  bool ExpireIfNeededInternal(DensePtr* prev, DensePtr* node) const;

  // Deletes the object pointed by ptr and removes it from the set.
  // If ptr is a link then it will be deleted internally.
  // If detach is true, returns the raw object instead of calling ObjDelete.
  void* Delete(DensePtr* prev, DensePtr* ptr, bool detach = false);

  // Processes a single bucket during Shrink, relocating elements as needed.
  void ShrinkBucket(size_t bucket_idx);

  std::vector entries_;

  mutable size_t obj_malloc_used_ = 0;
  mutable uint32_t size_ = 0;       // number of elements in the set.
  mutable uint32_t num_links_ = 0;  // number of links in the set.
  unsigned capacity_log_ = 0;       // log2 of the bucket count; 0 while there are no buckets.

  uint32_t time_now_ = 0;  // abstract "now" that TTLs are compared against, see set_time().

  mutable bool expiration_used_ = false;
};

// Returns the stored object equal to `obj` (whose hash is `hashcode`), or nullptr.
// Expired items met during the lookup are removed even though the method is const.
inline void* DenseSet::FindInternal(const void* obj, uint64_t hashcode, uint32_t cookie) const {
  if (entries_.empty())
    return nullptr;

  uint32_t bid = BucketId(hashcode);
  DensePtr* ptr = const_cast(this)->Find(obj, bid, cookie).second;
  return ptr ? ptr->GetObject() : nullptr;
}

}  // namespace dfly

================================================
FILE: src/core/detail/bitpacking.cc
================================================
// Copyright 2022, Roman Gershman.  All rights reserved.
// See LICENSE for licensing terms.
// #include "src/core/detail/bitpacking.h" #include #include "base/logging.h" #include "core/sse_port.h" using namespace std; namespace dfly { namespace detail { #if defined(__GNUC__) && !defined(__clang__) #pragma GCC push_options #pragma GCC optimize("Ofast") #endif static inline uint64_t Compress8x7bit(uint64_t x) { x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F); x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF); x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF); return x; } #if defined(__SSE3__) || defined(__aarch64__) static inline pair simd_variant1_pack(const char* ascii, const char* end, uint8_t* bin) { __m128i val, rpart, lpart; // Skips 8th byte (indexc 7) in the lower 8-byte part. const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0); // Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111 while (ascii <= end) { val = mm_loadu_si128(reinterpret_cast(ascii)); /* x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F); x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF); x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF); */ rpart = _mm_and_si128(val, _mm_set1_epi64x(0x007F007F007F007F)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0x7F007F007F007F00)); val = _mm_or_si128(_mm_srli_epi64(lpart, 1), rpart); rpart = _mm_and_si128(val, _mm_set1_epi64x(0x00003FFF00003FFF)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0x3FFF00003FFF0000)); val = _mm_or_si128(_mm_srli_epi64(lpart, 2), rpart); rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000)); val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart); val = _mm_shuffle_epi8(val, control); _mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val); bin += 14; ascii += 16; } return make_pair(ascii, bin); } static inline pair simd_variant2_pack(const char* ascii, const char* end, uint8_t* bin) { // Skips 8th byte 
(indexc 7) in the lower 8-byte part. const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0); __m128i val, rpart, lpart; // Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111 while (ascii <= end) { val = mm_loadu_si128(reinterpret_cast(ascii)); /* x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F); x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF); x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF); */ val = _mm_maddubs_epi16(_mm_set1_epi16(0x8001), val); val = _mm_madd_epi16(_mm_set1_epi32(0x40000001), val); rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000)); val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart); val = _mm_shuffle_epi8(val, control); _mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val); bin += 14; ascii += 16; } return make_pair(ascii, bin); } #endif // Daniel Lemire's function validate_ascii_fast() - under Apache/MIT license. // See https://github.com/lemire/fastvalidate-utf-8/ // The function returns true (1) if all chars passed in src are // 7-bit values (0x00..0x7F). Otherwise, it returns false (0). #ifdef __s390x__ bool validate_ascii_fast(const char* src, size_t len) { size_t i = 0; // Initialize a vector in which all the elements are set to zero. vector unsigned char has_error = vec_splat_s8(0); if (len >= 16) { for (; i <= len - 16; i += 16) { // Load 16 bytes from buffer into a vector. vector unsigned char current_bytes = vec_load_len((signed char*)(src + i), 16); // Perform a bitwise OR operation between the current and the previously loaded contents. has_error = vec_orc(has_error, current_bytes); } } // Initialize a vector in which all the elements are set to an invalid ASCII value. vector unsigned char rep_invalid_values = vec_splat_s8(0x80); // Perform bitwise AND-complement operation between two vectors. 
vector unsigned char andc_result = vec_andc(rep_invalid_values, has_error); // Tests whether any of corresponding elements of the given vectors are not equal. // After the bitwise operation, both vectors should be equal if ASCII values. if (!vec_all_eq(rep_invalid_values, andc_result)) { return false; } for (; i < len; i++) { if (src[i] & 0x80) { return false; } } return true; } #else bool validate_ascii_fast(const char* src, size_t len) { size_t i = 0; __m128i has_error = _mm_setzero_si128(); if (len >= 16) { for (; i <= len - 16; i += 16) { __m128i current_bytes = mm_loadu_si128((const __m128i*)(src + i)); has_error = _mm_or_si128(has_error, current_bytes); } } int error_mask = _mm_movemask_epi8(has_error); char tail_has_error = 0; for (; i < len; i++) { tail_has_error |= src[i]; } error_mask |= (tail_has_error & 0x80); return !error_mask; } #endif // len must be at least 16 void ascii_pack(const char* ascii, size_t len, uint8_t* bin) { uint64_t val; const char* end = ascii + len; while (ascii + 8 <= end) { val = absl::little_endian::Load64(ascii); uint64_t dest = (val & 0xFF); for (unsigned i = 1; i <= 7; ++i) { val >>= 1; dest |= (val & (0x7FUL << 7 * i)); } memcpy(bin, &dest, 7); bin += 7; ascii += 8; } // epilog - we do not pack since we have less than 8 bytes. while (ascii < end) { *bin++ = *ascii++; } } void ascii_pack2(const char* ascii, size_t len, uint8_t* bin) { uint64_t val; const char* end = ascii + len; while (ascii + 8 <= end) { val = absl::little_endian::Load64(ascii); val = Compress8x7bit(val); memcpy(bin, &val, 7); bin += 7; ascii += 8; } // epilog - we do not pack since we have less than 8 bytes. 
while (ascii < end) { *bin++ = *ascii++; } } // The algo - do in parallel what ascii_pack does on two uint64_t integers void ascii_pack_simd(const char* ascii, size_t len, uint8_t* bin) { #if defined(__SSE3__) || defined(__aarch64__) // I leave out 16 bytes in addition to 16 that we load in the loop // because we store into bin full 16 bytes instead of 14. To prevent data // overwrite we finish loop one iteration earlier. const char* end = ascii + len - 32; tie(ascii, bin) = simd_variant1_pack(ascii, end, bin); end += 32; // Bring back end. DCHECK(ascii < end); ascii_pack(ascii, end - ascii, bin); #else ascii_pack(ascii, len, bin); #endif } void ascii_pack_simd2(const char* ascii, size_t len, uint8_t* bin) { #if defined(__SSE3__) || defined(__aarch64__) // I leave out 16 bytes in addition to 16 that we load in the loop // because we store into bin full 16 bytes instead of 14. To prevent data // overwrite we finish loop one iteration earlier. const char* end = ascii + len - 32; // on arm var #if defined(__aarch64__) tie(ascii, bin) = simd_variant1_pack(ascii, end, bin); #else tie(ascii, bin) = simd_variant2_pack(ascii, end, bin); #endif end += 32; // Bring back end. DCHECK(ascii < end); ascii_pack(ascii, end - ascii, bin); #else ascii_pack(ascii, len, bin); #endif } // unpacks 8->7 encoded blob back to ascii. // generally, we can not unpack inplace because ascii (dest) buffer is 8/7 bigger than // the source buffer. // however, if binary data is positioned on the right of the ascii buffer with empty space on the // left than we can unpack inplace. void ascii_unpack(const uint8_t* bin, size_t ascii_len, char* ascii) { constexpr uint8_t kM = 0x7F; uint8_t p = 0; unsigned i = 0; while (ascii_len >= 8) { for (i = 0; i < 7; ++i) { uint8_t src = *bin; // keep on stack in case we unpack inplace. 
*ascii++ = (p >> (8 - i)) | ((src << i) & kM); p = src; ++bin; } ascii_len -= 8; *ascii++ = p >> 1; } DCHECK_LT(ascii_len, 8u); for (i = 0; i < ascii_len; ++i) { *ascii++ = *bin++; } } uint8_t ascii_unpack_byte(const uint8_t* bin, size_t ascii_len, size_t idx) { DCHECK(idx < ascii_len) << "Index oob for ascii byte unpacking: " << idx << " >= " << ascii_len; const size_t packed_groups = ascii_len / 8; const size_t group = idx / 8; const size_t idx_in_group = idx % 8; // Tail bytes (after the last full 8-char group) are stored unpacked. if (group >= packed_groups) { return bin[packed_groups * 7 + idx_in_group]; } // Unpack ascii group and return byte at idx. char buf[8]; ascii_unpack(bin + group * 7, 8, buf); return buf[idx_in_group]; } void ascii_pack_byte(uint8_t* bin, size_t ascii_len, size_t idx, uint8_t val) { DCHECK(idx < ascii_len) << "Index oob for ascii byte packing: " << idx << " >= " << ascii_len; DCHECK_LT(val, 128u) << "Only 7-bit ASCII values can be packed"; const size_t packed_groups = ascii_len / 8; const size_t group = idx / 8; const size_t idx_in_group = idx % 8; // Tail bytes (after the last full 8-char group) are stored unpacked. if (group >= packed_groups) { bin[packed_groups * 7 + idx_in_group] = val; return; } // Unpack ascii group and return, modify byte at idx and pack back. uint8_t* group_bin = bin + group * 7; char buf[8]; ascii_unpack(group_bin, 8, buf); buf[idx_in_group] = val; ascii_pack(buf, 8, group_bin); } // See CompactObjectTest.AsanTriggerReadOverflow for more details. void ascii_unpack_simd(const uint8_t* bin, size_t ascii_len, char* ascii) { #if defined(__SSE3__) || defined(__aarch64__) if (ascii_len < 18) { // ascii_len >=18 means bin length >=16. ascii_unpack(bin, ascii_len, ascii); return; } __m128i val, rpart, lpart; // we read 16 bytes from bin even when we need only 14 bytes. // So for last iteration we may access 2 bytes outside of the bin buffer. 
// To prevent this we need to round down the length of the bin buffer but since we // limit by ascii_len we reduce the ascii_len by two before computing number of iterations. size_t simd_len = ((ascii_len - 2) / 16) * 16; const char* end = ascii + simd_len; // shifts the second 7-byte blob to the left. const __m128i control = _mm_set_epi8(14, 13, 12, 11, 10, 9, 8, 7, -1, 6, 5, 4, 3, 2, 1, 0); while (ascii < end) { val = mm_loadu_si128(reinterpret_cast(bin)); val = _mm_shuffle_epi8(val, control); rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0x00FFFFFFF0000000)); val = _mm_or_si128(_mm_slli_epi64(lpart, 4), rpart); rpart = _mm_and_si128(val, _mm_set1_epi64x(0x00003FFF00003FFF)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0xFFFFC000FFFFC000)); val = _mm_or_si128(_mm_slli_epi64(lpart, 2), rpart); rpart = _mm_and_si128(val, _mm_set1_epi64x(0x007F007F007F007F)); lpart = _mm_and_si128(val, _mm_set1_epi64x(0x7F807F807F807F80)); val = _mm_or_si128(_mm_slli_epi64(lpart, 1), rpart); _mm_storeu_si128(reinterpret_cast<__m128i*>(ascii), val); ascii += 16; bin += 14; } ascii_len -= simd_len; if (ascii_len) ascii_unpack(bin, ascii_len, ascii); #else ascii_unpack(bin, ascii_len, ascii); #endif } // compares packed and unpacked strings. packed must be of length = binpacked_len(ascii_len). 
bool compare_packed(const uint8_t* packed, const char* ascii, size_t ascii_len) { unsigned i = 0; bool res = true; const char* end = ascii + ascii_len; while (ascii + 8 <= end) { for (i = 0; i < 7; ++i) { uint8_t conv = (ascii[0] >> i) | (ascii[1] << (7 - i)); res &= (conv == *packed); ++ascii; ++packed; } if (!res) return false; ++ascii; } while (ascii < end) { if (*ascii++ != *packed++) { return false; } } return true; } #if defined(__GNUC__) && !defined(__clang__) #pragma GCC pop_options #endif } // namespace detail } // namespace dfly ================================================ FILE: src/core/detail/bitpacking.h ================================================ // Copyright 2022, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include namespace dfly { namespace detail { bool validate_ascii_fast(const char* src, size_t len); // Unpacks an 8->7 encoded blob back to ascii. // Generally, we cannot unpack in place because the ascii (dest) buffer is 8/7 bigger than // the source buffer. // However, if the binary data is positioned on the right of the ascii buffer with empty space on the // left, then we can unpack in place. void ascii_unpack(const uint8_t* bin, size_t ascii_len, char* ascii); void ascii_unpack_simd(const uint8_t* bin, size_t ascii_len, char* ascii); // Access a single byte in a 7-bit ASCII-packed string without unpacking the entire buffer. // These helpers read/write the ASCII byte at logical position `idx` in the unpacked string // directly from/into the packed `bin` representation. // It is up to the caller to verify: // 1. `idx` must be less than `ascii_len` to avoid out-of-bounds access. // 2. `ascii` must be less than 128 (7-bit ASCII) for packing. uint8_t ascii_unpack_byte(const uint8_t* bin, size_t ascii_len, size_t idx); void ascii_pack_byte(uint8_t* bin, size_t ascii_len, size_t idx, uint8_t ascii); // Packs an ascii string (does not verify) into binary form, saving 1 bit per byte on average (12.5%).
void ascii_pack(const char* ascii, size_t len, uint8_t* bin); void ascii_pack2(const char* ascii, size_t len, uint8_t* bin); // SIMD implementation 1 of ascii_pack. void ascii_pack_simd(const char* ascii, size_t len, uint8_t* bin); // SIMD implementation 2 of ascii_pack. void ascii_pack_simd2(const char* ascii, size_t len, uint8_t* bin); bool compare_packed(const uint8_t* packed, const char* ascii, size_t ascii_len); // Maps ascii length to 7-bit packed length. Each 8 bytes are converted to 7 bytes. inline constexpr size_t binpacked_len(size_t ascii_len) { return (ascii_len * 7 + 7) / 8; /* rounded up */ } // Converts 7-bit packed length back to ascii length. Note that this conversion // is not accurate since it maps 7 bytes to 8 bytes (rounds up), while we may have // 7 byte strings converted to 7 bytes as well. inline constexpr size_t ascii_len(size_t bin_len) { return (bin_len * 8) / 7; } } // namespace detail } // namespace dfly ================================================ FILE: src/core/detail/bptree_internal.h ================================================ // Copyright 2023, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include namespace dfly { template class BPTree; namespace detail { // Internal classes related to B+tree implementation. The design is largely based on the // implementation of absl::btree_map/set. // The motivation for replacing zskiplist - significant size reduction: // we reduce the metadata overhead per record from 45 bytes in zskiplist to just a // few bytes with b-tree. The trick is using significantly large nodes (256 bytes) so that // their overhead is negligible compared to the items they store. // Why not use absl::btree_set? We must support Rank tree functionality that // absl does not supply. // Hacking into absl is not a simple task, implementing our own tree is easier. // Below are some design decisions: // 1.
We use predefined node size of 256 bytes and derive number of items in each node from it. // Inner nodes have fewer items than leaf nodes because they also need to store child pointers. // 2. BPTreeNode does not predeclare fields besides the metadata header (kKeyOffset = 4 bytes below) - everything else is // calculated at run-time and has dedicated accessors (similarly to absl). This allows // dense and efficient representation of tree nodes. // 3. We assume that we store small items (8, 16 bytes) which will have a large branching // factor (248/16), meaning the tree will stay shallow even for sizes reaching billion nodes. // 4. We do not store parent pointer like in absl tree. Instead we use BPTreePath to store // hierarchy of parent nodes. That should reduce our overhead even further by few bits per item. // 5. We assume we store trivially copyable types - this reduces the // complexity of the generics in the code. // 6. We support pmr memory resource. This allows us to use pluggable heaps. // // TODO: (all the ideas taken from absl implementation) // 1. to introduce slices when removing items from the tree (avoid shifts). // 2. to avoid merging/rebalancing when removing max/min items from the tree. // 3. Small tree optimization: when the tree is small with a single root node, we can // allocate less than 256 bytes (special case) to avoid relative blowups in memory for // small trees. constexpr uint16_t kBPNodeSize = 256; /** * @brief The BPNodeLayout class is a helper class that defines the layout of the B+tree node. * The inner node looks like this: * | 4 bytes metadata | keys ... | 4 bytes tree-count | children nodes | * The leaf node looks like this: * | 4 bytes metadata | keys ... | * * @tparam T the key type stored in the node; must be trivially copyable (enforced by the static_assert below). */ template class BPNodeLayout { static_assert(std::is_trivially_copyable::value, "KeyT must be triviall copyable"); static constexpr uint16_t kKeyOffset = 4; // 4 bytes for metadata static constexpr uint16_t kSubTreeLen = sizeof(uint32_t); // 4 bytes for count.
public: static constexpr uint16_t kKeySize = sizeof(T); static constexpr uint16_t kMaxLeafKeys = (kBPNodeSize - kKeyOffset) / kKeySize; static constexpr uint16_t kMinLeafKeys = kMaxLeafKeys / 2; // internal node: // x slots, (x+1) children, where 8 stands for sizeof(void*): kKeyOffset + kSubTreeLen + x * kKeySize + (x+1) * 8 <= kBPNodeSize // x = (kBPNodeSize - kKeyOffset - kSubTreeLen - 8) / (kKeySize + 8) static constexpr uint16_t kMaxInnerKeys = (kBPNodeSize - sizeof(void*) - kKeyOffset - kSubTreeLen) / (kKeySize + sizeof(void*)); static constexpr uint16_t kMinInnerKeys = kMaxInnerKeys / 2; using KeyT = T; // The class is constructed inside a block of memory of size kBPNodeSize. // Only BPTree can create it, hence it can access the memory outside its fields. static uint8_t* KeyPtr(unsigned index, void* node) { return reinterpret_cast(node) + kKeyOffset + kKeySize * index; } static const uint8_t* KeyPtr(unsigned index, const void* node) { return reinterpret_cast(node) + kKeyOffset + kKeySize * index; } static uint8_t* TreeCountPtr(void* node) { return reinterpret_cast(node) + kKeyOffset + kKeySize * kMaxInnerKeys; } static const uint8_t* TreeCountPtr(const void* node) { return reinterpret_cast(node) + kKeyOffset + kKeySize * kMaxInnerKeys; } static uint8_t* ChildrenStart(void* node) { return TreeCountPtr(node) + kSubTreeLen; } static const uint8_t* ChildrenStart(const void* node) { return TreeCountPtr(node) + kSubTreeLen; } static_assert(kMaxLeafKeys < 128); }; template class BPTreeNode { template friend class ::dfly::BPTree; BPTreeNode(const BPTreeNode&) = delete; BPTreeNode& operator=(const BPTreeNode&) = delete; BPTreeNode(bool leaf) : num_items_(0), leaf_(leaf) { } using Layout = BPNodeLayout; public: using KeyT = T; void InitSingle(T key) { SetKey(0, key); num_items_ = 1; } KeyT Key(unsigned index) const { KeyT res; memcpy(&res, Layout::KeyPtr(index, this), sizeof(KeyT)); return res; } void SetKey(size_t index, KeyT item) { uint8_t* slot = Layout::KeyPtr(index, this); memcpy(slot, &item,
sizeof(KeyT)); } bool IsLeaf() const { return leaf_; } struct SearchResult { uint16_t index; bool found; }; // Searches for a key in the node using binary search. // Returns SearchResult whose index is the position of the first item that is greater than or equal to // the searched key (NumItems() if all items are smaller); found is true on an exact match. // comp: a three-way comparator invoked on the node's items; a positive result means the item is smaller than the searched key. template SearchResult BSearch(Comp&& comp) const; void Split(BPTreeNode* right, KeyT* median); unsigned NumItems() const { return num_items_; } unsigned AvailableSlotCount() const { return MaxItems() - num_items_; } unsigned MaxItems() const { return IsLeaf() ? Layout::kMaxLeafKeys : Layout::kMaxInnerKeys; } unsigned MinItems() const { return IsLeaf() ? Layout::kMinLeafKeys : Layout::kMinInnerKeys; } // Returns the overall number of items for a subtree rooted at this node. // Equal to NumItems() for leaf nodes and GetInnerTreeCount() for inner nodes. uint32_t TreeCount() const { return IsLeaf() ? NumItems() : GetInnerTreeCount(); } void ShiftRight(unsigned index); void ShiftLeft(unsigned index, bool child_step_right = false); void LeafEraseRight() { assert(IsLeaf() && num_items_ > 0); --num_items_; } // Inserts item into a leaf node. // Assumes: the node is IsLeaf() and has some space. void LeafInsert(unsigned index, KeyT item) { assert(IsLeaf() && NumItems() < MaxItems()); InsertItem(index, item); } void Validate(KeyT upper_bound) const; // // Below is the inner node API // BPTreeNode* Child(unsigned i) { BPTreeNode* res; memcpy(&res, Layout::ChildrenStart(this) + sizeof(BPTreeNode*) * i, sizeof(BPTreeNode*)); return res; } const BPTreeNode* Child(unsigned i) const { BPTreeNode* res; memcpy(&res, Layout::ChildrenStart(this) + sizeof(BPTreeNode*) * i, sizeof(BPTreeNode*)); return res; } void SetChild(unsigned i, BPTreeNode* child) { memcpy(Layout::ChildrenStart(this) + sizeof(BPTreeNode*) * i, &child, sizeof(BPTreeNode*)); } // TODO: instead of storing counts at nodes we could keep at parent level // along the children array.
Unfortunately, this complicates implementation of the tree, // so we will do it after the whole functionality is completed. uint32_t GetChildTreeCount(unsigned i) { return Child(i)->TreeCount(); } void SetChildTreeCount(unsigned i, uint32_t cnt) { Child(i)->SetTreeCount(cnt); } void IncreaseTreeCount(int32_t delta) { uint32_t cnt = GetInnerTreeCount(); cnt += delta; memcpy(Layout::TreeCountPtr(this), &cnt, sizeof(uint32_t)); } // Rebalance a full child at position pos, at which we tried to insert at insert_pos. // Returns the node and the position to insert into if rebalancing succeeded. // Returns nullptr if rebalancing did not succeed. std::pair RebalanceChild(unsigned pos, unsigned insert_pos); // We do not update tree count and it is done on the caller side. // Inserts item into a inner node at position pos and adds `child` at position pos+1. void InnerInsert(unsigned index, KeyT item, BPTreeNode* child) { InsertItem(index, item); SetChild(index + 1, child); } // Tries to merge the child at position pos with its sibling. // If we did not succeed to merge, we try to rebalance. // Returns retired BPTreeNode* if children got merged and this parent node's children // count decreased, otherwise, we return nullptr (rebalanced). 
BPTreeNode* MergeOrRebalanceChild(unsigned pos); uint32_t DEBUG_TreeCount() const { uint32_t res = NumItems(); if (!IsLeaf()) { for (unsigned i = 0; i <= NumItems(); ++i) { res += Child(i)->DEBUG_TreeCount(); } } return res; } private: void SetTreeCount(uint32_t cnt) { assert(!IsLeaf()); memcpy(Layout::TreeCountPtr(this), &cnt, sizeof(uint32_t)); } void RebalanceChildToLeft(unsigned child_pos, unsigned count); void RebalanceChildToRight(unsigned child_pos, unsigned count); void MergeFromRight(KeyT key, BPTreeNode* right); void InsertItem(unsigned index, KeyT item) { assert(index <= num_items_); ShiftRight(index); SetKey(index, item); } uint32_t GetInnerTreeCount() const { assert(!IsLeaf()); uint32_t res; memcpy(&res, Layout::TreeCountPtr(this), sizeof(uint32_t)); return res; } struct { uint32_t num_items_ : 7; uint32_t leaf_ : 1; uint32_t : 24; }; }; // Contains parent/index pairs. Meaning that node0->Child(index0) == node1. template class BPTreePath { static constexpr unsigned kMaxDepth = 16; public: void Push(BPTreeNode* node, unsigned pos) { assert(depth_ < kMaxDepth); assert(depth_ == 0 || !record_[depth_ - 1].node->IsLeaf()); record_[depth_].node = node; record_[depth_].pos = pos; depth_++; } unsigned Depth() const { return depth_; } void Clear() { depth_ = 0; } bool Empty() const { return depth_ == 0; } std::pair*, unsigned> Last() const { assert(depth_ > 0u); return {record_[depth_ - 1].node, record_[depth_ - 1].pos}; } BPTreeNode* Node(unsigned i) const { assert(i < depth_); return record_[i].node; } unsigned Position(unsigned i) const { assert(i < depth_); return record_[i].pos; } void Pop() { assert(depth_ > 0u); depth_--; } bool HasValidTerminal() const { return depth_ > 0u && Last().second < Last().first->NumItems(); } T Terminal() const { assert(Last().second < Last().first->NumItems()); return Last().first->Key(Last().second); } /// @brief Returns the rank of the path's terminal item. /// Requires that the path is valid and has a terminal item. 
uint32_t Rank() const; /// @brief Advances the path to the next item. /// @return true if succeeded, false if reached the end. bool Next(); /// @brief Advances the path to the previous item. /// @return true if succeeded, false if reached the end. bool Prev(); // Extend the path to the leaf by always taking the rightmost child. void DigRight(); private: struct Record { BPTreeNode* node; unsigned pos; }; std::array record_; unsigned depth_ = 0; }; // Returns the position of the first item whose key is greater or equal than key. // if all items are smaller than key, returns num_items_. template template auto BPTreeNode::BSearch(Comp&& cmp_op) const -> SearchResult { uint16_t lo = 0; uint16_t hi = num_items_; assert(hi > 0); // optimization: check the last item first. int cmp_res = cmp_op(Key(hi - 1)); if (cmp_res >= 0) { return cmp_res > 0 ? SearchResult{.index = hi, .found = false} : SearchResult{.index = uint16_t(hi - 1), .found = true}; } // key < Key(hi - 1) --hi; while (lo < hi) { uint16_t mid = (lo + hi) >> 1; assert(mid < hi); KeyT item = Key(mid); int cmp_res = cmp_op(item); if (cmp_res == 0) { return SearchResult{.index = mid, .found = true}; } if (cmp_res < 0) { hi = mid; } else { lo = mid + 1; // we never return indices upto mid because they are strictly less than key. 
} } assert(lo == hi); return {.index = hi, .found = false}; } template void BPTreeNode::ShiftRight(unsigned index) { unsigned num_items_to_shift = num_items_ - index; if (num_items_to_shift > 0) { uint8_t* ptr = Layout::KeyPtr(index, this); memmove(ptr + Layout::kKeySize, ptr, num_items_to_shift * Layout::kKeySize); if (!IsLeaf()) { uint8_t* src = Layout::ChildrenStart(this) + index * sizeof(BPTreeNode*); uint8_t* dest = src + sizeof(BPTreeNode*); memmove(dest, src, (num_items_to_shift + 1) * sizeof(BPTreeNode*)); } } num_items_++; } template void BPTreeNode::ShiftLeft(unsigned index, bool child_step_right) { assert(index < num_items_); unsigned num_items_to_shift = num_items_ - index - 1; if (num_items_to_shift > 0) { memmove(Layout::KeyPtr(index, this), Layout::KeyPtr(index + 1, this), num_items_to_shift * Layout::kKeySize); if (!leaf_) { index += unsigned(child_step_right); num_items_to_shift = num_items_ - index; if (num_items_to_shift > 0) { uint8_t* dest = Layout::ChildrenStart(this) + index * sizeof(BPTreeNode*); uint8_t* src = dest + sizeof(BPTreeNode*); memmove(dest, src, num_items_to_shift * sizeof(BPTreeNode*)); } } } num_items_--; } /*** * Rebalances the (full) child at position pos with its sibling. `this` node is an inner node. * It first tried to rebalance (move items) from the full child to its left sibling. If the left * sibling does not have enough space, it tries to rebalance to the right sibling. The caller * passes the original position of the item it tried to insert into the full child. In case the * rebalance succeeds the function returns the new node and the position to insert into. Otherwise, * it returns result.first == nullptr. 
*/ template std::pair*, unsigned> BPTreeNode::RebalanceChild(unsigned pos, unsigned insert_pos) { unsigned to_move = 0; BPTreeNode* node = Child(pos); if (pos > 0) { BPTreeNode* left = Child(pos - 1); unsigned dest_free = left->AvailableSlotCount(); if (dest_free > 0) { // We bias rebalancing based on the position being inserted. If we're // inserting at the end of the right node then we bias rebalancing to // fill up the left node. if (insert_pos == node->NumItems()) { to_move = dest_free; assert(to_move < node->NumItems()); } else if (dest_free > 1) { // we move less than the left node's free capacity, which leaves us some space in the node. to_move = dest_free / 2; } if (to_move) { unsigned dest_old_count = left->NumItems(); RebalanceChildToLeft(pos, to_move); assert(node->AvailableSlotCount() == to_move); if (insert_pos < to_move) { assert(left->AvailableSlotCount() > 0u); // we did not fill up the left node. insert_pos = dest_old_count + insert_pos + 1; // +1 because we moved the separator. node = left; } else { insert_pos -= to_move; } return {node, insert_pos}; } } } if (pos < NumItems()) { BPTreeNode* right = Child(pos + 1); unsigned dest_free = right->AvailableSlotCount(); if (dest_free > 0) { // Mirror of the left case: when inserting at the front of the node we fill up the right sibling, otherwise we move half of its free capacity. if (insert_pos == 0) { to_move = dest_free; assert(to_move < node->NumItems()); } else if (dest_free > 1) { to_move = dest_free / 2; } if (to_move) { RebalanceChildToRight(pos, to_move); if (insert_pos > node->NumItems()) { insert_pos -= (node->NumItems() + 1); node = right; } return {node, insert_pos}; } } } return {nullptr, 0}; } template void BPTreeNode::RebalanceChildToLeft(unsigned child_pos, unsigned count) { assert(child_pos > 0u); BPTreeNode* src = Child(child_pos); BPTreeNode* dest = Child(child_pos - 1); assert(src->NumItems() >= count); assert(count >= 1u); assert(dest->AvailableSlotCount() >= count); unsigned dest_items = dest->NumItems(); // Move the delimiting value to the left node.
dest->SetKey(dest_items, Key(child_pos - 1)); // Copy src keys [0, count-2] to dest keys [dest_items+1, dest_items+count-1]; src key count-1 becomes the new delimiter in the parent. for (unsigned i = 1; i < count; ++i) { dest->SetKey(dest_items + i, src->Key(i - 1)); } SetKey(child_pos - 1, src->Key(count - 1)); // Shift the values in the right node to their correct position. for (unsigned i = count; i < src->NumItems(); ++i) { src->SetKey(i - count, src->Key(i)); } if (!src->IsLeaf()) { // Move the child pointers from the right to the left node. uint32_t src_move_count = 0; for (unsigned i = 0; i < count; ++i) { src_move_count += src->GetChildTreeCount(i); dest->SetChild(1 + dest->NumItems() + i, src->Child(i)); } uint32_t dest_tree_count = GetChildTreeCount(child_pos - 1); uint32_t src_tree_count = GetChildTreeCount(child_pos); SetChildTreeCount(child_pos - 1, dest_tree_count + src_move_count + count); SetChildTreeCount(child_pos, src_tree_count - src_move_count - count); for (unsigned i = count; i <= src->NumItems(); ++i) { src->SetChild(i - count, src->Child(i)); src->SetChild(i, NULL); } } // Fixup the counts on the src and dest nodes. dest->num_items_ += count; src->num_items_ -= count; } template void BPTreeNode::RebalanceChildToRight(unsigned child_pos, unsigned count) { assert(child_pos < NumItems()); BPTreeNode* src = Child(child_pos); BPTreeNode* dest = Child(child_pos + 1); assert(src->NumItems() >= count); assert(count >= 1u); assert(dest->AvailableSlotCount() >= count); unsigned dest_items = dest->NumItems(); assert(dest_items > 0u); // Shift the values in the right node to their correct position. for (int i = dest_items - 1; i >= 0; --i) { dest->SetKey(i + count, dest->Key(i)); } // Pick the new delimiting value from the left (src) node; the parent's current delimiter // is moved into the right (dest) node further below.
KeyT new_delim = src->Key(src->NumItems() - count); for (unsigned i = 1; i < count; ++i) { unsigned src_id = src->NumItems() - count + i; dest->SetKey(i - 1, src->Key(src_id)); } // Move parent's delimiter to destination and update it with new delimiter. dest->SetKey(count - 1, Key(child_pos)); SetKey(child_pos, new_delim); if (!src->IsLeaf()) { // Shift child pointers in the right node to their correct position. for (int i = dest_items; i >= 0; --i) { dest->SetChild(i + count, dest->Child(i)); } // Move child pointers from the left node to the right. uint32_t src_move_count = 0; for (unsigned i = 0; i < count; ++i) { unsigned src_id = src->NumItems() - (count - 1) + i; src_move_count += src->Child(src_id)->TreeCount(); dest->SetChild(i, src->Child(src_id)); src->SetChild(src_id, NULL); } uint32_t dest_tree_count = GetChildTreeCount(child_pos + 1); uint32_t src_tree_count = GetChildTreeCount(child_pos); SetChildTreeCount(child_pos + 1, dest_tree_count + src_move_count + count); SetChildTreeCount(child_pos, src_tree_count - src_move_count - count); } // Fixup the counts on the src and dest nodes. dest->num_items_ += count; src->num_items_ -= count; } template BPTreeNode* BPTreeNode::MergeOrRebalanceChild(unsigned pos) { BPTreeNode* node = Child(pos); BPTreeNode* left = nullptr; assert(NumItems() >= 1u); assert(node->NumItems() < node->MinItems()); if (pos > 0) { left = Child(pos - 1); if (left->NumItems() + 1 + node->NumItems() <= left->MaxItems()) { left->MergeFromRight(Key(pos - 1), node); ShiftLeft(pos - 1, true); return node; } } if (pos < NumItems()) { BPTreeNode* right = Child(pos + 1); if (node->NumItems() + 1 + right->NumItems() <= right->MaxItems()) { node->MergeFromRight(Key(pos), right); ShiftLeft(pos, true); return right; } // Try rebalancing with our right sibling. // TODO: don't perform rebalancing if // we deleted the first element from node and the node is not // empty. 
This is a small optimization for the common pattern of deleting // from the front of the tree. if (true) { unsigned to_move = (right->NumItems() - node->NumItems()) / 2; assert(to_move < right->NumItems()); RebalanceChildToLeft(pos + 1, to_move); return nullptr; } } assert(left); if (left) { // Try rebalancing with our left sibling. // TODO: don't perform rebalancing if we deleted the last element from node and the // node is not empty. This is a small optimization for the common pattern of deleting // from the back of the tree. if (true) { unsigned to_move = (left->NumItems() - node->NumItems()) / 2; assert(to_move < left->NumItems()); RebalanceChildToRight(pos - 1, to_move); return nullptr; } } return nullptr; } // splits the node into two nodes. The left node is the current node and the right node is // is filled with the right half of the items. The median key is returned in *median. template void BPTreeNode::Split(BPTreeNode* right, T* median) { unsigned mid = num_items_ / 2; *median = Key(mid); right->leaf_ = leaf_; right->num_items_ = num_items_ - (mid + 1); memmove(Layout::KeyPtr(0, right), Layout::KeyPtr(mid + 1, this), right->num_items_ * Layout::kKeySize); if (!IsLeaf()) { uint32_t right_subtree_count = right->num_items_; for (size_t i = 0; i <= right->num_items_; i++) { BPTreeNode* child = Child(mid + 1 + i); right_subtree_count += child->TreeCount(); right->SetChild(i, child); } right->SetTreeCount(right_subtree_count); IncreaseTreeCount(-(right_subtree_count + 1)); } num_items_ = mid; } template void BPTreeNode::MergeFromRight(KeyT key, BPTreeNode* right) { assert(NumItems() + 1 + right->NumItems() <= MaxItems()); unsigned dest_items = NumItems(); SetKey(dest_items, key); for (unsigned i = 0; i < right->NumItems(); ++i) { SetKey(dest_items + 1 + i, right->Key(i)); } if (!IsLeaf()) { for (unsigned i = 0; i <= right->NumItems(); ++i) { SetChild(dest_items + 1 + i, right->Child(i)); } IncreaseTreeCount(right->TreeCount() + 1); } num_items_ += 1 + 
right->NumItems(); right->num_items_ = 0; } template uint32_t BPTreePath::Rank() const { uint32_t rank = 0; unsigned bound = Depth(); for (unsigned i = 0; i < bound; ++i) { auto* node = Node(i); unsigned pos = Position(i); if (!node->IsLeaf()) { unsigned delta = (i == bound - 1) ? 1 : 0; for (unsigned j = 0; j < pos + delta; ++j) { rank += node->Child(j)->TreeCount(); } } rank += pos; } return rank; } template bool BPTreePath::Next() { assert(depth_ > 0); BPTreeNode* node = Last().first; // The data in BPTree is stored in both the leaf nodes and the inner nodes. if (node->IsLeaf()) { ++record_[depth_ - 1].pos; if (record_[depth_ - 1].pos < node->NumItems()) { return true; } // Advance to the next item, which is Key(i) in some ascendent of the subtree with // root Child(i). i in that case must be less than NumItems(). // Note, that subtree Child(i) in a inner node is located before Key(i). do { Pop(); } while (depth_ > 0 && Position(depth_ - 1) == Node(depth_ - 1)->NumItems()); // we either point now on separator Key(i) in the parent node or we finished the tree. return depth_ > 0; } // We are in the inner node after the ascent from the leaf node. We need to advance to the next // Child and dig left. assert(!node->IsLeaf()); assert(record_[depth_ - 1].pos < node->NumItems()); // we are in the inner node pointing to the separator. // now we need to advance to the next child and dig to the leftmost leaf. record_[depth_ - 1].pos++; do { node = node->Child(record_[depth_ - 1].pos); Push(node, 0); } while (!node->IsLeaf()); return true; } template bool BPTreePath::Prev() { assert(depth_ > 0); auto* node = record_[depth_ - 1].node; if (node->IsLeaf()) { /* node / \ l r We must go left (decrement pos), and if there is no left, we must go up until we can go left. */ while (record_[depth_ - 1].pos == 0) { Pop(); if (depth_ == 0) { return false; } } assert(depth_ > 0 && record_[depth_ - 1].pos > 0); // we finished backtracking from child(i+1) or stayed in the leaf. 
// either way stop at the next key on the left. --record_[depth_ - 1].pos; return true; } DigRight(); return true; } template void BPTreePath::DigRight() { assert(depth_ > 0); BPTreeNode* node = Last().first; assert(!node->IsLeaf()); // we are in the inner node pointing to the separator. // we now must explore the left subtree which is located under the same index as the separator. // we go far-right in the left subtree. do { node = node->Child(record_[depth_ - 1].pos); Push(node, node->NumItems()); } while (!node->IsLeaf()); // we reached the leaf node, fix the position to point to the last key. assert(record_[depth_ - 1].node->IsLeaf()); --record_[depth_ - 1].pos; } } // namespace detail } // namespace dfly ================================================ FILE: src/core/detail/gen_utils.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include namespace dfly { inline std::string GetRandomHex(absl::InsecureBitGen& gen, size_t len, size_t len_deviation = 0) { static_assert(std::is_same::value); if (len_deviation) { len += (gen() % len_deviation); } std::string res(len, '\0'); size_t indx = 0; for (size_t i = 0; i < len / 16; ++i) { // 2 chars per byte absl::numbers_internal::FastHexToBufferZeroPad16(gen(), res.data() + indx); indx += 16; } if (indx < res.size()) { char buf[32]; absl::numbers_internal::FastHexToBufferZeroPad16(gen(), buf); for (unsigned j = 0; indx < res.size(); indx++, j++) { res[indx] = buf[j]; } } return res; } } // namespace dfly ================================================ FILE: src/core/detail/listpack.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/detail/listpack.h" #include "base/logging.h" namespace dfly { namespace detail { using namespace std; QList::Entry ListPack::GetEntry(uint8_t* pos) { unsigned int slen; long long lval; uint8_t* vstr = lpGetValue(pos, &slen, &lval); return vstr ? QList::Entry(reinterpret_cast(vstr), slen) : QList::Entry(lval); } string ListPack::Pop(QList::Where where) { uint8_t* pos = GetFirst(where); DCHECK(pos); string res = GetEntry(pos).to_string(); lp_ = lpDelete(lp_, pos, nullptr); return res; } void ListPack::Push(string_view value, QList::Where where) { if (where == QList::HEAD) { lp_ = lpPrepend(lp_, (unsigned char*)value.data(), value.size()); } else { lp_ = lpAppend(lp_, (unsigned char*)value.data(), value.size()); } } string ListPack::First(QList::Where where) const { uint8_t* pos = GetFirst(where); DCHECK(pos); return GetEntry(pos).to_string(); } std::optional ListPack::At(long index) const { uint8_t* pos = lpSeek(lp_, index); if (!pos) return nullopt; return GetEntry(pos).to_string(); } vector ListPack::Pos(string_view element, uint32_t rank, uint32_t count, uint32_t max_len, QList::Where where) const { DCHECK_GT(rank, 0u); vector matches; uint8_t* p = GetFirst(where); unsigned index = 0; while (p && (max_len == 0 || index < max_len)) { if (GetEntry(p) == element) { if (rank == 1) { size_t sz = lpLength(lp_); auto k = (where == QList::HEAD) ? index : sz - index - 1; matches.push_back(k); if (count && matches.size() >= count) break; } else { rank--; } } index++; p = (where == QList::HEAD) ? lpNext(lp_, p) : lpPrev(lp_, p); } return matches; } uint8_t* ListPack::Find(std::string_view elem) const { uint8_t* p = lpFirst(lp_); while (p) { if (GetEntry(p) == elem) { return p; } p = lpNext(lp_, p); } return nullptr; } unsigned ListPack::Remove(const CollectionEntry& elem, unsigned count, QList::Where where) { unsigned removed = 0; auto is_match = [&](const QList::Entry& entry) { return elem.is_int() ? 
entry.is_int() && entry.ival() == elem.ival() : entry == elem.view(); }; uint8_t* p = GetFirst(where); while (p) { if (is_match(GetEntry(p))) { // lpDelete returns pointer to the element AFTER the deleted one (toward tail) lp_ = lpDelete(lp_, p, &p); if (where == QList::TAIL) { // Iterating backward (from TAIL): need to get the previous element if (p) { p = lpPrev(lp_, p); } else { // Deleted the tail element, lpDelete returned nullptr (no element after tail). // We need to continue from the new tail to keep moving towards HEAD. p = lpLast(lp_); } } // For HEAD direction, 'p' already points to the next element to check removed++; if (count && removed == count) break; continue; } p = (where == QList::HEAD) ? lpNext(lp_, p) : lpPrev(lp_, p); } return removed; } } // namespace detail } // namespace dfly ================================================ FILE: src/core/detail/listpack.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include "core/qlist.h" extern "C" { #include "redis/listpack.h" } namespace dfly { namespace detail { // A listpack wrapper that provides basic list operations. // Unfortunately, we already have a listpack wrapper in core/detail/listpack_wrap.h but // it's more map oriented and doesn't provide the basic list operations we need here. // TODO: to unify both wrappers into one. class ListPack { public: explicit ListPack(uint8_t* lp = nullptr) : lp_(lp) { } size_t Size() const { return lpLength(lp_); } // Removes and returns an element from the specified end (HEAD or TAIL). std::string Pop(QList::Where where); // Adds an element to the specified end (HEAD or TAIL). void Push(std::string_view value, QList::Where where); // Returns the first element from the specified end without removing it. 
std::string First(QList::Where where) const; // Returns the element at the specified index, or std::nullopt if out of bounds. std::optional At(long index) const; // Finds positions of an element matching the given criteria. std::vector Pos(std::string_view element, uint32_t rank, uint32_t count, uint32_t max_len, QList::Where where) const; uint8_t* Find(std::string_view elem) const; uint8_t* Seek(long index) const { return lpSeek(lp_, index); } // Inserts an element before or after the specified pivot element. void Insert(uint8_t* pivot, std::string_view elem, QList::InsertOpt insert_opt) { int where = (insert_opt == QList::BEFORE) ? LP_BEFORE : LP_AFTER; lp_ = lpInsertString(lp_, (unsigned char*)elem.data(), elem.size(), pivot, where, nullptr); } // Removes up to count occurrences of elem from the specified direction. unsigned Remove(const CollectionEntry& elem, unsigned count, QList::Where where); // Replaces the element at the specified index with a new value. void Replace(uint8_t* pos, std::string_view elem) { lp_ = lpReplace(lp_, &pos, (unsigned char*)elem.data(), elem.size()); } // Removes count elements starting from the specified index. void Erase(long start, long count) { lp_ = lpDeleteRange(lp_, start, count); } // Returns the raw listpack pointer. uint8_t* GetPointer() const { return lp_; } size_t BytesSize() const { return lpBytes(lp_); } private: static CollectionEntry GetEntry(uint8_t* pos); uint8_t* GetFirst(QList::Where where) const { return (where == QList::HEAD) ? lpFirst(lp_) : lpLast(lp_); } uint8_t* lp_; }; } // namespace detail } // namespace dfly ================================================ FILE: src/core/detail/listpack_wrap.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/detail/listpack_wrap.h" #include "base/logging.h" extern "C" { #include "redis/listpack.h" } namespace dfly::detail { ListpackWrap::Iterator::Iterator(uint8_t* lp, uint8_t* ptr, IntBuf& intbuf) : lp_{lp}, ptr_{ptr}, next_ptr_{nullptr}, intbuf_(intbuf) { static_assert(sizeof(intbuf_[0]) >= LP_INTBUF_SIZE); // to avoid header dependency Read(); } ListpackWrap::Iterator& ListpackWrap::Iterator::operator++() { ptr_ = next_ptr_; Read(); return *this; } void ListpackWrap::Iterator::Read() { if (!ptr_) return; key_v_ = GetView(ptr_, intbuf_[0]); next_ptr_ = lpNext(lp_, ptr_); value_v_ = GetView(next_ptr_, intbuf_[1]); next_ptr_ = lpNext(lp_, next_ptr_); } ListpackWrap::~ListpackWrap() { DCHECK(!dirty_); } ListpackWrap ListpackWrap::WithCapacity(size_t capacity) { return ListpackWrap{lpNew(capacity)}; } uint8_t* ListpackWrap::GetPointer() { dirty_ = false; return lp_; } ListpackWrap::Iterator ListpackWrap::Find(std::string_view key) const { if (size() == 0) return end(); uint8_t* ptr = lpFind(lp_, lpFirst(lp_), (unsigned char*)key.data(), key.size(), 1); return Iterator{lp_, ptr, intbuf_}; } bool ListpackWrap::Delete(std::string_view key) { if (size() == 0) return false; uint8_t* ptr = lpFind(lp_, lpFirst(lp_), (unsigned char*)key.data(), key.size(), 1); if (ptr == nullptr) return false; lp_ = lpDeleteRangeWithEntry(lp_, &ptr, 2); dirty_ = true; return true; } bool ListpackWrap::Insert(std::string_view key, std::string_view value, bool skip_exists) { uint8_t* vptr; uint8_t* fptr = lpFirst(lp_); uint8_t* fsrc = key.empty() ? lp_ : (uint8_t*)key.data(); // if we vsrc is NULL then lpReplace will delete the element, which is not what we want. // therefore, for an empty val we set it to some other valid address so that lpReplace // will do the right thing and encode empty string instead of deleting the element. uint8_t* vsrc = value.empty() ? 
lp_ : (uint8_t*)value.data(); bool updated = false; if (fptr) { fptr = lpFind(lp_, fptr, fsrc, key.size(), 1); if (fptr) { if (skip_exists) return false; // Grab pointer to the value (fptr points to the field) vptr = lpNext(lp_, fptr); // Replace value lp_ = lpReplace(lp_, &vptr, vsrc, value.size()); DCHECK_EQ(0u, lpLength(lp_) % 2); dirty_ = true; updated = true; } } if (!updated) { // Push new field/value pair onto the tail of the listpack. // TODO: we should at least allocate once for both elements lp_ = lpAppend(lp_, fsrc, key.size()); lp_ = lpAppend(lp_, vsrc, value.size()); dirty_ = true; } return !updated; } size_t ListpackWrap::size() const { return lpLength(lp_) / 2; } ListpackWrap::Iterator ListpackWrap::begin() const { return Iterator{lp_, lpFirst(lp_), intbuf_}; } ListpackWrap::Iterator ListpackWrap::end() const { return Iterator{lp_, nullptr, intbuf_}; } size_t ListpackWrap::UsedBytes() const { return lpBytes(lp_); } std::string_view ListpackWrap::GetView(uint8_t* lp_it, uint8_t int_buf[]) { int64_t ele_len = 0; uint8_t* elem = lpGet(lp_it, &ele_len, int_buf); DCHECK(elem); return std::string_view{reinterpret_cast(elem), size_t(ele_len)}; } bool ListpackWrap::Iterator::operator==(const Iterator& other) const { return lp_ == other.lp_ && ptr_ == other.ptr_; } } // namespace dfly::detail ================================================ FILE: src/core/detail/listpack_wrap.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// Wrapper around map data structure based on listpack.
// Entries are stored as consecutive field/value elements; lookups are linear.
struct ListpackWrap {
 private:
  // Two scratch buffers (key, value) used when an element has to be rendered as text.
  using IntBuf = uint8_t[2][24];

 public:
  // Debug-checks that any modification was picked up through GetPointer().
  ~ListpackWrap();

  // Forward iterator yielding (key, value) views.
  struct Iterator {
    using iterator_category = std::forward_iterator_tag;
    using difference_type = std::ptrdiff_t;
    // operator* builds its result from two string_view members, hence a pair of views.
    using value_type = std::pair<std::string_view, std::string_view>;
    using reference = value_type;
    using pointer = value_type*;

    Iterator(uint8_t* lp, uint8_t* ptr, IntBuf& intbuf);

    Iterator& operator++();

    // NOTE(review): views of integer entries presumably point into the shared IntBuf and
    // are overwritten when the iterator advances - confirm before storing them.
    value_type operator*() const {
      return {key_v_, value_v_};
    }

    bool operator==(const Iterator& other) const;

    bool operator!=(const Iterator& other) const {
      return !(operator==(other));
    }

   private:
    void Read();  // Read next entry at ptr and determine next_ptr

    uint8_t* lp_ = nullptr;
    uint8_t* ptr_ = nullptr;
    uint8_t* next_ptr_ = nullptr;
    std::string_view key_v_, value_v_;
    IntBuf& intbuf_;
  };

  explicit ListpackWrap(uint8_t* lp) : lp_{lp} {
  }

  // Create listpack with capacity
  static ListpackWrap WithCapacity(size_t capacity);

  uint8_t* GetPointer();  // Get new updated pointer

  Iterator Find(std::string_view key) const;  // Linear search

  // Removes `key` and its value. Returns whether the key was present.
  bool Delete(std::string_view key);

  // Returns true if a new pair was added, false if the key already existed.
  bool Insert(std::string_view key, std::string_view value, bool skip_exists);

  Iterator begin() const;
  Iterator end() const;

  size_t size() const;  // number of entries
  size_t UsedBytes() const;

  // Get view from raw listpack iterator
  static std::string_view GetView(uint8_t* lp_it, uint8_t int_buf[]);

 private:
  uint8_t* lp_;            // the listpack itself
  mutable IntBuf intbuf_;  // buffer for integers decoded to strings
  bool dirty_ = false;     // whether lp_ was updated, but never retrieved with GetPointer
};
// Allocator base that keeps no per-instance state: every request is forwarded to the
// memory resource returned by the static Impl::resource().
// T is the allocated value type, Impl is the derived allocator supplying resource().
template <typename T, typename Impl> class StatelessAllocatorBase {
 public:
  using value_type = T;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;
  using is_always_equal = std::true_type;

  // Constructs a U in place at `p` from the forwarded arguments.
  template <typename U, typename... Args> void construct(U* p, Args&&... args) {
    ::new (static_cast<void*>(p)) U(std::forward<Args>(args)...);
  }

  // Allocates storage for `n` objects of value_type from Impl::resource().
  static value_type* allocate(size_type n) {
    static_assert(
        std::is_empty_v<Impl>,
        "StatelessAllocator must not contain state, so it can use empty base optimization");

    void* ptr = Impl::resource()->allocate(n * sizeof(value_type), alignof(value_type));
    return static_cast<value_type*>(ptr);
  }

  // Returns storage obtained from allocate(n); `n` must match the allocation request.
  static void deallocate(value_type* ptr, size_type n) noexcept {
    Impl::resource()->deallocate(ptr, n * sizeof(value_type), alignof(value_type));
  }
};
// #include #include #include #ifdef USE_PCRE2 #define PCRE2_CODE_UNIT_WIDTH 8 #include #endif #ifdef USE_RE2 #include #endif #include #include #include #include "base/gtest.h" #include "base/logging.h" #include "core/glob_matcher.h" #include "core/huff_coder.h" #include "core/intent_lock.h" #include "core/tx_queue.h" namespace dfly { using namespace std; std::random_device rd; static string GetRandomHex(size_t len) { std::string res(len, '\0'); size_t indx = 0; for (; indx < len; indx += 16) { // 2 chars per byte absl::numbers_internal::FastHexToBufferZeroPad16(rd(), res.data() + indx); } if (indx < len) { char buf[24]; absl::numbers_internal::FastHexToBufferZeroPad16(rd(), buf); for (unsigned j = 0; indx < len; indx++, j++) { res[indx] = buf[j]; } } return res; } extern int stringmatchlen(const char* pattern, int patternLen, const char* string, int stringLen, int nocase); class TxQueueTest : public ::testing::Test { protected: TxQueueTest() { } uint64_t Pop() { if (pq_.Empty()) return uint64_t(-1); TxQueue::ValueType val = pq_.Front(); pq_.PopFront(); return std::get(val); } TxQueue pq_; }; TEST_F(TxQueueTest, Basic) { pq_.Insert(4); pq_.Insert(3); pq_.Insert(2); unsigned cnt = 0; auto head = pq_.Head(); auto it = head; do { ++cnt; it = pq_.Next(it); } while (it != head); EXPECT_EQ(3, cnt); ASSERT_EQ(2, Pop()); ASSERT_EQ(3, Pop()); ASSERT_EQ(4, Pop()); ASSERT_TRUE(pq_.Empty()); EXPECT_EQ(TxQueue::kEnd, pq_.Head()); pq_.Insert(10); ASSERT_EQ(10, Pop()); } class IntentLockTest : public ::testing::Test { protected: IntentLock lk_; }; TEST_F(IntentLockTest, Basic) { ASSERT_TRUE(lk_.Acquire(IntentLock::SHARED)); ASSERT_FALSE(lk_.Acquire(IntentLock::EXCLUSIVE)); lk_.Release(IntentLock::EXCLUSIVE); ASSERT_FALSE(lk_.Check(IntentLock::EXCLUSIVE)); lk_.Release(IntentLock::SHARED); ASSERT_TRUE(lk_.Check(IntentLock::EXCLUSIVE)); } class StringMatchTest : public ::testing::Test { protected: // wrapper around stringmatchlen with stringview arguments bool MatchLen(string_view 
// Checks the glob -> regex translation on a table of patterns, then confirms that one of
// the produced expressions is accepted by the reflex matcher.
TEST_F(StringMatchTest, Glob2Regex) {
  struct Case {
    const char* glob;
    const char* regex;
  };

  const Case kCases[] = {
      {"", ""},
      {"*", ".*"},
      {"\\*", "\\*"},
      {"\\?", "\\?"},
      {"[abc]", "[abc]"},
      {"[^abc]", "[^abc]"},
      {"h\\[^|", "h\\[\\^\\|"},
      {"[$?^]a", "[$?^]a"},
      {"[^]a", ".a"},
      {"[]a", "[]a"},
      {"\\d", "d"},
      {"[\\d]", "[\\\\d]"},
      {"abc\\", "abc\\\\"},
      {"[\\]]", "[\\]]"},
  };

  for (const Case& c : kCases) {
    EXPECT_EQ(GlobMatcher::Glob2Regex(c.glob), c.regex);
  }

  reflex::Matcher matcher("abc[\\\\d]e");
  matcher.input("abcde");
  ASSERT_TRUE(matcher.find());
}
// Matches `str` against `pattern` twice - case-sensitively and case-insensitively - and
// compares each outcome (converted to int) with the expected value.
// The body is wrapped in do { } while (0) so that the invocation together with its
// trailing ';' forms exactly one statement and stays valid inside an unbraced if/else.
#define TEST_STRINGMATCH(pattern, str, case_res, nocase_res) \
  do {                                                       \
    EXPECT_EQ(int(MatchLen(pattern, str, 0)), case_res);     \
    EXPECT_EQ(int(MatchLen(pattern, str, 1)), nocase_res);   \
  } while (0)
TEST_STRINGMATCH("[*]", "*", 1, 1); TEST_STRINGMATCH("[?]", "?", 1, 1); TEST_STRINGMATCH("[[]", "[", 1, 1); /* Not special as the first character in a character class: */ TEST_STRINGMATCH("[-]", "-", 1, 1); /* Not special as range end (undocumented): */ TEST_STRINGMATCH("[+-]]", "*", 0, 0); /* but not * (below) */ TEST_STRINGMATCH("[+-]]", "^", 0, 0); /* or ^ (above) */ TEST_STRINGMATCH("[+--]", ",", 1, 1); /* ASCII range + to - includes , */ TEST_STRINGMATCH("[+--]", "*", 0, 0); /* but not * (below) */ TEST_STRINGMATCH("[+--]", ".", 0, 0); /* or . (above) */ /* And the same, but unclosed: */ TEST_STRINGMATCH("[+-]", "*", 0, 0); TEST_STRINGMATCH("[+-]", "^", 0, 0); TEST_STRINGMATCH("[+--", ",", 1, 1); TEST_STRINGMATCH("[+--", "*", 0, 0); TEST_STRINGMATCH("[+--", ".", 0, 0); /* Escaped ] alone is literal: */ TEST_STRINGMATCH("[\\]a]", "]", 1, 1); TEST_STRINGMATCH("[\\]a]", "a", 1, 1); /* Escapes at range end: */ TEST_STRINGMATCH("[+-\\\\]", ",", 1, 1); /* ASCII range + to \ includes , */ TEST_STRINGMATCH("[+-\\\\]", "*", 0, 0); /* but not * (below) */ TEST_STRINGMATCH("[+-\\]]", "*", 0, 0); /* but not * (below) */ TEST_STRINGMATCH("[+-\\]]", "^", 0, 0); /* or ^ (above) */ /* Unclosed is the same: */ TEST_STRINGMATCH("[+-\\\\", ",", 1, 1); TEST_STRINGMATCH("[+-\\\\", "*", 0, 0); TEST_STRINGMATCH("[+-\\\\", "]", 0, 0); TEST_STRINGMATCH("[+-\\]", ",", 1, 1); TEST_STRINGMATCH("[+-\\]", "*", 0, 0); TEST_STRINGMATCH("[+-\\]", "^", 0, 0); /* An incomplete escape is treated as literal backslash: */ TEST_STRINGMATCH("[+-\\", ",", 1, 1); TEST_STRINGMATCH("[+-\\", "*", 0, 0); TEST_STRINGMATCH("[+-\\", "]", 0, 0); /* Empty character class matches nothing: */ TEST_STRINGMATCH("[]", "", 0, 0); TEST_STRINGMATCH("[]", "a", 0, 0); TEST_STRINGMATCH("[", "", 0, 0); /* Unclosed is the same */ TEST_STRINGMATCH("[", "a", 0, 0); /* Empty negated character class is equivalent to pattern "?": */ TEST_STRINGMATCH("[^]", "", 0, 0); TEST_STRINGMATCH("[^]", "a", 1, 1); TEST_STRINGMATCH("[^]", 
"ab", 0, 0); TEST_STRINGMATCH("[^", "", 0, 0); /* Unclosed is the same */ TEST_STRINGMATCH("[^", "a", 1, 1); TEST_STRINGMATCH("[^", "ab", 0, 0); /* Unclosed character classes are not an error (undocumented): */ TEST_STRINGMATCH("[A-", "B", 0, 0); } class HuffCoderTest : public ::testing::Test { protected: HuffmanEncoder encoder_; HuffmanDecoder decoder_; string error_msg_; const string_view good_table_{ "\x1b\x10\xd8\n\n\x19\xc6\x0c\xc3\x30\x0c\x43\x1e\x93\xe4\x11roB\xf6\xde\xbb\x18V\xc2Zk\x03"sv}; }; TEST_F(HuffCoderTest, Load) { string data("bad"); ASSERT_FALSE(encoder_.Load(data, &error_msg_)); data = good_table_; ASSERT_TRUE(encoder_.Load(data, &error_msg_)) << error_msg_; data.append("foo"); encoder_.Reset(); ASSERT_FALSE(encoder_.Load(data, &error_msg_)); } TEST_F(HuffCoderTest, Encode) { ASSERT_TRUE(encoder_.Load(good_table_, &error_msg_)) << error_msg_; EXPECT_EQ(1, encoder_.GetNBits('x')); EXPECT_EQ(3, encoder_.GetNBits(':')); EXPECT_EQ(5, encoder_.GetNBits('2')); EXPECT_EQ(5, encoder_.GetNBits('3')); string data("x:23xx"); array dest; uint32_t dest_size = dest.size(); ASSERT_TRUE(encoder_.Encode(data, dest.data(), &dest_size, &error_msg_)); ASSERT_EQ(3, dest_size); // testing small destination buffer. 
// Round trip: build a table from a histogram, encode a short string, export the table,
// load it into the decoder and decode the bytes back to the original text.
TEST_F(HuffCoderTest, Decode) {
  // Every symbol gets a count of 1; 'a' and 'b' are made far more frequent so that they
  // receive the short codes asserted below.
  array hist;
  hist.fill(1);
  hist['a'] = 100;
  hist['b'] = 50;
  // The second argument is hist.size() - 1, i.e. the index of the last histogram slot.
  ASSERT_TRUE(encoder_.Build(hist.data(), hist.size() - 1, &error_msg_));

  string data("aab");

  array encoded{0};
  // In/out parameter: holds the buffer capacity before the call and the encoded size after.
  uint32_t encoded_size = encoded.size();
  ASSERT_TRUE(encoder_.Encode(data, encoded.data(), &encoded_size, &error_msg_));
  // With 2 bits for 'a' and 3 bits for 'b', "aab" needs 2 + 2 + 3 = 7 bits of payload;
  // presumably that is why the result fits into a single byte.
  ASSERT_EQ(1, encoded_size);

  EXPECT_EQ(2, encoder_.GetNBits('a'));
  EXPECT_EQ(3, encoder_.GetNBits('b'));

  // The decoder is initialized from the serialized form of the encoder's table.
  string bindata = encoder_.Export();
  ASSERT_TRUE(decoder_.Load(bindata, &error_msg_)) << error_msg_;

  const char* src_ptr = reinterpret_cast(encoded.data());
  array decode_dest{0};
  // The caller passes the expected number of decoded symbols to Decode().
  size_t decoded_size = data.size();
  ASSERT_TRUE(decoder_.Decode({src_ptr, encoded_size}, decoded_size, decode_dest.data()));
  ASSERT_EQ("aab", string_view(decode_dest.data(), decoded_size));
}
// Without this scaling the algorithm crashes (see the sentinel explanation above).
} LOG(INFO) << "Total sum: " << sum << " reduced sum: " << sum / 4; ASSERT_TRUE(encoder_.Build(hist.data(), hist.size() - 1, &error_msg_)) << error_msg_; string bindata = encoder_.Export(); encoder_.Reset(); ASSERT_TRUE(encoder_.Load(bindata, &error_msg_)) << error_msg_; } using benchmark::DoNotOptimize; // Parse Double benchmarks static void BM_ParseFastFloat(benchmark::State& state) { std::vector args(100); std::random_device rd; for (auto& arg : args) { arg = std::to_string(std::uniform_real_distribution(0, 1e5)(rd)); } double res; while (state.KeepRunning()) { for (const auto& arg : args) { fast_float::from_chars(arg.data(), arg.data() + arg.size(), res); } } } BENCHMARK(BM_ParseFastFloat); static void BM_ParseDoubleAbsl(benchmark::State& state) { std::vector args(100); for (auto& arg : args) { arg = std::to_string(std::uniform_real_distribution(0, 1e5)(rd)); } double res; while (state.KeepRunning()) { for (const auto& arg : args) { absl::from_chars(arg.data(), arg.data() + arg.size(), res); } } } BENCHMARK(BM_ParseDoubleAbsl); template void BM_ClockType(benchmark::State& state) { timespec ts; while (state.KeepRunning()) { DoNotOptimize(clock_gettime(cid, &ts)); } } BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_REALTIME); BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_MONOTONIC); BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_PROCESS_CPUTIME_ID); BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_THREAD_CPUTIME_ID); // These clocks are not available on apple platform #if !defined(__APPLE__) BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_REALTIME_COARSE); BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_MONOTONIC_COARSE); BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_BOOTTIME); BENCHMARK_TEMPLATE(BM_ClockType, CLOCK_BOOTTIME_ALARM); #endif static void BM_MatchGlob(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); GlobMatcher matcher("*foobar*", true); while (state.KeepRunning()) { DoNotOptimize(matcher.Matches(random_val)); } } BENCHMARK(BM_MatchGlob)->Arg(32)->Arg(1000)->Arg(10000); static 
void BM_MatchGlob2(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); GlobMatcher matcher("bull:*:meta", true); while (state.KeepRunning()) { DoNotOptimize(matcher.Matches(random_val)); } } BENCHMARK(BM_MatchGlob2)->Arg(32)->Arg(1000)->Arg(10000); // See https://nvd.nist.gov/vuln/detail/cve-2022-36021 static void BM_MatchGlobExp(benchmark::State& state) { GlobMatcher matcher("a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*b", true); while (state.KeepRunning()) { DoNotOptimize(matcher.Matches("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")); } } BENCHMARK(BM_MatchGlobExp); static void BM_MatchFindSubstr(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); while (state.KeepRunning()) { DoNotOptimize(random_val.find("foobar")); } } BENCHMARK(BM_MatchFindSubstr)->Arg(1000)->Arg(10000); static void BM_MatchReflexFind(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); reflex::Matcher matcher("foobar"); while (state.KeepRunning()) { matcher.input(random_val); DoNotOptimize(matcher.find()); } } BENCHMARK(BM_MatchReflexFind)->Arg(1000)->Arg(10000); static void BM_MatchReflexFindStar(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); reflex::Matcher matcher(".*foobar"); while (state.KeepRunning()) { matcher.input(random_val); DoNotOptimize(matcher.find()); } } BENCHMARK(BM_MatchReflexFindStar)->Arg(1000)->Arg(10000); static void BM_MatchStd(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); std::regex regex(".*foobar"); std::match_results results; while (state.KeepRunning()) { std::regex_match(random_val, results, regex); } } BENCHMARK(BM_MatchStd)->Arg(1000)->Arg(10000); static void BM_MatchRedisGlob(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); const char* pattern = "*foobar*"; while (state.KeepRunning()) { DoNotOptimize( stringmatchlen(pattern, strlen(pattern), random_val.c_str(), random_val.size(), 0)); } } 
BENCHMARK(BM_MatchRedisGlob)->Arg(1000)->Arg(10000); static void BM_MatchRedisGlob2(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); const char* pattern = "bull:*:meta"; while (state.KeepRunning()) { DoNotOptimize( stringmatchlen(pattern, strlen(pattern), random_val.c_str(), random_val.size(), 0)); } } BENCHMARK(BM_MatchRedisGlob2)->Arg(32)->Arg(1000)->Arg(10000); static void BM_MatchData(benchmark::State& state) { vector keys(5000); for (unsigned i = 0; i < keys.size(); ++i) { keys[i] = GetRandomHex(80); } string_view pattern = "*2addb1c3-eae5-5265-ac8e-9fc9106dda8d*77de68daecd823babbb58edb1c8e14d7106e83bb"sv; if (state.range(0) == 1) { GlobMatcher matcher(pattern, true); while (state.KeepRunning()) { for (const auto& key : keys) { DoNotOptimize(matcher.Matches(key)); } } } else { while (state.KeepRunning()) { for (const auto& key : keys) { DoNotOptimize(stringmatchlen(pattern.data(), pattern.size(), key.c_str(), key.size(), 0)); } } } } BENCHMARK(BM_MatchData)->ArgName("algo")->Arg(0)->Arg(1); #ifdef USE_RE2 static void BM_MatchRe2(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); re2::RE2 re(".*foobar.*", re2::RE2::Latin1); CHECK(re.ok()); while (state.KeepRunning()) { DoNotOptimize(re2::RE2::FullMatch(random_val, re)); } } BENCHMARK(BM_MatchRe2)->Arg(1000)->Arg(10000); #endif #ifdef USE_PCRE2 pair create_pcre2(const char* pattern) { int errnum; PCRE2_SIZE erroffset; pcre2_code* re = pcre2_compile((PCRE2_SPTR)pattern, PCRE2_ZERO_TERMINATED, 0, &errnum, &erroffset, nullptr); CHECK(re); CHECK_EQ(0, pcre2_jit_compile(re, PCRE2_JIT_COMPLETE)); pcre2_match_data* match_data = pcre2_match_data_create_from_pattern(re, NULL); return {re, match_data}; } int pcre2_do_match(string_view str, pcre2_code* re, pcre2_match_data* match_data) { int rc = pcre2_jit_match(re, (PCRE2_SPTR)str.data(), str.size(), 0, PCRE2_ANCHORED | PCRE2_ENDANCHORED, match_data, NULL); return rc; } static void BM_MatchPcre2Jit(benchmark::State& 
state) { string random_val = GetRandomHex(state.range(0)); auto [re, match_data] = create_pcre2(".*foobar.*"); const char sample[] = "aaaaaaaaaaaaafoobar"; int rc = pcre2_do_match(sample, re, match_data); CHECK_EQ(1, rc); while (state.KeepRunning()) { rc = pcre2_do_match(random_val, re, match_data); CHECK_EQ(PCRE2_ERROR_NOMATCH, rc); } pcre2_match_data_free(match_data); pcre2_code_free(re); } BENCHMARK(BM_MatchPcre2Jit)->Arg(32)->Arg(1000)->Arg(10000); static void BM_MatchPcre2Jit2(benchmark::State& state) { string random_val = GetRandomHex(state.range(0)); auto [re, match_data] = create_pcre2("foo.*bar"); while (state.KeepRunning()) { int rc = pcre2_do_match(random_val, re, match_data); CHECK_EQ(PCRE2_ERROR_NOMATCH, rc); } pcre2_match_data_free(match_data); pcre2_code_free(re); } BENCHMARK(BM_MatchPcre2Jit2)->Arg(32)->Arg(1000)->Arg(10000); static void BM_MatchPcre2JitExp(benchmark::State& state) { string exponent_pattern = "a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*a*b"; string str = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; auto [re, match_data] = create_pcre2(exponent_pattern.c_str()); while (state.KeepRunning()) { int rc = pcre2_do_match(str, re, match_data); CHECK_EQ(PCRE2_ERROR_NOMATCH, rc); } pcre2_match_data_free(match_data); pcre2_code_free(re); } BENCHMARK(BM_MatchPcre2JitExp); #endif static void BM_MatchGlobSlow(benchmark::State& state) { GlobMatcher matcher("a*a*a*a*a*.pt", false); while (state.KeepRunning()) { DoNotOptimize(GlobMatcher("a*a*a*a*a*.pt", false)); } } BENCHMARK(BM_MatchGlobSlow); } // namespace dfly ================================================ FILE: src/core/dict_builder.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// Fast hash for 6-byte d-mers. Uses a simple multiplicative hash.
// Reads exactly 6 bytes starting at `data`; the caller guarantees they are readable.
// The bytes are loaded in native byte order, so the value depends on host endianness.
inline uint32_t HashDmer(const uint8_t* data) {
  uint64_t val = 0;
  memcpy(&val, data, 6);  // upper two bytes of `val` stay zero

  // ZSTD_hash6 algorithm: shift the 48 payload bits to the top, multiply by the prime and
  // keep the upper 32 bits of the product.
  constexpr uint64_t kPrime6Bytes = 227718039650203ULL;
  const uint64_t hash64 = ((val << 16) * kPrime6Bytes) >> 32;
  return static_cast<uint32_t>(hash64);
}
// Outcome of a best-segment search.
struct BestSegmentResult {
  // Start pointer and length of the winning segment; {nullptr, 0} means nothing was found.
  std::pair<const uint8_t*, size_t> data_piece{nullptr, 0};
  // Sum of the d-mer frequencies inside the winning segment; 0 when nothing was found.
  uint64_t score = 0;
};
// Iterates across all data pieces to find a contiguous byte window of `segment_size`
// that maximizes the sum of previously computed sequence frequencies.
// Pieces shorter than `segment_size` are skipped. Returns a default result
// ({nullptr, 0}, score 0) when no window has a positive score.
BestSegmentResult FindBestSegment(absl::Span<const std::pair<const uint8_t*, size_t>> data_pieces,
                                  size_t segment_size, const uint16_t* freq,
                                  uint32_t freq_table_mask) {
  BestSegmentResult best;

  // A window shorter than one d-mer contains no d-mers; without this guard the
  // subtraction below would wrap around and the scoring loop would read out of bounds.
  if (segment_size < kDmerLength)
    return best;

  // Number of d-mer start offsets inside one window.
  const size_t window_dmers = segment_size - kDmerLength + 1;

  for (const auto& [data, sz] : data_pieces) {
    if (sz < segment_size)
      continue;

    uint64_t score = 0;

    // Compute initial window score
    for (size_t j = 0; j < window_dmers; ++j) {
      score += freq[HashDmer(data + j) & freq_table_mask];
    }

    if (score > best.score) {
      best.score = score;
      best.data_piece = {data, segment_size};
    }

    // Slide the window one byte at a time: drop the d-mer that left on the left side and
    // add the one that entered on the right side.
    const size_t limit = sz - segment_size;
    for (size_t i = 1; i <= limit; ++i) {
      score -= freq[HashDmer(data + i - 1) & freq_table_mask];
      score += freq[HashDmer(data + i + window_dmers - 1) & freq_table_mask];

      if (score > best.score) {
        best.score = score;
        best.data_piece = {data + i, segment_size};
      }
    }
  }

  return best;
}
double EstimateCompressibility(absl::Span> data_pieces, unsigned step) {
  // step must be non-zero: it is the stride of the sampling loop below.
  DCHECK_GT(step, 0u);

  // Zero-initialized HyperLogLog registers (the trailing "()" value-initializes the array).
  unique_ptr registers(new uint8_t[kRegisterLen]());
  uint64_t total_dmers = 0;

  for (const auto& [data, sz] : data_pieces) {
    // Pieces shorter than one d-mer contribute nothing.
    if (sz < kDmerLength)
      continue;

    // Number of d-mer start offsets inside this piece.
    size_t limit = sz - kDmerLength + 1;
    for (size_t i = 0; i < limit; i += step) {
      UpdateHllRegister(HashDmer(data + i), registers.get());
      ++total_dmers;
    }
  }

  if (total_dmers == 0) {
    return 1.0;  // No d-mers - we consider it incompressible
  }

  // Ratio of (estimated) distinct d-mers to sampled d-mers, capped at 1.0.
  double estimate = EstimateHllCardinality(registers.get());
  double ratio = estimate / static_cast(total_dmers);
  return std::min(ratio, 1.0);
}

// Trains a dictionary using FastCover-style iterative segment selection.
// 1. Builds a frequency table of 6-byte d-mer hashes.
// 2. On every iteration scans all data pieces and selects the single window of
//    segment_size bytes that maximizes the sum of d-mer frequencies.
// 3. Appends the selected segment to the dictionary and zeros out its d-mer frequencies,
//    so the following iterations favour content that is not covered yet.
// Stops when the dictionary reaches dict_size or no segment with a positive score is left.
// Returns raw dictionary bytes, at most dict_size long.
string TrainDictionary(absl::Span> data_pieces, size_t dict_size, size_t segment_size) {
  DCHECK_GT(dict_size, 0u);
  DCHECK_GT(segment_size, kDmerLength);

  // The table size is used to derive an index mask (size - 1) for the d-mer hashes.
  uint32_t freq_table_size = CalculateFreqTableSize(data_pieces);
  uint32_t freq_table_mask = freq_table_size - 1;
  unique_ptr freq(new uint16_t[freq_table_size]());

  PopulateFrequencyTable(data_pieces, freq.get(), freq_table_mask);

  std::string dictionary;
  dictionary.reserve(dict_size);

  while (dictionary.size() < dict_size) {
    auto best = FindBestSegment(data_pieces, segment_size, freq.get(), freq_table_mask);
    if (!best.data_piece.first || best.score == 0) {
      break;  // No useful segments left.
    }

    // The last segment is truncated so that the dictionary never exceeds dict_size.
    size_t append_size = std::min(best.data_piece.second, dict_size - dictionary.size());
    dictionary.append(reinterpret_cast(best.data_piece.first), append_size);

    // Zeroing is done for the whole selected segment, even if only a prefix was appended.
    ZeroOutFrequencies(best.data_piece, freq.get(), freq_table_mask);
  }

  return dictionary;
}

}  // namespace dfly

================================================
FILE: src/core/dict_builder.h
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include
#include
#include
#include

namespace dfly {

// Estimates compressibility by counting unique 6-byte d-mers using HyperLogLog.
// data_pieces: spans of raw data (e.g., one per QList node).
// step: sampling stride (1 = every offset, higher = faster but less accurate). Must be > 0.
// Returns a value in [0, 1] where 0 means very compressible, and 1 means incompressible.
// Input without a single full d-mer is reported as incompressible (1.0).
double EstimateCompressibility(absl::Span> data_pieces, unsigned step);

// Trains a compression dictionary from a collection of sample data.
//
// Arguments:
//   data_pieces: Input data sources (spans of bytes) to extract dictionary segments from.
//   dict_size: The maximum target size of the resulting dictionary in bytes.
//   segment_size: The size of continuous byte segments chosen and appended per iteration.
//                 Pieces shorter than segment_size are never selected.
//
// Returns a raw string containing the trained dictionary up to `dict_size` bytes.
// The result may be shorter (or empty) when no useful segments are found.
std::string TrainDictionary(absl::Span> data_pieces, size_t dict_size = 4096,
                            size_t segment_size = 256);

}  // namespace dfly

================================================
FILE: src/core/dict_builder_test.cc
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "core/dict_builder.h" #include #include #include #include #include #include "base/logging.h" namespace dfly { using namespace std; class DictBuilderTest : public ::testing::Test { protected: using DataPiece = pair; // Generate Celery-like JSON entries with small variations. vector GenerateCeleryEntries(unsigned count) { vector entries; entries.reserve(count); for (unsigned i = 0; i < count; ++i) { string id = to_string(100000 + i); string entry = "{\"body\": \"W10=\", \"content-encoding\": \"utf-8\", " "\"content-type\": \"application/json\", " "\"headers\": {\"lang\": \"py\", \"task\": \"process_job\", " "\"id\": \"b3e4b923-8a77-4053-aff0-" + id + "\", \"shadow\": null, \"eta\": null, " "\"expires\": null, \"group\": null, \"retries\": 0, " "\"timelimit\": [null, null], " "\"root_id\": \"b3e4b923-8a77-4053-aff0-" + id + "\", \"parent_id\": null, " "\"argsrepr\": \"('job" + to_string(i) + "',)\", \"kwargsrepr\": \"{}\", " "\"origin\": \"gen917779@hut\"}, " "\"properties\": {\"correlation_id\": \"b3e4b923\", " "\"reply_to\": \"9933040c\", \"delivery_mode\": 2, " "\"delivery_info\": {\"exchange\": \"\", \"routing_key\": \"my_queue\"}, " "\"priority\": 0}}"; entries.push_back(std::move(entry)); } return entries; } vector ToPieces(const vector& entries) { vector pieces; pieces.reserve(entries.size()); for (const auto& e : entries) { pieces.emplace_back(reinterpret_cast(e.data()), e.size()); } return pieces; } // Generate random binary data. 
vector GenerateRandomEntries(unsigned count, size_t entry_size) { vector entries; entries.reserve(count); mt19937 rng(42); for (unsigned i = 0; i < count; ++i) { string entry(entry_size, '\0'); for (auto& c : entry) { c = static_cast(rng() & 0xFF); } entries.push_back(std::move(entry)); } return entries; } }; TEST_F(DictBuilderTest, RepetitiveDataIsCompressible) { auto entries = GenerateCeleryEntries(200); auto pieces = ToPieces(entries); double ratio = EstimateCompressibility(pieces, 1); LOG(INFO) << "Celery data uniqueness ratio: " << ratio; EXPECT_LT(ratio, 0.5f); } TEST_F(DictBuilderTest, RandomDataIsIncompressible) { auto entries = GenerateRandomEntries(200, 400); auto pieces = ToPieces(entries); double ratio = EstimateCompressibility(pieces, 1); LOG(INFO) << "Random data uniqueness ratio: " << ratio; EXPECT_FALSE(ratio < 0.85); } TEST_F(DictBuilderTest, TrainDictionaryProducesOutput) { auto entries = GenerateCeleryEntries(200); auto pieces = ToPieces(entries); string dict = TrainDictionary(pieces, 4096, 256); LOG(INFO) << "Trained dictionary size: " << dict.size() << " bytes"; EXPECT_GT(dict.size(), 0u); EXPECT_LE(dict.size(), 4096u); } TEST_F(DictBuilderTest, TrainDictionaryEmptyForTinyData) { // Single small entry - not enough for segment selection. string tiny = "hello"; vector pieces = {{reinterpret_cast(tiny.data()), tiny.size()}}; string dict = TrainDictionary(pieces, 4096, 256); EXPECT_TRUE(dict.empty()); } TEST_F(DictBuilderTest, ZstdCompressionWithTrainedDict) { auto entries = GenerateCeleryEntries(200); auto pieces = ToPieces(entries); string dict = TrainDictionary(pieces, 4096, 256); ASSERT_GT(dict.size(), 0u); // Create ZSTD CDict/DDict from trained dictionary. 
ZSTD_CDict* cdict = ZSTD_createCDict(dict.data(), dict.size(), 1); ASSERT_TRUE(cdict); ZSTD_DDict* ddict = ZSTD_createDDict(dict.data(), dict.size()); ASSERT_TRUE(ddict); ZSTD_CCtx* cctx = ZSTD_createCCtx(); ZSTD_DCtx* dctx = ZSTD_createDCtx(); size_t total_raw = 0; size_t total_compressed_dict = 0; size_t total_compressed_nodict = 0; for (const auto& entry : entries) { total_raw += entry.size(); // Compress with dictionary. size_t bound = ZSTD_compressBound(entry.size()); string compressed(bound, '\0'); size_t csz = ZSTD_compress_usingCDict(cctx, compressed.data(), bound, entry.data(), entry.size(), cdict); ASSERT_FALSE(ZSTD_isError(csz)) << ZSTD_getErrorName(csz); compressed.resize(csz); total_compressed_dict += csz; // Compress without dictionary for comparison. string compressed_nodict(bound, '\0'); size_t csz_nodict = ZSTD_compressCCtx(cctx, compressed_nodict.data(), bound, entry.data(), entry.size(), 1); ASSERT_FALSE(ZSTD_isError(csz_nodict)); total_compressed_nodict += csz_nodict; // Verify roundtrip. string decompressed(entry.size(), '\0'); size_t dsz = ZSTD_decompress_usingDDict(dctx, decompressed.data(), entry.size(), compressed.data(), csz, ddict); ASSERT_FALSE(ZSTD_isError(dsz)) << ZSTD_getErrorName(dsz); ASSERT_EQ(dsz, entry.size()); EXPECT_EQ(decompressed, entry); } double ratio_dict = double(total_raw) / double(total_compressed_dict); double ratio_nodict = double(total_raw) / double(total_compressed_nodict); LOG(INFO) << "Total raw: " << total_raw << " bytes"; LOG(INFO) << "With dict: " << total_compressed_dict << " bytes (ratio " << ratio_dict << "x)"; LOG(INFO) << "No dict: " << total_compressed_nodict << " bytes (ratio " << ratio_nodict << "x)"; LOG(INFO) << "Dict advantage: " << ratio_dict / ratio_nodict << "x better"; // Dictionary compression should be significantly better for repetitive data. EXPECT_GT(ratio_dict, ratio_nodict); EXPECT_GT(ratio_dict, 3.0f); // Expect at least 3x compression with dict. 
ZSTD_freeCCtx(cctx); ZSTD_freeDCtx(dctx); ZSTD_freeCDict(cdict); ZSTD_freeDDict(ddict); } TEST_F(DictBuilderTest, StepParameterWorks) { auto entries = GenerateCeleryEntries(200); auto pieces = ToPieces(entries); double step1_ratio = EstimateCompressibility(pieces, 1); double step4_ratio = EstimateCompressibility(pieces, 4); // Both should detect compressibility, though with slightly different ratios. EXPECT_TRUE(step1_ratio < 0.85); EXPECT_TRUE(step4_ratio < 0.85); LOG(INFO) << "Step=1 ratio: " << step1_ratio << ", Step=4 ratio: " << step4_ratio; } } // namespace dfly ================================================ FILE: src/core/dragonfly_core.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include #include "base/logging.h" #include "core/intent_lock.h" namespace dfly { const char* IntentLock::ModeName(Mode m) { switch (m) { case IntentLock::SHARED: return "SHARED"; case IntentLock::EXCLUSIVE: return "EXCLUSIVE"; } ABSL_UNREACHABLE(); } void IntentLock::VerifyDebug() { constexpr uint32_t kMsb = 1ULL << (sizeof(cnt_[0]) * 8 - 1); DCHECK_EQ(0u, cnt_[0] & kMsb); DCHECK_EQ(0u, cnt_[1] & kMsb); } } // namespace dfly ================================================ FILE: src/core/expire_period.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include namespace dfly { class ExpirePeriod { public: static constexpr size_t kMaxGenId = 15; ExpirePeriod() : val_(0), gen_(0), precision_(0) { static_assert(sizeof(ExpirePeriod) == 8); // TODO } explicit ExpirePeriod(uint64_t ms, unsigned gen = 0) : ExpirePeriod() { Set(ms); } // always returns milliseconds value. uint64_t duration_ms() const { return precision_ ? uint64_t(val_) * 1000 : val_; } // generation id for the base of this duration. 
// when we update the generation, we need to update the value as well according to this // logic: // new_val = (old_val + old_base) - new_base. unsigned generation_id() const { return gen_; } void Set(uint64_t ms); bool is_second_precision() { return precision_ == 1;} private: uint64_t val_ : 59; uint64_t gen_ : 4; uint64_t precision_ : 1; // 0 - ms, 1 - sec. }; inline void ExpirePeriod::Set(uint64_t ms) { constexpr uint64_t kBarrier = (1ULL << 48); if (ms < kBarrier) { val_ = ms; precision_ = 0; // ms return; } precision_ = 1; if (ms < kBarrier << 10) { ms = (ms + 500) / 1000; // seconds } val_ = ms >= kBarrier ? kBarrier - 1 : ms; } } // namespace dfly ================================================ FILE: src/core/extent_tree.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/extent_tree.h" #include "base/logging.h" namespace dfly { using namespace std; // offset, len must be multiplies of 256MB. 
void ExtentTree::Add(size_t start, size_t len) {
  DCHECK_GT(len, 0u);
  // Both indices always describe the same set of extents.
  DCHECK_EQ(len_extents_.size(), extents_.size());

  // First extent whose start is >= `start`, i.e. the potential right neighbour.
  auto it = extents_.lower_bound(start);
  optional prev_extent_key;

  if (it != extents_.begin()) {
    auto prev = it;
    --prev;

    // The left neighbour must end at or before the new range.
    DCHECK_LE(prev->second, start);
    if (prev->second == start) {  // combine with the previous extent
      size_t prev_len = prev->second - prev->first;
      CHECK_EQ(1u, len_extents_.erase(pair{prev_len, prev->first}));
      prev->second += len;
      // From here on [start, start + len) denotes the merged extent.
      start = prev->first;
      len += prev_len;
      prev_extent_key = prev->first;
    }
  }

  if (it != extents_.end()) {
    // The right neighbour must start at or after the end of the new range.
    DCHECK_GE(it->first, start + len);
    if (start + len == it->first) {  // merge with the next extent
      size_t it_len = it->second - it->first;
      CHECK_EQ(1u, len_extents_.erase(pair{it_len, it->first}));
      extents_.erase(it);
      len += it_len;
    }
  }

  len_extents_.emplace(len, start);
  if (prev_extent_key) {
    // The merged extent keeps the key of the previous extent; only its end changes.
    DCHECK(extents_.find(*prev_extent_key) != extents_.end());
    extents_[*prev_extent_key] = start + len;
  } else {
    extents_.emplace(start, start + len);
  }
}

// Carves a range of exactly `len` bytes, whose start is a multiple of `align`, out of the
// free extents. `align` must be a power of two and `len` a multiple of `align`.
// Returns nullopt if no extent can host such a range.
optional> ExtentTree::GetRange(size_t len, size_t align) {
  DCHECK_GT(align, 0u);
  DCHECK_EQ(0u, align & (align - 1));
  DCHECK_EQ(0u, len & (align - 1));

  // Smallest extent that is at least `len` long; larger ones follow in iteration order.
  auto it = len_extents_.lower_bound(pair{len, 0});
  if (it == len_extents_.end())
    return nullopt;

  size_t amask = align - 1;
  size_t aligned_start = it->second;
  size_t extent_end = it->first + it->second;

  while (true) {
    if ((aligned_start & amask) == 0)  // aligned
      break;

    // round up to the next aligned address
    aligned_start = (aligned_start + amask) & (~amask);

    if (aligned_start + len <= extent_end)  // check if we still inside the extent
      break;

    // The aligned range does not fit into this extent - try the next candidate.
    ++it;

    if (it == len_extents_.end())
      return nullopt;

    aligned_start = it->second;
    extent_end = it->first + it->second;
  }

  DCHECK_GE(aligned_start, it->second);

  // if we are here - we found the range starting at aligned_start.
  // now we need to possibly break the existing extent to several parts or completely
  // delete it.
  auto eit = extents_.find(it->second);
  DCHECK(eit != extents_.end());
  size_t range_end = aligned_start + len;

  len_extents_.erase(it);

  // we break the extent [eit->first, eit->second] to either 0, 1 or 2 intervals.
  if (aligned_start > eit->first) {  // do we have prefix?
    eit->second = aligned_start;
    len_extents_.emplace(eit->second - eit->first, eit->first);
  } else {
    extents_.erase(eit);
  }

  if (range_end < extent_end) {  // do we have suffix?
    extents_.emplace(range_end, extent_end);
    len_extents_.emplace(extent_end - range_end, range_end);
  }

  DCHECK_EQ(range_end - aligned_start, len);

  return pair{aligned_start, range_end};
}

}  // namespace dfly

================================================
FILE: src/core/extent_tree.h
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include
#include

namespace dfly {

// represents a tree of disjoint extents.
// Add() verifies with DCHECKs that a new range does not overlap its neighbours.
// automatically merges adjacent ranges that are added to the tree.
class ExtentTree {
 public:
  // Adds the free range [start, start + len). len must be positive.
  void Add(size_t start, size_t len);

  // in case of success, returns (start, end) pair, where (end-start) == len and
  // start is aligned by align. The returned range is removed from the tree.
  // align must be a power of two and len must be a multiple of align.
  std::optional> GetRange(size_t len, size_t align);

 private:
  absl::btree_map extents_;       // start -> end.
  absl::btree_set> len_extents_;  // (length, start)
};

}  // namespace dfly

================================================
FILE: src/core/extent_tree_test.cc
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "core/extent_tree.h" #include #include "base/gtest.h" #include "base/logging.h" namespace dfly { using namespace std; class ExtentTreeTest : public ::testing::Test { protected: static void SetUpTestSuite() { } static void TearDownTestSuite() { } ExtentTree tree_; }; TEST_F(ExtentTreeTest, Basic) { tree_.Add(0, 256); auto op = tree_.GetRange(64, 16); EXPECT_TRUE(op); EXPECT_THAT(*op, testing::Pair(0, 64)); // [64, 256) tree_.Add(56, 8); op = tree_.GetRange(64, 16); EXPECT_TRUE(op); EXPECT_THAT(*op, testing::Pair(64, 128)); // {[56, 64), [128, 256)} op = tree_.GetRange(18, 2); EXPECT_TRUE(op); EXPECT_THAT(*op, testing::Pair(128, 146)); // {[56, 64), [146, 256)} op = tree_.GetRange(80, 16); EXPECT_TRUE(op); EXPECT_THAT(*op, testing::Pair(160, 240)); // {[56, 64), [146, 160), [240, 256)} op = tree_.GetRange(4, 1); EXPECT_TRUE(op); EXPECT_THAT(*op, testing::Pair(56, 60)); // {[60, 64), [146, 160), [240, 256)} op = tree_.GetRange(32, 1); EXPECT_FALSE(op); tree_.Add(64, 146 - 64); op = tree_.GetRange(32, 4); EXPECT_TRUE(op); EXPECT_THAT(*op, testing::Pair(60, 92)); } TEST_F(ExtentTreeTest, Union) { tree_.Add(0, 16); tree_.Add(16, 16); auto range = tree_.GetRange(32, 1); ASSERT_TRUE(range); EXPECT_THAT(*range, testing::Pair(0, 32)); } } // namespace dfly ================================================ FILE: src/core/flatbuffers.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #ifndef __USE_GNU // needed to flatbuffers to compile with musl libc. #define FLATBUFFERS_LOCALE_INDEPENDENT 0 #endif #include #include #include namespace dfly { using FlatJson = flexbuffers::Reference; } // namespace dfly ================================================ FILE: src/core/flatbuffers_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/flatbuffers.h" #include #include "base/gtest.h" #include "base/logging.h" using namespace std; namespace dfly { class FlatBuffersTest : public ::testing::Test { protected: }; TEST_F(FlatBuffersTest, Basic) { flexbuffers::Builder fbb; fbb.Map([&] { fbb.String("foo", "bar"); fbb.Double("bar", 1.5); fbb.Vector("strs", [&] { fbb.String("hello"); fbb.String("world"); }); }); fbb.Finish(); auto buffer = fbb.GetBuffer(); flexbuffers::Reference ref = flexbuffers::GetRoot(buffer); auto map = ref.AsMap(); EXPECT_EQ("bar", map["foo"].AsString().str()); } TEST_F(FlatBuffersTest, FlexiParser) { flatbuffers::Parser parser; const char* json = R"( { "foo": "bar", "bar": 1.5, "strs": ["hello", "world"] } )"; flexbuffers::Builder fbb; ASSERT_TRUE(parser.ParseFlexBuffer(json, nullptr, &fbb)); fbb.Finish(); const auto& buffer = fbb.GetBuffer(); string_view buf_view{reinterpret_cast(buffer.data()), buffer.size()}; LOG(INFO) << "Binary buffer: " << absl::CHexEscape(buf_view); flexbuffers::Reference root = flexbuffers::GetRoot(buffer); auto map = root.AsMap(); EXPECT_EQ("bar", map["foo"].AsString().str()); } TEST_F(FlatBuffersTest, ParseJson) { const char* schema = R"( namespace dfly; table Foo { foo: string; bar: double; strs: [string]; } root_type Foo; )"; flatbuffers::Parser parser; ASSERT_TRUE(parser.Parse(schema)); parser.Serialize(); flatbuffers::DetachedBuffer bsb = parser.builder_.Release(); // This schema will always reference bsb. 
auto* fbs_schema = reflection::GetSchema(bsb.data()); flatbuffers::Verifier verifier(bsb.data(), bsb.size()); ASSERT_TRUE(fbs_schema->Verify(verifier)); auto* root_table = fbs_schema->root_table(); auto* fields = root_table->fields(); auto* field_foo = fields->LookupByKey("foo"); ASSERT_EQ(field_foo->type()->base_type(), reflection::String); const char* json = R"( { "foo": "value", "bar": 1.5, "strs": ["hello", "world"] } )"; ASSERT_TRUE(parser.Parse(json)); size_t buf_size = parser.builder_.GetSize(); ASSERT_TRUE( flatbuffers::Verify(*fbs_schema, *root_table, parser.builder_.GetBufferPointer(), buf_size)); auto* root_obj = flatbuffers::GetAnyRoot(parser.builder_.GetBufferPointer()); const flatbuffers::String* value = flatbuffers::GetFieldS(*root_obj, *field_foo); EXPECT_EQ("value", value->str()); // wrong type. ASSERT_FALSE(parser.Parse(R"({"foo": 1})")); } } // namespace dfly ================================================ FILE: src/core/generate_bin_sizes.py ================================================ #!/usr/bin/env python3 import argparse import random from array import array # We print in 64 bit words. 
ALIGN = 1 << 10 # 1KB alignment def print_small_bins(): prev_val = 0 for i in range(56, 1, -1): len = (4096 - i*8) # reduce by size of hashes len = (len // 8)*8 # make it 8 bytes aligned if len != prev_val: print(i, len) prev_val = len print() def main(): parser = argparse.ArgumentParser(description='') parser.add_argument('-n', type=int, dest='num', help='number of quadruplets', default=9) parser.add_argument('-small', action='store_true') args = parser.parse_args() if args.small: print("small") print_small_bins() return size = 512*4 print ('{512, 512*2, 512*3, ', end=' ') # print ('{', end=' ') for i in range(args.num): incr = size // 4 for j in range(4): assert size % 512 == 0, size print (f'{size}, ', end=' ') size += incr if i % 2 == 1: print('') print('};') if __name__ == "__main__": main() ================================================ FILE: src/core/glob_matcher.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/glob_matcher.h" #include #include "base/logging.h" namespace dfly { using namespace std; /* Glob-style pattern matching taken from Valkey. */ static int stringmatchlen_impl(const char* pattern, int patternLen, const char* string, int stringLen, int nocase, int* skipLongerMatches, int nesting) { /* Protection against abusive patterns. */ if (nesting > 1000) return 0; while (patternLen && stringLen) { switch (pattern[0]) { case '*': while (patternLen && pattern[1] == '*') { pattern++; patternLen--; } if (patternLen == 1) return 1; /* match */ while (stringLen) { if (stringmatchlen_impl(pattern + 1, patternLen - 1, string, stringLen, nocase, skipLongerMatches, nesting + 1)) return 1; /* match */ if (*skipLongerMatches) return 0; /* no match */ string++; stringLen--; } /* There was no match for the rest of the pattern starting * from anywhere in the rest of the string. 
If there were * any '*' earlier in the pattern, we can terminate the * search early without trying to match them to longer * substrings. This is because a longer match for the * earlier part of the pattern would require the rest of the * pattern to match starting later in the string, and we * have just determined that there is no match for the rest * of the pattern starting from anywhere in the current * string. */ *skipLongerMatches = 1; return 0; /* no match */ break; case '?': string++; stringLen--; break; case '[': { int not_op, match; pattern++; patternLen--; not_op = patternLen && pattern[0] == '^'; if (not_op) { pattern++; patternLen--; } match = 0; while (1) { if (patternLen >= 2 && pattern[0] == '\\') { pattern++; patternLen--; if (pattern[0] == string[0]) match = 1; } else if (patternLen == 0) { pattern--; patternLen++; break; } else if (pattern[0] == ']') { break; } else if (patternLen >= 3 && pattern[1] == '-') { int start = pattern[0]; int end = pattern[2]; int c = string[0]; if (start > end) { int t = start; start = end; end = t; } if (nocase) { start = tolower(start); end = tolower(end); c = tolower(c); } pattern += 2; patternLen -= 2; if (c >= start && c <= end) match = 1; } else { if (!nocase) { if (pattern[0] == string[0]) match = 1; } else { if (tolower((int)pattern[0]) == tolower((int)string[0])) match = 1; } } pattern++; patternLen--; } if (not_op) match = !match; if (!match) return 0; /* no match */ string++; stringLen--; break; } case '\\': if (patternLen >= 2) { pattern++; patternLen--; } /* fall through */ default: if (!nocase) { if (pattern[0] != string[0]) return 0; /* no match */ } else { if (tolower((int)pattern[0]) != tolower((int)string[0])) return 0; /* no match */ } string++; stringLen--; break; } pattern++; patternLen--; if (stringLen == 0) { while (patternLen && *pattern == '*') { pattern++; patternLen--; } break; } } if (patternLen == 0 && stringLen == 0) return 1; return 0; } int stringmatchlen(const char* pattern, int 
patternLen, const char* string, int stringLen, int nocase) { int skipLongerMatches = 0; return stringmatchlen_impl(pattern, patternLen, string, stringLen, nocase, &skipLongerMatches, 0); } string GlobMatcher::Glob2Regex(string_view glob) { string regex; regex.reserve(glob.size()); size_t in_group = 0; for (size_t i = 0; i < glob.size(); i++) { char c = glob[i]; if (in_group > 0) { if (c == ']') { if (i == in_group + 1) { if (glob[in_group] == '^') { // [^ regex.pop_back(); regex.back() = '.'; in_group = 0; continue; } } in_group = 0; } regex.push_back(c); if (c == '\\') { if (i + 1 < glob.size() && glob[i + 1] == ']') { ++i; regex.push_back(']'); } else { regex.push_back('\\'); // escape the backslash } } continue; } switch (c) { case '*': regex.append(".*"); break; case '?': regex.append("."); break; case '.': case '(': case ')': case '{': case '}': case '^': case '$': case '+': case '|': regex.push_back('\\'); regex.push_back(c); break; case '\\': if (i + 1 < glob.size()) { ++i; } if (absl::ascii_ispunct(glob[i])) { regex.push_back('\\'); } regex.push_back(glob[i]); break; case '[': regex.push_back('['); if (i + 1 < glob.size()) { in_group = i + 1; } break; default: regex.push_back(c); break; } } return regex; } GlobMatcher::GlobMatcher(string_view pattern, bool case_sensitive) : glob_(pattern), case_sensitive_(case_sensitive) { #ifdef REFLEX_PERFORMANCE if (!pattern.empty()) { starts_with_star_ = pattern.front() == '*'; pattern.remove_prefix(starts_with_star_); if (!pattern.empty()) { ends_with_star_ = (pattern.back() == '*') && (pattern.size() == 1 || pattern[pattern.size() - 2] != '\\'); pattern.remove_suffix(ends_with_star_); } } string regex("(?s"); // dotall mode if (!case_sensitive) { regex.push_back('i'); } regex.push_back(')'); if (pattern.empty()) { regex.append(Glob2Regex("*")); } else { regex.append(Glob2Regex(pattern)); } matcher_.pattern(regex); #elif defined(USE_PCRE2) string regex("(?s"); // dotall mode if (!case_sensitive) { regex.push_back('i'); 
} regex.push_back(')'); regex.append(Glob2Regex(pattern)); int errnum; PCRE2_SIZE erroffset; re_ = pcre2_compile((PCRE2_SPTR)regex.c_str(), regex.size(), 0, &errnum, &erroffset, nullptr); if (re_) { CHECK_EQ(0, pcre2_jit_compile(re_, PCRE2_JIT_COMPLETE)); match_data_ = pcre2_match_data_create_from_pattern(re_, NULL); } #endif } bool GlobMatcher::Matches(std::string_view str) const { #ifdef REFLEX_PERFORMANCE if (str.size() < 16) { return stringmatchlen(glob_.data(), glob_.size(), str.data(), str.size(), !case_sensitive_); } if (glob_.empty()) { return true; } DCHECK(!matcher_.pattern().empty()); matcher_.input(reflex::Input(str.data(), str.size())); bool use_find = starts_with_star_ || ends_with_star_; if (!use_find) { return matcher_.matches() > 0; } bool found = matcher_.find() > 0; if (!found) { return false; } if (!ends_with_star_ && matcher_.last() != str.size()) { return false; } if (!starts_with_star_ && matcher_.first() != 0) { return false; } return true; #elif defined(USE_PCRE2) if (!re_ || str.size() < 16) { return stringmatchlen(glob_.data(), glob_.size(), str.data(), str.size(), !case_sensitive_); } if (glob_.empty()) { return true; } int rc = pcre2_jit_match(re_, (PCRE2_SPTR)str.data(), str.size(), 0, 0, match_data_, NULL); return rc > 0; #else return stringmatchlen(glob_.data(), glob_.size(), str.data(), str.size(), !case_sensitive_); #endif } GlobMatcher::~GlobMatcher() { #ifdef REFLEX_PERFORMANCE #elif defined(USE_PCRE2) if (re_) { pcre2_code_free(re_); pcre2_match_data_free(match_data_); } #endif } } // namespace dfly ================================================ FILE: src/core/glob_matcher.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include // We opt for using Reflex library for glob matching. // While I find PCRE2 faster, it's not substantially faster to justify the shared lib dependency. 
// For some regex, Reflex (and pcre2) have extremely slow compile times(70+ms). // This latency is significant for the hot path and therefore both are disabled // and we fall back to the plain old stringmatchlen. For more info, refer to #5547 on gh. //#define REFLEX_PERFORMANCE #ifndef REFLEX_PERFORMANCE #ifdef USE_PCRE2 #define PCRE2_CODE_UNIT_WIDTH 8 #include #endif #endif namespace dfly { class GlobMatcher { GlobMatcher(const GlobMatcher&) = delete; GlobMatcher& operator=(const GlobMatcher&) = delete; public: explicit GlobMatcher(std::string_view pattern, bool case_sensitive); ~GlobMatcher(); bool Matches(std::string_view str) const; // Exposed for testing purposes. static std::string Glob2Regex(std::string_view glob); private: // TODO: we fix the problem of stringmatchlen being much // faster when the result is immediately known to be false, for example: "a*" vs "bxxxxx". // The goal is to demonstrate on-par performance for the following case: // > debug populate 5000000 keys 32 RAND // > while true; do time valkey-cli scan 0 match 'foo*bar'; done // Also demonstrate that the "improved" performance via SCAN command and not only via // micro-benchmark. // The performance of naive algorithm becomes worse in cases where string is long enough, // and the pattern has a star at the start (or it matches at first). #ifdef REFLEX_PERFORMANCE mutable reflex::Matcher matcher_; bool starts_with_star_ = false; bool ends_with_star_ = false; #elif defined(USE_PCRE2) pcre2_code_8* re_ = nullptr; pcre2_match_data_8* match_data_ = nullptr; #endif std::string_view glob_; bool case_sensitive_; }; } // namespace dfly ================================================ FILE: src/core/huff_coder.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/huff_coder.h" #include "base/logging.h" extern "C" { #include "huff/huf.h" } using namespace std; namespace dfly { constexpr size_t kWspSize = HUF_CTABLE_WORKSPACE_SIZE; bool HuffmanEncoder::Load(std::string_view binary_data, std::string* error_msg) { CHECK(!huf_ctable_); huf_ctable_.reset(new HUF_CElt[HUF_CTABLE_SIZE_ST(255)]); table_max_symbol_ = 255; unsigned has_zero_weights = 0; size_t read_size = HUF_readCTable(huf_ctable_.get(), &table_max_symbol_, binary_data.data(), binary_data.size(), &has_zero_weights); if (HUF_isError(read_size)) { huf_ctable_.reset(); *error_msg = HUF_getErrorName(read_size); return false; } if (read_size != binary_data.size()) { *error_msg = "Corrupted data"; huf_ctable_.reset(); return false; } HUF_CTableHeader header = HUF_readCTableHeader(huf_ctable_.get()); num_bits_ = header.tableLog; table_max_symbol_ = header.maxSymbolValue; return true; } bool HuffmanEncoder::Build(const unsigned hist[], unsigned max_symbol, std::string* error_msg) { CHECK(!huf_ctable_); huf_ctable_.reset(new HUF_CElt[HUF_CTABLE_SIZE_ST(max_symbol)]); unique_ptr wrkspace(new uint32_t[HUF_CTABLE_WORKSPACE_SIZE_U32]); size_t num_bits = HUF_buildCTable_wksp(huf_ctable_.get(), hist, max_symbol, 0, wrkspace.get(), kWspSize); if (HUF_isError(num_bits)) { *error_msg = HUF_getErrorName(num_bits); huf_ctable_.reset(); return false; } num_bits_ = static_cast(num_bits); table_max_symbol_ = max_symbol; return true; } void HuffmanEncoder::Reset() { huf_ctable_.reset(); table_max_symbol_ = 0; } bool HuffmanEncoder::Encode(std::string_view data, uint8_t* dest, uint32_t* dest_size, std::string* error_msg) const { DCHECK(huf_ctable_); size_t res = HUF_compress1X_usingCTable(dest, *dest_size, data.data(), data.size(), huf_ctable_.get(), 0); if (HUF_isError(res)) { *error_msg = HUF_getErrorName(res); return false; } *dest_size = static_cast(res); return true; } unsigned HuffmanEncoder::GetNBits(uint8_t symbol) const { DCHECK(huf_ctable_); return 
HUF_getNbBitsFromCTable(huf_ctable_.get(), symbol); } size_t HuffmanEncoder::EstimateCompressedSize(const unsigned hist[], unsigned max_symbol) const { DCHECK(huf_ctable_); size_t res = HUF_estimateCompressedSize(huf_ctable_.get(), hist, max_symbol); return res; } string HuffmanEncoder::Export() const { DCHECK(huf_ctable_); // Reverse engineered: (maxSymbolValue + 1) / 2 + 1. constexpr unsigned kMaxTableSize = 130; string res; res.resize(kMaxTableSize); unique_ptr wrkspace(new uint32_t[HUF_CTABLE_WORKSPACE_SIZE_U32]); // Seems we can reuse the same workspace, its capacity is enough. size_t size = HUF_writeCTable_wksp(res.data(), res.size(), huf_ctable_.get(), table_max_symbol_, num_bits_, wrkspace.get(), kWspSize); CHECK(!HUF_isError(size)); res.resize(size); return res; } // Copied from HUF_tightCompressBound. size_t HuffmanEncoder::CompressedBound(size_t src_size) const { return ((src_size * num_bits_) >> 3) + 8; } bool HuffmanDecoder::Load(std::string_view binary_data, std::string* error_msg) { DCHECK(!huf_dtable_); huf_dtable_.reset(new HUF_DTable[HUF_DTABLE_SIZE(HUF_TABLELOG_MAX)]); huf_dtable_[0] = (HUF_TABLELOG_MAX - 1) * 0x01000001; // some sort of magic number constexpr size_t kWspSize = HUF_DECOMPRESS_WORKSPACE_SIZE; unique_ptr wrksp(new uint8_t[kWspSize]); size_t res = HUF_readDTableX1_wksp(huf_dtable_.get(), binary_data.data(), binary_data.size(), wrksp.get(), kWspSize, 0); if (HUF_isError(res)) { *error_msg = HUF_getErrorName(res); huf_dtable_.reset(); return false; } if (res != binary_data.size()) { *error_msg = "Corrupted data"; huf_dtable_.reset(); return false; } return true; } bool HuffmanDecoder::Decode(std::string_view src, size_t dest_size, char* dest) const { DCHECK(huf_dtable_); size_t res = HUF_decompress1X_usingDTable(dest, dest_size, src.data(), src.size(), huf_dtable_.get(), 1); if (HUF_isError(res)) { LOG(DFATAL) << "Failed to decompress: " << HUF_getErrorName(res); return false; } return true; } } // namespace dfly 
================================================
FILE: src/core/huff_coder.h
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include
#include

namespace dfly {

// Owns a Huffman compression table (huf_ctable_) together with the maximal symbol value and
// the number of table bits it was created with.
class HuffmanEncoder {
 public:
  bool Build(const unsigned hist[], unsigned max_symbol, std::string* error_msg);

  bool Encode(std::string_view data, uint8_t* dest, uint32_t* dest_size,
              std::string* error_msg) const;
  size_t EstimateCompressedSize(const unsigned hist[], unsigned max_symbol) const;

  void Reset();

  // Load using the serialized data produced by Export().
  bool Load(std::string_view binary_data, std::string* error_msg);

  // Exports a binary representation of the table, that can be loaded using Load().
  std::string Export() const;

  uint8_t num_bits() const {
    return num_bits_;
  }

  // True when a table is currently held.
  bool valid() const {
    return bool(huf_ctable_);
  }

  unsigned max_symbol() const {
    return table_max_symbol_;
  }

  unsigned GetNBits(uint8_t symbol) const;

  // Estimation of the size of the destination buffer needed to store the compressed data.
  // destination of this size must be passed to Encode().
  size_t CompressedBound(size_t src_size) const;

 private:
  using HUF_CElt = size_t;
  std::unique_ptr huf_ctable_;
  unsigned table_max_symbol_ = 0;
  uint8_t num_bits_ = 0;
};

// Owns a Huffman decoding table (huf_dtable_).
class HuffmanDecoder {
 public:
  bool Load(std::string_view binary_data, std::string* error_msg);

  // True when a table is currently held.
  bool valid() const {
    return bool(huf_dtable_);
  }

  // decoded_size should be the *precise* size of the decoded data, otherwise the function will
  // fail. dest should point to a buffer of at least decoded_size bytes.
  // Returns true if decompression was successful, false if the data is corrupted.
  bool Decode(std::string_view src, size_t decoded_size, char* dest) const;

 private:
  using HUF_DTable = uint32_t;
  std::unique_ptr huf_dtable_;
};

}  // namespace dfly

================================================
FILE: src/core/intent_lock.h
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include
#include

#pragma once

namespace dfly {

// SHARED - can be acquired multiple times as long as other intents are absent.
// EXCLUSIVE - is acquired only if it's the only lock recorded.
// Transactions at the head of tx-queue are considered to be the ones that acquired the lock
//
// The class only keeps two counters (one per mode); it has no internal synchronization.
class IntentLock {
 public:
  enum Mode { SHARED = 0, EXCLUSIVE = 1 };

  // Returns true if lock was acquired. In any case, the intent is recorded.
  bool Acquire(Mode m) {
    ++cnt_[m];

    // `1 ^ m` is the opposite mode: any intent recorded there means we did not acquire.
    if (cnt_[1 ^ int(m)])
      return false;
    return m == SHARED || cnt_[EXCLUSIVE] == 1;
  }

  // Returns true if lock can be acquired using `m` mode.
  bool Check(Mode m) const {
    unsigned s = cnt_[EXCLUSIVE];
    if (s)
      return false;

    return (m == SHARED) ? true : cnt_[SHARED] == 0;
  }

  // Returns true if this lock would block transactions from running unless they are at the head
  // of the transaction queue (first ones)
  bool IsContended() const {
    return (cnt_[EXCLUSIVE] > 1) || (cnt_[EXCLUSIVE] == 1 && cnt_[SHARED] > 0);
  }

  // A heuristic function to estimate the contention amount with a single score.
  unsigned ContentionScore() const {
    return cnt_[EXCLUSIVE] * 256 + cnt_[SHARED];
  }

  // Drops `val` previously recorded intents of mode `m`.
  // The caller must not release more than was recorded (assert-ed).
  void Release(Mode m, unsigned val = 1) {
    assert(cnt_[m] >= val);

    cnt_[m] -= val;
    // return cnt_[m] == 0 ? cnt_[1 ^ int(m)] : 0;
  }

  // True when no intent of either mode is recorded.
  bool IsFree() const {
    return (cnt_[0] | cnt_[1]) == 0;
  }

  static const char* ModeName(Mode m);

  void VerifyDebug();

  friend std::ostream& operator<<(std::ostream& o, const IntentLock& lock) {
    return o << "{SHARED: " << lock.cnt_[0] << ", EXCLUSIVE: " << lock.cnt_[1] << "}";
  }

 private:
  // Indexed by Mode: cnt_[SHARED], cnt_[EXCLUSIVE].
  unsigned cnt_[2] = {0, 0};
};

}  // namespace dfly

================================================
FILE: src/core/interpreter.cc
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/interpreter.h"

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "base/flags.h"
#include "core/interpreter_polyfill.h"
#include "overloaded.h"

extern "C" {
#include
#include
#include

#include "redis/sds.h"
#include "redis/util.h"

LUALIB_API int(luaopen_cjson)(lua_State* L);
LUALIB_API int(luaopen_struct)(lua_State* L);
LUALIB_API int(luaopen_cmsgpack)(lua_State* L);
LUALIB_API int(luaopen_bit)(lua_State* L);
}

#include

#include "base/logging.h"

// Parameters passed to lua_gc(LUA_GCGEN, ...) - see Interpreter::UpdateGCParameters().
struct LuaGcGen {
  int minormul = 20;
  int majormul = 100;
};

// Parameters passed to lua_gc(LUA_GCINC, ...) - see Interpreter::UpdateGCParameters().
struct LuaGcInc {
  int pause = 200;
  int stepmul = 100;
  int stepsize = 13;
};

using LuaGcFlag = std::variant;

ABSL_FLAG(LuaGcFlag, luagc, {},
          "Specifies Lua garabage collector preferences. By default used default lua GC parameters."
          "Format should be 'inc/200/100/13' or 'gen/20/100' where 'inc' and 'gen' are types of "
          "GC, numbers are parameters."
          "For more information check https://www.lua.org/manual/5.4/manual.html#2.5");

ABSL_FLAG(uint64_t, lua_mem_gc_threshold, 10000000,
          "Specifies Lua interpreter's per thread memory limit in bytes after which the GC will be "
          "called forcefully. 0 value remove forced GC calls");

ABSL_FLAG(bool, lua_enable_redis_log, false, "Enable redis.log to write logs from lua script.");

// Parses the --luagc flag value.
// Accepted forms: the empty string (default-constructed flag), "gen/<minormul>/<majormul>" and
// "inc/<pause>/<stepmul>/<stepsize>" where all numbers must parse as integers.
// Returns false and sets *err for anything else.
static bool AbslParseFlag(std::string_view in, LuaGcFlag* flag, std::string* err) {
  if (in.empty()) {
    *flag = LuaGcFlag{};
    return true;
  }

  std::vector parts = absl::StrSplit(in, '/');
  if (parts.size() == 3) {
    if (parts[0] == "gen") {
      LuaGcGen args;
      if (absl::SimpleAtoi(parts[1], &args.minormul) &&
          absl::SimpleAtoi(parts[2], &args.majormul)) {
        *flag = args;
        return true;
      }
    }
  } else if (parts.size() == 4) {
    if (parts[0] == "inc") {
      LuaGcInc args;
      if (absl::SimpleAtoi(parts[1], &args.pause) && absl::SimpleAtoi(parts[2], &args.stepmul) &&
          absl::SimpleAtoi(parts[3], &args.stepsize)) {
        *flag = LuaGcFlag{args};
        return true;
      }
    }
  }

  *err = absl::StrCat("Invalid luagc flag parameters");

  return false;
}

// Inverse of AbslParseFlag: renders the flag in the same '/'-separated form
// (empty string for the monostate alternative).
static std::string AbslUnparseFlag(const LuaGcFlag& flag) {
  return std::visit(dfly::Overloaded{
                        [](std::monostate) { return std::string(); },
                        [](const LuaGcGen& gen) {
                          return absl::StrCat("gen", "/", gen.minormul, "/", gen.majormul);
                        },
                        [](const LuaGcInc& inc) {
                          return absl::StrCat("inc", "/", inc.pause, "/", inc.stepmul, "/",
                                              inc.stepsize);
                        },
                    },
                    flag);
}

namespace dfly {
using namespace std;

namespace {

// EVP_Q_digest is not present in the older versions of OpenSSL.
// Computes the SHA1 digest of [data, data + datalen) into `md` and, when `mdlen` is non-null,
// stores the digest length there. Returns the result of EVP_Digest().
int EVPDigest(const void* data, size_t datalen, unsigned char* md, size_t* mdlen) {
  unsigned int temp = 0;
  int ret = EVP_Digest(data, datalen, md, &temp, EVP_sha1(), NULL);

  if (mdlen != NULL)
    *mdlen = temp;
  return ret;
}

/* This function is used in order to push an error on the Lua stack in the
 * format used by redis.pcall to return errors, which is a lua table
 * with a single "err" field set to the error string. Note that this
 * table is never a valid reply by proper commands, since the returned
 * tables are otherwise always indexed by integers, never by strings.
 *
 * When `trace` is true and debug information for call level 1 is available, the message is
 * prefixed with "<source>: <line>: ".
 */
void PushError(lua_State* lua, string_view error, bool trace = true) {
  lua_Debug dbg;

  lua_newtable(lua);
  lua_pushstring(lua, "err");

  /* Attempt to figure out where this function was called, if possible */
  if (trace && lua_getstack(lua, 1, &dbg) && lua_getinfo(lua, "nSl", &dbg)) {
    string msg = absl::StrCat(dbg.source, ": ", dbg.currentline, ": ", error);
    lua_pushlstring(lua, msg.c_str(), msg.size());
  } else {
    lua_pushlstring(lua, error.data(), error.size());
  }
  lua_settable(lua, -3);
}

// Custom object explorer that collects all values into string array
// Scalars are stringified, nil becomes an empty string, array nesting is flattened and errors
// are only logged (they are not stored in `values`).
struct StringCollectorTranslator : public ObjectExplorer {
  void OnString(std::string_view str) final {
    values.emplace_back(str);
  }

  void OnArrayStart(unsigned len) final {
    // if values isn't empty it means we can not predict the needed size so reserve can
    // significantly decrease performance
    if (values.empty()) {
      values.reserve(len);
    }
  }

  void OnArrayEnd() final {
  }

  void OnBool(bool b) final {
    OnString(absl::AlphaNum(b).Piece());
  }

  void OnDouble(double d) final {
    OnString(absl::AlphaNum(d).Piece());
  }

  void OnInt(int64_t val) final {
    OnString(absl::AlphaNum(val).Piece());
  }

  void OnNil() final {
    OnString("");
  }

  void OnStatus(std::string_view str) final {
    OnString(str);
  }

  void OnError(std::string_view str) final {
    LOG(ERROR) << str;
  }

  vector values;
};

// Object explorer that pushes every visited value onto the Lua stack of `lua`.
// While inside an array, each pushed value is stored into the array table being built
// (see ArrayPost) at the next index tracked in array_index_.
class RedisTranslator : public ObjectExplorer {
 public:
  RedisTranslator(lua_State* lua) : lua_(lua) {
  }
  void OnBool(bool b) final;
  void OnString(std::string_view str) final;
  void OnDouble(double d) final;
  void OnInt(int64_t val) final;
  void OnArrayStart(unsigned len) final;
  void OnArrayEnd() final;
  void OnNil() final;
  void OnStatus(std::string_view str) final;
  void OnError(std::string_view str) final;

  bool HasError();

 private:
  // Currently a no-op hook called before a value is pushed.
  void ArrayPre() {
  }

  // If we are inside an array, pops the value just pushed and stores it in the table below it.
  void ArrayPost() {
    if (!array_index_.empty()) {
      lua_rawseti(lua_, -2, array_index_.back()++); /* set table at key `i' */
    }
  }

  lua_State* lua_;
  bool has_error_{false};
  // One entry per currently open array: the next 1-based index to fill.
  vector array_index_{};
};

// Only `false` is supported (CHECK-ed); it is pushed as Lua boolean false.
void RedisTranslator::OnBool(bool b) {
  CHECK(!b) << "Only false (nil) supported";
  ArrayPre();
  lua_pushboolean(lua_, 0);
  ArrayPost();
}

void RedisTranslator::OnString(std::string_view str) {
  ArrayPre();
  lua_pushlstring(lua_, str.data(), str.size());
  ArrayPost();
}

void RedisTranslator::OnDouble(double d) {
  const double kConvertEps = std::numeric_limits::epsilon();

  double fractpart, intpart;
  fractpart = modf(d, &intpart);

  ArrayPre();

  // Convert to integer when possible to allow converting to string without trailing zeros.
  if (abs(fractpart) < kConvertEps && intpart < double(std::numeric_limits::max()) &&
      intpart > std::numeric_limits::min())
    lua_pushinteger(lua_, static_cast(d));
  else
    lua_pushnumber(lua_, d);

  ArrayPost();
}

void RedisTranslator::OnInt(int64_t val) {
  ArrayPre();
  lua_pushinteger(lua_, val);
  ArrayPost();
}

// Nil is represented as Lua boolean false.
void RedisTranslator::OnNil() {
  ArrayPre();
  lua_pushboolean(lua_, 0);
  ArrayPost();
}

// Pushes a table {ok = str}. Only allowed outside of arrays (CHECK-ed).
void RedisTranslator::OnStatus(std::string_view str) {
  CHECK(array_index_.empty()) << "unexpected status";

  lua_createtable(lua_, 0, 1);
  lua_pushstring(lua_, "ok");
  lua_pushlstring(lua_, str.data(), str.size());
  lua_settable(lua_, -3);
}

// Records that an error was seen and pushes an {err = str} table without trace information.
void RedisTranslator::OnError(std::string_view str) {
  has_error_ = true;
  PushError(lua_, str, false);
}

// Pushes a new table pre-sized for `len` array slots and starts filling it from index 1.
void RedisTranslator::OnArrayStart(unsigned len) {
  ArrayPre();
  lua_createtable(lua_, len, 0);
  array_index_.push_back(1);
}

// Closes the innermost array; if it is nested, it is stored into its parent array.
void RedisTranslator::OnArrayEnd() {
  CHECK(!array_index_.empty());
  DCHECK(lua_istable(lua_, -1));

  array_index_.pop_back();
  ArrayPost();
}

bool RedisTranslator::HasError() {
  return has_error_;
}

// Loads and runs the Lua chunk `buf` (reported under chunk name `name`).
// Both a load failure (CHECK) and a runtime failure (LOG(FATAL)) abort the process.
void RunSafe(lua_State* lua, string_view buf, const char* name) {
  CHECK_EQ(0, luaL_loadbuffer(lua, buf.data(), buf.size(), name));
  int err = lua_pcall(lua, 0, 0, 0);
  if (err) {
    const char* errstr = lua_tostring(lua, -1);
    LOG(FATAL) << "Error running " << name << " " << errstr;
  }
}

// Opens the library via luaL_requiref (registering it as a global) and drops the copy left on
// the stack.
void Require(lua_State* lua, const char* name, lua_CFunction openf) {
  luaL_requiref(lua, name, openf, 1);
  lua_pop(lua, 1); /* remove lib */
}

// Returns a view of the value on top of the stack, using lua_tostring for the data and
// lua_rawlen for the length.
string_view TopSv(lua_State* lua) {
  return string_view{lua_tostring(lua, -1), lua_rawlen(lua, -1)};
}

// Looks up `key` through a protected call so that lookup errors are caught.
// If the lookup fails or yields nil, whatever was pushed is popped and nullopt is returned.
// Otherwise the fetched value is LEFT on the stack and its Lua type is returned.
// NOTE(review): index -3 inside the helper closure appears to address the caller's value that
// sits just below the pushed closure (i.e. the table on top of the stack at call time) -
// confirm against the Lua C API indexing rules.
optional FetchKey(lua_State* lua, const char* key) {
  lua_pushcfunction(lua, [](lua_State* lua) -> int {
    lua_gettable(lua, -3);
    return 1;
  });
  lua_pushstring(lua, key);

  int status = lua_pcall(lua, 1, 1, 0);
  if (status != LUA_OK) {
    lua_pop(lua, 1);
    return nullopt;
  }

  int type = lua_type(lua, -1);
  if (type == LUA_TNIL) {
    lua_pop(lua, 1);
    return nullopt;
  }
  return type;
}

// Creates a Lua array holding `args` as strings (1-based) and binds it to global `name`.
void SetGlobalArrayInternal(lua_State* lua, const char* name, Interpreter::SliceSpan args) {
  lua_createtable(lua, args.size(), 0);
  for (size_t j = 0; j < args.size(); j++) {
    lua_pushlstring(lua, args[j].data(), args[j].size());
    lua_rawseti(lua, -2, j + 1);
  }
  lua_setglobal(lua, name);
}

/* In case the error set into the Lua stack by PushError() was generated
 * by the non-error-trapping version of redis.pcall(), which is redis.call(),
 * this function will raise the Lua error so that the execution of the
 * script will be halted.
 * This function never returns, it unwinds the Lua call stack until an error handler is found or the
 * script exits */
int RaiseErrorAndAbort(lua_State* lua) {
  // Fetch the "err" field of the table on top of the stack and raise it as the error value.
  lua_pushstring(lua, "err");
  lua_gettable(lua, -2);
  return lua_error(lua);
}

// Calls the library opener `luafunc` with `libname` as its single argument.
void LoadLibrary(lua_State* lua, const char* libname, lua_CFunction luafunc) {
  lua_pushcfunction(lua, luafunc);
  lua_pushstring(lua, libname);
  lua_call(lua, 1, 0);
}

// Prepares a fresh Lua state: opens the base/table/string/math/debug libraries plus
// cjson/struct/cmsgpack/bit, installs the pcall error handler, restricts global variable
// access, removes loadfile/dofile and registers the polyfills.
void InitLua(lua_State* lua) {
  Require(lua, "", luaopen_base);
  Require(lua, LUA_TABLIBNAME, luaopen_table);
  Require(lua, LUA_STRLIBNAME, luaopen_string);
  Require(lua, LUA_MATHLIBNAME, luaopen_math);
  Require(lua, LUA_DBLIBNAME, luaopen_debug);

  LoadLibrary(lua, "cjson", luaopen_cjson);
  LoadLibrary(lua, "struct", luaopen_struct);
  LoadLibrary(lua, "cmsgpack", luaopen_cmsgpack);
  LoadLibrary(lua, "bit", luaopen_bit);

  /* Add a helper function we use for pcall error reporting.
   * Note that when the error is in the C function we want to report the
   * information about the caller, that's what makes sense from the point
   * of view of the user debugging a script. */
  {
    const char errh_func[] =
        "local dbg = debug\n"
        "function __redis__err__handler(err)\n"
        " local i = dbg.getinfo(2,'nSl')\n"
        " if i and i.what == 'C' then\n"
        " i = dbg.getinfo(3,'nSl')\n"
        " end\n"
        " if i then\n"
        " return i.source .. ':' .. i.currentline .. ': ' .. err\n"
        " else\n"
        " return err\n"
        " end\n"
        "end\n";
    RunSafe(lua, errh_func, "@err_handler_def");
  }

  // Install a metatable on _G that raises an error when a script creates a global variable or
  // reads a nonexistent one, then hide the `debug` library from scripts.
  {
    const char code[] = R"( local dbg=debug local mt = {} setmetatable(_G, mt) mt.__newindex = function (t, n, v) if dbg.getinfo(2) then local w = dbg.getinfo(2, "S").what if w ~= "main" and w ~= "C" then error("Script attempted to create global variable '"..tostring(n).."'", 2) end end rawset(t, n, v) end mt.__index = function (t, n) if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2) end return rawget(t, n) end debug = nil )";
    RunSafe(lua, code, "@enable_strict_lua");
  }

  // Remove the file-loading functions from the global environment.
  lua_pushnil(lua);
  lua_setglobal(lua, "loadfile");
  lua_pushnil(lua);
  lua_setglobal(lua, "dofile");

  // Register deprecated or removed functions to maintain compatibility with 5.1
  register_polyfills(lua);
}

// dest must have at least 41 chars.
void ToHex(const uint8_t* src, char* dest) { const char cset[] = "0123456789abcdef"; for (size_t j = 0; j < 20; j++) { dest[j * 2] = cset[((src[j] & 0xF0) >> 4)]; dest[j * 2 + 1] = cset[(src[j] & 0xF)]; } dest[40] = '\0'; } int DragonflyHashCommand(lua_State* lua) { XXH64_hash_t hash = absl::bit_cast(lua_tointeger(lua, 1)); bool requires_sort = lua_toboolean(lua, 2); // Pop first two arguments to call RedisGenericCommand from this function with tail lua_remove(lua, 1); lua_remove(lua, 1); // Compute key hash; for MGET hash all key arguments, otherwise just the first { size_t cmd_len; const char* cmd = lua_tolstring(lua, 1, &cmd_len); int top = lua_gettop(lua); int key_end = absl::EqualsIgnoreCase(absl::string_view(cmd, cmd_len), "mget") ? top : 2; for (int i = 2; i <= key_end; ++i) { size_t len; const char* key = lua_tolstring(lua, i, &len); hash = XXH64(key, len, hash); } } // Collect output into custom string collector StringCollectorTranslator translator; void** ptr = static_cast(lua_getextraspace(lua)); reinterpret_cast(*ptr)->RedisGenericCommand(false, false, &translator); if (requires_sort) sort(translator.values.begin(), translator.values.end()); // Compute new hash and return it for (string_view str : translator.values) hash = XXH64(str.data(), str.size(), hash); lua_pushinteger(lua, absl::bit_cast(hash)); return 1; } int DragonflyRandstrCommand(lua_State* state) { int argc = lua_gettop(state); lua_Integer dsize = lua_tonumber(state, 1); lua_remove(state, 1); std::string buf(dsize, ' '); auto push_str = [dsize, state, &buf]() { static const char alphanum[] = "0123456789" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz"; static const char pattern[] = "DRAGONFLY"; constexpr int pattern_len = sizeof(pattern) - 1; constexpr int pattern_interval = 53; for (int i = 0; i < dsize; ++i) { if (i % pattern_interval == 0 && i + pattern_len <= dsize) { // Insert the repeating pattern for better compression of random string. 
buf.replace(i, pattern_len, pattern, pattern_len); i += pattern_len - 1; // Adjust index to skip the pattern } else { // Fill the rest with semi-random characters for variation buf[i] = alphanum[rand() % (sizeof(alphanum) - 1)]; } } lua_pushlstring(state, buf.c_str(), buf.length()); }; if (argc == 1) { push_str(); } else { lua_Integer num = lua_tonumber(state, 1); lua_createtable(state, num, 0); for (int i = 1; i <= num; i++) { push_str(); lua_rawseti(state, -2, i); } } return 1; } int RedisSha1Command(lua_State* lua) { int argc = lua_gettop(lua); if (argc != 1) { lua_pushstring(lua, "wrong number of arguments"); return lua_error(lua); } size_t len; const char* s = lua_tolstring(lua, 1, &len); uint8_t digest[EVP_MAX_MD_SIZE]; EVPDigest(s, len, digest, NULL); char hex[41]; ToHex(digest, hex); lua_pushstring(lua, hex); return 1; } /* Returns a table with a single field 'field' set to the string value * passed as argument. This helper function is handy when returning * a Redis Protocol error or status reply from Lua: * * return redis.error_reply("ERR Some Error") * return redis.status_reply("ERR Some Error") */ int SingleFieldTable(lua_State* lua, const char* field) { if (lua_gettop(lua) != 1 || lua_type(lua, -1) != LUA_TSTRING) { PushError(lua, "wrong number or type of arguments"); return 1; } lua_newtable(lua); lua_pushstring(lua, field); lua_pushvalue(lua, -3); lua_settable(lua, -3); return 1; } int RedisErrorReplyCommand(lua_State* lua) { return SingleFieldTable(lua, "err"); } int RedisStatusReplyCommand(lua_State* lua) { return SingleFieldTable(lua, "ok"); } // no-op int RedisReplicateCommands(lua_State* lua) { lua_pushinteger(lua, 1); // number of results (the number of elements pushed to the lua stack return 1; } int RedisLogCommand(lua_State* lua) { int j, argc = lua_gettop(lua); sds log; if (argc < 2) { PushError(lua, "redis.log() requires two arguments or more."); return RaiseErrorAndAbort(lua); } else if (!lua_isnumber(lua, -argc)) { PushError(lua, "First 
argument must be a number (log level)."); return RaiseErrorAndAbort(lua); } if (absl::GetFlag(FLAGS_lua_enable_redis_log)) { int level = lua_tonumber(lua, -argc); if (level < LL_DEBUG || level > LL_WARNING) { PushError(lua, "Invalid log level."); return RaiseErrorAndAbort(lua); } /* Glue together all the arguments */ log = sdsempty(); for (j = 1; j < argc; j++) { size_t len; char* s; s = (char*)lua_tolstring(lua, (-argc) + j, &len); if (s) { if (j != 1) log = sdscatlen(log, " ", 1); log = sdscatlen(log, s, len); } } switch (level) { case LL_DEBUG: case LL_VERBOSE: VLOG(1) << log; break; case LL_NOTICE: LOG(INFO) << log; break; case LL_WARNING: LOG(WARNING) << log; default: break; } sdsfree(log); } return 0; } // See https://www.lua.org/manual/5.3/manual.html#lua_Alloc void* mimalloc_glue(void* ud, void* ptr, size_t osize, size_t nsize) { int64_t& used_bytes = *static_cast(ud); if (nsize == 0) { used_bytes -= mi_usable_size(ptr); mi_free_size(ptr, osize); return nullptr; } else if (ptr == nullptr) { ptr = mi_malloc(nsize); used_bytes += mi_usable_size(ptr); return ptr; } else { const auto old_size = mi_usable_size(ptr); ptr = mi_realloc(ptr, nsize); if (ptr) { used_bytes -= old_size; used_bytes += mi_usable_size(ptr); } return ptr; } } } // namespace Interpreter::Interpreter() { InterpreterManager::tl_stats().interpreter_cnt++; // interpreter can be runnned in different threads so we need to calculate // used memory via &used_bytes_ additional parameter lua_ = lua_newstate(mimalloc_glue, &used_bytes_); InitLua(lua_); void** ptr = static_cast(lua_getextraspace(lua_)); *ptr = this; // SaveOnRegistry(lua_, kInstanceKey, this); /* Register the dragonfly commands table and fields */ lua_newtable(lua_); /* dragonfly.ihash - compute quick integer hash of command result */ lua_pushstring(lua_, "ihash"); lua_pushcfunction(lua_, DragonflyHashCommand); lua_settable(lua_, -3); /* dragonfly.randstr - generate random string or table of random strings */ lua_pushstring(lua_, 
"randstr"); lua_pushcfunction(lua_, DragonflyRandstrCommand); lua_settable(lua_, -3); /* Finally set the table as 'dragonfly' global var. */ lua_setglobal(lua_, "dragonfly"); CHECK(lua_checkstack(lua_, 64)); /* Register the redis commands table and fields */ lua_newtable(lua_); /* redis.call */ lua_pushstring(lua_, "call"); lua_pushcfunction(lua_, RedisCallCommand); lua_settable(lua_, -3); /* redis.pcall */ lua_pushstring(lua_, "pcall"); lua_pushcfunction(lua_, RedisPCallCommand); lua_settable(lua_, -3); /* redis.acall */ lua_pushstring(lua_, "acall"); lua_pushcfunction(lua_, RedisACallCommand); lua_settable(lua_, -3); /* redis.apcall */ lua_pushstring(lua_, "apcall"); lua_pushcfunction(lua_, RedisAPCallCommand); lua_settable(lua_, -3); lua_pushstring(lua_, "sha1hex"); lua_pushcfunction(lua_, RedisSha1Command); lua_settable(lua_, -3); /* redis.error_reply and redis.status_reply */ lua_pushstring(lua_, "error_reply"); lua_pushcfunction(lua_, RedisErrorReplyCommand); lua_settable(lua_, -3); lua_pushstring(lua_, "status_reply"); lua_pushcfunction(lua_, RedisStatusReplyCommand); lua_settable(lua_, -3); /* no-op functions */ /* redis.replicate_commands*/ lua_pushstring(lua_, "replicate_commands"); lua_pushcfunction(lua_, RedisReplicateCommands); lua_settable(lua_, -3); /* redis.log*/ lua_pushstring(lua_, "log"); lua_pushcfunction(lua_, RedisLogCommand); lua_settable(lua_, -3); lua_pushinteger(lua_, LL_DEBUG); lua_setfield(lua_, -2, "LOG_DEBUG"); lua_pushinteger(lua_, LL_VERBOSE); lua_setfield(lua_, -2, "LOG_VERBOSE"); lua_pushinteger(lua_, LL_NOTICE); lua_setfield(lua_, -2, "LOG_NOTICE"); lua_pushinteger(lua_, LL_WARNING); lua_setfield(lua_, -2, "LOG_WARNING"); /* Finally set the table as 'redis' global var. 
*/ lua_setglobal(lua_, "redis"); CHECK(lua_checkstack(lua_, 64)); UpdateGCParameters(); } Interpreter::~Interpreter() { InterpreterManager::tl_stats().interpreter_cnt--; lua_close(lua_); } void Interpreter::FuncSha1(string_view body, char* fp) { uint8_t digest[EVP_MAX_MD_SIZE]; EVPDigest(body.data(), body.size(), digest, NULL); ToHex(digest, fp); } auto Interpreter::AddFunction(string_view sha, string_view body, string* result) -> AddResult { char funcname[43]; funcname[0] = 'f'; funcname[1] = '_'; DCHECK(sha.size() == 40); memcpy(funcname + 2, sha.data(), sha.size()); funcname[42] = '\0'; int type = lua_getglobal(lua_, funcname); lua_pop(lua_, 1); if (type == LUA_TNIL && !AddInternal(funcname, body, result)) return COMPILE_ERR; return type == LUA_TNIL ? ADD_OK : ALREADY_EXISTS; } bool Interpreter::Exists(string_view sha) const { DCHECK(lua_); if (sha.size() != 40) return false; char fname[43]; fname[0] = 'f'; fname[1] = '_'; fname[42] = '\0'; memcpy(fname + 2, sha.data(), 40); int type = lua_getglobal(lua_, fname); lua_pop(lua_, 1); return type == LUA_TFUNCTION; } auto Interpreter::RunFunction(string_view sha, std::string* error) -> RunResult { DVLOG(2) << "RunFunction " << sha << " " << lua_gettop(lua_); DCHECK_EQ(40u, sha.size()); lua_getglobal(lua_, "__redis__err__handler"); char fname[43]; fname[0] = 'f'; fname[1] = '_'; memcpy(fname + 2, sha.data(), 40); fname[42] = '\0'; int type = lua_getglobal(lua_, fname); if (type != LUA_TFUNCTION) { lua_pop(lua_, 2); return NOT_EXISTS; } // At this point lua stack has 2 globals. /* We have zero arguments and expect * a single return value. */ int err = lua_pcall(lua_, 0, 1, -2); if (err) { *error = lua_tostring(lua_, -1); } return err == 0 ? RUN_OK : RUN_ERR; } void Interpreter::SetGlobalArray(const char* name, SliceSpan args) { SetGlobalArrayInternal(lua_, name, args); } optional Interpreter::DetectPossibleAsyncCalls(string_view body_sv) { // We want to detect `redis.call` expressions with unused return values, i.e. 
they are a // standalone statement, not part of a expression, condition, function call or assignment. // // We search for all `redis.(p)call` statements, that are preceeded on the same line by // - `do` or `then` -> first statement in a new block, certainly unused value // - no tokens -> we need to check the previous line, if its part of a multi-line expression. // // If we need to check the previous line, we search for the last word (before comments, if it has // one). static const regex kRegex{"(?:(\\S+)(\\s*--.*?)*\\s*\n|(then)|(do)|(^))\\s*redis\\.(p*call)"}; // Taken from https://www.lua.org/manual/5.4/manual.html - 3.1 - Lexical conventions // If a line ends with it, then most likely the next line belongs to it as well static const set kContOperators = { "+", "-", "*", "/", "%", "^", "#", "&", "~", "|", "<<", ">>", "//", "==", "~=", "<=", ">=", "<", ">", "=", "(", "{", "[", "::", ":", ",", ".", ".."}; // If a line ends with it, then most likely the next line belongs to it as well static const set kContTokens = {"and", "else", "elseif", "for", "goto", "if", "in", "local", "not", "or", "repeat", "return", "until", "while"}; auto last_n = [](const string& s, size_t n) { return s.size() < n ? s : s.substr(s.size() - n, n); }; smatch sm; string body{body_sv}; vector targets; // We don't handle comment blocks yet. if (body.find("--[[") != string::npos) return {}; sregex_iterator it{body.begin(), body.end(), kRegex}; sregex_iterator end{}; for (; it != end; it++) { auto last_word = it->str(1); if (kContOperators.count(last_n(last_word, 2)) > 0 || kContOperators.count(last_n(last_word, 1)) > 0) continue; if (kContTokens.count(last_word) > 0) continue; targets.push_back(it->position(it->size() - 1)); } if (targets.empty()) return nullopt; // Insert 'a' before 'call' and 'pcall'. 
Reverse order to preserve positions reverse(targets.begin(), targets.end()); body.reserve(body.size() + targets.size()); for (auto pos : targets) body.insert(pos, "a"); VLOG(1) << "Detected " << targets.size() << " aync calls in script"; return body; } bool Interpreter::IsResultSafe() const { int top = lua_gettop(lua_); if (top >= 128) return false; int t = lua_type(lua_, -1); if (t != LUA_TTABLE) return true; bool res = IsTableSafe(); // Stack can contain intermediate unwindings that were not clean up. DCHECK_GE(lua_gettop(lua_), top); lua_settop(lua_, top); // restore to the original setting. return res; } bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) { string script = absl::StrCat("function ", f_id, "() \n"); absl::StrAppend(&script, body, "\nend"); int res = luaL_loadbuffer(lua_, script.data(), script.size(), "@user_script"); if (res == 0) { res = lua_pcall(lua_, 0, 0, 0); // run func definition code } if (res) { error->assign(lua_tostring(lua_, -1)); lua_pop(lua_, 1); // Remove the error. 
return false; } return true; } // Stack is cleaned for us, we can leave it dirty bool Interpreter::IsTableSafe() const { auto fres = FetchKey(lua_, "err"); if (fres && *fres == LUA_TSTRING) { return true; } fres = FetchKey(lua_, "ok"); if (fres && *fres == LUA_TSTRING) { return true; } // Copy root table because we remove it upon finishing traversal lua_pushnil(lua_); lua_copy(lua_, -2, -1); int depth = 1; lua_pushnil(lua_); // DFS based on lua stack: [parent-table] [parent-key] [parent-value = table] [key] while (depth > 0) { if (lua_checkstack(lua_, 3) == 0 || depth > 128) return false; bool descending = false; for (; lua_next(lua_, -2) != 0; lua_pop(lua_, 1)) { if (lua_type(lua_, -1) != LUA_TTABLE) continue; // If we descend, keep value as new table and push nil for start key depth++; lua_pushnil(lua_); descending = true; break; } if (!descending) { lua_pop(lua_, 1); depth--; } } return true; } void Interpreter::SerializeResult(ObjectExplorer* serializer) { int t = lua_type(lua_, -1); switch (t) { case LUA_TSTRING: serializer->OnString(TopSv(lua_)); break; case LUA_TBOOLEAN: serializer->OnBool(lua_toboolean(lua_, -1)); break; case LUA_TNUMBER: if (lua_isinteger(lua_, -1)) { serializer->OnInt(lua_tointeger(lua_, -1)); } else { serializer->OnDouble(lua_tonumber(lua_, -1)); } break; case LUA_TTABLE: { auto fres = FetchKey(lua_, "err"); if (fres && *fres == LUA_TSTRING) { serializer->OnError(TopSv(lua_)); lua_pop(lua_, 1); break; } fres = FetchKey(lua_, "ok"); if (fres && *fres == LUA_TSTRING) { serializer->OnStatus(TopSv(lua_)); lua_pop(lua_, 1); break; } fres = FetchKey(lua_, "map"); if (fres && *fres == LUA_TTABLE) { // Calculate length of map part, there is sadly no other way unsigned len = 0; for (lua_pushnil(lua_); lua_next(lua_, -2) != 0; lua_pop(lua_, 1)) len++; serializer->OnMapStart(len); for (lua_pushnil(lua_); lua_next(lua_, -2) != 0;) { // Push key to stack top: key value key lua_pushnil(lua_); lua_copy(lua_, -3, -1); SerializeResult(serializer); // 
pops key SerializeResult(serializer); // pop value } serializer->OnMapEnd(); lua_pop(lua_, 2); break; } unsigned len = lua_rawlen(lua_, -1); serializer->OnArrayStart(len); for (unsigned i = 0; i < len; ++i) { t = lua_rawgeti(lua_, -1, i + 1); // push table element // TODO: we should make sure that we have enough stack space // to traverse each object. This can be done as a dry-run before doing real serialization. // Once we are sure we are safe we can simplify the serialization flow and // remove the error factor. SerializeResult(serializer); // pops the element } serializer->OnArrayEnd(); break; } case LUA_TNIL: serializer->OnNil(); break; default: LOG(ERROR) << "Unsupported type " << lua_typename(lua_, t); serializer->OnNil(); } lua_pop(lua_, 1); } void Interpreter::ResetStack() { lua_settop(lua_, 0); } int64_t Interpreter::RunGC() { int64_t before_kb = lua_gc(lua_, LUA_GCCOUNT); lua_gc(lua_, LUA_GCCOLLECT); int64_t after_kb = lua_gc(lua_, LUA_GCCOUNT); LOG_IF(DFATAL, after_kb > before_kb) << "LUA_GCCOLLECT increase memory consumption from " << before_kb << "kB to " << after_kb << "kB"; int64_t res = (before_kb - after_kb) * 1024; return std::max(int64_t(0), res); } void Interpreter::UpdateGCParameters() { auto gc = absl::GetFlag(FLAGS_luagc); std::visit(dfly::Overloaded{ [](std::monostate) {}, [&](const LuaGcGen& gen) { lua_gc(lua_, LUA_GCGEN, gen.minormul, gen.majormul); }, [&](const LuaGcInc& inc) { lua_gc(lua_, LUA_GCINC, inc.pause, inc.stepmul, inc.stepsize); }, }, gc); } std::optional> Interpreter::PrepareArgs() { int argc = lua_gettop(lua_); /* Require at least one argument */ if (argc == 0) { PushError(lua_, "Please specify at least one argument for redis.call()"); return std::nullopt; } size_t blob_len = 0; char tmpbuf[64]; // Determine size required for backing storage for all args. // Skip command name (idx=1), as its stored in a separate buffer. 
for (int idx = 2; idx <= argc; idx++) { switch (lua_type(lua_, idx)) { case LUA_TNUMBER: if (lua_isinteger(lua_, idx)) { blob_len += absl::AlphaNum(lua_tointeger(lua_, idx)).size(); } else { int fmt_len = absl::SNPrintF(tmpbuf, sizeof(tmpbuf), "%.17g", lua_tonumber(lua_, idx)); CHECK_GT(fmt_len, 0); blob_len += fmt_len; } continue; case LUA_TSTRING: blob_len += lua_rawlen(lua_, idx) + 1; continue; default: PushError(lua_, "Lua redis() command arguments must be strings or integers"); return std::nullopt; } } absl::FixedArray args(argc); // Copy command name to name_buffer and set it as first arg. unsigned name_len = lua_rawlen(lua_, 1); if (name_len >= sizeof(name_buffer_)) { PushError(lua_, "Lua redis() command name too long"); return std::nullopt; } memcpy(name_buffer_, lua_tostring(lua_, 1), name_len); args[0] = {name_buffer_, name_len}; buffer_.resize(blob_len + 4, '\0'); // backing storage for args char* cur = buffer_.data(); char* end = cur + blob_len; for (int idx = 2; idx <= argc; idx++) { size_t len = 0; switch (lua_type(lua_, idx)) { case LUA_TNUMBER: if (lua_isinteger(lua_, idx)) { char* next = absl::numbers_internal::FastIntToBuffer(lua_tointeger(lua_, idx), cur); len = next - cur; } else if (lua_isnumber(lua_, idx)) { // we pass `end - cur + 1` because we do not want to skip the last character // if it's the last argument. int fmt_len = absl::SNPrintF(cur, end - cur + 1, "%.17g", lua_tonumber(lua_, idx)); CHECK_GT(fmt_len, 0); len = fmt_len; } break; case LUA_TSTRING: len = lua_rawlen(lua_, idx); memcpy(cur, lua_tostring(lua_, idx), len + 1); // + 1 for null terminator }; args[idx - 1] = {cur, len}; cur += len; } /* Pop all arguments from the stack, we do not need them anymore * and this way we guaranty we will have room on the stack for the result. */ lua_pop(lua_, argc); return args; } // Calls redis function // Returns false if error needs to be raised. 
bool Interpreter::CallRedisFunction(bool raise_error, bool async, ObjectExplorer* explorer, SliceSpan args) { // Calling with custom explorer is not supported with errors or async DCHECK(explorer == nullptr || (!raise_error && !async)); // If no custom explorer is set, use default translator optional translator; if (explorer == nullptr) { translator.emplace(lua_); explorer = &*translator; } cmd_depth_++; redis_func_(CallArgs{args, &buffer_, explorer, async, raise_error, &raise_error}); cmd_depth_--; // Shrink reusable buffer if it's too big. if (buffer_.capacity() > 128) { buffer_.clear(); buffer_.shrink_to_fit(); } if (!translator) return true; // Raise error for regular 'call' command if needed. if (raise_error && translator->HasError()) { // error is already on top of stack return false; } if (!async) DCHECK_EQ(1, lua_gettop(lua_)); return true; } // Returns number of results, which is always 1 in this case. // Please note that lua resets the stack once the function returns so no need // to unwind the stack manually in the function (though lua allows doing this). int Interpreter::RedisGenericCommand(bool raise_error, bool async, ObjectExplorer* explorer) { /* By using Lua debug hooks it is possible to trigger a recursive call * to luaRedisGenericCommand(), which normally should never happen. * To make this function reentrant is futile and makes it slower, but * we should at least detect such a misuse, and abort. */ if (cmd_depth_) { const char* recursion_warning = "luaRedisGenericCommand() recursive call detected. " "Are you doing funny stuff with Lua debug hooks?"; PushError(lua_, recursion_warning); return 1; } if (!redis_func_) { PushError(lua_, "internal error - redis function not defined"); if (raise_error) { return RaiseErrorAndAbort(lua_); } return 1; } // IMPORTANT! all allocations within this funciton must be freed // BEFORE calling RaiseErrorAndAbort in case of script error. 
RaiseErrorAndAbort // uses longjmp which bypasses stack unwinding and skips the destruction of objects. { std::optional> args = PrepareArgs(); if (args.has_value()) { raise_error = !CallRedisFunction(raise_error, async, explorer, SliceSpan{*args}); } } if (!raise_error) { return 1; } return RaiseErrorAndAbort(lua_); // this function never returns, it unwinds the Lua call stack } int Interpreter::RedisCallCommand(lua_State* lua) { void** ptr = static_cast(lua_getextraspace(lua)); return reinterpret_cast(*ptr)->RedisGenericCommand(true, false); } int Interpreter::RedisPCallCommand(lua_State* lua) { void** ptr = static_cast(lua_getextraspace(lua)); return reinterpret_cast(*ptr)->RedisGenericCommand(false, false); } int Interpreter::RedisACallCommand(lua_State* lua) { void** ptr = static_cast(lua_getextraspace(lua)); return reinterpret_cast(*ptr)->RedisGenericCommand(true, true); } int Interpreter::RedisAPCallCommand(lua_State* lua) { void** ptr = static_cast(lua_getextraspace(lua)); return reinterpret_cast(*ptr)->RedisGenericCommand(false, true); } InterpreterManager::Stats& InterpreterManager::Stats::operator+=(const Stats& other) { this->used_bytes += other.used_bytes; this->interpreter_cnt += other.interpreter_cnt; this->blocked_cnt += other.blocked_cnt; this->force_gc_calls += other.force_gc_calls; this->gc_duration_ns += other.gc_duration_ns; this->interpreter_return += other.interpreter_return; this->gc_freed_memory += other.gc_freed_memory; return *this; } InterpreterManager::Stats& InterpreterManager::tl_stats() { static thread_local Stats stats; return stats; } Interpreter* InterpreterManager::Get() { // Grow if none is available and we have unused capacity left. 
if (available_.empty() && storage_.size() < storage_.capacity()) { storage_.emplace_back(); return &storage_.back(); } bool blocked = waker_.await([this]() { return !available_.empty(); }); tl_stats().blocked_cnt += (uint64_t)blocked; Interpreter* ir = available_.back(); available_.pop_back(); return ir; } void InterpreterManager::Return(Interpreter* ir) { const uint64_t max_memory_usage = absl::GetFlag(FLAGS_lua_mem_gc_threshold); using namespace chrono; ++tl_stats().interpreter_return; tl_stats().used_bytes += ir->TakeUsedBytes(); if (max_memory_usage != 0 && tl_stats().used_bytes > max_memory_usage) { ++tl_stats().force_gc_calls; auto before = steady_clock::now(); tl_stats().gc_freed_memory += ir->RunGC(); VLOG(2) << "stats_used_bytes: " << tl_stats().used_bytes << " lua_mem_gc_threshold: " << max_memory_usage << " force_gc_calls: " << tl_stats().force_gc_calls << " freed_mem: " << tl_stats().gc_freed_memory; auto after = steady_clock::now(); tl_stats().gc_duration_ns += duration_cast(after - before).count(); } if (ir >= storage_.data() && ir < storage_.data() + storage_.size()) { available_.push_back(ir); waker_.notify(); } else if (return_untracked_ > 0) { return_untracked_--; if (return_untracked_ == 0) { reset_ec_.notify(); } } else { LOG(DFATAL) << "Returning untracked interpreter"; } } void InterpreterManager::Reset() { lock_guard guard{reset_mu_}; // we perform double buffer swapping with storage and wait for the old interepreters to be // returned. 
return_untracked_ = storage_.size() - available_.size(); std::vector next_storage; next_storage.reserve(storage_.capacity()); next_storage.resize(storage_.size()); next_storage.swap(storage_); available_.clear(); for (auto& ir : storage_) { available_.push_back(&ir); } reset_ec_.await([this]() { return return_untracked_ == 0; }); VLOG(1) << "InterpreterManager::Reset ended"; } void InterpreterManager::Alter(std::function modf) { vector taken; swap(taken, available_); // swap data because modf can preempt for (Interpreter* ir : taken) { modf(ir); Return(ir); } } } // namespace dfly ================================================ FILE: src/core/interpreter.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include "util/fibers/synchronization.h" typedef struct lua_State lua_State; namespace dfly { class ObjectExplorer { public: virtual ~ObjectExplorer() = default; virtual void OnBool(bool b) = 0; virtual void OnString(std::string_view str) = 0; virtual void OnDouble(double d) = 0; virtual void OnInt(int64_t val) = 0; virtual void OnArrayStart(unsigned len) = 0; virtual void OnArrayEnd() = 0; virtual void OnNil() = 0; virtual void OnStatus(std::string_view str) = 0; virtual void OnError(std::string_view str) = 0; virtual void OnMapStart(unsigned len) { OnArrayStart(len * 2); } virtual void OnMapEnd() { OnArrayEnd(); } }; class Interpreter { public: using SliceSpan = absl::Span; // Arguments received from redis.call struct CallArgs { // Full arguments, including cmd name. SliceSpan args; // Pointer to backing storage for args (excluding cmd name). // Moving can invalidate arg slice pointers. Moved by async to re-use buffer. 
std::string* buffer; ObjectExplorer* translator; bool async; // async by acall bool error_abort; // abort on errors (not pcall) // The function can request an abort due to an error, even if error_abort is false. // It happens when async cmds are flushed and result in an uncatched error. bool* requested_abort; }; using RedisFunc = std::function; Interpreter(); ~Interpreter(); Interpreter(const Interpreter&) = delete; void operator=(const Interpreter&) = delete; Interpreter(Interpreter&&) = default; Interpreter& operator=(Interpreter&&) = default; // Note: We leak the state for now. // Production code should not access this method. lua_State* lua() { return lua_; } enum AddResult { ADD_OK = 0, ALREADY_EXISTS = 1, COMPILE_ERR = 2, }; // Add function with sha and body to interpreter. AddResult AddFunction(std::string_view sha, std::string_view body, std::string* error); int64_t TakeUsedBytes() { return std::exchange(used_bytes_, 0); } bool Exists(std::string_view sha) const; enum RunResult { RUN_OK = 0, NOT_EXISTS = 1, RUN_ERR = 2, }; void SetGlobalArray(const char* name, SliceSpan args); // Runs already added function sha returned by a successful call to AddFunction(). // Returns: true if the call succeeded, otherwise fills error and returns false. // sha must be 40 char length. RunResult RunFunction(std::string_view sha, std::string* err); // Checks whether the result is safe to serialize. // Should fit 2 conditions: // 1. Be the only value on the stack. // 2. Should have depth of no more than 128. bool IsResultSafe() const; void SerializeResult(ObjectExplorer* serializer); void ResetStack(); // run gc and returns size of freed memory in bytes int64_t RunGC(); void UpdateGCParameters(); // fp must point to buffer with at least 41 chars. // fp[40] will be set to '\0'. 
static void FuncSha1(std::string_view body, char* fp); static std::optional DetectPossibleAsyncCalls(std::string_view body); template void SetRedisFunc(U&& u) { redis_func_ = std::forward(u); } // Invoke command with arguments from lua stack, given options and possibly custom explorer int RedisGenericCommand(bool raise_error, bool async, ObjectExplorer* explorer = nullptr); private: // Returns true if function was successfully added, // otherwise returns false and sets the error. bool AddInternal(const char* f_id, std::string_view body, std::string* error); bool IsTableSafe() const; static int RedisCallCommand(lua_State* lua); static int RedisPCallCommand(lua_State* lua); static int RedisACallCommand(lua_State* lua); static int RedisAPCallCommand(lua_State* lua); std::optional> PrepareArgs(); bool CallRedisFunction(bool raise_error, bool async, ObjectExplorer* explorer, SliceSpan args); lua_State* lua_; unsigned cmd_depth_ = 0; RedisFunc redis_func_; std::string buffer_; int64_t used_bytes_ = 0; char name_buffer_[32]; // backing storage for cmd name }; // Manages an internal interpreter pool. This allows multiple connections residing on the same // thread to run multiple lua scripts in parallel. class InterpreterManager { public: struct Stats { Stats& operator+=(const Stats& other); uint64_t used_bytes = 0; uint64_t interpreter_cnt = 0; uint64_t blocked_cnt = 0; uint64_t force_gc_calls = 0; uint64_t gc_duration_ns = 0; uint64_t interpreter_return = 0; int64_t gc_freed_memory = 0; }; public: InterpreterManager(unsigned num) : waker_{}, available_{}, storage_{} { // We pre-allocate the backing storage during initialization and // start storing pointers to slots in the available vector. storage_.reserve(num); } // Borrow interpreter. Always return it after usage. Interpreter* Get(); void Return(Interpreter*); // Clear all interpreters, keeps capacity. Waits until all are returned. void Reset(); // Run on all unused interpreters. 
Those are marked as used at once, so the callback can preempt void Alter(std::function modf); static Stats& tl_stats(); private: util::fb2::EventCount waker_, reset_ec_; std::vector available_; std::vector storage_; util::fb2::Mutex reset_mu_; // Acts as a singleton. unsigned return_untracked_ = 0; // Number of returned interpreters during reset. }; } // namespace dfly ================================================ FILE: src/core/interpreter_polyfill.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // // This header contains implementations of deprecated, removed or renamed lua functions. #pragma once extern "C" { #include #include #include // TODO: Fix checktab #define aux_getn(L, n, w) (luaL_len(L, n)) LUA_API void lua_len(lua_State* L, int idx); static int polyfill_table_getn(lua_State* L) { lua_len(L, 1); return 1; } static int polyfill_table_setn(lua_State* L) { // From Lua 5.1, ltablib.c luaL_checktype(L, 1, LUA_TTABLE); luaL_error(L, "setn is obsolete"); lua_pushvalue(L, 1); return 1; } static int polyfill_table_foreach(lua_State* L) { // From Lua 5.1, ltablib.c luaL_checktype(L, 1, LUA_TTABLE); luaL_checktype(L, 2, LUA_TFUNCTION); lua_pushnil(L); /* first key */ while (lua_next(L, 1)) { lua_pushvalue(L, 2); /* function */ lua_pushvalue(L, -3); /* key */ lua_pushvalue(L, -3); /* value */ lua_call(L, 2, 1); if (!lua_isnil(L, -1)) return 1; lua_pop(L, 2); /* remove value and result */ } return 0; } static int polyfill_table_foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); // Check type here because aux_getn is stripped // From Lua 5.1, ltablib.c int i; int n = aux_getn(L, 1, 0b11); luaL_checktype(L, 2, LUA_TFUNCTION); for (i = 1; i <= n; i++) { lua_pushvalue(L, 2); /* function */ lua_pushinteger(L, i); /* 1st argument */ lua_rawgeti(L, 1, i); /* 2nd argument */ lua_call(L, 2, 1); if (!lua_isnil(L, -1)) return 1; lua_pop(L, 1); /* remove nil result */ 
} return 0; } static void register_polyfills(lua_State* lua) { lua_getglobal(lua, "table"); // unpack was a global function until Lua 5.2 lua_getfield(lua, -1, "unpack"); lua_setglobal(lua, "unpack"); // table.getn - removed, length operator # should be used instead lua_pushcfunction(lua, polyfill_table_getn); lua_setfield(lua, -2, "getn"); // table.setn - removed, freely resizing a table is no longer possible lua_pushcfunction(lua, polyfill_table_setn); lua_setfield(lua, -2, "setn"); // table.getn - removed, instead the length operator # should be used lua_pushcfunction(lua, polyfill_table_foreach); lua_setfield(lua, -2, "foreach"); // table.forachi - removed, use for loops should be used instead lua_pushcfunction(lua, polyfill_table_foreachi); lua_setfield(lua, -2, "foreachi"); lua_remove(lua, -1); } } ================================================ FILE: src/core/interpreter_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/interpreter.h" extern "C" { #include #include } #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly { using namespace std; class TestSerializer : public ObjectExplorer { public: string res; void OnBool(bool b) final { absl::StrAppend(&res, "bool(", b, ") "); } void OnString(std::string_view str) final { absl::StrAppend(&res, "str(", str, ") "); } void OnDouble(double d) final { absl::StrAppend(&res, "d(", d, ") "); } void OnInt(int64_t val) final { absl::StrAppend(&res, "i(", val, ") "); } void OnArrayStart(unsigned len) final { absl::StrAppend(&res, "["); } void OnArrayEnd() final { if (res.back() == ' ') res.pop_back(); absl::StrAppend(&res, "] "); } void OnNil() final { absl::StrAppend(&res, "nil "); } void OnMapStart(unsigned len) final { absl::StrAppend(&res, "{"); } void OnMapEnd() final { if (res.back() == ' ') res.pop_back(); absl::StrAppend(&res, "} "); } void OnStatus(std::string_view str) { absl::StrAppend(&res, "status(", str, ") "); } void OnError(std::string_view str) { absl::StrAppend(&res, "err(", str, ") "); } }; using SliceSpan = Interpreter::SliceSpan; class InterpreterTest : public ::testing::Test { protected: InterpreterTest() { // configure redis lib zmalloc which requires mimalloc heap to work. auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); } lua_State* lua() { return intptr_.lua(); } void RunInline(string_view buf, const char* name, unsigned num_results = 0) { CHECK_EQ(0, luaL_loadbuffer(lua(), buf.data(), buf.size(), name)); CHECK_EQ(0, lua_pcall(lua(), 0, num_results, 0)); } void SetGlobalArray(const char* name, const vector& vec); // returns true if script run successfully. 
bool Execute(string_view script); Interpreter intptr_; TestSerializer ser_; string error_; vector> strings_; }; void InterpreterTest::SetGlobalArray(const char* name, const vector& vec) { vector slices(vec.size()); for (size_t i = 0; i < vec.size(); ++i) { strings_.emplace_back(new string(vec[i])); slices[i] = string_view{*strings_.back()}; } intptr_.SetGlobalArray(name, SliceSpan{slices}); } bool InterpreterTest::Execute(string_view script) { char sha_buf[64]; Interpreter::FuncSha1(script, sha_buf); string_view sha{sha_buf, std::strlen(sha_buf)}; string result; Interpreter::AddResult add_res = intptr_.AddFunction(sha, script, &result); if (add_res == Interpreter::COMPILE_ERR) { error_ = result; return false; } Interpreter::RunResult run_res = intptr_.RunFunction(sha, &error_); if (run_res != Interpreter::RUN_OK) { return false; } ser_.res.clear(); intptr_.SerializeResult(&ser_); ser_.res.pop_back(); return true; } TEST_F(InterpreterTest, Basic) { RunInline(R"( function foo(n) return n,n+1 end)", "code1"); int type = lua_getglobal(lua(), "foo"); ASSERT_EQ(LUA_TFUNCTION, type); lua_pushnumber(lua(), 42); lua_pcall(lua(), 1, 2, 0); int val1 = lua_tointeger(lua(), -1); int val2 = lua_tointeger(lua(), -2); lua_pop(lua(), 2); EXPECT_EQ(43, val1); EXPECT_EQ(42, val2); EXPECT_EQ(0, lua_gettop(lua())); lua_pushstring(lua(), "foo"); EXPECT_EQ(3, lua_rawlen(lua(), 1)); lua_pop(lua(), 1); RunInline("return {nil, 'b'}", "code2", 1); ASSERT_EQ(1, lua_gettop(lua())); LOG(INFO) << lua_typename(lua(), lua_type(lua(), -1)); ASSERT_TRUE(lua_istable(lua(), -1)); ASSERT_EQ(2, lua_rawlen(lua(), -1)); lua_len(lua(), -1); ASSERT_EQ(2, lua_tointeger(lua(), -1)); lua_pop(lua(), 1); lua_pushnil(lua()); while (lua_next(lua(), -2)) { /* uses 'key' (at index -2) and 'value' (at index -1) */ int kt = lua_type(lua(), -2); int vt = lua_type(lua(), -1); LOG(INFO) << "k/v : " << lua_typename(lua(), kt) << "/" << lua_tonumber(lua(), -2) << " " << lua_typename(lua(), vt); lua_pop(lua(), 1); } } 
TEST_F(InterpreterTest, UnknownFunc) { string_view code(R"( function foo(n) return myunknownfunc(1, n) end)"); CHECK_EQ(0, luaL_loadbuffer(lua(), code.data(), code.size(), "code1")); CHECK_EQ(0, lua_pcall(lua(), 0, 0, 0)); int type = lua_getglobal(lua(), "myunknownfunc"); ASSERT_EQ(LUA_TNIL, type); } TEST_F(InterpreterTest, Stack) { RunInline(R"( local x = {} for i=1,127 do x = {x} end return x )", "code1", 1); ASSERT_EQ(1, lua_gettop(lua())); ASSERT_TRUE(intptr_.IsResultSafe()); lua_pop(lua(), 1); RunInline(R"( local x = {} for i=1,128 do x = {x} end return x )", "code1", 1); ASSERT_EQ(1, lua_gettop(lua())); ASSERT_FALSE(intptr_.IsResultSafe()); } TEST_F(InterpreterTest, Add) { const char* s1 = "return 0"; const char* s2 = "foobar"; char sha_buf1[64], sha_buf2[64]; Interpreter::FuncSha1(s1, sha_buf1); Interpreter::FuncSha1(s2, sha_buf2); string_view sha1{sha_buf1, std::strlen(sha_buf1)}; string_view sha2{sha_buf2, std::strlen(sha_buf2)}; string err; EXPECT_EQ(Interpreter::ADD_OK, intptr_.AddFunction(sha1, "return 0", &err)); EXPECT_EQ(0, lua_gettop(lua())); EXPECT_EQ(Interpreter::COMPILE_ERR, intptr_.AddFunction(sha2, "foobar", &err)); EXPECT_THAT(err, testing::HasSubstr("syntax error")); EXPECT_EQ(0, lua_gettop(lua())); EXPECT_TRUE(intptr_.Exists(sha1)); } // Test cases taken from scripting.tcl TEST_F(InterpreterTest, Execute) { ASSERT_TRUE(Execute("return 42")); EXPECT_EQ("i(42)", ser_.res); EXPECT_TRUE(Execute("return 'hello'")); EXPECT_EQ("str(hello)", ser_.res); // Breaks compatibility. 
EXPECT_TRUE(Execute("return 100.5")); EXPECT_EQ("d(100.5)", ser_.res); EXPECT_TRUE(Execute("return true")); EXPECT_EQ("bool(1)", ser_.res); EXPECT_TRUE(Execute("return false")); EXPECT_EQ("bool(0)", ser_.res); EXPECT_TRUE(Execute("return {ok='fine'}")); EXPECT_EQ("status(fine)", ser_.res); EXPECT_TRUE(Execute("return {err= 'bla'}")); EXPECT_EQ("err(bla)", ser_.res); EXPECT_TRUE(Execute("return {1, 2, nil, 3}")); EXPECT_EQ("[i(1) i(2) nil i(3)]", ser_.res); EXPECT_TRUE(Execute("return {1,2,3,'ciao', {1,2}}")); EXPECT_EQ("[i(1) i(2) i(3) str(ciao) [i(1) i(2)]]", ser_.res); EXPECT_TRUE(Execute("return {map={a=1,b=2}}")); EXPECT_THAT(ser_.res, testing::AnyOf("{str(a) i(1) str(b) i(2)}", "{str(b) i(2) str(a) i(1)}")); } TEST_F(InterpreterTest, Call) { auto cb = [](auto ca) { auto* reply = ca.translator; auto span = ca.args; CHECK_GE(span.size(), 1u); string_view cmd{span[0].data(), span[0].size()}; if (cmd == "string") { reply->OnString("foo"); } else if (cmd == "double") { reply->OnDouble(3.1415); } else if (cmd == "int") { reply->OnInt(42); } else if (cmd == "err") { reply->OnError("myerr"); } else if (cmd == "status") { reply->OnStatus("mystatus"); } else { LOG(FATAL) << "Invalid param"; } }; intptr_.SetRedisFunc(cb); ASSERT_TRUE(Execute("local var = redis.pcall('string'); return {type(var), var}")); EXPECT_EQ("[str(string) str(foo)]", ser_.res); EXPECT_TRUE(Execute("local var = redis.pcall('double'); return {type(var), var}")); EXPECT_EQ("[str(number) d(3.1415)]", ser_.res); EXPECT_TRUE(Execute("local var = redis.pcall('int'); return {type(var), var}")); EXPECT_EQ("[str(number) i(42)]", ser_.res); EXPECT_TRUE(Execute("local var = redis.pcall('err'); return {type(var), var}")); EXPECT_EQ("[str(table) err(myerr)]", ser_.res); EXPECT_TRUE(Execute("local var = redis.pcall('status'); return {type(var), var}")); EXPECT_EQ("[str(table) status(mystatus)]", ser_.res); } TEST_F(InterpreterTest, CallArray) { auto cb = [](auto ca) { auto* reply = ca.translator; 
reply->OnArrayStart(2); reply->OnArrayStart(1); reply->OnArrayStart(2); reply->OnNil(); reply->OnString("s2"); reply->OnArrayEnd(); reply->OnArrayEnd(); reply->OnInt(42); reply->OnArrayEnd(); }; intptr_.SetRedisFunc(cb); EXPECT_TRUE(Execute("local var = redis.call(''); return {type(var), var}")); EXPECT_EQ("[str(table) [[[bool(0) str(s2)]] i(42)]]", ser_.res); } TEST_F(InterpreterTest, ArgKeys) { vector vec_arr{}; vector slices; SetGlobalArray("ARGV", {"foo", "bar"}); SetGlobalArray("KEYS", {"key1", "key2"}); EXPECT_TRUE(Execute("return {ARGV[1], KEYS[1], KEYS[2]}")); EXPECT_EQ("[str(foo) str(key1) str(key2)]", ser_.res); SetGlobalArray("INTKEYS", {"123456", "1"}); EXPECT_TRUE(Execute("return INTKEYS[1] + 0")) << error_; EXPECT_EQ("i(123456)", ser_.res); } TEST_F(InterpreterTest, Modules) { // cjson module EXPECT_TRUE(Execute("return cjson.encode({1, 2, 3})")); EXPECT_EQ("str([1,2,3])", ser_.res); EXPECT_TRUE(Execute("return cjson.decode('{\"a\": 1}')['a']")); EXPECT_EQ("i(1)", ser_.res); // cmsgpack module EXPECT_TRUE(Execute("return cmsgpack.pack('ok', true)")); EXPECT_EQ("str(\xA2ok\xC3)", ser_.res); // bit module EXPECT_TRUE(Execute("return bit.bor(8, 4, 5)")); EXPECT_EQ("i(13)", ser_.res); // struct module EXPECT_TRUE(Execute("return struct.pack('bbc4', 1, 2, 'test')")); EXPECT_EQ("str(\x1\x2test)", ser_.res); } // Check compatibility with Lua 5.1 TEST_F(InterpreterTest, Compatibility) { // unpack is no longer global EXPECT_TRUE(Execute("return unpack{1,2,3}")); EXPECT_EQ("i(1)", ser_.res); string_view test_foreach_template = "local t = {1,'two',3;four='yes'}; local out = {};" "table.{TESTF} (t, function(k, v) table.insert(out, {k, v}) end); " "return out; "; // table.foreach was removed string test_foreach = absl::StrReplaceAll(test_foreach_template, {{"{TESTF}", "foreach"}}); EXPECT_TRUE(Execute(test_foreach)); EXPECT_EQ("[[i(1) i(1)] [i(2) str(two)] [i(3) i(3)] [str(four) str(yes)]]", ser_.res); // table.foreachi was removed string test_foreachi = 
absl::StrReplaceAll(test_foreach_template, {{"{TESTF}", "foreachi"}}); EXPECT_TRUE(Execute(test_foreachi)); EXPECT_EQ("[[i(1) i(1)] [i(2) str(two)] [i(3) i(3)]]", ser_.res); EXPECT_FALSE(Execute("table.foreachi('not-a-table', print);")); // check invalid args // table.getn was replaced with length operator EXPECT_TRUE(Execute("return table.getn{1, 2, 3};")); EXPECT_EQ("i(3)", ser_.res); // table.setn was removed, resizing is no longer needed, it thows an error EXPECT_FALSE(Execute("local t = {}; local a = 1; table.setn(t, 100); return a+123;")); } TEST_F(InterpreterTest, AsyncReplacement) { const string_view kCases[] = { R"( redis.[A]call('INCR', 'A') redis.[A]call('INCR', 'A') )", R"( function test() redis.[A]call('INCR', 'A') end )", R"( local b = redis.call('GET', 'A') + redis.call('GET', 'B') )", R"( if redis.call('EXISTS', 'A') then redis.[A]call('SET', 'B', 1) end )", R"( while redis.call('EXISTS', 'A') do redis.[A]call('SET', 'B', 1) end )", R"( while redis.call('EXISTS', 'A') do print("OK") end )", R"( print(redis.call('GET', 'A')) )", R"( local table = { redis.call('GET', 'A') } )", R"( while true do redis.[A]call('INCR', 'A') end )", R"( if 1 + -- now this is a tricky comment redis.call('GET', 'A') > 0 then end )", R"( print('Output' .. redis.call('GET', 'A') ) )", R"( while 0 < -- we have a comment here unfortunately redis.call('GET', 'A') then end )", R"( while -- we have -- a tricky -- multiline comment redis.call('EXISTS') do end )", R"( --[[ WE SKIP COMMENT BLOCKS FOR NOW ]] redis.call('ECHO', 'TEST') )"}; for (auto test : kCases) { auto expected = absl::StrReplaceAll(test, {{"[A]", "a"}}); auto input = absl::StrReplaceAll(test, {{"[A]", ""}}); auto result = Interpreter::DetectPossibleAsyncCalls(input); string_view output = result ? 
*result : input; EXPECT_EQ(expected, output); } } TEST_F(InterpreterTest, ReplicateCommands) { EXPECT_TRUE(Execute("return redis.replicate_commands()")); EXPECT_EQ("i(1)", ser_.res); EXPECT_TRUE(Execute("redis.replicate_commands()")); EXPECT_EQ("nil", ser_.res); } TEST_F(InterpreterTest, Log) { EXPECT_FALSE(Execute(R"(redis.log('nonsense', 'nonsense'))")); EXPECT_THAT(error_, testing::HasSubstr("First argument must be a number (log level).")); EXPECT_TRUE(Execute(R"(redis.log(redis.LOG_WARNING, 'warn'))")); EXPECT_EQ("nil", ser_.res); EXPECT_FALSE(Execute(R"(redis.log(4))")); EXPECT_THAT(error_, testing::HasSubstr("requires two arguments or more")); } TEST_F(InterpreterTest, Robust) { EXPECT_FALSE(Execute(R"(eval "local a = {} setmetatable(a,{__index=function() foo() end}) return a")")); EXPECT_EQ("", ser_.res); } TEST_F(InterpreterTest, Unpack) { auto cb = [](Interpreter::CallArgs ca) { auto* reply = ca.translator; reply->OnInt(1); }; intptr_.SetRedisFunc(cb); ASSERT_TRUE(lua_checkstack(lua(), 7000)); bool res = Execute(R"( local N = 7000 local stringTable = {} for i = 1, N do stringTable[i] = "String " .. 
i end return redis.pcall('func', unpack(stringTable)) )"); ASSERT_TRUE(res) << error_; EXPECT_EQ("i(1)", ser_.res); } TEST_F(InterpreterTest, AvoidIntOverflow) { EXPECT_TRUE(Execute("return bit.tohex(65535, -2147483648)")); EXPECT_EQ("str(0000FFFF)", ser_.res); } TEST_F(InterpreterTest, LuaIntOverflow) { EXPECT_FALSE(Execute("EVAL \"struct.pack('>I2147483648', '10')\" 0")); } TEST_F(InterpreterTest, LuaGcStatistic) { InterpreterManager im(1); auto* interpreter = im.Get(); std::string_view keys[] = {"key1", "key2", "key3", "key4", "key5", "key6", "key7"}; interpreter->SetGlobalArray("KEYS", SliceSpan{keys}); auto cb = [](Interpreter::CallArgs ca) { auto* reply = ca.translator; reply->OnInt(1); }; interpreter->SetRedisFunc(cb); // next script generate several big values and set them to the keys // after the script is finished, GM isn't called for all values and // in the most cases we have more than 300k allocated memory // that will be cleaned later in the separate thread std::string script = R"( for i = 1, 7 do local str = string.rep(i, 1024 * 100) redis.call('SET', KEYS[1], str .. 
str) end )"; char sha_buf[64]; Interpreter::FuncSha1(script, sha_buf); string_view sha{sha_buf, std::strlen(sha_buf)}; string result; Interpreter::AddResult add_res = interpreter->AddFunction(sha, script, &result); EXPECT_EQ(Interpreter::ADD_OK, add_res); // When script is executed in the most cases we see that not all memory was deallocated // immediately and can be deallocated later Interpreter::RunResult run_res = interpreter->RunFunction(sha, &error_); EXPECT_EQ(Interpreter::RUN_OK, run_res); // check that after script is finished not the all memory was deallocated uint64_t used_bytes = InterpreterManager::tl_stats().used_bytes; EXPECT_GE(used_bytes, 0); auto force_gc_calls = InterpreterManager::tl_stats().force_gc_calls; // we need return interpreter to update statistic // force_gc_calls shouldn't be called im.Return(interpreter); EXPECT_EQ(force_gc_calls, InterpreterManager::tl_stats().force_gc_calls); EXPECT_LE(used_bytes, InterpreterManager::tl_stats().used_bytes); used_bytes = InterpreterManager::tl_stats().used_bytes; // we get the same interpeter again to call GC in separate thread auto* new_interpreter = im.Get(); EXPECT_EQ(interpreter, new_interpreter); // check that even if memory is deallocated in separate thread our statistic is correct std::thread t([&] { interpreter->RunGC(); EXPECT_EQ(InterpreterManager::tl_stats().used_bytes, 0); }); t.join(); im.Return(interpreter); EXPECT_GE(used_bytes, InterpreterManager::tl_stats().used_bytes); } } // namespace dfly ================================================ FILE: src/core/json/CMakeLists.txt ================================================ gen_flex(jsonpath_lexer) gen_bison(jsonpath_grammar) cur_gen_dir(gen_dir) add_library(jsonpath lexer_impl.cc driver.cc path.cc ${gen_dir}/jsonpath_lexer.cc ${gen_dir}/jsonpath_grammar.cc json_object.cc detail/jsoncons_dfs.cc detail/flat_dfs.cc detail/interned_blob.cc detail/interned_string.cc) target_link_libraries(jsonpath base absl::strings TRDP::reflex 
TRDP::jsoncons TRDP::flatbuffers dfly_page_usage) helio_cxx_test(jsonpath_test jsonpath dfly_core LABELS DFLY) helio_cxx_test(json_test jsonpath TRDP::jsoncons LABELS DFLY) helio_cxx_test(interned_blob_test dfly_core TRDP::mimalloc2 LABELS DFLY) ================================================ FILE: src/core/json/detail/common.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once namespace dfly::json::detail { enum MatchStatus { OUT_OF_BOUNDS, MISMATCH, }; } ================================================ FILE: src/core/json/detail/flat_dfs.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/json/detail/flat_dfs.h" #include "base/logging.h" namespace dfly::json::detail { using namespace std; using nonstd::make_unexpected; inline bool IsRecursive(flexbuffers::Type type) { return type == flexbuffers::FBT_MAP || type == flexbuffers::FBT_VECTOR; } // Binary search of a key, returns UINT_MAX if not found. 
unsigned FindByKey(const flexbuffers::TypedVector& keys, const char* elem) { unsigned s = 0, end = keys.size(); while (s < end) { unsigned mid = (s + end) / 2; flexbuffers::String mid_elem = keys[mid].AsString(); int res = strcmp(elem, mid_elem.c_str()); if (res < 0) { end = mid; } else if (res > 0) { s = mid + 1; } else { return mid; } } return UINT_MAX; } auto FlatDfsItem::Init(const PathSegment& segment) -> AdvanceResult { switch (segment.type()) { case SegmentType::IDENTIFIER: { if (obj().IsMap()) { auto map = obj().AsMap(); flexbuffers::TypedVector keys = map.Keys(); unsigned index = FindByKey(keys, segment.identifier().c_str()); if (index == UINT_MAX) { return Exhausted(); } state_.emplace(index, index); return DepthState{obj().AsVector()[index], depth_state_.second + 1}; } break; } case SegmentType::INDEX: { auto vec = obj().AsVector(); IndexExpr index = segment.index().Normalize(vec.size()); if (index.Empty()) { return make_unexpected(OUT_OF_BOUNDS); } state_ = index; return Next(vec[index.first]); break; } case SegmentType::DESCENT: if (segment_step_ == 1) { // first time, branching to return the same object but with the next segment, // exploring the path of ignoring the DESCENT operator. // Also, shift the state (segment_step) to bypass this branch next time. segment_step_ = 0; return DepthState{depth_state_.first, depth_state_.second + 1}; } // Now traverse all the children but do not progress with segment path. // This is why segment_step_ is set to 0. 
[[fallthrough]]; case SegmentType::WILDCARD: { auto vec = obj().AsVector(); if (vec.size() == 0) { return Exhausted(); } state_ = IndexExpr::All(); return Next(vec[0]); } break; default: LOG(DFATAL) << "Unknown segment " << SegmentName(segment.type()); } // end switch return nonstd::make_unexpected(MISMATCH); } auto FlatDfsItem::Advance(const PathSegment& segment) -> AdvanceResult { if (!state_) { return Init(segment); } ++state_->first; if (state_->Empty()) return Exhausted(); auto vec = obj().AsVector(); return Next(vec[state_->first]); } FlatDfs FlatDfs::Traverse(absl::Span path, const flexbuffers::Reference root, const PathFlatCallback& callback) { DCHECK(!path.empty()); FlatDfs dfs; if (path.size() == 1) { dfs.PerformStep(path[0], root, callback); return dfs; } using ConstItem = FlatDfsItem; vector stack; stack.emplace_back(root); do { unsigned segment_index = stack.back().segment_idx(); const auto& path_segment = path[segment_index]; // init or advance the current object ConstItem::AdvanceResult res = stack.back().Advance(path_segment); if (res && !res->first.IsNull()) { const flexbuffers::Reference next = res->first; DVLOG(2) << "Handling now " << next.GetType() << " " << next.ToString(); // We descent only if next is object or an array. 
if (IsRecursive(next.GetType())) { unsigned next_seg_id = res->second; if (next_seg_id + 1 < path.size()) { stack.emplace_back(next, next_seg_id); } else { // terminal step // TODO: to take into account MatchStatus // for `json.set foo $.a[10]` or for `json.set foo $.*.b` dfs.PerformStep(path[next_seg_id], next, callback); } } } else { stack.pop_back(); } } while (!stack.empty()); return dfs; } auto FlatDfs::PerformStep(const PathSegment& segment, const flexbuffers::Reference node, const PathFlatCallback& callback) -> nonstd::expected { switch (segment.type()) { case SegmentType::IDENTIFIER: { if (!node.IsMap()) return make_unexpected(MISMATCH); auto map = node.AsMap(); flexbuffers::Reference value = map[segment.identifier().c_str()]; if (!value.IsNull()) { DoCall(callback, string_view{segment.identifier()}, value); } } break; case SegmentType::INDEX: { if (!node.IsUntypedVector()) return make_unexpected(MISMATCH); auto vec = node.AsVector(); IndexExpr index = segment.index().Normalize(vec.size()); if (index.Empty()) { return make_unexpected(OUT_OF_BOUNDS); } for (; index.first <= index.second; ++index.first) DoCall(callback, nullopt, vec[index.first]); } break; case SegmentType::DESCENT: case SegmentType::WILDCARD: { auto vec = node.AsVector(); // always succeeds auto keys = node.AsMap().Keys(); // always succeeds string str; for (size_t i = 0; i < vec.size(); ++i) { flexbuffers::Reference key = keys[i]; optional opt_key; if (key.IsString()) { str = key.ToString(); opt_key = str; } DoCall(callback, opt_key, vec[i]); } } break; default: LOG(DFATAL) << "Unknown segment " << SegmentName(segment.type()); } return {}; } } // namespace dfly::json::detail ================================================ FILE: src/core/json/detail/flat_dfs.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#pragma once

#include
#include

#include "core/flatbuffers.h"
#include "core/json/detail/common.h"
#include "core/json/path.h"

namespace dfly::json::detail {

// State of the DFS traversal for a single flexbuffers node: the node itself, the index
// of the path segment it is matched against, and the iteration state over its children.
class FlatDfsItem {
 public:
  using ValueType = flexbuffers::Reference;
  using DepthState = std::pair;  // object, segment_idx pair
  using AdvanceResult = nonstd::expected;

  FlatDfsItem(ValueType val, unsigned idx = 0) : depth_state_(val, idx) {
  }

  // Returns the next object to traverse
  // or null if traverse was exhausted or the segment does not match.
  AdvanceResult Advance(const PathSegment& segment);

  // Index of the path segment this item is matched against.
  unsigned segment_idx() const {
    return depth_state_.second;
  }

 private:
  ValueType obj() const {
    return depth_state_.first;
  }

  // Pairs a child with the segment index it should be matched against.
  DepthState Next(ValueType obj) const {
    return {obj, depth_state_.second + segment_step_};
  }

  // Sentinel for "no more children": a default-constructed reference.
  DepthState Exhausted() const {
    return {ValueType(), 0};
  }

  AdvanceResult Init(const PathSegment& segment);

  // For most operations we advance the path segment by 1 when we descend into the children.
  unsigned segment_step_ = 1;

  DepthState depth_state_;
  std::optional state_;
};

// Traverses a json object according to the given path and calls the callback for each matching
// field. With DESCENT segments it will match 0 or more fields in depth.
// MATCH(node, DESCENT|SUFFIX) = MATCH(node, SUFFIX) ||
// { MATCH(node->child, DESCENT/SUFFIX) for each child of node }
class FlatDfs {
 public:
  // TODO: for some operations we need to know the type of mismatches.
  static FlatDfs Traverse(absl::Span path, const flexbuffers::Reference root,
                          const PathFlatCallback& callback);

  // Number of times the callback was invoked.
  unsigned matches() const {
    return matches_;
  }

 private:
  bool TraverseImpl(absl::Span path, const PathFlatCallback& callback);

  nonstd::expected PerformStep(const PathSegment& segment, const flexbuffers::Reference node,
                               const PathFlatCallback& callback);

  // Counts the match and forwards it to the callback.
  void DoCall(const PathFlatCallback& callback, std::optional key,
              const flexbuffers::Reference node) {
    ++matches_;
    callback(key, node);
  }

  unsigned matches_ = 0;
};

}  // namespace dfly::json::detail

================================================
FILE: src/core/json/detail/interned_blob.cc
================================================
// Copyright 2026, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.

#include "core/json/detail/interned_blob.h"

#include
#include

#include "core/detail/stateless_allocator.h"

namespace {
// Width of each of the two header fields (size, refcount) and of the full header.
constexpr size_t kUint32Size = sizeof(uint32_t);
constexpr size_t kHeaderSize = sizeof(uint32_t) * 2;
}  // namespace

namespace dfly::detail {

// Allocates a blob holding `sv` with refcount 1 and returns a handle that points at the
// string bytes (just past the header). An empty `sv` yields a null handle with no allocation.
InternedBlobHandle InternedBlobHandle::Create(std::string_view sv) {
  if (sv.empty()) {
    return InternedBlobHandle{nullptr};
  }

  constexpr uint32_t ref_count = 1;

  DCHECK_LE(sv.size(), std::numeric_limits::max());
  const uint32_t str_len = sv.size();

  // We need +1 byte for \0 because jsoncons expects c_str() and data() style accessors on keys
  BlobPtr blob = StatelessAllocator{}.allocate(kHeaderSize + str_len + 1);

  // Header: [size][refcount], followed by the string bytes.
  std::memcpy(blob, &str_len, kUint32Size);
  std::memcpy(blob + kUint32Size, &ref_count, kUint32Size);
  std::memcpy(blob + kHeaderSize, sv.data(), str_len);

  // null terminate so jsoncons can directly access the char* as string
  blob[kHeaderSize + str_len] = '\0';
  return InternedBlobHandle{blob + kHeaderSize};
}

// Length of the stored string read from the header; 0 for a null handle.
uint32_t InternedBlobHandle::Size() const {
  if (!blob_)
    return 0;

  uint32_t size;
  std::memcpy(&size, blob_ - kHeaderSize, kUint32Size);
  return size;
}

// Current reference count read from the header. Must not be called on a null handle.
uint32_t InternedBlobHandle::RefCount() const {
  DCHECK(blob_) << "Called RefCount() on empty blob";

  uint32_t ref_count;
  std::memcpy(&ref_count, blob_ - kUint32Size, kUint32Size);
  return ref_count;
}

void InternedBlobHandle::IncrRefCount() {  // NOLINT - non-const, mutates via ptr
  const uint32_t ref_count = RefCount();
  DCHECK_LT(ref_count, std::numeric_limits::max()) << "Attempt to increase max refcount";
  const uint32_t updated_count = ref_count + 1;
  std::memcpy(blob_ - kUint32Size, &updated_count, kUint32Size);
}

void InternedBlobHandle::DecrRefCount() {  // NOLINT - non-const, mutates via ptr
  const uint32_t ref_count = RefCount();
  DCHECK_GE(ref_count, 1ul) << "Attempt to decrease zero refcount";
  const uint32_t updated_count = ref_count - 1;
  std::memcpy(blob_ - kUint32Size, &updated_count, kUint32Size);
}

// Usable size reported by mimalloc for the whole allocation (header included); 0 if null.
size_t InternedBlobHandle::MemUsed() const {
  return blob_ ? mi_usable_size(blob_ - kHeaderSize) : 0;
}

// Frees the allocation behind `handle` (if any) and resets the handle to null.
void InternedBlobHandle::Destroy(InternedBlobHandle& handle) {
  if (handle.blob_) {
    // Must mirror the size passed to allocate() in Create().
    const size_t to_destroy = kHeaderSize + handle.Size() + 1;
    StatelessAllocator{}.deallocate(handle.blob_ - kHeaderSize, to_destroy);
    handle.blob_ = nullptr;
  }
}

}  // namespace dfly::detail

================================================
FILE: src/core/json/detail/interned_blob.h
================================================
// Copyright 2026, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.

#pragma once

#include
#include

namespace dfly::detail {

// Layout is: 4 bytes size, 4 bytes refcount, char data, followed by nul-char.
// The trailing nul-char is required because jsoncons needs to access c_str/data without a
// size. The blob_ itself points directly to the data, so that callers do not have to perform
// pointer arithmetic for c_str() and data() calls:
// [size:4] [refcount:4] [string] [\0]
// ^-8      ^- 4         ^blob_
using BlobPtr = char*;

// A lightweight handle around a blob pointer, used to wrap the blob data when storing it in hashset
// and also within interned strings. Does not handle lifetime of the data. Only provides convenience
// methods to change state inside the blob and "view" style methods to access the string inside the
// blob. Multiple handles can point to the same blob.
class InternedBlobHandle {
 public:
  InternedBlobHandle() = default;

  [[nodiscard]] static InternedBlobHandle Create(std::string_view sv);

  uint32_t Size() const;

  uint32_t RefCount() const;

  // Pointer to the string bytes; null for an empty handle.
  const char* Data() const {
    return blob_;
  }

  // The refcount methods are explicitly part of the public API and not tied to the handle lifetime
  // to keep control over exactly when we modify data in the blob ptr. We do not want to increase
  // ref count on each handle creation and conversely decrease it when a handle is destroyed, eg on
  // every hash table lookup etc. The ref count is only increased or decreased at the InternedString
  // API level, when a new string is created, and when a string is destroyed. This allows us to
  // avoid writing to memory unless absolutely necessary, making the handle cheap.

  // Increment ref count, asserts if count grows over type max limit
  void IncrRefCount();

  // Decrement ref count, asserts if the count is already zero
  void DecrRefCount();

  // Returns bytes used, including string, header and trailing byte
  size_t MemUsed() const;

  // Convenience method to deallocate storage. Not for use in destructor.
  static void Destroy(InternedBlobHandle& handle);

  operator std::string_view() const {  // NOLINT (non-explicit operator for easier comparisons)
    return blob_ ? std::string_view{blob_, Size()} : "";
  }

  // Defaulted comparisons operate on the blob_ pointer member, not on the string contents.
  auto operator<=>(const InternedBlobHandle& other) const = default;
  bool operator==(const InternedBlobHandle& other) const = default;

  explicit operator bool() const {
    return blob_;
  }

 private:
  explicit InternedBlobHandle(BlobPtr blob) : blob_{blob} {
  }

  BlobPtr blob_{nullptr};
};

// Hashes by string contents; handles reach it through their string_view conversion.
struct BlobHash {
  using is_transparent = void;
  size_t operator()(std::string_view sv) const {
    return std::hash{}(sv);
  }
};

// Two handles are equal when they point at the same blob; string views compare by contents.
struct BlobEq {
  using is_transparent = void;
  bool operator()(const InternedBlobHandle& a, const InternedBlobHandle& b) const {
    return a.Data() == b.Data();
  }

  bool operator()(std::string_view a, std::string_view b) const {
    return a == b;
  }
};

// This pool holds blob handles and is used by InternedString to manage string access. It would be
// nice to keep this on the mimalloc heap by using StatelessAllocator. However, JSON memory usage is
// estimated by comparing mimalloc usage before and after creating an object. If we keep this pool
// on mimalloc, it can introduce variations such as resizing of its internal store when adding a new
// object. This results in non-deterministic memory usage, which introduces incorrectness in tests
// and the memory usage command. To keep memory estimation per object accurate, the pool is
// allocated on the default heap.
using InternedBlobPool = absl::flat_hash_set;

// The handle must stay exactly pointer-sized.
static_assert(sizeof(InternedBlobHandle) == sizeof(char*));

}  // namespace dfly::detail

================================================
FILE: src/core/json/detail/interned_string.cc
================================================
// Copyright 2026, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
#include "core/json/detail/interned_string.h"

namespace {
// Release() rebuilds the pool when its load factor drops below this value.
constexpr auto kLoadFactorToShrinkPool = 0.2;

// Per-thread statistics for the per-thread pool.
thread_local dfly::InternedStringStats tl_stats;
}  // namespace

namespace dfly::detail {

// Copy-and-swap: `other` is taken by value, so its destructor releases the old entry.
InternedString& InternedString::operator=(InternedString other) {
  swap(other);
  return *this;
}

// Destroys every blob in this thread's pool, empties the pool and zeroes the size-related stats.
void InternedString::ResetPool() {
  InternedBlobPool& pool = GetPoolRef();
  // `handle` is a copy of the pool element; Destroy frees the blob it points at.
  for (InternedBlobHandle handle : pool) {
    InternedBlobHandle::Destroy(handle);
  }
  pool.clear();
  // Pool hits and misses are not reset, they are monotonically increasing counters
  // TODO reset these two fields in config resetstats
  tl_stats.pool_bytes = 0;
  tl_stats.pool_entries = 0;
  tl_stats.pool_table_bytes = 0;
  tl_stats.live_references = 0;
}

// Returns a handle for `sv`: an existing pool entry with its refcount bumped (hit), or a newly
// created blob inserted into the pool (miss). An empty `sv` yields a null handle and touches
// neither the pool nor the stats.
InternedBlobHandle InternedString::Intern(const std::string_view sv) {
  if (sv.empty())
    return {};

  tl_stats.live_references += 1;

  InternedBlobPool& pool_ref = GetPoolRef();
  if (const auto it = pool_ref.find(sv); it != pool_ref.end()) {
    tl_stats.hits++;
    InternedBlobHandle blob = *it;
    blob.IncrRefCount();
    return blob;
  }

  InternedBlobHandle handle = InternedBlobHandle::Create(sv);
  pool_ref.emplace(handle);
  tl_stats.pool_entries++;
  tl_stats.pool_bytes += handle.MemUsed();
  tl_stats.misses++;
  return handle;
}

// Adds one reference to the held blob; no-op for a null entry.
void InternedString::Acquire() {  // NOLINT
  if (!entry_)
    return;

  tl_stats.live_references += 1;
  entry_.IncrRefCount();
}

// Drops one reference. When the count reaches zero the blob is erased from the pool and
// destroyed, and the pool may be rebuilt to improve its load factor. No-op for a null entry.
void InternedString::Release() {
  if (!entry_)
    return;

  entry_.DecrRefCount();
  tl_stats.live_references -= 1;

  if (entry_.RefCount() == 0) {
    InternedBlobPool& pool_ref = GetPoolRef();
    // The erase and the stats update must happen before Destroy(), which frees the blob
    // and nulls entry_.
    pool_ref.erase(entry_);
    tl_stats.pool_entries--;
    tl_stats.pool_bytes -= entry_.MemUsed();
    InternedBlobHandle::Destroy(entry_);

    // When pool is underutilized, shrink it by swapping.
    if (const auto load_factor = pool_ref.load_factor();
        ABSL_PREDICT_FALSE(load_factor > 0 && load_factor < kLoadFactorToShrinkPool)) {
      // The LHS of swap is a new pool constructed from the original pool reference. The RHS is the
      // original pool. After the swap, the temporary is destroyed. Note that this is not a strict
      // shrink. The new pool internally allocates enough capacity so that the load factor is around
      // 0.8. So the capacity after swap is still larger than size, but the load factor is improved.
      InternedBlobPool(pool_ref).swap(pool_ref);
    }
  }
}

InternedBlobPool& InternedString::GetPoolRef() {
  // Note on lifetimes: this pool is thread local and depends on the thread local memory resource
  // defined in the stateless allocator in src/core/detail/stateless_allocator.h. Since there is no
  // well-defined order of destruction, this pool must be manually reset before the memory resource
  // destruction.
  thread_local InternedBlobPool pool;
  return pool;
}

}  // namespace dfly::detail

namespace dfly {

// Field-wise accumulation of another stats snapshot into this one.
InternedStringStats& InternedStringStats::operator+=(const InternedStringStats& other) {
  pool_entries += other.pool_entries;
  pool_bytes += other.pool_bytes;
  hits += other.hits;
  misses += other.misses;
  pool_table_bytes += other.pool_table_bytes;
  live_references += other.live_references;
  return *this;
}

// Returns a copy of the calling thread's stats after refreshing pool_table_bytes from the
// pool's current capacity.
InternedStringStats GetInternedStringStats() {
  // NOTE(review): the "+ 1" presumably accounts for one control byte per slot - confirm.
  tl_stats.pool_table_bytes =
      detail::InternedString::GetPoolRef().capacity() * (sizeof(detail::InternedBlobHandle) + 1);
  return tl_stats;
}

}  // namespace dfly

================================================
FILE: src/core/json/detail/interned_string.h
================================================
// Copyright 2026, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.

#pragma once

#include "core/detail/stateless_allocator.h"
#include "core/json/detail/interned_blob.h"

namespace dfly::detail {

// InternedString handles incrementing and decrementing reference counts of the blobs tied to its
// own lifecycle. It deletes the blob from a shard local pool when refcount is 0.
// TODO examine cross shard json object interactions. Can a pool end up accessed from another shard?
class InternedString {
 public:
  using allocator_type = StatelessAllocator;

  // A default-constructed string holds a null entry.
  InternedString() = default;

  // Interns `sv` in the pool (see Intern()).
  explicit InternedString(const std::string_view sv) : entry_(Intern(sv)) {
  }

  // The following constructors and members are added because they are required by jsoncons for
  // keys. Each of these is added in response to compiler errors and should not be removed, even if
  // they are seemingly a no-op or duplicated.

  // jsoncons sometimes creates empty obj with custom allocator. If it creates an object with any
  // other allocator, we should fail during compilation.
  template explicit InternedString(StatelessAllocator /*unused*/) {
  }

  template InternedString(const char* data, size_t size, Alloc alloc);

  template InternedString(It begin, It end);

  // Copy shares the same blob and adds a reference.
  InternedString(const InternedString& other) : entry_{other.entry_} {
    Acquire();
  }

  // Move steals the entry without touching the refcount and leaves `other` empty.
  InternedString(InternedString&& other) noexcept : entry_{other.entry_} {
    other.entry_ = {};
  }

  InternedString& operator=(InternedString other);

  ~InternedString() {
    Release();
  }

  operator std::string_view() const {
    return entry_;
  }

  // Never returns null: an empty string yields "".
  const char* data() const {
    return entry_ ? entry_.Data() : "";
  }

  const char* c_str() const {
    return data();
  }

  void swap(InternedString& other) noexcept {
    std::swap(entry_, other.entry_);
  }

  size_t length() const {
    return size();
  }

  size_t size() const {
    return entry_.Size();
  }

  int compare(const InternedString& other) const {
    return std::string_view{*this}.compare(other);
  }

  int compare(std::string_view other) const {
    return std::string_view{*this}.compare(other);
  }

  // lex. comparison
  auto operator<=>(const InternedString& other) const {
    return std::string_view{*this} <=> std::string_view{other};
  }

  // Defaulted: compares the entry_ handles rather than the characters.
  bool operator==(const InternedString& other) const = default;

  void shrink_to_fit() {  // NOLINT (must be non-const to align with jsoncons usage)
  }

  // Destroys all strings in the pool. Must be called on process shutdown before the backing memory
  // resource is destroyed.
  static void ResetPool();

  static InternedBlobPool& GetPoolRef();

  size_t MemUsed() const {
    return entry_.MemUsed();
  }

 private:
  // If a string exists in the pool, increments its refcount. If not, adds the string to the pool.
  // Returns a handle wrapping the string.
  static InternedBlobHandle Intern(std::string_view sv);

  // Increments the refcount if the entry is not null
  void Acquire();

  // Decrements the refcount, removes entry from the pool if necessary, destroying the interned
  // blob. A side effect may be shrinking the pool if the load factor is suboptimal (see
  // kLoadFactorToShrinkPool in the implementation)
  void Release();

  // Wraps a null pointer by default
  InternedBlobHandle entry_;
};

// The allocator argument is ignored; the string is interned like the string_view constructor.
template
InternedString::InternedString(const char* data, size_t size, Alloc /*unused*/)
    : InternedString(std::string_view{data, size}) {
}

// Interns the characters in [begin, end). An empty range leaves the entry null.
template InternedString::InternedString(It begin, It end) {
  if (begin == end) {
    return;
  }

  const auto size = std::distance(begin, end);
  // Takes the address of the first element and reads `size` characters from it, so the
  // range is treated as contiguous storage.
  const auto data_ptr = &*begin;
  entry_ = Intern(std::string_view(data_ptr, size));
}

}  // namespace dfly::detail

namespace dfly {

// Counters describing the interned string pool.
struct InternedStringStats {
  size_t pool_entries = 0;
  size_t pool_bytes = 0;
  size_t hits = 0;
  size_t misses = 0;
  size_t pool_table_bytes = 0;
  size_t live_references = 0;

  InternedStringStats& operator+=(const InternedStringStats& other);
};

InternedStringStats GetInternedStringStats();

}  // namespace dfly

================================================
FILE: src/core/json/detail/jsoncons_dfs.cc
================================================
// Copyright 2024, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
//

// clang-format off
#include
// clang-format on

#include "core/json/detail/jsoncons_dfs.h"

namespace dfly::json::detail {

using namespace std;
using nonstd::make_unexpected;

// Streams the segment type name; used by the DVLOG statements below.
ostream& operator<<(ostream& os, const PathSegment& ps) {
  os << SegmentName(ps.type());
  return os;
}

// True for the container types (object, array) that the traversal can descend into.
inline bool IsRecursive(jsoncons::json_type type) {
  return type == jsoncons::json_type::object_value || type == jsoncons::json_type::array_value;
}

// Read-only traversal: walks `root` along `path` with an explicit stack and applies
// PerformStep with the final segment to every container node reached by the preceding
// segments. A single-segment path is applied to `root` directly. `path` must not be empty.
Dfs Dfs::Traverse(absl::Span path, const JsonType& root, const Cb& callback) {
  DCHECK(!path.empty());

  Dfs dfs;

  if (path.size() == 1) {
    dfs.PerformStep(path[0], root, callback);
    return dfs;
  }

  using ConstItem = JsonconsDfsItem;
  vector stack;
  stack.emplace_back(&root);

  do {
    unsigned segment_index = stack.back().segment_idx();
    const auto& path_segment = path[segment_index];

    // init or advance the current object
    DVLOG(2) << "Advance segment [" << segment_index << "] " << path_segment;
    ConstItem::AdvanceResult res = stack.back().Advance(path_segment);
    if (res && res->first != nullptr) {
      const JsonType* next = res->first;
      // We descend only if next is object or an array.
      if (IsRecursive(next->type())) {
        unsigned next_seg_id = res->second;

        if (next_seg_id + 1 < path.size()) {
          DVLOG(2) << "Exploring node[" << stack.size() << "] " << next->type() << " "
                   << next->to_string();
          stack.emplace_back(next, next_seg_id);
        } else {
          DVLOG(2) << "Terminal node[" << stack.size() << "] " << next->type() << " "
                   << next->to_string() << ", segment:" << path[next_seg_id];
          // terminal step
          // TODO: to take into account MatchStatus
          // for `json.set foo $.a[10]` or for `json.set foo $.*.b`
          dfs.PerformStep(path[next_seg_id], *next, callback);
        }
      }
    } else {
      // Error result or exhausted children: done with the top item.
      stack.pop_back();
    }
  } while (!stack.empty());

  return dfs;
}

// Mutating traversal: first collects every node the terminal segment must be applied to,
// then runs MutateStep on them in discovery order once the walk has finished.
// A single-segment path is applied to `json` directly. `path` must not be empty.
Dfs Dfs::Mutate(absl::Span path, const MutateCallback& callback, JsonType* json) {
  DCHECK(!path.empty());

  Dfs dfs;

  if (path.size() == 1) {
    dfs.MutateStep(path[0], callback, json);
    return dfs;
  }

  // Use vector to maintain order
  std::vector nodes_to_mutate;

  using Item = detail::JsonconsDfsItem;
  vector stack;
  stack.emplace_back(json);

  do {
    unsigned segment_index = stack.back().segment_idx();
    const auto& path_segment = path[segment_index];

    // init or advance the current object
    Item::AdvanceResult res = stack.back().Advance(path_segment);
    if (res && res->first != nullptr) {
      JsonType* next = res->first;
      DVLOG(2) << "Handling now " << next->type() << " " << next->to_string();

      // We descend only if next is object or an array.
      if (IsRecursive(next->type())) {
        unsigned next_seg_id = res->second;

        if (next_seg_id + 1 < path.size()) {
          stack.emplace_back(next, next_seg_id);
        } else {
          // Terminal step: collect node for mutation
          nodes_to_mutate.push_back(next);
        }
      }
    } else {
      // If Advance failed (e.g., MISMATCH or OUT_OF_BOUNDS), the current node itself
      // might still be a terminal match because of the previous DESCENT segment.
      // Instead of mutating immediately (which could break ordering guarantees),
      // collect the node and defer mutation until after traversal.
      if (!res && segment_index > 0 && path[segment_index - 1].type() == SegmentType::DESCENT &&
          stack.back().get_segment_step() == 0) {
        if (segment_index + 1 == path.size()) {
          // Terminal node discovered via DESCENT – store for later processing.
          nodes_to_mutate.push_back(stack.back().obj_ptr());
        }
      }
      stack.pop_back();
    }
  } while (!stack.empty());

  // Apply mutations after DFS traversal is complete
  const PathSegment& terminal_segment = path.back();

  for (auto it = nodes_to_mutate.begin(); it != nodes_to_mutate.end(); ++it) {
    dfs.MutateStep(terminal_segment, callback, *it);
  }

  return dfs;
}

// Deleting traversal: same walk as Mutate, but DeleteStep is executed as soon as a
// terminal node is found instead of being deferred.
// A single-segment path is applied to `json` directly. `path` must not be empty.
Dfs Dfs::Delete(absl::Span path, JsonType* json) {
  DCHECK(!path.empty());

  Dfs dfs;

  if (path.size() == 1) {
    dfs.DeleteStep(path[0], json);
    return dfs;
  }

  using Item = detail::JsonconsDfsItem;
  vector stack;
  stack.emplace_back(json);

  do {
    unsigned segment_index = stack.back().segment_idx();
    const auto& path_segment = path[segment_index];

    Item::AdvanceResult res = stack.back().Advance(path_segment);
    if (res && res->first != nullptr) {
      JsonType* next = res->first;

      if (IsRecursive(next->type())) {
        unsigned next_seg_id = res->second;

        if (next_seg_id + 1 < path.size()) {
          stack.emplace_back(next, next_seg_id);
        } else {
          // Terminal step: perform deletion immediately
          // At this point we're in the deepest level, so safe to delete
          dfs.DeleteStep(path[next_seg_id], next);
        }
      }
    } else {
      // Same DESCENT special case as in Mutate: the item being popped may itself be a
      // terminal match.
      if (!res && segment_index > 0 && path[segment_index - 1].type() == SegmentType::DESCENT &&
          stack.back().get_segment_step() == 0) {
        if (segment_index + 1 == path.size()) {
          // Terminal node discovered via DESCENT - safe to delete immediately
          // as we're backtracking
          dfs.DeleteStep(path[segment_index], stack.back().obj_ptr());
        }
      }
      stack.pop_back();
    }
  } while (!stack.empty());

  return dfs;
}

// Applies a single (terminal) segment to `node`, reporting each matching child via DoCall
// (which also counts it). Returns MISMATCH when the node type does not fit the segment and
// OUT_OF_BOUNDS when the normalized index range is empty.
auto Dfs::PerformStep(const PathSegment& segment, const JsonType& node, const Cb& callback)
    -> nonstd::expected {
  switch (segment.type()) {
    case SegmentType::IDENTIFIER: {
      if (!node.is_object())
        return make_unexpected(MISMATCH);

      auto it = node.find(segment.identifier());
      if (it != node.object_range().end()) {
        DoCall(callback, it->key(), it->value());
      }
    } break;
    case SegmentType::INDEX: {
      if (!node.is_array())
        return make_unexpected(MISMATCH);
      IndexExpr index = segment.index().Normalize(node.size());
      if (index.Empty()) {
        return make_unexpected(OUT_OF_BOUNDS);
      }
      // The range is inclusive on both ends.
      for (; index.first <= index.second; ++index.first) {
        DoCall(callback, nullopt, node[index.first]);
      }
    } break;

    case SegmentType::DESCENT:
    case SegmentType::WILDCARD: {
      // Every direct child matches; scalars have no children and produce no calls.
      if (node.is_object()) {
        for (const auto& k_v : node.object_range()) {
          DoCall(callback, k_v.key(), k_v.value());
        }
      } else if (node.is_array()) {
        for (const auto& item : node.array_range()) {
          DoCall(callback, nullopt, item);
        }
      }
    } break;
    default:
      LOG(DFATAL) << "Unknown segment " << SegmentName(segment.type());
  }
  return {};
}

// Applies a single (terminal) segment to `node`, handing each matching child to `cb` as a
// mutable pointer. Unlike PerformStep, this calls `cb` directly and does not update matches_.
// Returns MISMATCH / OUT_OF_BOUNDS under the same conditions as PerformStep.
auto Dfs::MutateStep(const PathSegment& segment, const MutateCallback& cb, JsonType* node)
    -> nonstd::expected {
  switch (segment.type()) {
    case SegmentType::IDENTIFIER: {
      if (!node->is_object())
        return make_unexpected(MISMATCH);

      auto it = node->find(segment.identifier());
      if (it != node->object_range().end()) {
        cb(it->key(), &it->value());
      }
    } break;
    case SegmentType::INDEX: {
      if (!node->is_array())
        return make_unexpected(MISMATCH);
      IndexExpr index = segment.index().Normalize(node->size());
      if (index.Empty()) {
        return make_unexpected(OUT_OF_BOUNDS);
      }

      // The iterator is re-derived from begin() on every round instead of being kept
      // across callback invocations.
      while (index.first <= index.second) {
        auto it = node->array_range().begin() + index.first;
        cb(nullopt, &*it);
        ++index.first;
      }
    } break;

    case SegmentType::DESCENT:
    case SegmentType::WILDCARD: {
      if (node->is_object()) {
        auto it = node->object_range().begin();
        while (it != node->object_range().end()) {
          cb(it->key(), &it->value());
          ++it;
        }
      } else if (node->is_array()) {
        auto it = node->array_range().begin();
        while (it != node->array_range().end()) {
          cb(nullopt, &*it);
          ++it;
        }
      }
    } break;
    case SegmentType::FUNCTION:
      LOG(DFATAL) << "Function segment is not supported for mutation";
      break;
  }
  return {};
}

// Removes the children of `node` selected by a single (terminal) segment and adds the
// number of removed elements to matches_.
// Returns MISMATCH / OUT_OF_BOUNDS under the same conditions as PerformStep.
auto Dfs::DeleteStep(const PathSegment& segment, JsonType* node) -> nonstd::expected {
  switch (segment.type()) {
    case SegmentType::IDENTIFIER: {
      if (!node->is_object())
        return make_unexpected(MISMATCH);

      auto it = node->find(segment.identifier());
      if (it != node->object_range().end()) {
        node->erase(it);
        ++matches_;
      }
    } break;
    case SegmentType::INDEX: {
      if (!node->is_array())
        return make_unexpected(MISMATCH);
      IndexExpr index = segment.index().Normalize(node->size());
      if (index.Empty()) {
        return make_unexpected(OUT_OF_BOUNDS);
      }

      // Delete from end to beginning to maintain indices
      for (int i = index.second; i >= index.first; --i) {
        auto it = node->array_range().begin() + i;
        node->erase(it);
        ++matches_;
      }
    } break;

    case SegmentType::DESCENT:
    case SegmentType::WILDCARD: {
      // All children go at once; the size before clearing is the match count.
      size_t initial_size = node->size();
      node->clear();
      matches_ += initial_size;
    } break;
    case SegmentType::FUNCTION:
      LOG(DFATAL) << "Function segment is not supported for deletion";
      break;
  }
  return {};
}

}  // namespace dfly::json::detail

================================================
FILE: src/core/json/detail/jsoncons_dfs.h
================================================
// Copyright 2024, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include
#include

#include "core/json/detail/common.h"
#include "core/json/json_object.h"
#include "core/json/path.h"
#include "core/overloaded.h"

namespace dfly::json::detail {

// Describes the current state of the DFS traversal for a single node inside json hierarchy.
// Specifically it holds the parent object (can be either a real object or an array),
// and the iterator to one of its children that is currently being traversed.
template class JsonconsDfsItem {
 public:
  // Const-ness of the node and iterator types follows the IsConst parameter.
  using ValueType = std::conditional_t;
  using Ptr = ValueType*;
  using Ref = ValueType&;
  using ObjIterator = std::conditional_t;
  using ArrayIterator = std::conditional_t;
  using DepthState = std::pair;  // object, segment_idx pair
  using AdvanceResult = nonstd::expected;

  JsonconsDfsItem(Ptr o, unsigned idx = 0) : depth_state_(o, idx) {
  }

  // Returns the next object to traverse
  // or null if traverse was exhausted or the segment does not match.
  AdvanceResult Advance(const PathSegment& segment);

  // Index of the path segment this item is matched against.
  unsigned segment_idx() const {
    return depth_state_.second;
  }

  // The node this item iterates over.
  Ptr obj_ptr() const {
    return depth_state_.first;
  }

  // 1 by default; Init() sets it to 0 once a DESCENT segment has been handled on this item.
  unsigned get_segment_step() const {
    return segment_step_;
  }

 private:
  static bool ShouldIterateAll(SegmentType type) {
    return type == SegmentType::WILDCARD || type == SegmentType::DESCENT;
  }

  ObjIterator Begin() const {
    if constexpr (IsConst) {
      return obj().object_range().cbegin();
    } else {
      return obj().object_range().begin();
    }
  }

  ArrayIterator ArrBegin() const {
    if constexpr (IsConst) {
      return obj().array_range().cbegin();
    } else {
      return obj().array_range().begin();
    }
  }

  ArrayIterator ArrEnd() const {
    if constexpr (IsConst) {
      return obj().array_range().cend();
    } else {
      return obj().array_range().end();
    }
  }

  Ref obj() const {
    return *depth_state_.first;
  }

  // Pairs a child with the segment index it should be matched against.
  DepthState Next(Ref obj) const {
    return {&obj, depth_state_.second + segment_step_};
  }

  // Sentinel for "no more children".
  DepthState Exhausted() const {
    return {nullptr, 0};
  }

  AdvanceResult Init(const PathSegment& segment);

  // For most operations we advance the path segment by 1 when we descend into the children.
  unsigned segment_step_ = 1;

  DepthState depth_state_;
  // Iteration state: not started yet (monostate), an object iterator, or a pair of array
  // iterators whose second member is the last element to visit (inclusive).
  std::variant> state_;
};

// Traverses a json object according to the given path and calls the callback for each matching
// field. With DESCENT segments it will match 0 or more fields in depth.
// MATCH(node, DESCENT|SUFFIX) = MATCH(node, SUFFIX) ||
// { MATCH(node->child, DESCENT/SUFFIX) for each child of node }
class Dfs {
 public:
  using Cb = PathCallback;

  // TODO: for some operations we need to know the type of mismatches.
  static Dfs Traverse(absl::Span path, const JsonType& json, const Cb& callback);
  static Dfs Mutate(absl::Span path, const MutateCallback& callback, JsonType* json);
  // Simplified deletion without callback - more efficient for deletion operations
  static Dfs Delete(absl::Span path, JsonType* json);

  // Number of matches counted so far (via DoCall or DeleteStep).
  unsigned matches() const {
    return matches_;
  }

 private:
  bool TraverseImpl(absl::Span path, const Cb& callback);

  nonstd::expected PerformStep(const PathSegment& segment, const JsonType& node,
                               const Cb& callback);

  nonstd::expected MutateStep(const PathSegment& segment, const MutateCallback& cb,
                              JsonType* node);

  nonstd::expected DeleteStep(const PathSegment& segment, JsonType* node);

  // Counts the match and forwards it to the callback.
  void DoCall(const Cb& callback, std::optional key, const JsonType& node) {
    ++matches_;
    callback(key, node);
  }

  unsigned matches_ = 0;
};

// Steps to the next child according to the stored iteration state:
//  - not started: delegates to Init();
//  - object iterator: moves to the next member for WILDCARD/DESCENT segments, otherwise
//    the item is exhausted right away;
//  - array iterator pair: moves forward until the (inclusive) last element was returned.
template
auto JsonconsDfsItem::Advance(const PathSegment& segment) -> AdvanceResult {
  AdvanceResult result = std::visit(  // line break
      Overloaded{
          [&](std::monostate) { return Init(segment); },  // Init state
          [&](ObjIterator& it) -> AdvanceResult {
            if (!ShouldIterateAll(segment.type()))
              return Exhausted();

            ++it;
            return it == obj().object_range().end() ? Exhausted() : Next(it->value());
          },
          [&](std::pair& pair) -> AdvanceResult {
            if (pair.first == pair.second)
              return Exhausted();
            ++pair.first;
            return Next(*pair.first);
          },
      },
      state_);
  return result;
}

// First step on this node for `segment`: selects the first matching child and stores the
// iteration state used by subsequent Advance() calls.
// Returns OUT_OF_BOUNDS for an empty normalized index range and MISMATCH when the node
// type does not fit the segment.
template
auto JsonconsDfsItem::Init(const PathSegment& segment) -> AdvanceResult {
  switch (segment.type()) {
    case SegmentType::IDENTIFIER: {
      if (obj().is_object()) {
        auto it = obj().find(segment.identifier());
        if (it != obj().object_range().end()) {
          state_ = it;
          return DepthState{&it->value(), depth_state_.second + 1};
        } else {
          return Exhausted();
        }
      }
      break;
    }
    case SegmentType::INDEX:
      if (obj().is_array()) {
        IndexExpr index = segment.index().Normalize(obj().size());
        if (index.Empty()) {
          return nonstd::make_unexpected(OUT_OF_BOUNDS);
        }

        // `end` points at the last selected element (inclusive), matching Advance().
        auto start = ArrBegin() + index.first, end = ArrBegin() + index.second;
        state_ = std::make_pair(start, end);
        return Next(*start);
      }
      break;

    case SegmentType::DESCENT:
      if (segment_step_ == 1) {
        // first time, branching to return the same object but with the next segment,
        // exploring the path of ignoring the DESCENT operator.
        // Also, shift the state (segment_step) to bypass this branch next time.
        segment_step_ = 0;
        return DepthState{depth_state_.first, depth_state_.second + 1};
      }

      // Now traverse all the children but do not progress with segment path.
      // This is why segment_step_ is set to 0.
      [[fallthrough]];
    case SegmentType::WILDCARD: {
      if (obj().is_object()) {
        jsoncons::range rng = obj().object_range();
        if (rng.cbegin() == rng.cend()) {
          return Exhausted();
        }
        state_ = Begin();
        return Next(Begin()->value());
      }

      if (obj().is_array()) {
        auto start = ArrBegin(), end = ArrEnd();
        if (start == end) {
          return Exhausted();
        }
        state_ = std::make_pair(start, end - 1);  // end is inclusive
        return Next(*start);
      }
      break;
    }

    default:
      LOG(DFATAL) << "Unknown segment " << SegmentName(segment.type());
  }  // end switch

  return nonstd::make_unexpected(MISMATCH);
}

}  // namespace dfly::json::detail

================================================
FILE: src/core/json/driver.cc
================================================
// Copyright 2024, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
//

#include "src/core/json/driver.h"

#include

#include "base/logging.h"
#include "src/core/json/lexer_impl.h"
#include "src/core/overloaded.h"

using namespace std;

namespace dfly::json {

namespace {

// Common base for aggregations whose result is a single stored value (min/max).
class SingleValueImpl : public AggFunction {
  Result GetResultImpl() const final {
    return val_;
  }

 protected:
  // Seeds val_ from the first numeric input: a double if the source is floating point,
  // an integer otherwise.
  void Init(const JsonType& src) {
    if (src.is_double()) {
      val_.emplace(src.as_double());
    } else {
      val_.emplace(src.as());
    }
  }

  void Init(const flexbuffers::Reference src) {
    if (src.IsFloat()) {
      val_.emplace(src.AsDouble());
    } else {
      val_.emplace(src.AsInt64());
    }
  }

  Result val_;
};

// Keeps the maximum of the numeric inputs. Non-numeric inputs are rejected (returns false).
class MaxImpl : public SingleValueImpl {
  bool ApplyImpl(const JsonType& src) final {
    if (!src.is_number()) {
      return false;
    }

    visit(Overloaded{
              [&](monostate) { Init(src); },
              [&](double d) { val_ = max(d, src.as_double()); },
              [&](int64_t i) {
                // An integer accumulator becomes a double once a floating point input arrives.
                if (src.is_double())
                  val_ = max(double(i), src.as_double());
                else
                  val_ = max(i, src.as());
              },
          },
          val_);

    return true;
  }

  bool ApplyImpl(flexbuffers::Reference src) final {
    if (!src.IsNumeric()) {
      return false;
    }

    visit(Overloaded{
              [&](monostate) { Init(src); },
              [&](double d) { val_ = max(d, src.AsDouble()); },
              [&](int64_t i) {
                if (src.IsFloat())
                  val_ = max(double(i), src.AsDouble());
                else
                  val_ = max(i, src.AsInt64());
              },
          },
          val_);

    return true;
  }
};

// Keeps the minimum of the numeric inputs. Non-numeric inputs are rejected (returns false).
class MinImpl : public SingleValueImpl {
 private:
  bool ApplyImpl(const JsonType& src) final {
    if (!src.is_number()) {
      return false;
    }

    visit(Overloaded{
              [&](monostate) { Init(src); },
              [&](double d) { val_ = min(d, src.as_double()); },
              [&](int64_t i) {
                // An integer accumulator becomes a double once a floating point input arrives.
                if (src.is_double())
                  val_ = min(double(i), src.as_double());
                else
                  val_ = min(i, src.as());
              },
          },
          val_);

    return true;
  }

  bool ApplyImpl(flexbuffers::Reference src) final {
    if (!src.IsNumeric()) {
      return false;
    }

    visit(Overloaded{
              [&](monostate) { Init(src); },
              [&](double d) { val_ = min(d, src.AsDouble()); },
              [&](int64_t i) {
                if (src.IsFloat())
                  val_ = min(double(i), src.AsDouble());
                else
                  val_ = min(i, src.AsInt64());
              },
          },
          val_);

    return true;
  }
};

// Averages the numeric inputs as doubles. Non-numeric inputs are rejected (returns false).
class AvgImpl : public AggFunction {
 private:
  bool ApplyImpl(const JsonType& src) final {
    if (!src.is_number()) {
      return false;
    }
    sum_ += src.as_double();
    count_++;

    return true;
  }

  bool ApplyImpl(flexbuffers::Reference src) final {
    if (!src.IsNumeric()) {
      return false;
    }
    sum_ += src.AsDouble();
    count_++;

    return true;
  }

  Result GetResultImpl() const final {
    DCHECK_GT(count_, 0u);  // AggFunction guarantees that

    return Result(double(sum_ / count_));
  }

  double sum_ = 0;
  uint64_t count_ = 0;
};

}  // namespace

Driver::Driver() : lexer_(make_unique()) {
}

Driver::~Driver() {
}

// Takes ownership of `str`, points the lexer at the stored copy and drops any path
// collected from a previous input.
void Driver::SetInput(string str) {
  cur_str_ = std::move(str);
  lexer_->in(cur_str_);
  path_.clear();
}

// Replaces the lexer with a freshly constructed one.
void Driver::ResetScanner() {
  lexer_ = make_unique();
}

// Appends an aggregation function segment for `fname` ("max", "min" or "avg").
// Throws Parser::syntax_error if the path already has segments or the name is unknown.
void Driver::AddFunction(string_view fname) {
  if (!path_.empty()) {
    throw Parser::syntax_error(lexer_->location(),
                               "function can be only at the beginning of the path");
  }
  shared_ptr func;
  if (fname == "max") {
    func = make_shared();
  } else if (fname == "min") {
    func = make_shared();
  } else if (fname == "avg") {
    func = make_shared();
  } else {
    throw Parser::syntax_error(lexer_->location(), absl::StrCat("Unknown function: ", fname));
  }
  path_.emplace_back(std::move(func));
}

}  // namespace dfly::json

================================================
FILE: src/core/json/driver.h
================================================
// Copyright 2024, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include

#include "src/core/json/path.h"

namespace dfly {
namespace json {

class Lexer;
class location;  // from jsonpath_grammar.hh

// Owns the lexer and the input string, and accumulates the path segments added through
// the Add* methods until the path is taken with TakePath().
class Driver {
 public:
  Driver();
  virtual ~Driver();

  Lexer* lexer() {
    return lexer_.get();
  }

  void SetInput(std::string str);
  void ResetScanner();

  // Error reporting hook to be implemented by subclasses.
  virtual void Error(const location& l, const std::string& msg) = 0;

  void AddIdentifier(const std::string& identifier) {
    AddSegment(PathSegment(SegmentType::IDENTIFIER, identifier));
  }

  void AddFunction(std::string_view fname);

  void AddWildcard() {
    AddSegment(PathSegment(SegmentType::WILDCARD));
  }

  void AddSegment(PathSegment segment) {
    path_.push_back(std::move(segment));
  }

  // Moves the collected path out of the driver.
  Path TakePath() {
    return std::move(path_);
  }

 private:
  Path path_;
  std::string cur_str_;
  std::unique_ptr lexer_;
};

}  // namespace json
}  // namespace dfly

================================================
FILE: src/core/json/interned_blob_test.cc
================================================
// Copyright 2026, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
#include "base/gtest.h"
#include "core/detail/stateless_allocator.h"
#include "core/json/detail/interned_string.h"
#include "core/mi_memory_resource.h"

using namespace std::literals;
using namespace dfly;

namespace {

// Returns a thread-local MiMemoryResource built on a thread-local heap from mi_heap_new().
MiMemoryResource* MemoryResource() {
  thread_local mi_heap_t* heap = mi_heap_new();
  thread_local MiMemoryResource memory_resource{heap};
  return &memory_resource;
}

}  // namespace

// Fixture that installs the thread-local memory resource for the stateless allocator before each
// test and cleans it up afterwards.
class InternedBlobTest : public testing::Test {
 protected:
  void SetUp() override {
    InitTLStatelessAllocMR(MemoryResource());
  }

  void TearDown() override {
    CleanupStatelessAllocMR();
  }
};

using detail::BlobPtr;
using detail::InternedBlobHandle;

// Creating a blob must grow the resource's used() by exactly blob.MemUsed(); destroying it must
// bring used() back to the starting value.
TEST_F(InternedBlobTest, MemoryUsage) {
  const auto* mr = MemoryResource();
  const auto usage_before = mr->used();
  InternedBlobHandle blob = InternedBlobHandle::Create("1234567");
  const auto usage_after = mr->used();
  const auto expected_delta = blob.MemUsed();
  EXPECT_EQ(usage_before + expected_delta, usage_after);
  InternedBlobHandle::Destroy(blob);
  EXPECT_EQ(usage_before, mr->used());
}

// Checks the blob's content, size and reference count (1 unless specified).
void CheckBlob(InternedBlobHandle& blob, std::string_view expected, uint32_t ref_cnt = 1) {
  EXPECT_EQ(blob, expected);
  EXPECT_EQ(blob.Size(), expected.size());
  EXPECT_EQ(blob.RefCount(), ref_cnt);
}

// A handle created from "" has size 0 and tests false; a copied handle compares equal to the
// same content as its source.
TEST_F(InternedBlobTest, Ctors) {
  auto blob = InternedBlobHandle::Create("");
  EXPECT_EQ(blob.Size(), 0);
  EXPECT_FALSE(blob);
  InternedBlobHandle::Destroy(blob);

  InternedBlobHandle src = InternedBlobHandle::Create("foobar");
  InternedBlobHandle dest{src};
  CheckBlob(dest, "foobar");
  CheckBlob(src, "foobar");
  InternedBlobHandle::Destroy(dest);
}

// BlobEq accepts blob/string, string/blob and blob/blob operands.
TEST_F(InternedBlobTest, Comparison) {
  auto blob = InternedBlobHandle::Create("foobar");
  constexpr detail::BlobEq blob_eq;
  EXPECT_TRUE(blob_eq(blob, "foobar"));
  EXPECT_TRUE(blob_eq("foobar", blob));

  InternedBlobHandle second = blob;
  second.IncrRefCount();
  EXPECT_TRUE(blob_eq(blob, second));
  InternedBlobHandle::Destroy(blob);
}

// A new blob starts with refcount 1; decrementing past zero is expected to die in debug builds.
TEST_F(InternedBlobTest, RefCounts) {
  auto blob = InternedBlobHandle::Create("1234567");
  EXPECT_EQ(blob.RefCount(), 1);
  blob.DecrRefCount();
  EXPECT_DEBUG_DEATH(blob.DecrRefCount(), "Attempt to decrease zero refcount");
  InternedBlobHandle::Destroy(blob);
}

TEST_F(InternedBlobTest, Pool) {
  detail::InternedBlobPool pool{};
  InternedBlobHandle b1 = InternedBlobHandle::Create("foo");
  pool.emplace(b1);

  // search by string view
  EXPECT_TRUE(pool.contains("foo"));

  // increment the refcount. The blob is still found because the hasher only looks at the string
  b1.IncrRefCount();
  EXPECT_TRUE(pool.contains("foo"));
  InternedBlobHandle::Destroy(b1);
}

using detail::InternedString;

namespace {

// Verifies that every accessor of `s` (data, c_str, size, length, string_view conversion) agrees
// with the C string `ptr`.
void StringCheck(const InternedString& s, const char* ptr) {
  std::string_view sv{ptr};
  EXPECT_STREQ(s.data(), ptr);
  EXPECT_STREQ(s.c_str(), ptr);

  EXPECT_EQ(s.size(), sv.size());
  EXPECT_EQ(s.length(), sv.size());

  EXPECT_EQ(std::string_view(s), sv);
  EXPECT_EQ(std::string_view(s.data(), s.size()), sv);
  EXPECT_EQ(std::string_view(s.c_str(), s.size()), sv);
}

}  // namespace

// Tracks pool size and the hit/miss/entry/byte statistics while strings are created and destroyed:
// a first occurrence counts as a miss, a repeated one as a hit, and the pool empties once the last
// owner goes away.
TEST_F(InternedBlobTest, StringPool) {
  size_t hits = GetInternedStringStats().hits;
  size_t misses = GetInternedStringStats().misses;
  const auto& pool = InternedString::GetPoolRef();
  EXPECT_TRUE(pool.empty());
  {
    const InternedString s1{"foobar"};
    StringCheck(s1, "foobar");
    EXPECT_EQ(pool.size(), 1);
    misses += 1;
    EXPECT_EQ(GetInternedStringStats().misses, misses);
    EXPECT_EQ(GetInternedStringStats().pool_entries, 1);
    {
      const InternedString s2{"foobar"};
      StringCheck(s2, "foobar");
      EXPECT_EQ(pool.size(), 1);
      EXPECT_EQ(GetInternedStringStats().misses, misses);
      EXPECT_EQ(GetInternedStringStats().pool_entries, 1);
      hits += 1;
      EXPECT_EQ(GetInternedStringStats().hits, hits);
    }
    EXPECT_EQ(pool.size(), 1);
  }
  EXPECT_TRUE(pool.empty());
  EXPECT_EQ(GetInternedStringStats().misses, misses);
  EXPECT_EQ(GetInternedStringStats().pool_entries, 0);
  EXPECT_EQ(GetInternedStringStats().pool_bytes, 0);
  EXPECT_EQ(GetInternedStringStats().hits, hits);

  // NOTE(review): element type of this vector is missing in this extract.
  std::vector strings;
  for (auto i = 0; i < 1000; ++i) {
    strings.emplace_back(std::to_string(i));
  }
  EXPECT_EQ(pool.size(), 1000);
  EXPECT_EQ(GetInternedStringStats().pool_entries, 1000);
  misses += 1000;
  EXPECT_EQ(GetInternedStringStats().misses, misses);
  strings.clear();
  EXPECT_TRUE(pool.empty());
  EXPECT_EQ(GetInternedStringStats().pool_entries, 0);
  EXPECT_EQ(GetInternedStringStats().pool_bytes, 0);

  for (auto i = 0; i < 1000; ++i) {
    strings.emplace_back("zyx");
  }
  EXPECT_EQ(pool.size(), 1);
  EXPECT_EQ(GetInternedStringStats().pool_entries, 1);
  hits += 999;
  EXPECT_EQ(GetInternedStringStats().hits, hits);
  strings.clear();
  EXPECT_TRUE(pool.empty());

  InternedString empty;
  EXPECT_TRUE(pool.empty());
}

TEST_F(InternedBlobTest, StringApi) {
  InternedString s1{"foobar"};
  EXPECT_EQ(std::string_view{s1}, "foobar"sv);
  StringCheck(s1, "foobar");

  const auto& pool = InternedString::GetPoolRef();

  InternedString s2{"psi"};
  StringCheck(s2, "psi");
  EXPECT_EQ(pool.size(), 2);

  // swap pointers into the pool
  s1.swap(s2);
  EXPECT_EQ(pool.size(), 2);
  StringCheck(s1, "psi");
  StringCheck(s2, "foobar");

  EXPECT_NE(s1, s2);
  EXPECT_EQ(s1, s1);

  // foobar < psi lexicographically
  EXPECT_LT(s2, s1);
}

TEST_F(InternedBlobTest, StringCtors) {
  const auto& pool = InternedString::GetPoolRef();
  InternedString s1{"foobar"};
  EXPECT_EQ(pool.size(), 1);

  // move ctor
  auto to = std::move(s1);
  EXPECT_EQ(pool.size(), 1);
  StringCheck(to, "foobar");
  StringCheck(s1, "");

  // These tests exercise self-move and self-copy behavior. This causes errors on newer GCC when
  // warnings are treated as errors (on CI). We need to version gate this because on older GCC this
  // check is not present.
#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 13
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wself-move"
#endif
  to = std::move(to);
  StringCheck(to, "foobar");

  auto copied = to;
  EXPECT_EQ(pool.size(), 1);
  StringCheck(to, "foobar");
  StringCheck(copied, "foobar");

  copied = copied;
#if defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 13
#pragma GCC diagnostic pop
#endif
  StringCheck(copied, "foobar");
  EXPECT_EQ(pool.size(), 1);

  const auto* mr = MemoryResource();
  const auto before = mr->used();

  std::string_view sv{"......."};
  // ptr and size with some allocator, allocator will be ignored
  InternedString x{sv.data(), sv.size(), std::allocator{}};
  StringCheck(x, ".......");
  EXPECT_EQ(pool.size(), 2);
  EXPECT_GE(mr->used(), before + x.MemUsed());

  InternedString k{sv.begin(), sv.end()};
  StringCheck(k, ".......");
  EXPECT_EQ(pool.size(), 2);
}

// Drains a vector of 1000 interned strings in steps of 20 while recording the pool capacity after
// each step, then checks that the recorded capacities never grow and change rarely.
TEST_F(InternedBlobTest, PoolShrink) {
  InternedString::ResetPool();
  std::vector v;
  const auto& ref = InternedString::GetPoolRef();
  for (const auto i : std::views::iota(0, 1000))
    v.emplace_back(std::to_string(i));

  std::vector caps;
  constexpr auto jitter = std::views::iota(0, 6);
  while (!v.empty()) {
    constexpr auto step = 20;
    const auto from = v.end() - std::min(step, v.size());
    v.erase(from, v.end());
    // Interleaving inserts right after a possible resize, to ensure we don't have to increase
    // capacity right after a shrink. The caps vector should remain monotonically decreasing.
    for (const auto j : jitter)
      v.emplace_back(std::to_string(10000 + j));
    caps.push_back(ref.capacity());
    for (size_t i = 0; i < jitter.size(); ++i)
      v.pop_back();
  }
  EXPECT_EQ(ref.load_factor(), 0);
  EXPECT_TRUE(std::ranges::is_sorted(caps, std::ranges::greater{}));

  // Check that capacity changes very infrequently
  size_t cap_trans = 0;
  for (size_t i = 1; i < caps.size(); ++i) {
    if (caps[i] != caps[i - 1])
      ++cap_trans;
  }
  EXPECT_LT(cap_trans, caps.size() / 2);
}

================================================
FILE: src/core/json/json_object.cc
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/json/json_object.h"

#include  // NOTE(review): header name is missing in this extract; restore from upstream.

#include "base/logging.h"
#include "core/page_usage/page_usage_stats.h"

using namespace jsoncons;

namespace {

// Parses `input` with a jsoncons parser feeding `decoder` and returns the decoded document, or
// std::nullopt when parsing reports an error or the decoder did not produce a valid result.
// Parse errors are logged at VLOG(1) and stop the parse (the handler returns false).
// NOTE(review): the template header and the template arguments of optional/json_decoder are
// missing in this extract and are left as found.
template std::optional ParseWithDecoder(std::string_view input, json_decoder&& decoder) {
  std::error_code ec;
  auto JsonErrorHandler = [](json_errc ec, const ser_context&) {
    VLOG(1) << "Error while decode JSON: " << make_error_code(ec).message();
    return false;
  };

  // The maximum allowed JSON nesting depth is 64.
  // The limit was reduced from 256 to 64. This change is reasonable, as most documents contain
  // no more than 20-30 levels of nesting. In the test case, over 128 levels were used, causing
  // the parser to enter a long stall due to excessive resource consumption. Even a limit of 128
  // does not mitigate the issue. A limit of 64 is a sensible compromise.
  // See https://github.com/dragonflydb/dragonfly/issues/5028
  const uint32_t json_nesting_depth_limit = 64;

  /* The maximum possible JSON nesting depth is either the specified json_nesting_depth_limit or
     half of the input size. Since nesting a JSON object requires at least 2 characters. */
  // Clamp in size_t before narrowing: converting input.size() / 2 straight to uint32_t wraps for
  // inputs of 8 GiB or more and could produce an arbitrarily small (even zero) depth limit.
  const auto max_depth =
      static_cast<uint32_t>(std::min<size_t>(json_nesting_depth_limit, input.size() / 2));
  auto parser_options = json_options{}.max_nesting_depth(max_depth);

  json_parser parser(parser_options, JsonErrorHandler);

  parser.update(input);
  parser.finish_parse(decoder, ec);

  if (!ec && decoder.is_valid()) {
    return decoder.get_result();
  }
  return std::nullopt;
}

using namespace dfly;

// The following two functions allocate a string-based object by copying data to a fresh memory
// page. Then the move-assignment operator swaps it with the input node (swap_l_r in jsoncons), and
// the temporary is destroyed at the end of the scope.
// Both return false without touching `j` when the string is empty or its page is not reported as
// under-utilized by `page_usage`.
bool DefragmentByteString(JsonType& j, PageUsage* page_usage) {
  const auto& byte_storage = j.cast();
  if (byte_storage.length() == 0 ||
      !page_usage->IsPageForObjectUnderUtilized(const_cast(byte_storage.data())))
    return false;

  const byte_string_view bsv{byte_storage.data(), byte_storage.length()};
  // Extended tags are rebuilt from ext_tag(); every other tag is carried over through tag().
  if (j.tag() == semantic_tag::ext) {
    j = JsonType(byte_string_arg, bsv, j.ext_tag(), byte_storage.get_allocator());
    return true;
  }

  j = JsonType(byte_string_arg, bsv, j.tag(), byte_storage.get_allocator());
  return true;
}

bool DefragmentLongString(JsonType& j, PageUsage* page_usage) {
  const auto& str_storage = j.cast();
  if (str_storage.length() == 0 ||
      !page_usage->IsPageForObjectUnderUtilized(const_cast(str_storage.data())))
    return false;

  JsonType::string_view_type svt{str_storage.data(), str_storage.length()};
  j = JsonType(svt, j.tag(), str_storage.get_allocator());
  return true;
}

// Allocates a new json object of type json_object_arg, with fresh memory allocation for its
// contained vector of key value pairs. Then moves members from j to this new object. Finally j is
// swapped with the new object.
bool DefragmentJsonObject(JsonType& j, PageUsage* page_usage) {
  auto& object = j.cast().value();
  if (object.empty() || !page_usage->IsPageForObjectUnderUtilized(&*object.begin()))
    return false;

  // Creates a fresh object and reserves space for the underlying vector.
  JsonType new_node{json_object_arg, j.tag(), object.get_allocator()};
  new_node.reserve(object.size());

  for (auto& member : object) {
    // The member values are JsonType themselves, they just wrap pointers to actual storage.
    // Their move invokes the move ctor in jsoncons, which will move the value wrappers to new_node,
    // and leave the original in `j` holding references to `null_storage` type, see
    // `uninitialized_move_a` in jsoncons. The member key (a string) is not moved but copied into
    // new_node members.
    new_node.try_emplace(member.key(), std::move(member.value()));
  }

  // Invokes move assignment. A swap is performed, and new_node now holds null_storage
  // references instead of `j`. It will be destroyed on leaving scope, cleaning up its memory.
  j = std::move(new_node);
  return true;
}

// Same as DefragmentJsonObject except uses an array object. The contained members are moved
// similarly, and on exit the old node is destroyed.
bool DefragmentJsonArray(JsonType& j, PageUsage* page_usage) {
  auto& array = j.cast().value();
  if (array.empty() || !page_usage->IsPageForObjectUnderUtilized(&*array.begin()))
    return false;

  JsonType new_node{json_array_arg, j.tag(), array.get_allocator()};
  new_node.reserve(array.size());
  for (JsonType& member : array) {
    new_node.push_back(std::move(member));
  }

  j = std::move(new_node);
  return true;
}

}  // namespace

namespace dfly {

// Parses `input` with a default-constructed decoder.
std::optional JsonFromString(std::string_view input) {
  return ParseWithDecoder(input, json_decoder{});
}

// Parses `input` with a decoder constructed from a StatelessAllocator.
optional ParseJsonUsingShardHeap(string_view input) {
  return ParseWithDecoder(input, json_decoder{StatelessAllocator{}});
}

// Walks the tree rooted at `j` with an explicit stack and defragments every string, byte string,
// object and array node for which the per-kind helper decides to reallocate.
// Returns true if at least one node was reallocated.
bool Defragment(JsonType& j, PageUsage* page_usage) {
  bool did_defragment = false;

  // stack-based traversal inspired from jsoncons::basic_json::compute_memory_size
  std::stack stack;
  stack.push(&j);

  while (!stack.empty()) {
    JsonType* current = stack.top();
    stack.pop();

    const json_storage_kind storage_kind = current->storage_kind();
    switch (storage_kind) {
      case json_storage_kind::byte_str:
        did_defragment |= DefragmentByteString(*current, page_usage);
        break;
      case json_storage_kind::long_str:
        did_defragment |= DefragmentLongString(*current, page_usage);
        break;
      case json_storage_kind::object: {
        // Defragment the container first, then queue its (possibly relocated) members.
        did_defragment |= DefragmentJsonObject(*current, page_usage);
        auto& object = current->cast().value();
        for (auto& member : object) {
          stack.push(&member.value());
        }
        break;
      }
      case json_storage_kind::array: {
        did_defragment |= DefragmentJsonArray(*current, page_usage);
        auto& array = current->cast().value();
        for (auto& member : array) {
          stack.push(&member);
        }
        break;
      }
      default:
        DCHECK(is_trivial_storage(storage_kind))
            << "unexpected non trivial storage type:" << storage_kind;
        break;
    }
  }

  return did_defragment;
}

// Sums mi_usable_size() over the heap buffers reachable from `j`: the member storage of non-empty
// objects and arrays, object keys (through MemUsed()), and long/byte string data.
// Nodes with trivial storage are skipped.
size_t ComputeMemorySize(const JsonType& j) {
  std::stack stack;
  stack.push(&j);

  size_t total = 0;
  auto add_used_memory = [&total](const auto* data) {
    if (data)
      total += mi_usable_size(data);
  };

  using enum json_storage_kind;
  while (!stack.empty()) {
    const auto* current = stack.top();
    stack.pop();

    const auto storage = current->storage_kind();
    if (is_trivial_storage(storage))
      continue;

    switch (storage) {
      case object: {
        const auto& object_storage = current->cast().value();
        if (!object_storage.empty())
          add_used_memory(&*object_storage.begin());
        for (const auto& member : object_storage) {
          total += member.key().MemUsed();
          const auto& value = member.value();
          if (!is_trivial_storage(value.storage_kind()))
            stack.push(&value);
        }
      } break;
      case array: {
        const auto& arr = current->cast().value();
        if (!arr.empty())
          add_used_memory(&arr[0]);
        for (const auto& elem : arr)
          if (!is_trivial_storage(elem.storage_kind()))
            stack.push(&elem);
      } break;
      case long_str:
        add_used_memory(current->cast().data());
        break;
      case byte_str:
        add_used_memory(current->cast().data());
        break;
      default:
        DCHECK(false) << "unexpected non trivial storage type:" << storage;
    }
  }
  return total;
}

}  // namespace dfly

================================================
FILE: src/core/json/json_object.h
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include  // for __cpp_lib_to_chars macro.

#include "core/detail/stateless_allocator.h"
#include "core/json/detail/interned_string.h"

// std::from_chars is available in C++17 if __cpp_lib_to_chars is defined.
#if __cpp_lib_to_chars >= 201611L
#define JSONCONS_HAS_STD_FROM_CHARS 1
#endif
// NOTE(review): the header names of the five includes below are missing in this extract.
#include
#include
#include
#include
#include

namespace dfly {

class PageUsage;

using TmpJson = jsoncons::json;

// Policy derived from jsoncons::sorted_policy that makes object member keys InternedString.
// NOTE(review): template parameter lists in this header are missing in this extract.
struct InternedStringPolicy : jsoncons::sorted_policy {
  template using member_key = detail::InternedString;
};

using JsonType = jsoncons::basic_json>;

// A helper type to use in template functions which are expected to work with both TmpJson
// and JsonType
template using JsonWithAllocator = jsoncons::basic_json;

// Parses string into JSON. Any allocations are done using the std allocator. This method should be
// used for generic JSON parsing, in particular, it should not be used to parse objects which will
// be stored in the db, as the backing storage is not managed by mimalloc.
std::optional JsonFromString(std::string_view input);

// Parses string into JSON, using mimalloc heap for allocations. This method should only be used on
// shards where mimalloc heap is initialized.
std::optional ParseJsonUsingShardHeap(std::string_view input);

// Defragments the given json object by traversing its tree structure non-recursively, examining
// nodes and defragmenting as needed. Returns true if any object within the node was reallocated
bool Defragment(JsonType& j, PageUsage* page_usage);

// Builds a jsoncons jsonpath expression for `path`. Errors are reported through `ec`.
// The allocator set passed to make_expression is default-constructed from the Json type's
// allocator_type and a std::allocator for temporaries.
template auto MakeJsonPathExpr(std::string_view path, std::error_code& ec)
    -> jsoncons::jsonpath::jsonpath_expression {
  using ResultAllocT = typename Json::allocator_type;
  using TmpAllocT = std::allocator;
  using AllocSetT = jsoncons::allocator_set;

  return jsoncons::jsonpath::make_expression(AllocSetT(), path, ec);
}

// Returns a byte count for the heap memory used by `j`; see the definition for what is counted.
size_t ComputeMemorySize(const JsonType& j);

}  // namespace dfly

================================================
FILE: src/core/json/json_test.cc
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

// NOTE(review): the header names of the four includes below are missing in this extract.
#include
#include
#include
#include

#include "base/gtest.h"
#include "base/logging.h"

namespace dfly {
using namespace jsoncons;
using namespace jsoncons::literals;
using namespace testing;

// Tests in this file exercise the jsoncons API directly (parsing, jsonpath query/replace,
// allocators).
class JsonTest : public ::testing::Test {
 protected:
  JsonTest() {
  }
};

// Parses a document into pmr::json and replaces a nested value through a jsonpath expression.
TEST_F(JsonTest, Basic) {
  std::string data = R"( { "application": "hiking", "reputons": [ { "rater": "HikingAsylum", "assertion": "advanced", "rated": "Marilyn C", "rating": 0.90, "confidence": 0.99 } ] } )";

  pmr::json j = pmr::json::parse(data);
  EXPECT_TRUE(j.contains("reputons"));
  jsonpath::json_replace(j, "$.reputons[*].rating", 1.1);
  EXPECT_EQ(1.1, j["reputons"][0]["rating"].as_double());
}

TEST_F(JsonTest, SetEmpty) {
  pmr::json dest{json_object_arg};  // crashes on UB without the tag.
  dest["bar"] = "foo";
}

TEST_F(JsonTest, Query) {
  json j = R"( {"a":{}, "b":{"a":1}, "c":{"a":1, "b":2}} )"_json;

  json out = jsonpath::json_query(j, "$..*");
  EXPECT_EQ(R"([{},{"a":1},{"a":1,"b":2},1,1,2])"_json, out);

  json j2 = R"( {"firstName":"John","lastName":"Smith","age":27,"weight":135.25,"isAlive":true,"address":{"street":"21 2nd Street","city":"New York","state":"NY","zipcode":"10021-3100"},"phoneNumbers":[{"type":"home","number":"212 555-1234"},{"type":"office","number":"646 555-4567"}],"children":[],"spouse":null} )"_json;

  // json_query always returns arrays.
  // See here: https://github.com/danielaparker/jsoncons/issues/82
  // Therefore we are going to only support the "extended" semantics
  // of json API (as they are called in AWS documentation).
  out = jsonpath::json_query(j2, "$.address");
  EXPECT_EQ(R"([{"street":"21 2nd Street","city":"New York", "state":"NY","zipcode":"10021-3100"}])"_json, out);
}

// Feeds the parser an input and expects an unexpected_eof error and an invalid decoder.
TEST_F(JsonTest, Errors) {
  auto cb = [](json_errc, const ser_context&) { return false; };

  json_decoder decoder;
  basic_json_parser parser(basic_json_decode_options{}, cb);

  // NOTE(review): a string_view constructed from a const char* stops at the first NUL, so this
  // view is empty and "bla" is never fed to the parser. Confirm that is the intent.
  std::string_view input{"\000bla"};
  parser.update(input.data(), input.size());

  std::error_code ec;
  parser.parse_some(decoder, ec);
  EXPECT_TRUE(ec);
  EXPECT_EQ(ec, json_errc::unexpected_eof);
  EXPECT_FALSE(decoder.is_valid());
}

// Exercises jsonpath expressions: plain fields, a dash in a field name, the max() function,
// slices and negative indices.
TEST_F(JsonTest, Path) {
  std::error_code ec;
  json j1 = R"({"field" : 1, "field-dash": 2})"_json;

  auto expr = jsonpath::make_expression("$.field", ec);
  EXPECT_FALSE(ec);

  expr.evaluate(j1, [](const std::string& path, const json& val) {
    ASSERT_EQ("$['field']", path);
    ASSERT_EQ(1, val.as());
  });

  expr = jsonpath::make_expression("$.field-dash", ec);
  ASSERT_FALSE(ec);  // parses '-'

  expr.evaluate(j1, [](const std::string& path, const json& val) {
    ASSERT_EQ("$['field-dash']", path);
    ASSERT_EQ(2, val.as());
  });

  int called = 0;
  jsonpath::json_query(j1, "max($.*)", [&](const std::string& path, const json& val) {
    EXPECT_EQ("$", path);
    ASSERT_EQ(2, val.as());
    ++called;
  });
  EXPECT_EQ(1, called);

  auto res = jsonpath::json_query(j1, "max($.*)");
  ASSERT_TRUE(res.is_array() && res.size() == 1);
  EXPECT_EQ(2, res[0].as());

  called = 0;
  json j2 = R"({"field" : [1, 2, 3, 4, 5]})"_json;
  jsonpath::json_query(j2, "$.field[1:2]", [&](const std::string& path, const json& val) {
    EXPECT_EQ("$['field'][1]", path);
    ASSERT_EQ(2, val.as());
    ++called;
  });
  EXPECT_EQ(1, called);

  std::vector vals;
  jsonpath::json_query(j2, "$.field[1:]", [&](const std::string& path, const json& val) {
    vals.push_back(val.as());
  });
  EXPECT_THAT(vals, ElementsAre(2, 3, 4, 5));

  jsonpath::json_query(j2, "$.field[-1]", [&](const std::string& path, const json& val) {
    EXPECT_EQ(5, val.as());
  });

  jsonpath::json_query(j2, "$.field[-6:1]", [&](const std::string& path, const json& val) {
    EXPECT_EQ(1, val.as());
  });
}

TEST_F(JsonTest, Delete) {
  json j1 = R"({"c":{"a":1, "b":2}, "d":{"a":1, "b":2, "c":3}, "e": [1,2]})"_json;

  // Only logs the visited path; the erase logic below is commented out.
  auto deleter = [](const json::string_view_type& path, json& val) {
    LOG(INFO) << "path: " << path;
    // val.evaluate();
    // if (val.is_object())
    //   val.erase(val.object_range().begin(), val.object_range().end());
  };
  jsonpath::json_replace(j1, "$.d.*", deleter);

  auto expr = jsonpath::make_expression("$.d.*");

  auto callback = [](const std::string& path, const json& val) {
    LOG(INFO) << path << ": " << val << "\n";
  };
  expr.evaluate(j1, callback, jsonpath::result_options::path);
  auto it = j1.find("d");
  ASSERT_TRUE(it != j1.object_range().end());

  it->value().erase("a");
  EXPECT_EQ(R"({"c":{"a":1, "b":2}, "d":{"b":2, "c":3}, "e": [1,2]})"_json, j1);
}

// Parses into pmr::json using a polymorphic allocator over a monotonic buffer resource that is
// seeded with a 1 KiB stack buffer.
TEST_F(JsonTest, JsonWithPolymorhicAllocator) {
  char buffer[1024] = {};
  std::pmr::monotonic_buffer_resource pool{std::data(buffer), std::size(buffer)};
  std::pmr::polymorphic_allocator alloc(&pool);

  std::string input = R"( { "store": { "book": [ { "category": "Roman", "author": "Felix Lobrecht", "title": "Sonne und Beton", "price": 12.99 }, { "category": "Roman", "author": "Thomas F. Schneider", "title": "Im Westen nichts Neues", "price": 10.00 } ] } } )";

  auto j1 = pmr::json::parse(combine_allocators(alloc), input, json_options{});
  EXPECT_EQ("Roman", j1["store"]["book"][0]["category"].as_string());
  EXPECT_EQ("Felix Lobrecht", j1["store"]["book"][0]["author"].as_string());
  EXPECT_EQ(12.99, j1["store"]["book"][0]["price"].as_double());

  EXPECT_EQ("Roman", j1["store"]["book"][1]["category"].as_string());
  EXPECT_EQ("Im Westen nichts Neues", j1["store"]["book"][1]["title"].as_string());
  EXPECT_EQ(10.00, j1["store"]["book"][1]["price"].as_double());
}

}  // namespace dfly

================================================
FILE: src/core/json/jsonpath_grammar.y
================================================
%skeleton "lalr1.cc" // -*- C++ -*-
%require "3.5"  // fedora 32 has this one.

%defines  // %header starts from 3.8.1

%define api.namespace {dfly::json}

%define api.token.raw
%define api.token.constructor
%define api.value.type variant
%define api.parser.class {Parser}
%define parse.assert

// Added to header file before parser declaration.
%code requires {
  #include "src/core/json/path.h"
  namespace dfly {
  namespace json {
    class Driver;
  }
  }
}

// Added to cc file
%code {

#include "src/core/json/lexer_impl.h"
#include "src/core/json/driver.h"
#include
#include "base/logging.h"

// GCC 13+ yields spurious warnings about uninitialized variant members in bison-generated code
#if !defined(__clang__) && __GNUC__ >= 13
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif

#define yylex driver->lexer()->Lex

using namespace std;

// Converts `s` to int with absl::SimpleAtoi. Success is only checked with DCHECK, so callers are
// expected to pass text that is already known to be a valid int.
static int unsafe_stoi(std::string_view s) {
  int value;
  bool success = absl::SimpleAtoi(s, &value);
  DCHECK(success);
  return value;
}

}

%parse-param { Driver *driver  }

%locations

%define parse.trace
%define parse.error verbose  // detailed
%define parse.lac full
%define api.token.prefix {TOK_}

%token
  LBRACKET "["
  RBRACKET "]"
  COLON    ":"
  LPARENT  "("
  RPARENT  ")"
  ROOT "$"
  DOT  "."
  WILDCARD "*"
  DESCENT ".."
  SINGLE_QUOTE "'"
  DOUBLE_QUOTE "\""

// Needed 0 at the end to satisfy bison 3.5.1
%token YYEOF 0
%token UNQ_STR "unquoted string"
%token INT "integer"

%nterm identifier
%nterm bracket_index
%nterm single_quoted_string
%nterm double_quoted_string
%nterm quoted_content

%%
// Based on the following specification:
// https://danielaparker.github.io/JsonCons.Net/articles/JsonPath/Specification.html

// A path is either "$" followed by an optional location, or a function call followed by one.
jsonpath: ROOT { /* skip adding root */ } opt_relative_location
        | function_expr opt_relative_location

opt_relative_location:
        | relative_location

relative_location: DOT relative_path
        | DESCENT { driver->AddSegment(PathSegment{SegmentType::DESCENT}); } relative_path
        | bracket_expr

relative_path: identifier { driver->AddIdentifier($1); } opt_relative_location
        | WILDCARD { driver->AddWildcard(); } opt_relative_location
        | bracket_expr

identifier: UNQ_STR
         | INT

bracket_expr: LBRACKET bracket_index RBRACKET { driver->AddSegment($2); } opt_relative_location

// Quoted names become IDENTIFIER segments; "*", a single index and the slice forms become INDEX
// segments.
bracket_index: single_quoted_string { $$ = PathSegment(SegmentType::IDENTIFIER, $1); }
        | double_quoted_string { $$ = PathSegment(SegmentType::IDENTIFIER, $1); }
        | WILDCARD { $$ = PathSegment{SegmentType::INDEX, IndexExpr::All()}; }
        | INT {
                int tmp_idx = unsafe_stoi($1);
                $$ = PathSegment(SegmentType::INDEX, IndexExpr(tmp_idx, tmp_idx));
              }
        | INT COLON INT {
                $$ = PathSegment(SegmentType::INDEX, IndexExpr::HalfOpen(
                  unsafe_stoi($1), unsafe_stoi($3)));
              }
        | INT COLON { $$ = PathSegment(SegmentType::INDEX, IndexExpr(unsafe_stoi($1), INT_MAX)); }
        | COLON INT { $$ = PathSegment(SegmentType::INDEX, IndexExpr::HalfOpen(0, unsafe_stoi($2))); }

single_quoted_string: SINGLE_QUOTE quoted_content SINGLE_QUOTE { $$ = $2; }

double_quoted_string: DOUBLE_QUOTE quoted_content DOUBLE_QUOTE { $$ = $2; }

// Quoted content is one or more UNQ_STR/INT tokens joined back together with ".".
quoted_content: UNQ_STR { $$ = $1; }
        | INT { $$ = $1; }
        | quoted_content DOT UNQ_STR { $$ = $1 + "." + $3; }
        | quoted_content DOT INT { $$ = $1 + "." + $3; }

function_expr: UNQ_STR { driver->AddFunction($1); } LPARENT ROOT relative_location RPARENT
%%


void dfly::json::Parser::error(const location_type& l, const string& m)
{
  driver->Error(l, m);
}

================================================
FILE: src/core/json/jsonpath_lexer.lex
================================================
%top{
  // generated in the header file.
  #include "core/json/jsonpath_grammar.hh"
}

%o bison-cc-namespace="dfly.json" bison-cc-parser="Parser"
%o namespace="dfly.json"

// Generated class and main function
%o lexer="AbstractLexer" lex="Lex"

// our derived class from AbstractLexer
%o class="Lexer"

/* nodefault removes default echo rule */
%o nodefault batch
%option unicode

/* Declarations before lexer implementation.  */
%{
  #define DFLY_LEXER_CC 1
  #include "src/core/json/lexer_impl.h"
  #undef DFLY_LEXER_CC
%}

%{
  // Code run each time a pattern is matched.
%}

%%

%{
  // Code run each time lex() is called.
%}

[[:space:]]+     ; // skip white space

"$"         return Parser::make_ROOT(loc());
".."        return Parser::make_DESCENT(loc());
"."         return Parser::make_DOT(loc());
":"         return Parser::make_COLON(loc());
"["         return Parser::make_LBRACKET(loc());
"]"         return Parser::make_RBRACKET(loc());
"*"         return Parser::make_WILDCARD(loc());
"("         return Parser::make_LPARENT(loc());
")"         return Parser::make_RPARENT(loc());
"'"         return Parser::make_SINGLE_QUOTE(loc());
"\""        return Parser::make_DOUBLE_QUOTE(loc());
-?[0-9]{1,9}  return Parser::make_INT(str(), loc());
[\w_\-]+    return Parser::make_UNQ_STR(str(), loc());
<>          return Parser::make_YYEOF(loc());
.           throw Parser::syntax_error(loc(), UnknownTokenMsg());
%%

// Function definitions

================================================
FILE: src/core/json/jsonpath_test.cc
================================================
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include #include "base/gtest.h" #include "base/logging.h" #include "core/json/driver.h" #include "core/json/lexer_impl.h" #include "core/mi_memory_resource.h" namespace flexbuffers { bool operator==(const Reference left, const Reference right) { return left.ToString() == right.ToString(); } } // namespace flexbuffers namespace dfly::json { using namespace std; using testing::ElementsAre; MATCHER_P(SegType, value, "") { return ExplainMatchResult(testing::Property(&PathSegment::type, value), arg, result_listener); } void PrintTo(SegmentType st, std::ostream* os) { *os << " segment(" << SegmentName(st) << ")"; } class TestDriver : public Driver { public: void Error(const location& l, const std::string& msg) final { LOG(INFO) << "Error at " << l << ": " << msg; } }; template JSON ValidJson(string_view str); template <> JsonType ValidJson(string_view str) { auto res = ParseJsonUsingShardHeap(str); CHECK(res) << "Failed to parse json: " << str; return *res; } template <> FlatJson ValidJson(string_view str) { static flexbuffers::Builder fbb; flatbuffers::Parser parser; fbb.Clear(); CHECK(parser.ParseFlexBuffer(str.data(), nullptr, &fbb)); fbb.Finish(); const auto& buffer = fbb.GetBuffer(); return flexbuffers::GetRoot(buffer); } bool is_int(const JsonType& val) { return val.is(); } int to_int(const JsonType& val) { return val.as(); } bool is_object(const JsonType& val) { return val.is_object(); } bool is_array(const JsonType& val) { return val.is_array(); } int is_int(FlatJson ref) { return ref.IsInt(); } int to_int(FlatJson ref) { return ref.AsInt32(); } bool is_object(FlatJson ref) { return ref.IsMap(); } bool is_array(FlatJson ref) { return ref.IsUntypedVector(); } class ScannerTest : public ::testing::Test { protected: void SetUp() override { Test::SetUp(); InitTLStatelessAllocMR(&m_); } ScannerTest() : m_(mi_heap_get_backing()) { driver_.lexer()->set_debug(1); } void SetInput(const std::string& str) { driver_.SetInput(str); } Parser::symbol_type Lex() { try { 
// Checks that the lexer splits jsonpath expressions into the expected token sequence.
TEST_F(ScannerTest, Basic) {
  // Unquoted identifiers may contain non-ASCII letters, dashes and digits.
  SetInput("$.мага-зин2.book[0].*");
  NEXT_TOK(ROOT);
  NEXT_TOK(DOT);
  NEXT_EQ(UNQ_STR, string, "мага-зин2");
  NEXT_TOK(DOT);
  NEXT_EQ(UNQ_STR, string, "book");
  NEXT_TOK(LBRACKET);
  NEXT_EQ(INT, string, "0");
  NEXT_TOK(RBRACKET);
  NEXT_TOK(DOT);
  NEXT_TOK(WILDCARD);

  // An unknown character makes the lexer throw a syntax error,
  // which ScannerTest::Lex() translates into a YYEOF token.
  SetInput("|");
  NEXT_TOK(YYEOF);

  // ".." is produced as a single DESCENT token rather than two DOT tokens.
  SetInput("$..*");
  NEXT_TOK(ROOT);
  NEXT_TOK(DESCENT);
  NEXT_TOK(WILDCARD);
}
// Verifies parsing of the recursive-descent operator ("..") followed by an identifier
// or a wildcard, and that malformed uses of the operator are rejected.
TYPED_TEST(JsonPathTest, Descent) {
  EXPECT_EQ(0, this->Parse("$..foo"));
  Path path = this->driver_.TakePath();
  ASSERT_EQ(2, path.size());
  EXPECT_THAT(path[0], SegType(SegmentType::DESCENT));
  EXPECT_THAT(path[1], SegType(SegmentType::IDENTIFIER));
  EXPECT_EQ("foo", path[1].identifier());

  EXPECT_EQ(0, this->Parse("$..*"));
  // Fetch the freshly parsed path before checking its size, so that the assertion
  // guards the element accesses below instead of re-checking the previous path.
  path = this->driver_.TakePath();
  ASSERT_EQ(2, path.size());
  EXPECT_THAT(path[0], SegType(SegmentType::DESCENT));
  EXPECT_THAT(path[1], SegType(SegmentType::WILDCARD));

  // A descent operator must be followed by something, and "..." is not a valid token sequence.
  EXPECT_NE(0, this->Parse("$.."));
  EXPECT_NE(0, this->Parse("$...foo"));
}
TypeParam& val) { ++called; ASSERT_EQ(3, to_int(val)); EXPECT_EQ("v13", key); }); ASSERT_EQ(1, called); path.clear(); path.emplace_back(SegmentType::IDENTIFIER, "v11"); path.emplace_back(SegmentType::IDENTIFIER, "f"); called = 0; EvaluatePath(path, json, [&](optional key, const TypeParam& val) { ++called; ASSERT_EQ(1, to_int(val)); EXPECT_EQ("f", key); }); ASSERT_EQ(1, called); path.clear(); path.emplace_back(SegmentType::WILDCARD); path.emplace_back(SegmentType::IDENTIFIER, "f"); called = 0; EvaluatePath(path, json, [&](optional key, const TypeParam& val) { ++called; ASSERT_TRUE(is_int(val)); EXPECT_EQ("f", key); }); ASSERT_EQ(2, called); } TYPED_TEST(JsonPathTest, EvalDescent) { TypeParam json = ValidJson(R"( {"v11":{ "f" : 1, "a2": [0]}, "v12": {"f": 2, "v21": {"f": 3, "a2": [1]}}, "v13": { "a2" : { "b" : {"f" : 4}}} })"); Path path; int called_arr = 0, called_obj = 0; path.emplace_back(SegmentType::DESCENT); path.emplace_back(SegmentType::IDENTIFIER, "a2"); EvaluatePath(path, json, [&](optional key, const TypeParam& val) { EXPECT_EQ("a2", key); if (is_array(val)) { ++called_arr; } else if (is_object(val)) { ++called_obj; } else { FAIL() << "Unexpected type"; } }); ASSERT_EQ(2, called_arr); ASSERT_EQ(1, called_obj); path.pop_back(); path.emplace_back(SegmentType::IDENTIFIER, "f"); int called = 0; EvaluatePath(path, json, [&](optional key, const TypeParam& val) { ASSERT_TRUE(is_int(val)); ASSERT_EQ("f", key); ++called; }); ASSERT_EQ(4, called); json = ValidJson(R"( {"a":[7], "inner": {"a": {"b": 2, "c": 1337}}} )"); path.pop_back(); path.emplace_back(SegmentType::IDENTIFIER, "a"); vector arr; auto gettype = [](const TypeParam& p) { if (is_array(p)) return 'a'; return is_object(p) ? 
// Checks evaluation of the bracket wildcard "[*]" on an array and on a non-array value.
TYPED_TEST(JsonPathTest, Wildcard) {
  ASSERT_EQ(0, this->Parse("$.arr[*]"));
  Path path = this->driver_.TakePath();
  ASSERT_EQ(2, path.size());
  // "[*]" is parsed into an INDEX segment (covering the whole range), not a WILDCARD segment.
  EXPECT_THAT(path[1], SegType(SegmentType::INDEX));

  TypeParam json = ValidJson(R"({"arr": [1, 2, 3], "i":1})");
  vector arr;
  EvaluatePath(path, json, [&](optional key, const TypeParam& val) {
    ASSERT_FALSE(key);  // array elements are reported without a key.
    arr.push_back(to_int(val));
  });
  ASSERT_THAT(arr, ElementsAre(1, 2, 3));

  // Applying "[*]" to the integer field "i" must not produce any match.
  ASSERT_EQ(0, this->Parse("$.i[*]"));
  path = this->driver_.TakePath();
  arr.clear();
  EvaluatePath(path, json, [&](optional key, const TypeParam& val) { arr.push_back(to_int(val)); });
  ASSERT_THAT(arr, ElementsAre());
}
// Replaces every numeric "value" field found by "$..value" with an object that itself
// contains a "value" key, and checks the resulting document.
// NOTE(review): the body never uses TypeParam, so both typed instantiations appear to run
// the same JsonType-only scenario - confirm whether a FlatJson variant was intended.
TYPED_TEST(JsonPathTest, MutateRecursiveDescentKey) {
  ASSERT_EQ(0, this->Parse("$..value"));
  Path path = this->driver_.TakePath();

  JsonType json = ValidJson(R"({"data":{"value":10,"subdata":{"value":20}}})");
  JsonType replacement = ValidJson(R"({"value": 30})");

  // Only numeric values are replaced; an already substituted object does not pass this guard.
  auto cb = [&](optional key, JsonType* val) {
    if (key && key.value() == "value" && (val->is_int64() || val->is_double())) {
      *val = replacement;
    }
  };

  unsigned reported_matches = MutatePath(path, cb, &json);

  JsonType expected =
      ValidJson(R"({"data":{"subdata":{"value":{"value":30}},"value":{"value":30}}})");
  EXPECT_EQ(expected, json);
  // NOTE(review): 0 matches are expected although two values were replaced - confirm that
  // this is the intended contract of MutatePath for descent paths.
  EXPECT_EQ(0, reported_matches);
}
// Checks that DeletePath with a descent path removes a key at every nesting level
// and reports one match per removed entry.
TYPED_TEST(JsonPathTest, DeleteNestedWithSameKey) {
  // Test for deleting nested elements with the same key using "$..a"
  // Corresponds to command: JSON.DEL doc1 "$..a"
  ASSERT_EQ(0, this->Parse("$..a"));
  Path path = this->driver_.TakePath();

  TypeParam json = ValidJson(R"({"a": 1, "nested": {"a": 2, "b": 3}})");

  if constexpr (std::is_same_v) {
    // JsonType flavour: the document is modified in place.
    unsigned reported_matches = DeletePath(path, &json);
    EXPECT_EQ(2, reported_matches);

    auto expected = ValidJson(R"({"nested": {"b": 3}})");
    EXPECT_EQ(expected, json);
  } else {
    // FlatJson flavour: the resulting document is written into the builder.
    flexbuffers::Builder fbb;
    unsigned reported_matches = DeletePath(path, json, &fbb);
    EXPECT_EQ(2, reported_matches);

    FlatJson result = flexbuffers::GetRoot(fbb.GetBuffer());
    auto expected = ValidJson(R"({"nested": {"b": 3}})");
    EXPECT_EQ(expected, FromFlat(result));
  }
}
ValidJson(R"({"b": ["a", "b"], "nested": {"b": [true, "a", "b"]}})"); EXPECT_EQ(expected, json); } else { flexbuffers::Builder fbb; unsigned reported_matches = DeletePath(path, json, &fbb); EXPECT_EQ(1, reported_matches); FlatJson result = flexbuffers::GetRoot(fbb.GetBuffer()); auto expected = ValidJson(R"({"b": ["a", "b"], "nested": {"b": [true, "a", "b"]}})"); EXPECT_EQ(expected, FromFlat(result)); } } } // namespace dfly::json ================================================ FILE: src/core/json/lexer_impl.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "src/core/json/lexer_impl.h" #include namespace dfly::json { Lexer::Lexer() { } Lexer::~Lexer() { } std::string Lexer::UnknownTokenMsg() const { std::string res = absl::StrCat("Unknown token '", text(), "'"); return res; } } // namespace dfly::json ================================================ FILE: src/core/json/lexer_impl.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once // We should not include lexer.h when compiling from lexer.cc file because it already // includes lexer.h #ifndef DFLY_LEXER_CC #include "src/core/json/jsonpath_lexer.h" #endif #include "src/core/json/jsonpath_grammar.hh" namespace dfly { namespace json { class Lexer : public AbstractLexer { public: Lexer(); ~Lexer(); Parser::symbol_type Lex() final; private: dfly::json::location loc() { return location(); } std::string UnknownTokenMsg() const; }; } // namespace json } // namespace dfly ================================================ FILE: src/core/json/path.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "src/core/json/path.h" #include #include #include "base/logging.h" #include "core/json/detail/flat_dfs.h" #include "core/json/detail/jsoncons_dfs.h" #include "core/json/jsonpath_grammar.hh" #include "src/core/json/driver.h" #include "src/core/overloaded.h" using namespace std; using nonstd::make_unexpected; namespace dfly::json { using detail::Dfs; using detail::FlatDfs; namespace { class JsonPathDriver : public json::Driver { public: string msg; void Error(const json::location& l, const std::string& msg) final { this->msg = absl::StrCat("Error: ", msg); } }; } // namespace const char* SegmentName(SegmentType type) { switch (type) { case SegmentType::IDENTIFIER: return "IDENTIFIER"; case SegmentType::INDEX: return "INDEX"; case SegmentType::WILDCARD: return "WILDCARD"; case SegmentType::DESCENT: return "DESCENT"; case SegmentType::FUNCTION: return "FUNCTION"; } return nullptr; } IndexExpr IndexExpr::Normalize(size_t array_len) const { if (array_len == 0) return IndexExpr(1, 0); // empty range. IndexExpr res = *this; auto wrap = [array_len](int negative) { unsigned positive = -negative; return positive > array_len ? 
0 : array_len - positive; }; if (res.second >= int(array_len)) { res.second = array_len - 1; } else if (res.second < 0) { res.second = wrap(res.second); DCHECK_GE(res.second, 0); } if (res.first < 0) { res.first = wrap(res.first); DCHECK_GE(res.first, 0); } return res; } void PathSegment::Evaluate(const JsonType& json) const { CHECK(type() == SegmentType::FUNCTION); AggFunction* func = std::get>(value_).get(); CHECK(func); func->Apply(json); } void PathSegment::Evaluate(FlatJson json) const { CHECK(type() == SegmentType::FUNCTION); AggFunction* func = std::get>(value_).get(); CHECK(func); func->Apply(json); } AggFunction::Result PathSegment::GetResult() const { CHECK(type() == SegmentType::FUNCTION); const auto& func = std::get>(value_).get(); CHECK(func); return func->GetResult(); } void EvaluatePath(const Path& path, const JsonType& json, PathCallback callback) { if (path.empty()) { // root node callback(nullopt, json); return; } if (path.front().type() != SegmentType::FUNCTION) { Dfs::Traverse(path, json, std::move(callback)); return; } // Handling the case of `func($.somepath)` // We pass our own callback to gather all the results and then call the function. JsonType result(JsonType::null()); absl::Span path_tail(path.data() + 1, path.size() - 1); const PathSegment& func_segment = path.front(); if (path_tail.empty()) { LOG(DFATAL) << "Invalid path"; // parser should not allow this. } else { Dfs::Traverse(path_tail, json, [&](auto, const JsonType& val) { func_segment.Evaluate(val); }); } AggFunction::Result res = func_segment.GetResult(); JsonType val = visit( // Transform the result to JsonType. 
Overloaded{ [](monostate) { return JsonType::null(); }, [&](double d) { return JsonType(d); }, [&](int64_t i) { return JsonType(i); }, }, res); callback(nullopt, val); } nonstd::expected ParsePath(string_view path) { if (path.size() > 8192) return nonstd::make_unexpected("Path too long"); VLOG(2) << "Parsing path: " << path; JsonPathDriver driver; Parser parser(&driver); driver.SetInput(string(path)); int res = parser(); if (res != 0) { return nonstd::make_unexpected(driver.msg); } return driver.TakePath(); } unsigned MutatePath(const Path& path, MutateCallback callback, JsonType* json) { if (path.empty()) { callback(nullopt, json); return 1; } Dfs dfs = Dfs::Mutate(path, callback, json); return dfs.matches(); } unsigned DeletePath(const Path& path, JsonType* json) { if (path.empty()) { // For empty path, we cannot delete the root JSON itself within this function // as it would require modifying the pointer itself. Return 0 for no deletion. return 0; } Dfs dfs = Dfs::Delete(path, json); return dfs.matches(); } // Flat json path evaluation void EvaluatePath(const Path& path, FlatJson json, PathFlatCallback callback) { if (path.empty()) { // root node callback(nullopt, json); return; } if (path.front().type() != SegmentType::FUNCTION) { FlatDfs::Traverse(path, json, std::move(callback)); return; } // Handling the case of `func($.somepath)` // We pass our own callback to gather all the results and then call the function. FlatJson result; absl::Span path_tail(path.data() + 1, path.size() - 1); const PathSegment& func_segment = path.front(); if (path_tail.empty()) { LOG(DFATAL) << "Invalid path"; // parser should not allow this. } else { FlatDfs::Traverse(path_tail, json, [&](auto, FlatJson val) { func_segment.Evaluate(val); }); } AggFunction::Result res = func_segment.GetResult(); flexbuffers::Builder fbb; FlatJson val = visit( // Transform the result to a flexbuffer reference. 
Overloaded{ [](monostate) { return FlatJson{}; }, [&](double d) { fbb.Double(d); fbb.Finish(); return flexbuffers::GetRoot(fbb.GetBuffer()); }, [&](int64_t i) { fbb.Int(i); fbb.Finish(); return flexbuffers::GetRoot(fbb.GetBuffer()); }, }, res); callback(nullopt, val); } JsonType FromFlat(FlatJson src) { if (src.IsNull()) { return JsonType::null(); } if (src.IsBool()) { return JsonType(src.AsBool()); } if (src.IsInt()) { return JsonType(src.AsInt64()); } if (src.IsFloat()) { return JsonType(src.AsDouble()); } if (src.IsString()) { flexbuffers::String str = src.AsString(); return JsonType(string_view{str.c_str(), str.size()}); } CHECK(src.IsVector()); auto vec = src.AsVector(); JsonType js = src.IsMap() ? JsonType{jsoncons::json_object_arg} : JsonType{jsoncons::json_array_arg}; auto keys = src.AsMap().Keys(); for (unsigned i = 0; i < vec.size(); ++i) { JsonType value = FromFlat(vec[i]); if (src.IsMap()) { js[keys[i].AsKey()] = std::move(value); } else { js.push_back(std::move(value)); } } return js; } void FromJsonType(const JsonType& src, flexbuffers::Builder* fbb) { if (src.is_null()) { return fbb->Null(); } if (src.is_bool()) { return fbb->Bool(src.as_bool()); } if (src.is_int64()) { return fbb->Int(src.as()); } if (src.is_double()) { return fbb->Double(src.as_double()); } if (src.is_string()) { string_view sv = src.as_string_view(); fbb->String(sv.data(), sv.size()); return; } if (src.is_object()) { auto range = src.object_range(); size_t start = fbb->StartMap(); for (auto it = range.cbegin(); it != range.cend(); ++it) { fbb->Key(it->key().c_str(), it->key().size()); FromJsonType(it->value(), fbb); } fbb->EndMap(start); return; } CHECK(src.is_array()); auto range = src.array_range(); size_t start = fbb->StartVector(); for (auto it = range.cbegin(); it != range.cend(); ++it) { FromJsonType(*it, fbb); } fbb->EndVector(start, false, false); } unsigned MutatePath(const Path& path, MutateCallback callback, FlatJson json, flexbuffers::Builder* fbb) { JsonType mut_json 
= FromFlat(json); unsigned res = MutatePath(path, std::move(callback), &mut_json); // Populate the output builder 'fbb' with the resulting JSON state // (mutated or original if res == 0) and finalize it. // The builder MUST be finished before returning so that the caller // can safely access the resulting flatbuffer data (e.g., via GetBuffer()). // Skipping Finish() would leave the builder in an invalid, unusable state. FromJsonType(mut_json, fbb); // Always convert (changed or not) JSON fbb->Finish(); // Always finish the builder // Return the number of actual mutations that occurred. return res; } unsigned DeletePath(const Path& path, FlatJson json, flexbuffers::Builder* fbb) { JsonType mut_json = FromFlat(json); unsigned res = DeletePath(path, &mut_json); FromJsonType(mut_json, fbb); fbb->Finish(); return res; } } // namespace dfly::json ================================================ FILE: src/core/json/path.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include "core/flatbuffers.h" #include "core/json/json_object.h" namespace dfly::json { enum class SegmentType { IDENTIFIER = 1, // $.identifier INDEX = 2, // $.array[index_expr] WILDCARD = 3, // $.* DESCENT = 4, // $..identifier FUNCTION = 5, // max($.prices[*]) }; const char* SegmentName(SegmentType type); class AggFunction { public: using Result = std::variant; virtual ~AggFunction() { } void Apply(const JsonType& src) { if (valid_ != 0) valid_ = ApplyImpl(src); } void Apply(FlatJson src) { if (valid_ != 0) valid_ = ApplyImpl(src); } // returns null if Apply was not called or ApplyImpl failed. Result GetResult() const { return valid_ == 1 ? 
// Bracket index representation. IndexExpr is a closed range, i.e. both ends are inclusive.
// A single index I is stored as <I, I>, the wildcard as <0, INT_MAX> and
// a half-open slice [begin:end) as <begin, end - 1>.
// IndexExpr is 0-based, with negative indices referring to the array size of the applied object.
struct IndexExpr : public std::pair {
  // True when the range selects nothing, i.e. its lower bound lies past its upper bound.
  bool Empty() const {
    return first > second;
  }

  // The range that covers every element: <0, INT_MAX>.
  static IndexExpr All() {
    return IndexExpr{0, INT_MAX};
  }

  using pair::pair;

  // Returns this range resolved against an array of `array_len` elements: negative bounds
  // are converted to absolute indices and the upper bound is clamped to the last element.
  IndexExpr Normalize(size_t array_len) const;

  // Returns IndexExpr representing [left_closed, right_open) range.
  static IndexExpr HalfOpen(int left_closed, int right_open) {
    return IndexExpr(left_closed, right_open - 1);
  }
};
using MutateCallback = absl::FunctionRef, JsonType*)>; void EvaluatePath(const Path& path, const JsonType& json, PathCallback callback); // Same as above but for flatbuffers. void EvaluatePath(const Path& path, FlatJson json, PathFlatCallback callback); // returns number of matches found with the given path. unsigned MutatePath(const Path& path, MutateCallback callback, JsonType* json); unsigned MutatePath(const Path& path, MutateCallback callback, FlatJson json, flexbuffers::Builder* fbb); // Simplified deletion operation without callback - more efficient for JSON.DEL operations unsigned DeletePath(const Path& path, JsonType* json); unsigned DeletePath(const Path& path, FlatJson json, flexbuffers::Builder* fbb); // utility function to parse a jsonpath. Returns an error message if a parse error was // encountered. nonstd::expected ParsePath(std::string_view path); // Transforms FlatJson to JsonType. JsonType FromFlat(FlatJson src); // Transforms JsonType to a buffer using flexbuffers::Builder. // Does not call flexbuffers::Builder::Finish. void FromJsonType(const JsonType& src, flexbuffers::Builder* fbb); } // namespace dfly::json ================================================ FILE: src/core/linear_search_map.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include "base/logging.h" namespace dfly { /* LinearSearchMap is a small key-value map implemented using an inlined vector of (key, value) pairs. It performs key lookup using linear search (O(n)) and is optimized for small maps (typically <32 keys). Compared to a hash map, it avoids hashing overhead and has better memory locality and cache performance. Use it when: - The number of keys is small - You care about minimal memory usage - Fast iteration is more important than fast lookup NOTE: - Insert() and Emplace() do NOT check for duplicate keys at runtime. 
Inserting a duplicate key results in undefined behavior. - You must ensure keys are unique when inserting. - This syntax is used to maintain compatibility with absl::InlinedVector. */ template class LinearSearchMap : public absl::InlinedVector, N> { private: using Base = absl::InlinedVector, N>; public: using Base::operator[]; using Base::erase; using iterator = typename Base::iterator; using const_iterator = typename Base::const_iterator; // Does not check if key already exists. // If key already exists - undefined behavior. void insert(Key key, Value value); template void emplace(Key key, Args&&... args); void erase(const Key& key); bool contains(const Key& key) const; iterator find(const Key& key); const_iterator find(const Key& key) const; size_t find_index(const Key& key) const; Value& operator[](const Key& key); const Value& operator[](const Key& key) const; }; // Implementation /******************************************************************/ template void LinearSearchMap::insert(Key key, Value value) { DCHECK(!contains(key)) << "Key already exists: " << key; this->emplace_back(std::move(key), std::move(value)); } template template void LinearSearchMap::emplace(Key key, Args&&... 
args) { DCHECK(!contains(key)) << "Key already exists: " << key; this->emplace_back(std::piecewise_construct, std::forward_as_tuple(std::move(key)), std::forward_as_tuple(std::forward(args)...)); } template void LinearSearchMap::erase(const Key& key) { erase(find(key)); } template bool LinearSearchMap::contains(const Key& key) const { return find(key) != this->end(); } template typename LinearSearchMap::iterator LinearSearchMap::find( const Key& key) { return std::find_if(this->begin(), this->end(), [&key](const auto& pair) { return pair.first == key; }); } template typename LinearSearchMap::const_iterator LinearSearchMap::find( const Key& key) const { return std::find_if(this->begin(), this->end(), [&key](const auto& pair) { return pair.first == key; }); } template size_t LinearSearchMap::find_index(const Key& key) const { return std::distance(this->begin(), find(key)); } template Value& LinearSearchMap::operator[](const Key& key) { return find(key)->second; } template const Value& LinearSearchMap::operator[](const Key& key) const { return find(key)->second; } } // namespace dfly ================================================ FILE: src/core/linear_search_map_test.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/linear_search_map.h" #include #include #include #include "base/gtest.h" #include "base/logging.h" namespace dfly { class LinearSearchMapTest : public testing::Test { protected: }; TEST_F(LinearSearchMapTest, Insert) { LinearSearchMap map; for (int i = 0; i < 100; ++i) { map.insert(i, i * 1.1); } for (int i = 199; i >= 100; --i) { map.insert(i, i * 12.1); } for (int i = 0; i < 200; ++i) { auto it = map.find(i); EXPECT_NE(it, map.end()); EXPECT_TRUE(map.contains(i)); EXPECT_EQ(it->second, (i < 100) ? 
i * 1.1 : i * 12.1); } } TEST_F(LinearSearchMapTest, Emplace) { struct Value { Value(double value_, std::string str_) : value(value_), str(std::move(str_)) { } double value; std::string str; }; LinearSearchMap map; for (int i = 0; i < 100; ++i) { map.emplace(i, i * 1.1, "value_" + std::to_string(i)); } for (int i = 199; i >= 100; --i) { map.emplace(i, i * 12.1, "value_" + std::to_string(i)); } for (int i = 0; i < 200; ++i) { auto it = map.find(i); EXPECT_NE(it, map.end()); EXPECT_TRUE(map.contains(i)); EXPECT_EQ(it->second.value, (i < 100) ? i * 1.1 : i * 12.1); EXPECT_EQ(it->second.str, "value_" + std::to_string(i)); } } TEST_F(LinearSearchMapTest, EraseSimple) { LinearSearchMap map; for (int i = 0; i < 200; ++i) { map.insert(i, i * 1.1); } // Erase by iterator for (int i = 0; i < 100; ++i) { auto it = map.find(i); EXPECT_NE(it, map.end()); EXPECT_TRUE(map.contains(i)); map.erase(it); EXPECT_FALSE(map.contains(i)); } // Erase by key for (int i = 100; i < 200; ++i) { EXPECT_TRUE(map.contains(i)); map.erase(i); EXPECT_FALSE(map.contains(i)); } EXPECT_TRUE(map.empty()); } TEST_F(LinearSearchMapTest, Erase) { std::unordered_map expected_map; LinearSearchMap map; // First wave insert / erase for (int i = 0; i < 300; i++) { double value = i * 1.1; map.insert(i, value); expected_map[i] = value; } for (int i = 0; i < 300; i += 3) { EXPECT_TRUE(map.contains(i)); map.erase(i); expected_map.erase(i); EXPECT_FALSE(map.contains(i)); } // Second wave insert / erase for (int i = 300; i < 600; i++) { double value = i * 2.2; map.insert(i, value); expected_map[i] = value; } for (int i = 300; i < 600; i += 5) { EXPECT_TRUE(map.contains(i)); map.erase(i); expected_map.erase(i); EXPECT_FALSE(map.contains(i)); } // Erase all remaining elements while (!expected_map.empty()) { size_t index = 0; const size_t step = 7; for (auto it = expected_map.begin(); it != expected_map.end(); ++index) { auto [i, value] = *it; EXPECT_TRUE(map.contains(i)); EXPECT_EQ(map.find(i)->second, value); if 
(index % step == 0) { map.erase(i); it = expected_map.erase(it); } else { ++it; } } } EXPECT_TRUE(map.empty()); } TEST_F(LinearSearchMapTest, BasicFunctionality) { LinearSearchMap map; for (double i = 0; i < 100; ++i) { map.insert(i, i * 1.1); } EXPECT_EQ(map.size(), 100); // Using indexes for (size_t i = 0; i < map.size(); ++i) { auto [key, value] = map[i]; EXPECT_EQ(value, key * 1.1); } // Get index by key for (double i = 0; i < 100; ++i) { size_t index = map.find_index(i); auto [key, value] = map[index]; EXPECT_EQ(value, key * 1.1); } // Get value by key for (double i = 0; i < 100; ++i) { EXPECT_EQ(map[i], i * 1.1); } // Iterate through the map for (const auto& [key, value] : map) { EXPECT_EQ(value, key * 1.1); } } } // namespace dfly ================================================ FILE: src/core/listpack_test.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/detail/listpack.h" #include #include #include "base/gtest.h" #include "base/logging.h" extern "C" { #include "redis/listpack.h" #include "redis/zmalloc.h" } namespace dfly { namespace detail { using namespace std; using namespace testing; class ListPackTest : public ::testing::Test { protected: static void SetUpTestSuite() { init_zmalloc_threadlocal(mi_heap_get_backing()); } void SetUp() override { ptr_ = lpNew(0); lp_ = ListPack(ptr_); } void TearDown() override { ptr_ = lp_.GetPointer(); lpFree(ptr_); // Ensure there are no memory leaks after every test EXPECT_EQ(zmalloc_used_memory_tl, 0); } unsigned Remove(string_view elem, unsigned count, QList::Where where) { return lp_.Remove(CollectionEntry{elem.data(), elem.size()}, count, where); } ListPack lp_; uint8_t* ptr_ = nullptr; }; TEST_F(ListPackTest, FindNotFound) { lp_.Push("first", QList::TAIL); lp_.Push("third", QList::TAIL); EXPECT_EQ(lp_.Find("second"), nullptr); } TEST_F(ListPackTest, RemoveIntegerFromHead) { lp_.Push("1", 
QList::TAIL); lp_.Push("2", QList::TAIL); lp_.Push("1", QList::TAIL); lp_.Push("3", QList::TAIL); // Remove integer value "1" from head unsigned removed = Remove("1", 0, QList::HEAD); EXPECT_EQ(2, removed); EXPECT_EQ(2, lp_.Size()); EXPECT_EQ("2", lp_.At(0)); EXPECT_EQ("3", lp_.At(1)); } TEST_F(ListPackTest, RemoveFromTailAll) { // List: a, b, a, c, a lp_.Push("a", QList::TAIL); lp_.Push("b", QList::TAIL); lp_.Push("a", QList::TAIL); lp_.Push("c", QList::TAIL); lp_.Push("a", QList::TAIL); // Remove all "a" from tail direction unsigned removed = Remove("a", 0, QList::TAIL); EXPECT_EQ(3, removed); EXPECT_EQ(2, lp_.Size()); // Remaining elements: b, c EXPECT_EQ("b", lp_.At(0)); EXPECT_EQ("c", lp_.At(1)); } TEST_F(ListPackTest, RemoveFromTailWithCount) { // List: a, b, a, c, a lp_.Push("a", QList::TAIL); lp_.Push("b", QList::TAIL); lp_.Push("a", QList::TAIL); lp_.Push("c", QList::TAIL); lp_.Push("a", QList::TAIL); // Remove only 2 occurrences of "a" from tail (removes indices 4 and 2) unsigned removed = Remove("a", 2, QList::TAIL); EXPECT_EQ(2, removed); EXPECT_EQ(3, lp_.Size()); // Remaining elements: a, b, c EXPECT_EQ("a", lp_.At(0)); EXPECT_EQ("b", lp_.At(1)); EXPECT_EQ("c", lp_.At(2)); } // Test removing consecutive tail elements - verifies lpLast is called correctly // after deleting the tail element to continue finding remaining matches. TEST_F(ListPackTest, RemoveFromTailConsecutive) { // List: x, target, target, target - three consecutive at tail lp_.Push("x", QList::TAIL); lp_.Push("target", QList::TAIL); lp_.Push("target", QList::TAIL); lp_.Push("target", QList::TAIL); unsigned removed = Remove("target", 0, QList::TAIL); EXPECT_EQ(3, removed); EXPECT_EQ(1, lp_.Size()); EXPECT_EQ("x", lp_.At(0)); } // Test removing the head element while iterating from TAIL direction. // After checking all elements from tail to head and deleting the head, // lpDelete returns pointer to element after head, and lpPrev on that returns nullptr, // correctly ending iteration. 
TEST_F(ListPackTest, RemoveFromTailDeletesHead) { // List: a, b, c - removing "a" (at head) while iterating from tail lp_.Push("a", QList::TAIL); lp_.Push("b", QList::TAIL); lp_.Push("c", QList::TAIL); unsigned removed = Remove("a", 0, QList::TAIL); EXPECT_EQ(1, removed); EXPECT_EQ(2, lp_.Size()); EXPECT_EQ("b", lp_.At(0)); EXPECT_EQ("c", lp_.At(1)); } TEST_F(ListPackTest, ReplaceAtIndex) { lp_.Push("first", QList::TAIL); lp_.Push("second", QList::TAIL); lp_.Push("third", QList::TAIL); // Replace element at index 1 uint8_t* pos = lp_.Seek(1); EXPECT_NE(pos, nullptr); lp_.Replace(pos, "replaced"); EXPECT_EQ(3, lp_.Size()); EXPECT_EQ("first", lp_.At(0)); EXPECT_EQ("replaced", lp_.At(1)); EXPECT_EQ("third", lp_.At(2)); } TEST_F(ListPackTest, ReplaceAtNegativeIndex) { lp_.Push("first", QList::TAIL); lp_.Push("second", QList::TAIL); lp_.Push("third", QList::TAIL); // Replace element at index -1 (last element) uint8_t* pos = lp_.Seek(-1); EXPECT_NE(pos, nullptr); lp_.Replace(pos, "new_last"); EXPECT_EQ(3, lp_.Size()); EXPECT_EQ("first", lp_.At(0)); EXPECT_EQ("second", lp_.At(1)); EXPECT_EQ("new_last", lp_.At(2)); } TEST_F(ListPackTest, ReplaceOutOfBounds) { lp_.Push("first", QList::TAIL); lp_.Push("second", QList::TAIL); // Replace at out-of-bounds index should return false uint8_t* pos = lp_.Seek(5); EXPECT_EQ(pos, nullptr); pos = lp_.Seek(-5); EXPECT_EQ(pos, nullptr); } TEST_F(ListPackTest, ReplaceWithLargerString) { lp_.Push("a", QList::TAIL); lp_.Push("b", QList::TAIL); // Replace with a much larger string string large(500, 'x'); uint8_t* pos = lp_.Seek(0); EXPECT_NE(pos, nullptr); lp_.Replace(pos, large); EXPECT_EQ(2, lp_.Size()); EXPECT_EQ(large, lp_.At(0)); EXPECT_EQ("b", lp_.At(1)); } TEST_F(ListPackTest, ReplaceWithEmptyString) { lp_.Push("first", QList::TAIL); lp_.Push("second", QList::TAIL); // Replace with empty string uint8_t* pos = lp_.Seek(0); EXPECT_NE(pos, nullptr); lp_.Replace(pos, ""); EXPECT_EQ(2, lp_.Size()); EXPECT_EQ("", lp_.At(0)); 
EXPECT_EQ("second", lp_.At(1)); } } // namespace detail } // namespace dfly ================================================ FILE: src/core/memory_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // // Disable mimalloc internal debug assertions for accessing internal structures #define MI_DEBUG 0 #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" // Stub out internal mimalloc assertions that aren't exported // These are used by inline functions in internal.h [[noreturn]] void _mi_assert_fail(const char* assertion, const char* fname, unsigned int line, const char* func) noexcept { fprintf(stderr, "mimalloc assertion failed: %s at %s:%u in %s\n", assertion, fname, line, func); abort(); } namespace dfly { class MiHeapTest : public ::testing::Test { protected: MiHeapTest() { } }; TEST_F(MiHeapTest, Basic) { mi_heap_t* heap = mi_heap_get_default(); void* ptr = mi_heap_malloc_aligned(heap, 1024 /* size*/, 64 /* alignment*/); ASSERT_TRUE(ptr != nullptr); EXPECT_EQ(heap->tld->stats.malloc_normal.current, 1024); EXPECT_EQ(heap->tld->stats.malloc_huge.current, 0); void* ptr2 = mi_heap_malloc_aligned(heap, 1024 * 1024 /* size*/, 64 /* alignment*/); EXPECT_EQ(heap->tld->stats.malloc_normal.current, 1024); EXPECT_GE(heap->tld->stats.malloc_huge.current, 1024 * 1024); mi_free(ptr); EXPECT_EQ(heap->tld->stats.malloc_normal.current, 0); EXPECT_GE(heap->tld->stats.malloc_huge.current, 1024 * 1024); mi_free(ptr2); EXPECT_EQ(heap->tld->stats.malloc_huge.current, 0); } TEST_F(MiHeapTest, Threaded) { mi_heap_t* heap = mi_heap_get_default(); void* ptr = mi_heap_malloc_aligned(heap, 1024 /* size*/, 64 /* alignment*/); ASSERT_TRUE(ptr != nullptr); // adding ptr to heap->thread_delayed_free std::thread t2([ptr]() { mi_free(ptr); // thread local stats are updated. 
EXPECT_EQ(mi_heap_get_default()->tld->stats.malloc_normal.current, -1024); }); t2.join(); EXPECT_EQ(heap->tld->stats.malloc_normal.current, 1024); EXPECT_EQ(heap->generic_collect_count, 0); // Force many mallocs to trigger delayed blocks collection. for (unsigned i = 0; i < 200; ++i) { ptr = mi_malloc(16 * i); mi_free(ptr); } // delayed collections was triggered EXPECT_GE(heap->generic_collect_count, 1); // mi_malloc does not track malloc back sizes back to the original heap threads. EXPECT_EQ(heap->tld->stats.malloc_normal.current, 1024); } // Verify that xthread_free lists are processed correctly during force collection // on full pages. TEST_F(MiHeapTest, FullPageThreadFreeInternal) { mi_heap_t* heap = mi_heap_get_default(); constexpr size_t block_size = 64; std::vector allocations; // Allocate blocks until page is full void* first_ptr = mi_heap_malloc(heap, block_size); ASSERT_TRUE(first_ptr != nullptr); allocations.push_back(first_ptr); mi_page_t* page = _mi_ptr_page(first_ptr); ASSERT_TRUE(page != nullptr); while (page->used < page->capacity) { void* ptr = mi_heap_malloc(heap, block_size); ASSERT_TRUE(ptr != nullptr); if (_mi_ptr_page(ptr) == page) { allocations.push_back(ptr); } else { mi_free(ptr); break; } } EXPECT_EQ(page->used, page->capacity); // Free one block from another thread void* cross_thread_ptr = allocations.back(); allocations.pop_back(); std::thread t([cross_thread_ptr]() { mi_free(cross_thread_ptr); }); t.join(); EXPECT_EQ(page->used, page->capacity); EXPECT_NE(mi_atomic_load_relaxed(&page->xthread_free), 0); // Force collection should process xthread_free mi_heap_collect(heap, true); EXPECT_LT(page->used, page->capacity); EXPECT_EQ(mi_atomic_load_relaxed(&page->xthread_free), 0); // New allocation should reuse the freed block void* new_ptr = mi_heap_malloc(heap, block_size); EXPECT_EQ(_mi_ptr_page(new_ptr), page); // Clean up mi_free(new_ptr); for (void* ptr : allocations) { mi_free(ptr); } } // Verify that MI_BIN_FULL pages are cleared 
during collection. TEST_F(MiHeapTest, FullBinQueueCollection) { mi_heap_t* heap = mi_heap_get_default(); constexpr size_t block_size = 64; auto count_xthread_free = [&heap]() { size_t count = 0; for (size_t i = 0; i <= MI_BIN_FULL; ++i) { for (mi_page_t* page = heap->pages[i].first; page != nullptr; page = page->next) { if (mi_atomic_load_relaxed(&page->xthread_free) != 0) { count++; } } } return count; }; // Allocate and cross-thread free to populate xthread_free lists std::vector allocations(2000); for (size_t i = 0; i < allocations.size(); ++i) { allocations[i] = mi_heap_malloc(heap, block_size); ASSERT_TRUE(allocations[i] != nullptr); } std::thread t([&allocations]() { for (size_t i = 0; i < allocations.size() / 2; ++i) { mi_free(allocations[i]); } }); t.join(); size_t xthread_before = count_xthread_free(); EXPECT_GT(xthread_before, 0); mi_heap_collect(heap, true); EXPECT_EQ(count_xthread_free(), 0) << "All xthread_free lists should be cleared"; // Clean up for (size_t i = allocations.size() / 2; i < allocations.size(); ++i) { mi_free(allocations[i]); } } // Test that verifies memory accounting and reclamation behavior when allocations are made in // one thread and freed in another after the allocating thread exits. This exercises the // MI_ABANDON / cross-thread free handling where mimalloc should properly reclaim pages from // the abandoned thread heap once collection runs. // // This test uses the default heap and verifies reclamation by checking its statistics. 
TEST_F(MiHeapTest, AbandonedHeapReclamation) { constexpr size_t block_size = 128; constexpr size_t num_blocks = 2000; std::vector allocations(num_blocks); mi_heap_t* main_heap = mi_heap_get_default(); // Allocate memory in a separate thread, then exit the thread std::thread allocator_thread([&]() { for (size_t i = 0; i < num_blocks; ++i) { allocations[i] = mi_malloc(block_size); ASSERT_TRUE(allocations[i] != nullptr); } }); allocator_thread.join(); // Free all allocations from the main thread (cross-thread free to abandoned heap) for (void* ptr : allocations) { mi_free(ptr); } // Force collection to reclaim abandoned segments mi_collect(true); // Verify memory and abandoned pages are reclaimed EXPECT_EQ(main_heap->tld->stats.malloc_normal.current, 0); } } // namespace dfly ================================================ FILE: src/core/mi_memory_resource.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/mi_memory_resource.h" #include #include "base/logging.h" namespace dfly { using namespace std; void* MiMemoryResource::do_allocate(size_t size, size_t align) { DCHECK(align); void* res = mi_heap_malloc_aligned(heap_, size, align); if (!res) throw bad_alloc{}; // It seems that mimalloc has a bug with larger allocations that causes // mi_heap_contains_block to lie. See https://github.com/microsoft/mimalloc/issues/587 // For now I avoid the check by checking the size. mi_usable_size works though. 
DCHECK(size > 33554400 || mi_heap_contains_block(heap_, res)); size_t delta = mi_usable_size(res); used_ += delta; DVLOG(1) << "do_allocate: " << heap_ << " " << delta; return res; } void MiMemoryResource::do_deallocate(void* ptr, size_t size, size_t align) { DCHECK(size > 33554400 || mi_heap_contains_block(heap_, ptr)); size_t usable = mi_usable_size(ptr); DVLOG(1) << "do_deallocate: " << heap_ << " " << usable; DCHECK_GE(used_, size); used_ -= usable; mi_free_size_aligned(ptr, size, align); } } // namespace dfly ================================================ FILE: src/core/mi_memory_resource.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include "base/pmr/memory_resource.h" namespace dfly { // Per thread memory resource that uses mimalloc. class MiMemoryResource : public PMR_NS::memory_resource { public: explicit MiMemoryResource(mi_heap_t* heap) : heap_(heap) { } mi_heap_t* heap() { return heap_; } size_t used() const { return used_; } private: void* do_allocate(std::size_t size, std::size_t align) final; void do_deallocate(void* ptr, std::size_t size, std::size_t align) final; bool do_is_equal(const PMR_NS::memory_resource& o) const noexcept { return this == &o; } mi_heap_t* heap_; size_t used_ = 0; }; } // namespace dfly ================================================ FILE: src/core/oah_entry.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/oah_entry.h" #include "base/hash.h" #include "base/logging.h" namespace dfly { OAHEntry::OAHEntry(std::string_view key, uint32_t expiry) { uint32_t key_size = key.size(); uint32_t expiry_size = (expiry != UINT32_MAX) * sizeof(expiry); uint32_t key_len_field_size = key_size <= std::numeric_limits::max() ? 
1 : 4; auto size = key_len_field_size + key_size + expiry_size; auto* expiry_pos = (char*)zmalloc(size); data_ = reinterpret_cast(expiry_pos); if (expiry_size) { SetExpiryBit(true); std::memcpy(expiry_pos, &expiry, sizeof(expiry)); } auto* key_size_pos = expiry_pos + expiry_size; if (key_len_field_size == 1) { SetSsoBit(); uint8_t sso_key_size = key_size; std::memcpy(key_size_pos, &sso_key_size, key_len_field_size); } else { std::memcpy(key_size_pos, &key_size, key_len_field_size); } auto* key_pos = key_size_pos + key_len_field_size; std::memcpy(key_pos, key.data(), key_size); } // returns the expiry time of the current entry or UINT32_MAX if no expiry is set. uint32_t OAHEntry::GetExpiry() const { std::uint32_t res = UINT32_MAX; if (HasExpiry()) { assert(!IsVector()); std::memcpy(&res, Raw(), sizeof(res)); } return res; } bool OAHEntry::CheckNoCollisions(const uint64_t ext_hash) { auto stored_hash = GetHash(); return ((stored_hash != ext_hash) & (stored_hash != 0)) | (Empty()); } void OAHEntry::SetExtHash(uint64_t ext_hash) { assert(data_); assert(!IsVector()); data_ = (data_ & ~kExtHashShiftedMask) | (ext_hash << kExtHashShift); } void OAHEntry::SetExpiry(uint32_t at_sec) { assert(!IsVector()); if (HasExpiry()) { auto* expiry_pos = Raw(); std::memcpy(expiry_pos, &at_sec, sizeof(at_sec)); } else { *this = OAHEntry(Key(), at_sec); } } void OAHEntry::ExpireIfNeeded(uint32_t time_now, uint32_t* set_size, size_t* alloc_used) { assert(!IsVector()); if (GetExpiry() <= time_now) { *alloc_used -= AllocSize(); Clear(); --*set_size; } } // TODO refactor, because it's inefficient size_t OAHEntry::Insert(OAHEntry&& e) { if (Empty()) { *this = std::move(e); return 0; } else if (!IsVector()) { OAHEntry tmp(PtrVector::FromLogSize(1)); auto& arr = tmp.AsVector(); arr[0] = std::move(*this); arr[1] = std::move(e); auto res = arr.AllocSize(); *this = std::move(tmp); return res; } else { auto& arr = AsVector(); size_t i = 0; for (; i < arr.Size(); ++i) { if (!arr[i]) { arr[i] = 
std::move(e); return 0; } } size_t prev_alloc_size = arr.AllocSize(); auto new_pos = arr.Size(); arr.ResizeLog(arr.LogSize() + 1); arr[new_pos] = (std::move(e)); return arr.AllocSize() - prev_alloc_size; } } uint32_t OAHEntry::ElementsNum() { if (Empty()) { return 0; } else if (!IsVector()) { return 1; } return AsVector().Size(); } // TODO remove, it is inefficient OAHEntry& OAHEntry::operator[](uint32_t pos) { assert(!Empty()); if (!IsVector()) { assert(pos == 0); return *this; } else { auto& arr = AsVector(); assert(pos < arr.Size()); return arr[pos]; } } OAHEntry OAHEntry::Remove(uint32_t pos) { if (Empty()) { // I'm not sure that this scenario should be check at all assert(pos == 0); return OAHEntry(); } else if (!IsVector()) { assert(pos == 0); return std::move(*this); } else { auto& arr = AsVector(); assert(pos < arr.Size()); return std::move(arr[pos]); } } OAHEntry OAHEntry::Pop() { if (IsVector()) { auto& arr = AsVector(); for (auto& e : arr) { if (e) return std::move(e); } return {}; } return std::move(*this); } void OAHEntry::Clear() { // TODO add optimization to avoid destructor calls during vector allocator if (!data_) return; if (IsVector()) { AsVector().~PtrVector(); } else { zfree(Raw()); } data_ = 0; } uint32_t OAHEntry::GetKeySize() const { if (HasSso()) { uint8_t size = 0; std::memcpy(&size, Raw() + GetExpirySize(), sizeof(size)); return size; } uint32_t size = 0; std::memcpy(&size, Raw() + GetExpirySize(), sizeof(size)); return size; } void OAHEntry::SetExpiryBit(bool b) { if (b) data_ |= kExpiryBit; else data_ &= ~kExpiryBit; } size_t OAHEntry::Size() { size_t key_field_size = HasSso() ? 1 : 4; size_t expiry_field_size = HasExpiry() ? 4 : 0; return expiry_field_size + key_field_size + GetKeySize(); } } // namespace dfly ================================================ FILE: src/core/oah_entry.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include "base/hash.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly { #define PREFETCH_READ(x) __builtin_prefetch(x, 0, 1) #define FORCE_INLINE __attribute__((always_inline)) // TODO add allocator support template class PtrVector { static constexpr size_t kVectorBit = 1ULL << 0; // first 3 bits aren't used by pointer static constexpr size_t kTagMask = (4095ULL << 52) | 7; // we reserve 12 high bits and 3 low bits static constexpr size_t kLogSizeShift = 56; static constexpr size_t kLogSizeMask = 0xFFULL; static constexpr size_t kLogSizeShiftedMask = kLogSizeMask << kLogSizeShift; public: static PtrVector FromLogSize(uint64_t log_size) { return PtrVector(log_size); } T* begin() const { return &Raw()[0]; } T* end() const { return &Raw()[Size()]; } PtrVector(PtrVector&& other) { uptr_ = other.uptr_; other.uptr_ = 0; } ~PtrVector() { Clear(); } size_t LogSize() const { return (uptr_ >> kLogSizeShift) & kLogSizeMask; } size_t Size() const { return 1 << LogSize(); } uint64_t Release() { uint64_t res = uptr_; uptr_ = 0; return res; } bool Empty() const { if (uptr_ == 0) return true; for (auto& el : *this) { if (el) return false; } return true; } void ResizeLog(uint64_t new_log_size) { auto new_ptr = reinterpret_cast(zmalloc(sizeof(T) << new_log_size)); size_t new_size = 1 << new_log_size; const size_t size = std::min(Size(), new_size); for (size_t i = 0; i < size; ++i) { new (new_ptr + i) T(std::move(Raw()[i])); } for (size_t i = size; i < new_size; ++i) { new (new_ptr + i) T(); } Clear(); uptr_ = reinterpret_cast(new_ptr); SetLogSize(new_log_size); } T& operator[](size_t idx) { return Raw()[idx]; } const T& operator[](size_t idx) const { return Raw()[idx]; } T* Raw() const { return (T*)(uptr_ & ~kTagMask); } size_t AllocSize() const { return Size() * sizeof(T); } private: void Clear() { const size_t size = Size(); T* raw = Raw(); if (!raw) return; for (size_t i = 0; i < size; ++i) { if (raw[i]) raw[i].~T(); } 
zfree(Raw()); uptr_ = 0; } // because of log_size I prefer to hide it PtrVector(uint64_t log_size) { assert(log_size <= 32); uptr_ = reinterpret_cast(zmalloc(sizeof(T) << log_size)); const uint64_t size = 1 << log_size; for (uint64_t i = 0; i < size; ++i) { new (reinterpret_cast(uptr_) + i) T(); } SetLogSize(log_size); } void SetLogSize(uint64_t log_size) { uptr_ = (uptr_ & ~kLogSizeShiftedMask) | kVectorBit | (uint64_t(log_size) << kLogSizeShift); } uint64_t uptr_ = 0; }; // doesn't possess memory, it should be created and release manually class OAHEntry { public: // we can assume that high 12 bits of user address space // can be used for tagging. At most 52 bits of address are reserved for // some configurations, and usually it's 48 bits. // https://docs.kernel.org/arch/arm64/memory.html // first 3 bits aren't used by pointer static constexpr size_t kVectorBit = 1ULL << 0; static constexpr size_t kExpiryBit = 1ULL << 1; // if bit is set the string length field is 1 byte instead of 4 static constexpr size_t kSsoBit = 1ULL << 2; // extended hash allows us to reduce keys comparisons static constexpr size_t kExtHashShift = 52; static constexpr uint32_t kExtHashSize = 12; static constexpr size_t kExtHashMask = 0xFFFULL; static constexpr size_t kExtHashShiftedMask = kExtHashMask << kExtHashShift; static constexpr size_t kTagMask = (4095ULL << 52) | 7; // we reserve 12 high bits and 3 low. 
OAHEntry() = default; OAHEntry(std::string_view key, uint32_t expiry = UINT32_MAX); // TODO add initializer list constructor OAHEntry(PtrVector&& vec) { data_ = vec.Release() | kVectorBit; } OAHEntry(const OAHEntry& e) = delete; OAHEntry(OAHEntry&& e) { data_ = e.data_; e.data_ = 0; } // consider manual removing, we waste a lot of time to check nullptr ~OAHEntry() { Clear(); } OAHEntry& operator=(const OAHEntry& e) = delete; OAHEntry& operator=(OAHEntry&& e) { std::swap(data_, e.data_); return *this; } bool Empty() const { return data_ == 0; } operator bool() const { return !Empty(); } bool IsVector() const { return (data_ & kVectorBit) != 0; } bool IsEntry() const { return (data_ != 0) & !(data_ & kVectorBit); } size_t AllocSize() const { return zmalloc_usable_size(Raw()); } PtrVector& AsVector() { static_assert(sizeof(PtrVector) == sizeof(uint64_t)); return *reinterpret_cast*>(&data_); } std::string_view Key() const { assert(!IsVector()); return {GetKeyData(), GetKeySize()}; } bool HasExpiry() const { return (data_ & kExpiryBit) != 0; } // returns the expiry time of the current entry or UINT32_MAX if no expiry is set. 
uint32_t GetExpiry() const; // TODO consider another option to implement iterator OAHEntry* operator->() { return this; } uint64_t GetHash() const { return (data_ & kExtHashShiftedMask) >> kExtHashShift; } bool CheckNoCollisions(const uint64_t ext_hash); void SetExtHash(uint64_t ext_hash); void ClearHash() { data_ &= ~kExtHashShiftedMask; } void SetExpiry(uint32_t at_sec); void ExpireIfNeeded(uint32_t time_now, uint32_t* set_size, size_t* alloc_used); // TODO refactor, because it's inefficient // Returns additional allocation size of ptrVector [[nodiscard]] size_t Insert(OAHEntry&& e); uint32_t ElementsNum(); // TODO remove, it is inefficient OAHEntry& operator[](uint32_t pos); OAHEntry Remove(uint32_t pos); OAHEntry Pop(); char* Raw() const { return (char*)(data_ & ~kTagMask); } protected: void Clear(); const char* GetKeyData() const { uint32_t key_field_size = HasSso() ? 1 : 4; return Raw() + GetExpirySize() + key_field_size; } uint32_t GetKeySize() const; void SetExpiryBit(bool b); void SetVectorBit() { data_ |= kVectorBit; } void SetSsoBit() { data_ |= kSsoBit; } bool HasSso() const { return (data_ & kSsoBit) != 0; } size_t Size(); std::uint32_t GetExpirySize() const { return HasExpiry() ? sizeof(std::uint32_t) : 0; } // memory daya layout [Expiry, key_size, key] uint64_t data_ = 0; }; } // namespace dfly ================================================ FILE: src/core/oah_set.h ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include "core/detail/stateless_allocator.h" #include "oah_entry.h" namespace dfly { // TODO add template parameter instead of OAHEntry class OAHSet { // Open Addressing Hash Set using OAHEntryAllocator = StatelessAllocator; using Buckets = std::vector; public: class iterator { public: using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; using value_type = OAHEntry; using pointer = OAHEntry*; using reference = OAHEntry&; iterator(OAHSet* owner, uint32_t bucket_id, uint32_t pos_in_bucket) : owner_(owner), bucket_(bucket_id), pos_(pos_in_bucket) { } void SetExpiryTime(uint32_t ttl_sec) { auto& entry = owner_->entries_[bucket_][pos_]; owner_->obj_alloc_used_ -= entry.AllocSize(); owner_->entries_[bucket_][pos_].SetExpiry(owner_->EntryTTL(ttl_sec)); owner_->obj_alloc_used_ += entry.AllocSize(); } iterator& operator++() { ++pos_; SetEntryIt(); return *this; } bool operator==(const iterator& r) const { if (owner_ == nullptr || r.owner_ == nullptr) { return owner_ == r.owner_; } assert(owner_ == r.owner_); return bucket_ == r.bucket_ && pos_ == r.pos_; } bool operator!=(const iterator& r) const { return !operator==(r); } reference operator*() { return owner_->entries_[bucket_][pos_]; } reference operator->() { return owner_->entries_[bucket_][pos_]; } bool HasExpiry() { return owner_->entries_[bucket_][pos_].HasExpiry(); } uint32_t ExpiryTime() { return owner_->entries_[bucket_][pos_].GetExpiry(); } uint32_t bucket_id() const { return bucket_; } operator bool() const { return owner_; } // find valid entry_ iterator starting from buckets_it_ and set it void SetEntryIt() { if (!owner_) return; for (auto num_entries = owner_->entries_.size(); bucket_ < num_entries; ++bucket_) { auto& bucket = owner_->entries_[bucket_]; for (uint32_t bucket_size = bucket.ElementsNum(); pos_ < bucket_size; ++pos_) { if (bucket[pos_]) return; } pos_ = 0; } owner_ = nullptr; } private: OAHSet* owner_ = nullptr; 
uint32_t bucket_ = 0; uint32_t pos_ = 0; }; iterator begin() { iterator res(this, 0, 0); res.SetEntryIt(); return res; } iterator end() { return iterator(nullptr, 0, 0); } explicit OAHSet() = default; bool Add(std::string_view str, uint32_t ttl_sec = UINT32_MAX) { uint64_t hash = Hash(str); auto bucket_id = BucketId(hash, capacity_log_); PREFETCH_READ(entries_.data() + bucket_id); PREFETCH_READ(entries_.data() + bucket_id + 8); if (size_ >= entries_.size()) { Reserve(BucketCount() * 2); bucket_id = BucketId(hash, capacity_log_); } uint32_t at = EntryTTL(ttl_sec); // TODO maybe we should split memory allocation and copying for the case when we can't add it // into set OAHEntry entry(str, at); SetEntryHash(entry, hash); if (FastCheck(bucket_id, str, hash)) { return false; } obj_alloc_used_ += entry.AllocSize(); AddUnique(std::move(entry), bucket_id, ttl_sec); return true; } void Reserve(size_t sz) { sz = absl::bit_ceil(sz); if (sz > entries_.size()) { auto prev_capacity_log = capacity_log_; capacity_log_ = std::max(kMinCapacityLog, uint32_t(absl::bit_width(sz) - 1)); size_t prev_size = entries_.size(); entries_.resize(Capacity()); Rehash(prev_capacity_log, prev_size); } assert(entries_.size() >= kDisplacementSize); } // Shrinks the table to the specified size. The new_size must be a power of 2, // >= kMinCapacity (which is 1 << kMinCapacityLog), and >= current number of elements. // This method should be called explicitly when memory reclamation is needed. void Shrink(size_t new_size) { assert(absl::has_single_bit(new_size)); assert(new_size >= (1u << kMinCapacityLog)); assert(new_size < entries_.size()); size_t prev_size = entries_.size(); capacity_log_ = absl::bit_width(new_size) - 1; // Process from low to high (opposite of Grow/Rehash). 
for (size_t i = 0; i < prev_size; ++i) { ShrinkBucket(i); } entries_.resize(Capacity()); entries_.shrink_to_fit(); } void Clear() { capacity_log_ = 0; entries_.resize(0); size_ = 0; obj_alloc_used_ = 0; ptr_vectors_alloc_used_ = 0; } // TODO should be removed, inefficient void AddUnique(OAHEntry&& e, uint32_t bid, uint32_t ttl_sec = UINT32_MAX) { ++size_; assert(Capacity() >= kDisplacementSize); for (uint32_t i = 0; i < kDisplacementSize; i++) { const uint32_t bucket_id = bid + i; if (entries_[bucket_id].Empty()) { entries_[bucket_id] = std::move(e); return; } // TODO add expiration logic } bid = GetExtensionPoint(bid); assert(bid < entries_.size()); ptr_vectors_alloc_used_ += entries_[bid].Insert(std::move(e)); } unsigned AddMany(absl::Span span, uint32_t ttl_sec = UINT32_MAX) { Reserve(span.size()); unsigned res = 0; for (auto& s : span) { if (Add(s, ttl_sec) != end()) { res++; } } return res; } // TODO: Consider using chunks for this as in StringSet void Fill(OAHSet* other) { assert(other->entries_.empty()); other->Reserve(UpperBoundSize()); other->set_time(time_now()); for (auto it = begin(), it_end = end(); it != it_end; ++it) { other->Add(it->Key(), it.HasExpiry() ? it.ExpiryTime() - time_now() : UINT32_MAX); } } /** * stable scanning api. has the same guarantees as redis scan command. * we avoid doing bit-reverse by using a different function to derive a bucket id * from hash values. By using msb part of hash we make it "stable" with respect to * rehashes. For example, with table log size 4 (size 16), entries in bucket id * 1110 come from hashes 1110XXXXX.... When a table grows to log size 5, * these entries can move either to 11100 or 11101. So if we traversed with our cursor * range [0000-1110], it's guaranteed that in grown table we do not need to cover again * [00000-11100]. Similarly with shrinkage, if a table is shrunk to log size 3, * keys from 1110 and 1111 will move to bucket 111. 
Again, it's guaranteed that we * covered the range [000-111] (all keys in that case). * Returns: next cursor or 0 if reached the end of scan. * cursor = 0 - initiates a new scan. */ using ItemCb = std::function; uint32_t Scan(uint32_t cursor, const ItemCb& cb) { if (entries_.empty()) return 0; uint32_t bucket_id = cursor >> (32 - capacity_log_); // First find the bucket to scan, skip empty buckets. for (; bucket_id < BucketCount(); ++bucket_id) { bool res = false; for (uint32_t i = 0; i < kDisplacementSize; i++) { const uint32_t shifted_bid = bucket_id + i; res |= ScanBucket(entries_[shifted_bid], cb, bucket_id); } if (res) break; } if (++bucket_id >= BucketCount()) { return 0; } return bucket_id << (32 - capacity_log_); } OAHEntry Pop() { for (auto& bucket : entries_) { if (auto res = bucket.Pop(); !res.Empty()) { assert(!res.IsVector()); --size_; obj_alloc_used_ -= res.AllocSize(); if (bucket.IsVector()) { if (bucket.AsVector().Empty()) { ptr_vectors_alloc_used_ -= bucket.AsVector().AllocSize(); bucket = OAHEntry(); } } return res; } } return {}; } bool Erase(std::string_view str) { if (entries_.empty()) return false; uint64_t hash = Hash(str); auto bucket_id = BucketId(hash, capacity_log_); auto item = FindInternal(bucket_id, str, hash); if (item != end()) { --size_; obj_alloc_used_ -= item->AllocSize(); *item = OAHEntry(); uint32_t erase_bucket = item.bucket_id(); if (entries_[erase_bucket].IsVector()) { if (entries_[erase_bucket].AsVector().Empty()) { ptr_vectors_alloc_used_ -= entries_[erase_bucket].AsVector().AllocSize(); entries_[erase_bucket] = OAHEntry(); } } return true; } return false; } iterator Find(std::string_view member) { if (entries_.empty()) return end(); uint64_t hash = Hash(member); auto bucket_id = BucketId(hash, capacity_log_); const auto ext_hash = CalcExtHash(hash, capacity_log_); // fast check for (uint32_t i = 0; i < kDisplacementSize; i++) { const uint32_t bid = bucket_id + i; if ((entries_[bid].GetHash() == ext_hash) && 
entries_[bid].IsEntry()) { if (entries_[bid].Key() == member) { entries_[bid].ExpireIfNeeded(time_now_, &size_, &obj_alloc_used_); return !entries_[bid].Empty() ? iterator{this, bid, 0} : end(); } } } auto res = FindInternal(bucket_id, member, hash); return res; } bool Contains(std::string_view member) { return Find(member) != end(); } // Returns the number of elements in the map. Note that it might be that some of these elements // have expired and can't be accessed. size_t UpperBoundSize() const { return size_; } bool Empty() const { return size_ == 0; } std::uint32_t BucketCount() const { return entries_.empty() ? 0 : (1 << capacity_log_); } std::uint32_t Capacity() const { return (1 << capacity_log_) + kDisplacementSize - 1; } // set an abstract time that allows expiry. void set_time(uint32_t val) { time_now_ = val; } uint32_t time_now() const { return time_now_; } size_t ObjAllocUsed() const { return obj_alloc_used_; } size_t SetAllocUsed() const { return entries_.capacity() * sizeof(OAHEntry) + ptr_vectors_alloc_used_; } bool ExpirationUsed() const { // TODO assert(false); return true; } size_t SizeSlow() { // TODO assert(false); // CollectExpired(); return size_; } private: static uint64_t Hash(std::string_view str) { constexpr XXH64_hash_t kHashSeed = 24061983; return XXH3_64bits_withSeed(str.data(), str.size(), kHashSeed); } static uint32_t BucketId(uint64_t hash, uint32_t capacity_log) { return hash >> (64 - capacity_log); } // was Grow in StringSet void Rehash(uint32_t prev_capacity_log, uint32_t prev_size) { if (prev_size == 0) { return; } // we should prevent moving elements before current possition to avoid double processing constexpr size_t mix_size = (2 << kShiftLog) - 1; std::array old_buckets{}; for (size_t i = 0; i < mix_size; ++i) { old_buckets[i] = std::move(entries_[i]); } for (size_t bucket_id = prev_size - 1; bucket_id >= mix_size; --bucket_id) { auto bucket = std::move(entries_[bucket_id]); for (uint32_t pos = 0, size = 
bucket.ElementsNum(); pos < size; ++pos) { if (bucket[pos]) { auto new_bucket_id = RehashEntry(bucket[pos], bucket_id, prev_capacity_log); new_bucket_id = FindEmptyAround(new_bucket_id); ptr_vectors_alloc_used_ += entries_[new_bucket_id].Insert(std::move(bucket[pos])); } } if (bucket.IsVector()) ptr_vectors_alloc_used_ -= bucket.AsVector().AllocSize(); } for (size_t bucket_id = 0; bucket_id < mix_size; ++bucket_id) { auto& bucket = old_buckets[bucket_id]; for (uint32_t pos = 0, size = bucket.ElementsNum(); pos < size; ++pos) { if (bucket[pos]) { auto new_bucket_id = RehashEntry(bucket[pos], bucket_id, prev_capacity_log); new_bucket_id = FindEmptyAround(new_bucket_id); ptr_vectors_alloc_used_ += entries_[new_bucket_id].Insert(std::move(bucket[pos])); } } if (bucket.IsVector()) ptr_vectors_alloc_used_ -= bucket.AsVector().AllocSize(); } } // it is inefficient for now, // TODO predict new position by current position and extended hash void ShrinkBucket(uint32_t bucket_id) { auto bucket = std::move(entries_[bucket_id]); if (bucket.Empty()) return; for (uint32_t pos = 0, size = bucket.ElementsNum(); pos < size; ++pos) { if (bucket[pos]) { // Check for TTL expiration during shrink - skip expired elements if (bucket[pos].HasExpiry() && bucket[pos].GetExpiry() <= time_now_) { obj_alloc_used_ -= bucket[pos].AllocSize(); --size_; continue; } auto hash = Hash(bucket[pos].Key()); auto new_bucket_id = BucketId(hash, capacity_log_); SetEntryHash(bucket[pos], hash); new_bucket_id = FindEmptyAround(new_bucket_id); ptr_vectors_alloc_used_ += entries_[new_bucket_id].Insert(std::move(bucket[pos])); } } if (bucket.IsVector()) { ptr_vectors_alloc_used_ -= bucket.AsVector().AllocSize(); } } uint32_t GetExtensionPoint(const uint32_t bid) const { constexpr uint32_t extension_point_shift = kDisplacementSize - 1; return bid | extension_point_shift; } bool FastCheck(const uint32_t bid, std::string_view str, uint64_t hash) { const auto ext_hash = CalcExtHash(hash, capacity_log_); const auto 
ext_bid = GetExtensionPoint(bid); bool res = true; for (uint32_t i = 0; i < kDisplacementSize; i++) { const uint32_t bucket_id = bid + i; res &= entries_[bucket_id].CheckNoCollisions(ext_hash); } if (res) { if (entries_[ext_bid].IsVector()) { auto& vec = entries_[ext_bid].AsVector(); auto raw_arr = vec.Raw(); for (size_t i = 0, size = vec.Size(); i < size; ++i) { res &= raw_arr[i].CheckNoCollisions(ext_hash); } } if (!res) { auto pos = FindInBucket(entries_[ext_bid], str, ext_hash); if (pos) { return true; } } } else { return FindInternal(bid, str, hash); } return false; } template >* = nullptr> bool ScanBucket(OAHEntry& entry, const T& cb, uint32_t bucket_id) { if (!entry.IsVector()) { entry.ExpireIfNeeded(time_now_, &size_, &obj_alloc_used_); if (CheckBucketAffiliation(entry, bucket_id)) { cb(entry.Key()); return true; } } else { auto& arr = entry.AsVector(); bool result = false; for (auto& el : arr) { el.ExpireIfNeeded(time_now_, &size_, &obj_alloc_used_); if (CheckBucketAffiliation(el, bucket_id)) { cb(el.Key()); result = true; } } return result; } return false; } uint32_t EntryTTL(uint32_t ttl_sec) const { return ttl_sec == UINT32_MAX ? ttl_sec : time_now_ + ttl_sec; } uint32_t FindEmptyAround(uint32_t bid) { for (uint32_t i = 0; i < kDisplacementSize; i++) { const uint32_t bucket_id = bid + i; if (entries_[bucket_id].Empty()) return bucket_id; // TODO add expiration logic } bid = GetExtensionPoint(bid); assert(bid < entries_.size()); return bid; } // Searches for a string within a bucket entry (which may be a single entry or a vector). // Returns the position within the bucket if found, or std::nullopt if not found. std::optional FindInBucket(OAHEntry& bucket, std::string_view str, uint64_t ext_hash) { if (bucket.IsEntry()) { bucket.ExpireIfNeeded(time_now_, &size_, &obj_alloc_used_); return CheckExtendedHash(bucket, ext_hash) && bucket.Key() == str ? 
0 : std::optional(); } if (bucket.IsVector()) { auto& vec = bucket.AsVector(); auto raw_arr = vec.Raw(); for (size_t i = 0, size = vec.Size(); i < size; ++i) { raw_arr[i].ExpireIfNeeded(time_now_, &size_, &obj_alloc_used_); if (CheckExtendedHash(raw_arr[i], ext_hash) && raw_arr[i].Key() == str) { return i; } } } return std::nullopt; } // return bucket_id and position otherwise max iterator FindInternal(uint32_t bid, std::string_view str, uint64_t hash) { const auto ext_hash = CalcExtHash(hash, capacity_log_); for (uint32_t i = 0; i < kDisplacementSize; i++) { const uint32_t bucket_id = bid + i; auto pos = FindInBucket(entries_[bucket_id], str, ext_hash); if (pos) { return iterator{this, bucket_id, *pos}; } } return end(); } private: static constexpr std::uint32_t kShiftLog = 2; // TODO make template static constexpr std::uint32_t kMinCapacityLog = kShiftLog; // should be >= ShiftLog static constexpr std::uint32_t kDisplacementSize = (1 << kShiftLog); // TODO check static uint64_t CalcExtHash(uint64_t hash, uint32_t capacity_log) { const uint32_t start_hash_bit = capacity_log > kShiftLog ? capacity_log - kShiftLog : 0; const uint32_t ext_hash_shift = 64 - start_hash_bit - OAHEntry::kExtHashSize; return (hash >> ext_hash_shift) & OAHEntry::kExtHashMask; } uint64_t SetEntryHash(OAHEntry& entry, uint64_t hash) { uint64_t ext_hash = CalcExtHash(hash, capacity_log_); entry.SetExtHash(ext_hash); return ext_hash; } bool CheckBucketAffiliation(OAHEntry& entry, uint32_t bucket_id) { assert(!entry.IsVector()); if (entry.Empty()) return false; uint32_t bucket_id_hash_part = capacity_log_ > kShiftLog ? 
kShiftLog : capacity_log_; uint32_t bucket_mask = (1 << bucket_id_hash_part) - 1; bucket_id &= bucket_mask; auto stored_hash = entry.GetHash(); if (!stored_hash) { stored_hash = SetEntryHash(entry, Hash(entry.Key())); } uint32_t stored_bucket_id = stored_hash >> (OAHEntry::kExtHashSize - bucket_id_hash_part); return bucket_id == stored_bucket_id; } bool CheckExtendedHash(OAHEntry& entry, uint64_t ext_hash) { auto stored_hash = entry.GetHash(); if (!stored_hash) { if (entry.IsEntry()) { stored_hash = SetEntryHash(entry, Hash(entry.Key())); } else { return false; } } return stored_hash == ext_hash; } // return new bucket_id uint32_t RehashEntry(OAHEntry& entry, uint32_t current_bucket_id, uint32_t prev_capacity_log) { assert(!entry.IsVector()); auto stored_hash = entry.GetHash(); const uint32_t logs_diff = capacity_log_ - prev_capacity_log; const uint32_t prev_significant_bits = prev_capacity_log > kShiftLog ? kShiftLog : prev_capacity_log; const uint32_t needed_hash_bits = prev_significant_bits + logs_diff; if (!stored_hash || needed_hash_bits > OAHEntry::kExtHashSize) { auto hash = Hash(entry.Key()); SetEntryHash(entry, hash); return BucketId(hash, capacity_log_); } const uint32_t real_bucket_end = stored_hash >> (OAHEntry::kExtHashSize - prev_significant_bits); const uint32_t prev_shift_mask = (1 << prev_significant_bits) - 1; const uint32_t curr_shift = (current_bucket_id - real_bucket_end) & prev_shift_mask; const uint32_t prev_bucket_mask = (1 << prev_capacity_log) - 1; const uint32_t base_bucket_id = (current_bucket_id - curr_shift) & prev_bucket_mask; const uint32_t last_bits_mask = (1 << logs_diff) - 1; const uint32_t stored_hash_shift = OAHEntry::kExtHashSize - needed_hash_bits; const uint32_t last_bits = (stored_hash >> stored_hash_shift) & last_bits_mask; const uint32_t new_bucket_id = (base_bucket_id << logs_diff) | last_bits; entry.ClearHash(); // the cache is invalid after rehash operation assert(BucketId(Hash(entry.Key()), capacity_log_) == 
new_bucket_id); return new_bucket_id; } mutable size_t obj_alloc_used_ = 0; mutable size_t ptr_vectors_alloc_used_ = 0; std::uint32_t capacity_log_ = 0; std::uint32_t size_ = 0; // number of elements in the set. std::uint32_t time_now_ = 0; Buckets entries_; }; } // namespace dfly ================================================ FILE: src/core/oah_set_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/oah_set.h" #include #include #include #include #include #include #include "base/gtest.h" #include "core/mi_memory_resource.h" #include "glog/logging.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly { using namespace std; class OAHSetTest : public ::testing::Test { protected: static void SetUpTestSuite() { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); InitTLStatelessAllocMR(PMR_NS::get_default_resource()); } static void TearDownTestSuite() { } void SetUp() override { ss_ = new OAHSet; generator_.seed(0); } void TearDown() override { delete ss_; // ensure there are no memory leaks after every test EXPECT_EQ(zmalloc_used_memory_tl, 0); } OAHSet* ss_; mt19937 generator_; }; static string random_string(mt19937& rand, unsigned len) { const string_view alpanum = "1234567890abcdefghijklmnopqrstuvwxyz"; string ret; ret.reserve(len); for (size_t i = 0; i < len; ++i) { ret += alpanum[rand() % alpanum.size()]; } return ret; } TEST_F(OAHSetTest, PtrVectorTest) { PtrVector vp(PtrVector::FromLogSize(3)); EXPECT_EQ(vp.Size(), 8); EXPECT_EQ(vp.LogSize(), 3); size_t i = 0; for (; i < vp.Size(); ++i) { EXPECT_EQ(vp[i], 0); vp[i] = i + 1; } vp.ResizeLog(4); for (; i < vp.Size(); ++i) { EXPECT_EQ(vp[i], 0); vp[i] = i + 1; } EXPECT_EQ(vp.Size(), 16); EXPECT_EQ(vp.LogSize(), 4); for (size_t i = 0; i < vp.Size(); ++i) { EXPECT_EQ(vp[i], i + 1); } } TEST_F(OAHSetTest, OAHEntryTest) { OAHEntry test("0123456789", 2); EXPECT_EQ(test.Key(), 
"0123456789"sv); EXPECT_EQ(test.GetExpiry(), 2); OAHEntry first("123456789"); EXPECT_EQ(test.Insert(std::move(first)), 16); EXPECT_EQ(test.Insert(OAHEntry("23456789")), 16); EXPECT_TRUE(test.Remove(0)); EXPECT_FALSE(test.Remove(0)); EXPECT_EQ(test.Remove(2).Key(), "23456789"); EXPECT_EQ(test.Pop().Key(), "123456789"); } TEST_F(OAHSetTest, OAHSetAddFindTest) { OAHSet ss; std::set test_set; for (int i = 0; i < 10000; ++i) { test_set.insert(base::RandStr(20)); } for (const auto& s : test_set) { EXPECT_TRUE(ss.Add(s)); } for (const auto& s : test_set) { auto e = ss.Find(s); EXPECT_EQ(e->Key(), s); } EXPECT_EQ(ss.BucketCount(), 16384); } TEST_F(OAHSetTest, Basic) { EXPECT_TRUE(ss_->Add("foo"sv)); EXPECT_TRUE(ss_->Add("bar"sv)); uint32_t size = ss_->UpperBoundSize(); EXPECT_FALSE(ss_->Add("foo"sv)); EXPECT_FALSE(ss_->Add("bar"sv)); EXPECT_EQ(ss_->UpperBoundSize(), size); EXPECT_TRUE(ss_->Contains("foo"sv)); EXPECT_TRUE(ss_->Contains("bar"sv)); EXPECT_EQ(2, ss_->UpperBoundSize()); } TEST_F(OAHSetTest, StandardAddErase) { EXPECT_TRUE(ss_->Add("@@@@@@@@@@@@@@@@") != ss_->end()); EXPECT_TRUE(ss_->Add("A@@@@@@@@@@@@@@@") != ss_->end()); EXPECT_TRUE(ss_->Add("AA@@@@@@@@@@@@@@") != ss_->end()); EXPECT_TRUE(ss_->Add("AAA@@@@@@@@@@@@@") != ss_->end()); EXPECT_TRUE(ss_->Add("AAAAAAAAA@@@@@@@") != ss_->end()); EXPECT_TRUE(ss_->Add("AAAAAAAAAA@@@@@@") != ss_->end()); EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAA@") != ss_->end()); EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAAA") != ss_->end()); EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAAD") != ss_->end()); EXPECT_TRUE(ss_->Add("BBBBBAAAAAAAAAAA") != ss_->end()); EXPECT_TRUE(ss_->Add("BBBBBBBBAAAAAAAA") != ss_->end()); EXPECT_TRUE(ss_->Add("CCCCCBBBBBBBBBBB") != ss_->end()); // Remove link in the middle of chain EXPECT_TRUE(ss_->Erase("BBBBBBBBAAAAAAAA")); // Remove start of a chain EXPECT_TRUE(ss_->Erase("CCCCCBBBBBBBBBBB")); // Remove end of link EXPECT_TRUE(ss_->Erase("AAA@@@@@@@@@@@@@")); // Remove only item in chain 
EXPECT_TRUE(ss_->Erase("AA@@@@@@@@@@@@@@")); EXPECT_TRUE(ss_->Erase("AAAAAAAAA@@@@@@@")); EXPECT_TRUE(ss_->Erase("AAAAAAAAAA@@@@@@")); EXPECT_TRUE(ss_->Erase("AAAAAAAAAAAAAAA@")); } TEST_F(OAHSetTest, DisplacedBug) { string_view vals[] = {"imY", "OVl", "NhH", "BCe", "YDL", "lpb", "nhF", "xod", "zYR", "PSa", "hce", "cTR"}; ss_->AddMany(absl::MakeSpan(vals), UINT32_MAX); ss_->Add("fIc"); ss_->Erase("YDL"); ss_->Add("fYs"); ss_->Erase("hce"); ss_->Erase("nhF"); ss_->Add("dye"); ss_->Add("xZT"); ss_->Add("LVK"); ss_->Erase("zYR"); ss_->Erase("fYs"); ss_->Add("ueB"); ss_->Erase("PSa"); ss_->Erase("OVl"); ss_->Add("cga"); ss_->Add("too"); ss_->Erase("ueB"); ss_->Add("HZe"); ss_->Add("oQn"); ss_->Erase("too"); ss_->Erase("HZe"); ss_->Erase("xZT"); ss_->Erase("cga"); ss_->Erase("cTR"); ss_->Erase("BCe"); ss_->Add("eua"); ss_->Erase("lpb"); ss_->Add("OXK"); ss_->Add("QmO"); ss_->Add("SzV"); ss_->Erase("QmO"); ss_->Add("jbe"); ss_->Add("BPN"); ss_->Add("OfH"); ss_->Add("Muf"); ss_->Add("CwP"); ss_->Erase("Muf"); ss_->Erase("xod"); ss_->Add("Cis"); ss_->Add("Xvd"); ss_->Erase("SzV"); ss_->Erase("eua"); ss_->Add("DGb"); ss_->Add("leD"); ss_->Add("MVX"); ss_->Add("HPq"); } TEST_F(OAHSetTest, Resizing) { constexpr size_t num_strs = 4096; unordered_set strs; while (strs.size() != num_strs) { auto str = random_string(generator_, 10); strs.insert(str); } unsigned size = 0; for (auto it = strs.begin(); it != strs.end(); ++it) { const auto& str = *it; EXPECT_TRUE(ss_->Add(str, 1)); EXPECT_EQ(ss_->UpperBoundSize(), size + 1); // make sure we haven't lost any items after a grow // which happens every power of 2 if ((size & (size - 1)) == 0) { for (auto j = strs.begin(); j != it; ++j) { const auto& str = *j; auto it = ss_->Find(str); ASSERT_NE(it, ss_->end()); EXPECT_TRUE(it.HasExpiry()); EXPECT_EQ(it.ExpiryTime(), ss_->time_now() + 1); } } ++size; } } TEST_F(OAHSetTest, SimpleScan) { unordered_set info = {"foo", "bar"}; unordered_set seen; for (auto str : info) { 
EXPECT_TRUE(ss_->Add(str)); } uint32_t cursor = 0; do { cursor = ss_->Scan(cursor, [&](std::string_view str) { EXPECT_TRUE(info.count(str)); seen.insert(str); }); } while (cursor != 0); EXPECT_EQ(seen.size(), info.size()); EXPECT_TRUE(equal(seen.begin(), seen.end(), info.begin())); } // // Ensure REDIS scan guarantees are met TEST_F(OAHSetTest, ScanGuarantees) { unordered_set to_be_seen = {"foo", "bar"}; unordered_set not_be_seen = {"AAA", "BBB"}; unordered_set maybe_seen = {"AA@@@@@@@@@@@@@@", "AAA@@@@@@@@@@@@@", "AAAAAAAAA@@@@@@@", "AAAAAAAAAA@@@@@@"}; unordered_set seen; auto scan_callback = [&](std::string_view str) { EXPECT_TRUE(to_be_seen.count(str) || maybe_seen.count(str)); EXPECT_FALSE(not_be_seen.count(str)); if (to_be_seen.count(str)) { seen.insert(str); } }; EXPECT_EQ(ss_->Scan(0, scan_callback), 0); for (auto str : not_be_seen) { EXPECT_TRUE(ss_->Add(str)); } for (auto str : not_be_seen) { EXPECT_TRUE(ss_->Erase(str)); } for (auto str : to_be_seen) { EXPECT_TRUE(ss_->Add(str)); } // should reach at least the first item in the set uint32_t cursor = ss_->Scan(0, scan_callback); for (auto str : maybe_seen) { EXPECT_TRUE(ss_->Add(str)); } while (cursor != 0) { cursor = ss_->Scan(cursor, scan_callback); } EXPECT_TRUE(seen.size() == to_be_seen.size()); } TEST_F(OAHSetTest, IntOnly) { constexpr size_t num_ints = 8192; unordered_set numbers; for (size_t i = 0; i < num_ints; ++i) { numbers.insert(i); EXPECT_TRUE(ss_->Add(to_string(i))); } EXPECT_EQ(ss_->UpperBoundSize(), num_ints); for (size_t i = 0; i < num_ints; ++i) { ASSERT_FALSE(ss_->Add(to_string(i))); } EXPECT_EQ(ss_->UpperBoundSize(), num_ints); size_t num_remove = generator_() % 4096; unordered_set removed; for (size_t i = 0; i < num_remove; ++i) { auto remove_int = generator_() % num_ints; auto remove = to_string(remove_int); if (numbers.count(remove_int)) { ASSERT_TRUE(ss_->Contains(remove)) << remove_int; EXPECT_TRUE(ss_->Erase(remove)); numbers.erase(remove_int); } else { 
EXPECT_FALSE(ss_->Erase(remove)); } EXPECT_FALSE(ss_->Contains(remove)); removed.insert(remove); } size_t expected_seen = 0; auto scan_callback = [&](std::string_view str_v) { std::string str(str_v); EXPECT_FALSE(removed.count(str)); if (numbers.count(std::atoi(str.data()))) { ++expected_seen; } }; uint32_t cursor = 0; do { cursor = ss_->Scan(cursor, scan_callback); // randomly throw in some new numbers uint32_t val = generator_(); ss_->Add(to_string(val)); } while (cursor != 0); EXPECT_GE(expected_seen + removed.size(), num_ints); } TEST_F(OAHSetTest, XtremeScanGrow) { unordered_set to_see, force_grow, seen; while (to_see.size() != 8) { to_see.insert(random_string(generator_, 10)); } while (force_grow.size() != 8192) { string str = random_string(generator_, 10); if (to_see.count(str)) { continue; } force_grow.insert(random_string(generator_, 10)); } for (auto& str : to_see) { EXPECT_TRUE(ss_->Add(str)); } auto scan_callback = [&](string_view strv) { std::string str(strv); if (to_see.count(str)) { seen.insert(str); } }; uint32_t cursor = ss_->Scan(0, scan_callback); // force approx 10 grows for (auto& s : force_grow) { EXPECT_TRUE(ss_->Add(s)); } while (cursor != 0) { cursor = ss_->Scan(cursor, scan_callback); } EXPECT_EQ(seen.size(), to_see.size()); } TEST_F(OAHSetTest, Pop) { constexpr size_t num_items = 8; unordered_set to_insert; while (to_insert.size() != num_items) { auto str = random_string(generator_, 10); if (to_insert.count(str)) { continue; } to_insert.insert(str); EXPECT_TRUE(ss_->Add(str)); } while (!ss_->Empty()) { size_t size = ss_->UpperBoundSize(); auto str = ss_->Pop(); DCHECK(ss_->UpperBoundSize() == to_insert.size() - 1); DCHECK(str); DCHECK(to_insert.count(std::string(str.Key()))); DCHECK_EQ(ss_->UpperBoundSize(), size - 1); to_insert.erase(std::string(str.Key())); } DCHECK(ss_->Empty()); DCHECK(to_insert.empty()); } TEST_F(OAHSetTest, Iteration) { ss_->Add("foo"); for (const auto& ptr : *ss_) { LOG(INFO) << ptr; } ss_->Clear(); constexpr 
size_t num_items = 8192; unordered_set to_insert; while (to_insert.size() != num_items) { auto str = random_string(generator_, 10); if (to_insert.count(str)) { continue; } to_insert.insert(str); EXPECT_TRUE(ss_->Add(str)); } for (const auto& ptr : *ss_) { std::string str(ptr.Key()); EXPECT_TRUE(to_insert.count(str)); to_insert.erase(str); } EXPECT_EQ(to_insert.size(), 0); } TEST_F(OAHSetTest, SetFieldExpireHasExpiry) { EXPECT_TRUE(ss_->Add("k1", 100)); auto k = ss_->Find("k1"); EXPECT_TRUE(k.HasExpiry()); EXPECT_EQ(k.ExpiryTime(), 100); k.SetExpiryTime(1); EXPECT_TRUE(k.HasExpiry()); EXPECT_EQ(k.ExpiryTime(), 1); } TEST_F(OAHSetTest, SetFieldExpireNoHasExpiry) { EXPECT_TRUE(ss_->Add("k1")); auto k = ss_->Find("k1"); EXPECT_FALSE(k.HasExpiry()); k.SetExpiryTime(10); EXPECT_TRUE(k.HasExpiry()); EXPECT_EQ(k.ExpiryTime(), 10); } TEST_F(OAHSetTest, Ttl) { EXPECT_TRUE(ss_->Add("bla"sv, 1)); EXPECT_FALSE(ss_->Add("bla"sv, 1)); auto it = ss_->Find("bla"sv); EXPECT_EQ(1u, it.ExpiryTime()); ss_->set_time(1); EXPECT_TRUE(ss_->Add("bla"sv, 1)); EXPECT_EQ(1u, ss_->UpperBoundSize()); for (unsigned i = 0; i < 100; ++i) { EXPECT_TRUE(ss_->Add(absl::StrCat("foo", i), 1)); } EXPECT_EQ(101u, ss_->UpperBoundSize()); it = ss_->Find("foo50"); EXPECT_EQ("foo50"sv, it->Key()); EXPECT_EQ(2u, it.ExpiryTime()); ss_->set_time(2); // Cleanup all `foo` entries uint32_t cursor = 0; do { cursor = ss_->Scan(cursor, [&](std::string_view) {}); } while (cursor != 0); for (unsigned i = 0; i < 100; ++i) { EXPECT_TRUE(ss_->Add(absl::StrCat("bar", i))); } EXPECT_EQ(100u, ss_->UpperBoundSize()); it = ss_->Find("bar50"); EXPECT_FALSE(it.HasExpiry()); for (auto it = ss_->begin(); it != ss_->end(); ++it) { ASSERT_TRUE(absl::StartsWith(it->Key(), "bar")) << it->Key(); string str(it->Key()); VLOG(1) << *it; } } TEST_F(OAHSetTest, Grow) { for (size_t j = 0; j < 10; ++j) { for (size_t i = 0; i < 4098; ++i) { ss_->Reserve(generator_() % 256); auto str = random_string(generator_, 3); ss_->Add(str); } ss_->Clear(); 
} } TEST_F(OAHSetTest, Reserve) { vector strs; for (size_t i = 0; i < 10; ++i) { strs.push_back(random_string(generator_, 10)); ss_->Add(strs.back()); } for (size_t j = 2; j < 20; j += 3) { ss_->Reserve(j * 20); for (size_t i = 0; i < 10; ++i) { ASSERT_TRUE(ss_->Contains(strs[i])); } } } TEST_F(OAHSetTest, Fill) { for (size_t i = 0; i < 100; ++i) { ss_->Add(random_string(generator_, 10)); } OAHSet s2; ss_->Fill(&s2); EXPECT_EQ(s2.UpperBoundSize(), ss_->UpperBoundSize()); for (const auto& s : *ss_) { EXPECT_TRUE(s2.Contains(s.Key())); } } TEST_F(OAHSetTest, IterateEmpty) { for (const auto& s : *ss_) { // We're iterating to make sure there is no crash. However, if we got here, it's a bug CHECK(false) << "Found entry " << s << " in empty set"; } } static size_t MemUsed(OAHSet& obj) { return obj.ObjAllocUsed() + obj.SetAllocUsed(); } void BM_Clone(benchmark::State& state) { vector strs; mt19937 generator(0); OAHSet ss1, ss2; unsigned elems = state.range(0); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, 10); ss1.Add(str); } ss2.Reserve(ss1.UpperBoundSize()); while (state.KeepRunning()) { for (auto& src : ss1) { ss2.Add(src.Key()); } state.PauseTiming(); ss2.Clear(); ss2.Reserve(ss1.UpperBoundSize()); state.ResumeTiming(); } } BENCHMARK(BM_Clone)->ArgName("elements")->Arg(32000); void BM_Fill(benchmark::State& state) { unsigned elems = state.range(0); vector strs; mt19937 generator(0); OAHSet ss1, ss2; for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, 10); ss1.Add(str); } while (state.KeepRunning()) { ss1.Fill(&ss2); state.PauseTiming(); ss2.Clear(); state.ResumeTiming(); } } BENCHMARK(BM_Fill)->ArgName("elements")->Arg(32000); void BM_Clear(benchmark::State& state) { unsigned elems = state.range(0); mt19937 generator(0); OAHSet ss; while (state.KeepRunning()) { state.PauseTiming(); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, 16); ss.Add(str); } state.ResumeTiming(); ss.Clear(); } } 
BENCHMARK(BM_Clear)->ArgName("elements")->Arg(32000); void BM_Add(benchmark::State& state) { vector strs; mt19937 generator(0); OAHSet ss; unsigned elems = state.range(0); unsigned keySize = state.range(1); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, keySize); strs.push_back(str); } ss.Reserve(elems); size_t mem_used = 0; while (state.KeepRunning()) { for (auto& str : strs) ss.Add(str); state.PauseTiming(); mem_used += MemUsed(ss); ss.Clear(); ss.Reserve(elems); state.ResumeTiming(); } state.counters["Memory_Used"] = mem_used / state.iterations(); } BENCHMARK(BM_Add) ->ArgNames({"elements", "KeySize"}) ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_AddMany(benchmark::State& state) { vector strs; mt19937 generator(0); OAHSet ss; unsigned elems = state.range(0); unsigned keySize = state.range(1); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, keySize); strs.push_back(str); } ss.Reserve(elems); vector svs; size_t mem_used = 0; for (const auto& str : strs) { svs.push_back(str); } while (state.KeepRunning()) { ss.AddMany(absl::MakeSpan(svs)); state.PauseTiming(); CHECK_EQ(ss.UpperBoundSize(), elems); mem_used += MemUsed(ss); ss.Clear(); ss.Reserve(elems); state.ResumeTiming(); } state.counters["Memory_Used"] = mem_used / state.iterations(); } BENCHMARK(BM_AddMany) ->ArgNames({"elements", "KeySize"}) ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_Erase(benchmark::State& state) { std::vector strs; mt19937 generator(0); OAHSet ss; auto elems = state.range(0); auto keySize = state.range(1); for (long int i = 0; i < elems; ++i) { std::string str = random_string(generator, keySize); strs.push_back(str); ss.Add(str); } state.counters["Memory_Before_Erase"] = MemUsed(ss); size_t mem_used = 0; while (state.KeepRunning()) { for (auto& str : strs) { ss.Erase(str); } state.PauseTiming(); mem_used += MemUsed(ss); for (auto& str : strs) { ss.Add(str); } state.ResumeTiming(); } 
state.counters["Memory_After_Erase"] = mem_used / state.iterations(); } BENCHMARK(BM_Erase) ->ArgNames({"elements", "KeySize"}) ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_Get(benchmark::State& state) { std::vector strs; mt19937 generator(0); OAHSet ss; auto elems = state.range(0); auto keySize = state.range(1); for (long int i = 0; i < elems; ++i) { std::string str = random_string(generator, keySize); strs.push_back(str); ss.Add(str); } while (state.KeepRunning()) { for (auto& str : strs) { ss.Find(str); } } } BENCHMARK(BM_Get) ->ArgNames({"elements", "KeySize"}) ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_Grow(benchmark::State& state) { vector strs; mt19937 generator(0); OAHSet src; unsigned elems = 1 << 18; for (size_t i = 0; i < elems; ++i) { src.Add(random_string(generator, 16), UINT32_MAX); strs.push_back(random_string(generator, 16)); } while (state.KeepRunning()) { state.PauseTiming(); OAHSet tmp; src.Fill(&tmp); CHECK_EQ(tmp.BucketCount(), elems); state.ResumeTiming(); for (const auto& str : strs) { tmp.Add(str); if (tmp.BucketCount() > elems) { break; // we grew } } CHECK_GT(tmp.BucketCount(), elems); } } BENCHMARK(BM_Grow); // unsigned total_wasted_memory = 0; // TEST_F(OAHSetTest, ReallocIfNeeded) { // auto build_str = [](size_t i) { return to_string(i) + string(131, 'a'); }; // auto count_waste = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, // size_t block_size, void* arg) { // size_t used = block_size * area->used; // total_wasted_memory += area->committed - used; // return true; // }; // for (size_t i = 0; i < 10'000; i++) // ss_->Add(build_str(i)); // for (size_t i = 0; i < 10'000; i++) { // if (i % 10 == 0) // continue; // ss_->Erase(build_str(i)); // } // mi_heap_collect(mi_heap_get_backing(), true); // mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); // size_t wasted_before = total_wasted_memory; // size_t underutilized = 0; // for (auto it = ss_->begin(); it 
!= ss_->end(); ++it) { // underutilized += zmalloc_page_is_underutilized(*it, 0.9); // it.ReallocIfNeeded(0.9); // } // // Check there are underutilized pages // CHECK_GT(underutilized, 0u); // total_wasted_memory = 0; // mi_heap_collect(mi_heap_get_backing(), true); // mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); // size_t wasted_after = total_wasted_memory; // // Check we waste significanlty less now // EXPECT_GT(wasted_before, wasted_after * 2); // EXPECT_EQ(ss_->UpperBoundSize(), 1000); // for (size_t i = 0; i < 1000; i++) // EXPECT_EQ(*ss_->Find(build_str(i * 10)), build_str(i * 10)); // } class ShrinkTest : public OAHSetTest, public ::testing::WithParamInterface {}; TEST_P(ShrinkTest, BasicShrink) { constexpr size_t num_strs = 1000000; size_t shrink_to = GetParam(); vector strs; for (size_t i = 0; i < num_strs; ++i) { strs.push_back(random_string(generator_, 10)); EXPECT_TRUE(ss_->Add(strs.back())); } // Grow to a larger size ss_->Reserve(1 << 22); size_t original_bucket_count = ss_->BucketCount(); EXPECT_EQ(original_bucket_count, 1u << 22); // Shrink to the parameterized size ss_->Shrink(shrink_to); EXPECT_EQ(ss_->BucketCount(), shrink_to); EXPECT_EQ(ss_->UpperBoundSize(), num_strs); // Verify all elements are still accessible for (const auto& str : strs) { EXPECT_TRUE(ss_->Contains(str)) << "Missing: " << str; } } INSTANTIATE_TEST_SUITE_P(ShrinkSizes, ShrinkTest, ::testing::Values(1u << 21, // 2M buckets (sparse) 1u << 20, // 1M buckets (~1 per bucket) 1u << 19), // 512K buckets (~2 per bucket) [](const auto& info) { return absl::StrCat("buckets_", info.param); }); TEST_F(OAHSetTest, ShrinkWithTTL) { constexpr size_t num_strs = 1000000; // Track elements by their TTL category vector expired_strs; // TTL 1-50, will expire vector surviving_strs; // TTL 51-100, will survive vector no_ttl_strs; // No TTL, will survive for (size_t i = 0; i < num_strs; ++i) { string str = random_string(generator_, 10); if (i % 3 == 0) { // No TTL 
EXPECT_TRUE(ss_->Add(str)); no_ttl_strs.push_back(str); } else if (i % 3 == 1) { // TTL 1-50 (will expire when time=50) uint32_t ttl = (i % 50) + 1; EXPECT_TRUE(ss_->Add(str, ttl)); expired_strs.push_back(str); } else { // TTL 51-100 (will survive when time=50) uint32_t ttl = (i % 50) + 51; EXPECT_TRUE(ss_->Add(str, ttl)); surviving_strs.push_back(str); } } // Grow to larger size ss_->Reserve(1 << 22); // Set time to 50 - this will expire elements with TTL <= 50 ss_->set_time(50); // Shrink ss_->Shrink(1 << 21); EXPECT_EQ(ss_->BucketCount(), 1u << 21); // Verify expired elements are gone for (const auto& str : expired_strs) { EXPECT_EQ(ss_->Find(str), ss_->end()) << "Should be expired: " << str; } // Verify surviving TTL elements are still accessible with correct TTL for (const auto& str : surviving_strs) { auto it = ss_->Find(str); ASSERT_NE(it, ss_->end()) << "Missing surviving TTL element: " << str; EXPECT_TRUE(it.HasExpiry()); EXPECT_GT(it.ExpiryTime(), 50u); } // Verify no-TTL elements are still accessible for (const auto& str : no_ttl_strs) { auto it = ss_->Find(str); ASSERT_NE(it, ss_->end()) << "Missing no-TTL element: " << str; EXPECT_FALSE(it.HasExpiry()); } } TEST_F(OAHSetTest, ScanWithShrinkBetweenCalls) { // Test that cursor-based scanning works correctly when Grow and Shrink happen between Scan calls // This verifies SCAN guarantees: elements present at start and end of scan must be seen constexpr size_t num_strs = 1000000; vector strs; unordered_set must_see; // Add elements and track them for (size_t i = 0; i < num_strs; ++i) { strs.push_back(random_string(generator_, 10)); EXPECT_TRUE(ss_->Add(strs.back())); must_see.insert(strs.back()); } // Note initial bucket count (will be ~1M after adding 1M elements) size_t initial_bucket_count = ss_->BucketCount(); unordered_set seen; auto scan_callback = [&](const string_view str) { seen.emplace(str); }; // Start scanning BEFORE Grow uint32_t cursor = ss_->Scan(0, scan_callback); EXPECT_NE(cursor, 0u) << 
"Should not finish in one iteration"; // Grow to large size in the middle of scanning ss_->Reserve(1 << 22); EXPECT_EQ(ss_->BucketCount(), 1u << 22); EXPECT_GT(ss_->BucketCount(), initial_bucket_count); // Continue scanning a bit after Grow cursor = ss_->Scan(cursor, scan_callback); // Now Shrink in the middle of scanning - this is the key test // Elements that existed at scan start must still be visible ss_->Shrink(1 << 21); EXPECT_EQ(ss_->BucketCount(), 1u << 21); // Continue scanning with the same cursor constexpr int max_iterations = 1 << 22; int iterations = 0; while (cursor != 0 && iterations < max_iterations) { cursor = ss_->Scan(cursor, scan_callback); iterations++; } EXPECT_LT(iterations, max_iterations) << "Hit iteration limit"; EXPECT_EQ(cursor, 0u) << "Scan should complete"; // Verify all original elements were seen for (const auto& str : must_see) { ASSERT_TRUE(seen.count(str)) << "Missing element after shrink: " << str; } EXPECT_EQ(seen.size(), must_see.size()) << "Should see exactly all original elements"; } } // namespace dfly ================================================ FILE: src/core/overloaded.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // // #pragma once namespace dfly { template struct Overloaded : Ts... { using Ts::operator()...; }; template Overloaded(Ts...) -> Overloaded; } // namespace dfly ================================================ FILE: src/core/page_usage/CMakeLists.txt ================================================ add_library(dfly_page_usage page_usage_stats.cc) target_link_libraries(dfly_page_usage base TRDP::hdr_histogram redis_lib absl::strings) ================================================ FILE: src/core/page_usage/page_usage_stats.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/page_usage/page_usage_stats.h" #include #include #include #include #include #include #include "base/cycle_clock.h" extern "C" { #include #include "redis/zmalloc.h" mi_page_usage_stats_t mi_heap_page_is_underutilized(mi_heap_t* heap, void* p, float ratio, bool collect_stats); } namespace dfly { using absl::StrAppend; using absl::StrFormat; using absl::StripTrailingAsciiWhitespace; namespace { constexpr auto kUsageHistPoints = std::array{50, 90, 99}; constexpr auto kHistSignificantFigures = 3; HllBufferPtr InitHllPtr() { HllBufferPtr p; p.size = getDenseHllSize(); p.hll = new uint8_t[p.size]; CHECK_EQ(0, createDenseHll(p)); return p; } } // namespace CycleQuota::CycleQuota(const uint64_t quota_usec) : CycleQuota(base::CycleClock::FromUsec(quota_usec), true) { } void CycleQuota::Arm() { start_cycles_ = base::CycleClock::Now(); } bool CycleQuota::Depleted() const { if (quota_cycles_ == kMaxQuota) return false; return UsedCycles() >= quota_cycles_; } uint64_t CycleQuota::UsedCycles() const { return base::CycleClock::Now() - start_cycles_; } CycleQuota CycleQuota::Unlimited() { return CycleQuota(kMaxQuota, true); } void CycleQuota::Extend(const uint64_t quota_usec) { if (quota_cycles_ == kMaxQuota) return; quota_cycles_ += base::CycleClock::FromUsec(quota_usec); } CycleQuota::CycleQuota(const uint64_t quota_cycles, bool /*tag*/) : quota_cycles_{quota_cycles} { Arm(); } void CollectedPageStats::Merge(CollectedPageStats&& other, uint16_t shard_id) { this->pages_scanned += other.pages_scanned; this->pages_marked_for_realloc += other.pages_marked_for_realloc; this->pages_full += other.pages_full; this->pages_reserved_for_malloc += other.pages_reserved_for_malloc; this->pages_with_heap_mismatch += other.pages_with_heap_mismatch; this->pages_above_threshold += other.pages_above_threshold; this->objects_skipped_not_required += other.objects_skipped_not_required; this->objects_skipped_not_supported += other.objects_skipped_not_supported; 
shard_wide_summary.emplace(std::make_pair(shard_id, std::move(other.page_usage_hist))); } CollectedPageStats CollectedPageStats::Merge(std::vector&& stats, const float threshold) { CollectedPageStats result; result.threshold = threshold; size_t shard_index = 0; for (CollectedPageStats& stat : stats) { result.Merge(std::move(stat), shard_index++); } return result; } std::string CollectedPageStats::ToString() const { std::string response; StrAppend(&response, "Page usage threshold: ", threshold * 100, "\n"); StrAppend(&response, "Pages scanned: ", pages_scanned, "\n"); StrAppend(&response, "Pages marked for reallocation: ", pages_marked_for_realloc, "\n"); StrAppend(&response, "Pages full: ", pages_full, "\n"); StrAppend(&response, "Pages reserved for malloc: ", pages_reserved_for_malloc, "\n"); StrAppend(&response, "Pages skipped due to heap mismatch: ", pages_with_heap_mismatch, "\n"); StrAppend(&response, "Pages with usage above threshold: ", pages_above_threshold, "\n"); StrAppend(&response, "Objects skipped (do not require defragmentation): ", objects_skipped_not_required, "\n"); StrAppend(&response, "Objects skipped (do not support defragmentation): ", objects_skipped_not_supported, "\n"); for (const auto& [shard_id, usage] : shard_wide_summary) { StrAppend(&response, "[Shard ", shard_id, "]\n"); for (const auto& [percentage, count] : usage) { StrAppend(&response, StrFormat(" %d%% pages are below %d%% block usage\n", percentage, count)); } } StripTrailingAsciiWhitespace(&response); return response; } PageUsage::UniquePages::UniquePages() : pages_scanned{InitHllPtr()}, pages_marked_for_realloc{InitHllPtr()}, pages_full{InitHllPtr()}, pages_reserved_for_malloc{InitHllPtr()}, pages_with_heap_mismatch{InitHllPtr()}, pages_above_threshold{InitHllPtr()} { hdr_histogram* h = nullptr; const auto init_result = hdr_init(1, 100, kHistSignificantFigures, &h); CHECK_EQ(0, init_result) << "failed to initialize histogram"; page_usage_hist = h; } 
PageUsage::UniquePages::~UniquePages() { delete[] pages_scanned.hll; delete[] pages_marked_for_realloc.hll; delete[] pages_full.hll; delete[] pages_reserved_for_malloc.hll; delete[] pages_with_heap_mismatch.hll; delete[] pages_above_threshold.hll; hdr_close(page_usage_hist); } void PageUsage::UniquePages::AddStat(mi_page_usage_stats_t stat) { // NOLINT should not be const const auto data = reinterpret_cast(&stat.page_address); auto record = [&data](HllBufferPtr ctr) { pfadd_dense(ctr, data, sizeof(stat.page_address)); }; record(pages_scanned); if (stat.flags & MI_DFLY_PAGE_BELOW_THRESHOLD) { record(pages_marked_for_realloc); } if (stat.flags & MI_DFLY_PAGE_FULL) { record(pages_full); } if (stat.flags & MI_DFLY_HEAP_MISMATCH) { record(pages_with_heap_mismatch); } if (stat.flags & MI_DFLY_PAGE_USED_FOR_MALLOC) { record(pages_reserved_for_malloc); } if (stat.flags == 0) { // No special flags means the page is above the threshold but not full - record usage for // histogram. This allows tuning the threshold for future commands. 
record(pages_above_threshold); hdr_record_value(page_usage_hist, 100.0 * stat.used / stat.capacity); } } CollectedPageStats PageUsage::UniquePages::CollectedStats() const { CollectedPageStats::ShardUsageSummary usage; for (const auto p : kUsageHistPoints) { usage[p] = hdr_value_at_percentile(page_usage_hist, p); } return CollectedPageStats{ .pages_scanned = static_cast(pfcountSingle(pages_scanned)), .pages_marked_for_realloc = static_cast(pfcountSingle(pages_marked_for_realloc)), .pages_full = static_cast(pfcountSingle(pages_full)), .pages_reserved_for_malloc = static_cast(pfcountSingle(pages_reserved_for_malloc)), .pages_with_heap_mismatch = static_cast(pfcountSingle(pages_with_heap_mismatch)), .pages_above_threshold = static_cast(pfcountSingle(pages_above_threshold)), .objects_skipped_not_required = objects_skipped_not_required, .objects_skipped_not_supported = objects_skipped_not_supported, .page_usage_hist = std::move(usage), .shard_wide_summary = {}}; } PageUsage::PageUsage(CollectPageStats collect_stats, float threshold, CycleQuota quota) : collect_stats_{collect_stats}, threshold_{threshold}, quota_{quota} { } void PageUsage::ArmQuotaTimer() { quota_.Arm(); } uint64_t PageUsage::UsedQuotaCycles() const { return quota_.UsedCycles(); } bool PageUsage::IsPageForObjectUnderUtilized(void* object) { mi_page_usage_stats_t stat; zmalloc_page_is_underutilized(object, threshold_, collect_stats_ == CollectPageStats::YES, &stat); return ConsumePageStats(stat); } bool PageUsage::IsPageForObjectUnderUtilized(mi_heap_t* heap, void* object) { return ConsumePageStats(mi_heap_page_is_underutilized(heap, object, threshold_, collect_stats_ == CollectPageStats::YES)); } bool PageUsage::ConsumePageStats(mi_page_usage_stats_t stat) { const bool should_reallocate = stat.flags == MI_DFLY_PAGE_BELOW_THRESHOLD; if (collect_stats_ == CollectPageStats::YES) { unique_pages_.AddStat(stat); } return force_reallocate_ || should_reallocate; } bool PageUsage::QuotaDepleted() const { return 
quota_.Depleted(); } void PageUsage::ExtendQuota(uint64_t quota_usec) { quota_.Extend(quota_usec); } } // namespace dfly ================================================ FILE: src/core/page_usage/page_usage_stats.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #define MI_BUILD_RELEASE 1 #include extern "C" { #include "redis/hyperloglog.h" } struct hdr_histogram; namespace dfly { class CycleQuota { public: static constexpr uint64_t kMaxQuota = std::numeric_limits::max(); static constexpr uint64_t kDefaultDefragQuota = 150; explicit CycleQuota(uint64_t quota_usec); // Sets the starting point for the quota to be counted from. Can be called multiple times to reset // the quota counter. void Arm(); bool Depleted() const; uint64_t UsedCycles() const; static CycleQuota Unlimited(); // Extends the quota by the given amount. If any quota was already left over, it is also retained // on top of the newly added quota. For example, if 80 usec was left, and we extend by 50 usec, // the task now has 130 usec before the quota will be depleted. 
void Extend(uint64_t quota_usec); private: explicit CycleQuota(uint64_t quota_cycles, bool /*tag*/); uint64_t quota_cycles_; uint64_t start_cycles_{0}; }; enum class CollectPageStats : uint8_t { YES, NO }; struct CollectedPageStats { double threshold{0.0}; uint64_t pages_scanned{0}; uint64_t pages_marked_for_realloc{0}; uint64_t pages_full{0}; uint64_t pages_reserved_for_malloc{0}; uint64_t pages_with_heap_mismatch{0}; uint64_t pages_above_threshold{0}; uint64_t objects_skipped_not_required{0}; uint64_t objects_skipped_not_supported{0}; using ShardUsageSummary = absl::btree_map; ShardUsageSummary page_usage_hist; absl::btree_map shard_wide_summary; void Merge(CollectedPageStats&& other, uint16_t shard_id); static CollectedPageStats Merge(std::vector&& stats, float threshold); std::string ToString() const; }; class PageUsage { public: PageUsage(CollectPageStats collect_stats, float threshold, CycleQuota quota = CycleQuota::Unlimited()); virtual ~PageUsage() = default; // Resets the quota timer to split defragmentation into different groups with separate quotas. // For example, first defragment objects with a quota and then defragment search indices with the // same quota independently. 
void ArmQuotaTimer(); uint64_t UsedQuotaCycles() const; virtual bool IsPageForObjectUnderUtilized(void* object); bool IsPageForObjectUnderUtilized(mi_heap_t* heap, void* object); CollectedPageStats CollectedStats() const { return unique_pages_.CollectedStats(); } bool ConsumePageStats(mi_page_usage_stats_t stats); void RecordNotRequired() { unique_pages_.objects_skipped_not_required += 1; } void RecordNotSupported() { unique_pages_.objects_skipped_not_supported += 1; } void SetForceReallocate(bool force_reallocate) { force_reallocate_ = force_reallocate; } bool QuotaDepleted() const; void ExtendQuota(uint64_t quota_usec); private: CollectPageStats collect_stats_{CollectPageStats::NO}; float threshold_; struct UniquePages { HllBufferPtr pages_scanned; HllBufferPtr pages_marked_for_realloc; HllBufferPtr pages_full; HllBufferPtr pages_reserved_for_malloc; HllBufferPtr pages_with_heap_mismatch; HllBufferPtr pages_above_threshold; hdr_histogram* page_usage_hist{}; uint64_t objects_skipped_not_required{0}; uint64_t objects_skipped_not_supported{0}; explicit UniquePages(); ~UniquePages(); void AddStat(mi_page_usage_stats_t stat); CollectedPageStats CollectedStats() const; }; UniquePages unique_pages_; CycleQuota quota_; // For use in testing, forces reallocate check to always return true bool force_reallocate_{false}; }; } // namespace dfly ================================================ FILE: src/core/page_usage_stats_test.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/page_usage/page_usage_stats.h" #include #include #include #include "base/gtest.h" #include "base/logging.h" #include "core/compact_object.h" #include "core/qlist.h" #include "core/score_map.h" #include "core/search/block_list.h" #include "core/search/search.h" #include "core/small_string.h" #include "core/sorted_map.h" #include "core/string_map.h" #include "core/string_set.h" #include "redis/redis_aux.h" #include "util/fibers/fibers.h" extern "C" { #include "redis/zmalloc.h" } ABSL_DECLARE_FLAG(bool, experimental_flat_json); using namespace dfly; using namespace std::chrono_literals; namespace { std::string GenerateTestJSON(size_t num_objects) { std::string data = R"({"contents":[)"; for (size_t i = 0; i < num_objects; ++i) { const auto si = std::to_string(i); data += R"({"id":)" + si + R"(,"class":"v___)" + si + R"(","value":)" + si + R"(})"; if (i < num_objects - 1) { data += ","; } } data += R"(], "data": "some", "count": 1, "checked": false})"; return data; } // Helper to defragment only if a randomly generated value is less than preset probability. 
For // benchmarking realistic situations, where some nodes are fragmented and others are not class SelectiveDefragment : public PageUsage { public: explicit SelectiveDefragment(const double fragmentation_probability) : PageUsage(CollectPageStats::NO, 0), frag_prob_{fragmentation_probability} { } bool IsPageForObjectUnderUtilized(void*) override { return dist_(rng_) < frag_prob_; } private: double frag_prob_; std::mt19937 rng_{99}; std::uniform_real_distribution dist_{0.0, 1.0}; }; struct MemStats { size_t total_reserved{0}; size_t total_committed{0}; size_t total_used{0}; size_t total_wasted{0}; size_t num_pages{0}; }; MemStats LogMemStats(const mi_heap_t* heap) { MemStats stats; mi_heap_visit_blocks( heap, false, [](const mi_heap_t* /*h*/, const mi_heap_area_t* area, void* /*block*/, size_t block_size, void* arg) { const size_t committed = area->committed; const size_t used = area->used * block_size; const auto s = static_cast(arg); s->num_pages++; s->total_committed += committed; s->total_reserved += area->reserved; s->total_used += used; s->total_wasted += committed - used; return true; }, &stats); LOG(INFO) << "Pages: " << stats.num_pages; LOG(INFO) << "Reserved : " << stats.total_reserved << " bytes"; LOG(INFO) << "Committed: " << stats.total_committed << " bytes"; LOG(INFO) << "Used: " << stats.total_used << " bytes"; LOG(INFO) << "Wasted: " << stats.total_wasted << " bytes"; if (stats.total_committed) { LOG(INFO) << "Wasted%: " << static_cast(stats.total_wasted) / stats.total_committed * 100.0; LOG(INFO) << "Utilization%: " << static_cast(stats.total_used) / stats.total_committed * 100.0; } return stats; } } // namespace class PageUsageStatsTest : public ::testing::Test { protected: static void SetUpTestSuite() { init_zmalloc_threadlocal(mi_heap_get_backing()); } static void TearDownTestSuite() { mi_heap_collect(mi_heap_get_backing(), true); mi_heap_visit_blocks( mi_heap_get_backing(), false, [](auto*, auto* a, void*, size_t block_sz, void*) { LOG(ERROR) << 
"Unfreed allocations: block_size " << block_sz << ", allocated: " << a->used * block_sz; return true; }, nullptr); } PageUsageStatsTest() : m_(mi_heap_get_backing()) { InitTLStatelessAllocMR(&m_); } void SetUp() override { CompactObj::InitThreadLocal(&m_); score_map_ = std::make_unique(); sorted_map_ = std::make_unique(); string_set_ = std::make_unique(); string_map_ = std::make_unique(); SmallString::InitThreadLocal(m_.heap()); qlist_ = std::make_unique(2, 2); } void TearDown() override { score_map_.reset(); sorted_map_.reset(); string_set_.reset(); string_map_.reset(); small_string_.Free(); qlist_->Clear(); EXPECT_EQ(zmalloc_used_memory_tl, 0); c_obj_.Reset(); CleanupStatelessAllocMR(); } MiMemoryResource m_; std::unique_ptr score_map_; std::unique_ptr sorted_map_; std::unique_ptr string_set_; std::unique_ptr string_map_; SmallString small_string_{}; std::unique_ptr qlist_; CompactValue c_obj_{}; }; TEST_F(PageUsageStatsTest, Defrag) { score_map_->AddOrUpdate("test", 0.1); sorted_map_->InsertNew(0.1, "x"); string_set_->Add("a"); string_map_->AddOrUpdate("key", "value"); small_string_.Assign("small-string"); // INT_TAG, defrag will be skipped c_obj_.SetString("1"); qlist_->Push("xxxx", QList::HEAD); { PageUsage p{CollectPageStats::YES, 0.1}; score_map_->begin().ReallocIfNeeded(&p); sorted_map_->DefragIfNeeded(&p); string_set_->begin().ReallocIfNeeded(&p); string_map_->begin().ReallocIfNeeded(&p); small_string_.DefragIfNeeded(&p); c_obj_.DefragIfNeeded(&p); qlist_->DefragIfNeeded(&p); const auto stats = p.CollectedStats(); EXPECT_GT(stats.pages_scanned, 0); EXPECT_EQ(stats.objects_skipped_not_required, 1); } { PageUsage p{CollectPageStats::NO, 0.1}; score_map_->begin().ReallocIfNeeded(&p); sorted_map_->DefragIfNeeded(&p); string_set_->begin().ReallocIfNeeded(&p); string_map_->begin().ReallocIfNeeded(&p); small_string_.DefragIfNeeded(&p); qlist_->DefragIfNeeded(&p); EXPECT_EQ(p.CollectedStats().pages_scanned, 0); } } TEST_F(PageUsageStatsTest, StatCollection) { 
constexpr auto threshold = 0.5; PageUsage p{CollectPageStats::YES, threshold}; for (size_t i = 0; i < 10000; ++i) { p.ConsumePageStats({.page_address = uintptr_t{100000 + i}, .block_size = 1, .capacity = 100, .reserved = 100, .used = 65, .flags = 0}); } for (size_t i = 0; i < 2000; ++i) { p.ConsumePageStats({.page_address = uintptr_t{200000 + i}, .block_size = 1, .capacity = 100, .reserved = 100, .used = 85, .flags = 0}); } for (size_t i = 0; i < 1000; ++i) { p.ConsumePageStats({.page_address = uintptr_t{300000 + i}, .block_size = 1, .capacity = 100, .reserved = 100, .used = 89, .flags = 0}); } constexpr auto page_count_per_flag = 150; auto start = 0; for (const uint8_t flag : {MI_DFLY_PAGE_FULL, MI_DFLY_PAGE_USED_FOR_MALLOC, MI_DFLY_HEAP_MISMATCH, MI_DFLY_PAGE_BELOW_THRESHOLD}) { for (size_t i = 0; i < page_count_per_flag; ++i) { p.ConsumePageStats({.page_address = uintptr_t{start + i}, .block_size = 1, .capacity = 100, .reserved = 100, .used = 100, .flags = flag}); } start += page_count_per_flag; } CollectedPageStats st; st.Merge(p.CollectedStats(), 1); EXPECT_GT(st.pages_scanned, 12000); // Expect a small error margin due to HLL EXPECT_NEAR(st.pages_full, page_count_per_flag, 5); EXPECT_NEAR(st.pages_reserved_for_malloc, page_count_per_flag, 5); EXPECT_NEAR(st.pages_marked_for_realloc, page_count_per_flag, 5); const auto usage = st.shard_wide_summary; EXPECT_EQ(usage.size(), 1); EXPECT_TRUE(usage.contains(1)); const CollectedPageStats::ShardUsageSummary expected{{50, 65}, {90, 85}, {99, 89}}; EXPECT_EQ(usage.at(1), expected); } TEST_F(PageUsageStatsTest, JSONCons) { // Because of the static encoding it is not possible to easily test the flat encoding. Once the // encoding flag is set, it is not re-read. If friend class is used to access the compact object // inner fields and call `DefragIfNeeded` directly on the flat variant of the union, the test will // still fail. 
This is because freeing the compact object code path takes the wrong branch based // on encoding. The flat encoding was tested manually adjusting this same test with changed // encoding. std::string data = GenerateTestJSON(1000); auto* mr = static_cast(CompactObj::memory_resource()); size_t before = mr->used(); auto parsed = ParseJsonUsingShardHeap(data); EXPECT_TRUE(parsed.has_value()); c_obj_.SetJson(std::move(parsed.value())); c_obj_.SetJsonSize(mr->used() - before); EXPECT_GT(c_obj_.MallocUsed(), 0); PageUsage p{CollectPageStats::YES, 0.1}; p.SetForceReallocate(true); c_obj_.DefragIfNeeded(&p); EXPECT_GT(c_obj_.MallocUsed(), 0); const auto stats = p.CollectedStats(); EXPECT_GT(stats.pages_scanned, 0); EXPECT_EQ(stats.objects_skipped_not_required, 0); EXPECT_EQ(c_obj_.ObjType(), OBJ_JSON); auto json_obj = c_obj_.GetJson(); EXPECT_EQ(json_obj->at("data").as_string_view(), "some"); EXPECT_EQ(json_obj->at("count").as_integer(), 1); EXPECT_EQ(json_obj->at("checked").as_bool(), false); } TEST_F(PageUsageStatsTest, JsonDefragEmpty) { auto parsed = ParseJsonUsingShardHeap(R"({})"); EXPECT_TRUE(parsed.has_value()); PageUsage p{CollectPageStats::NO, 0}; p.SetForceReallocate(true); Defragment(parsed.value(), &p); EXPECT_TRUE(parsed->empty()); } TEST_F(PageUsageStatsTest, JsonDefragNested) { constexpr auto data = R"({"a":{"b":{"c":{"d":"value"}}}})"; auto parsed = ParseJsonUsingShardHeap(data); EXPECT_TRUE(parsed.has_value()); PageUsage p{CollectPageStats::NO, 0}; p.SetForceReallocate(true); Defragment(parsed.value(), &p); EXPECT_EQ(parsed->at("a").at("b").at("c").at("d").as_string_view(), "value"); } TEST_F(PageUsageStatsTest, JsonDefragRemainsInSameHeap) { // This is a brute force test that defragmentation does not erroneously move data to the default // heap. Comparing allocators before/after defragmentation is not useful as stateless allocators // are all equal. 
It might be possible to compare the allocator type, but this approach checks // that the pointers in a JSON object belong to the same heap as they did before defragmentation. const std::string data = R"({ "data": {"sub-data": "attr1"}, "values": [true, false, 1.11, 2], "secretkey": ")" + std::string(1024, '.') + "\"}"; auto json = ParseJsonUsingShardHeap(data); EXPECT_TRUE(json.has_value()); auto key_before = json->at("secretkey").as_string_view(); auto sub_before = json->at("data").at("sub-data").as_string_view(); auto values_before = &*json->at("values").array_range().begin(); EXPECT_TRUE(mi_heap_contains_block(m_.heap(), key_before.data())); EXPECT_TRUE(mi_heap_contains_block(m_.heap(), sub_before.data())); EXPECT_TRUE(mi_heap_contains_block(m_.heap(), values_before)); PageUsage p{CollectPageStats::NO, 0}; p.SetForceReallocate(true); Defragment(json.value(), &p); auto key_after = json->at("secretkey").as_string_view(); auto sub_after = json->at("data").at("sub-data").as_string_view(); auto values_after = &*json->at("values").array_range().begin(); // Data still managed by the same heap. 
EXPECT_TRUE(mi_heap_contains_block(m_.heap(), key_after.data())); EXPECT_TRUE(mi_heap_contains_block(m_.heap(), sub_after.data())); EXPECT_TRUE(mi_heap_contains_block(m_.heap(), values_after)); // Defragment actually changed addresses EXPECT_NE(key_after.data(), key_before.data()); EXPECT_NE(sub_after.data(), sub_before.data()); EXPECT_NE(values_after, values_before); } TEST_F(PageUsageStatsTest, QuotaChecks) { { PageUsage p{CollectPageStats::NO, 0}; EXPECT_FALSE(p.QuotaDepleted()); } { PageUsage p{CollectPageStats::NO, 0, CycleQuota{4}}; util::ThisFiber::SleepFor(5us); EXPECT_TRUE(p.QuotaDepleted()); } } TEST_F(PageUsageStatsTest, BlockList) { search::BlockList> bl{&m_, 20}; PageUsage p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); // empty list auto result = bl.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 0); // single item will move twice, once for the blocklist and once for the sorted vector bl.Insert(1); result = bl.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 2); // quota depleted without defragmentation PageUsage p_zero{CollectPageStats::NO, 0.1, CycleQuota{0}}; p_zero.SetForceReallocate(true); result = bl.Defragment(&p_zero); EXPECT_TRUE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 0); } TEST_F(PageUsageStatsTest, BlockListDefragmentResumes) { search::BlockList> bl{&m_, 20}; PageUsage p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); for (size_t i = 0; i < 1000; ++i) { bl.Insert(i); } PageUsage p_small_quota{CollectPageStats::NO, 0.1, CycleQuota{10}}; p_small_quota.SetForceReallocate(true); util::ThisFiber::SleepFor(10us); auto result = bl.Defragment(&p_small_quota); EXPECT_TRUE(result.quota_depleted); EXPECT_GE(result.objects_moved, 0); result = bl.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_GT(result.objects_moved, 0); } TEST_F(PageUsageStatsTest, BlockListWithPairs) { search::BlockList>> bl{&m_, 20}; PageUsage 
p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); for (size_t i = 0; i < 100; ++i) { bl.Insert({i, i * 1.1}); } const auto result = bl.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_GT(result.objects_moved, 0); } TEST_F(PageUsageStatsTest, BlockListWithNonDefragmentableContainer) { search::BlockList bl{&m_, 20}; PageUsage p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); // empty list auto result = bl.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 0); // will reallocate once for the blocklist, the inner sorted set will be skipped bl.Insert(1); result = bl.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 1); } class MockDocument final : public search::DocumentAccessor { public: MockDocument() { words.reserve(1000); for (size_t i = 0; i < 1000; ++i) { words.push_back(absl::StrFormat("word-%d", i)); } } std::optional GetStrings(std::string_view active_field) const override { return {{words[absl::GetCurrentTimeNanos() % words.size()]}}; } std::optional GetVector(std::string_view active_field, size_t dim) const override { return std::nullopt; } std::optional GetNumbers(std::string_view active_field) const override { return {{1, 2, 3, 4}}; } std::optional GetTags(std::string_view active_field) const override { return {{words[absl::GetCurrentTimeNanos() % words.size()]}}; } std::vector words; }; TEST_F(PageUsageStatsTest, DefragmentTagIndex) { search::Schema schema; schema.fields["field_name"] = search::SchemaField{search::SchemaField::TAG, 0, "fn", search::SchemaField::TagParams{}}; search::FieldIndices index{schema, {}, &m_, nullptr}; PageUsage p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); // Empty index search::DefragmentResult result = index.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 0); const MockDocument md; index.Add(1, md); result = index.Defragment(&p); EXPECT_FALSE(result.quota_depleted); // single doc with 
single term returned by `GetTags` should result in two reallocations. EXPECT_EQ(result.objects_moved, 2); PageUsage p_zero{CollectPageStats::NO, 0.1, CycleQuota{0}}; p_zero.SetForceReallocate(true); result = index.Defragment(&p_zero); EXPECT_TRUE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 0); } TEST_F(PageUsageStatsTest, TagIndexDefragResumeWithChanges) { search::Schema schema; schema.fields["field_name"] = search::SchemaField{search::SchemaField::TAG, 0, "fn", search::SchemaField::TagParams{}}; search::FieldIndices index{schema, {}, &m_, nullptr}; PageUsage p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); const MockDocument md; for (size_t i = 0; i < 100; ++i) { index.Add(i, md); } PageUsage p_small_quota{CollectPageStats::NO, 0.1, CycleQuota{10}}; p_small_quota.SetForceReallocate(true); util::ThisFiber::SleepFor(10us); search::DefragmentResult result = index.Defragment(&p_small_quota); EXPECT_TRUE(result.quota_depleted); EXPECT_GE(result.objects_moved, 0); index.Remove(99, md); for (size_t i = 200; i < 300; ++i) { index.Add(i, md); } result = index.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_GT(result.objects_moved, 0); } TEST_F(PageUsageStatsTest, DefragmentIndexWithNonDefragmentableFields) { search::Schema schema; schema.fields["text"] = search::SchemaField{search::SchemaField::TEXT, 0, "fn", search::SchemaField::TextParams{}}; schema.fields["num"] = search::SchemaField{search::SchemaField::NUMERIC, 0, "fn", search::SchemaField::NumericParams{}}; search::IndicesOptions options{{}}; search::FieldIndices index{schema, options, &m_, nullptr}; PageUsage p{CollectPageStats::NO, 0.1}; p.SetForceReallocate(true); const MockDocument md; index.Add(1, md); // Unsupported index types will skip defragmenting themselves const search::DefragmentResult result = index.Defragment(&p); EXPECT_FALSE(result.quota_depleted); EXPECT_EQ(result.objects_moved, 0); } TEST_F(PageUsageStatsTest, DefragReducesWaste) { // This test works with actual 
defragmentation, by deleting every other json object which creates // holes in pages which cannot be directly freed. The test asserts that wasted memory goes down as // well as committed memory after defragmentation. std::vector> all_objects; constexpr auto total_json = 100; all_objects.reserve(total_json); for (auto i = 0; i < total_json; ++i) { auto parsed = ParseJsonUsingShardHeap(GenerateTestJSON(500)); EXPECT_TRUE(parsed.has_value()); all_objects.emplace_back(std::move(parsed.value())); } // Delete every other object to create gaps, so that the pages are partially used. for (size_t i = 0; i < all_objects.size(); i += 2) { all_objects[i].reset(); } // Allow mimalloc to free any completely empty pages, if any mi_heap_collect(m_.heap(), true); // Collects stats using mi_visit.. also logs, to see logs run the test with: // --vmodule=page_usage_stats_test=1 --logtostderr const auto before = LogMemStats(m_.heap()); PageUsage p{CollectPageStats::NO, 0.8}; for (auto& j : all_objects) { if (j.has_value()) { Defragment(j.value(), &p); } } mi_heap_collect(m_.heap(), true); const auto after = LogMemStats(m_.heap()); EXPECT_LT(after.total_wasted, before.total_wasted); EXPECT_LT(after.total_committed, before.total_committed); } TEST_F(PageUsageStatsTest, MixedFlagHandling) { PageUsage p{CollectPageStats::YES, 0.0}; auto add_pages = [&](size_t count, uintptr_t start_address, uint8_t flags) { for (const size_t i : std::views::iota(0UL, count)) { p.ConsumePageStats({.page_address = uintptr_t{start_address + i}, .block_size = 100, .capacity = 1000, .reserved = 100, .used = 99, .flags = flags}); } }; add_pages(2000, 10, MI_DFLY_PAGE_FULL | MI_DFLY_PAGE_USED_FOR_MALLOC | MI_DFLY_HEAP_MISMATCH); add_pages(500, 50000, MI_DFLY_PAGE_BELOW_THRESHOLD); const auto stats = p.CollectedStats(); constexpr auto tolerance = 60; EXPECT_NEAR(stats.pages_full, 2000, tolerance); EXPECT_NEAR(stats.pages_reserved_for_malloc, 2000, tolerance); EXPECT_NEAR(stats.pages_with_heap_mismatch, 2000, 
tolerance); EXPECT_EQ(stats.pages_full, stats.pages_reserved_for_malloc); EXPECT_EQ(stats.pages_full, stats.pages_with_heap_mismatch); EXPECT_NEAR(stats.pages_marked_for_realloc, 500, 15); } namespace { void InitBenchMemRes() { static bool initialized = false; if (!initialized) { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); static MiMemoryResource m{tlh}; InitTLStatelessAllocMR(&m); CompactObj::InitThreadLocal(&m); initialized = true; } } } // namespace void BM_JSONDefragSelective(benchmark::State& state) { InitBenchMemRes(); std::string json_data = GenerateTestJSON(state.range(0)); for (auto _ : state) { state.PauseTiming(); auto parsed = ParseJsonUsingShardHeap(json_data); DCHECK(parsed.has_value()); SelectiveDefragment p{state.range(1) / 100.0}; state.ResumeTiming(); Defragment(parsed.value(), &p); benchmark::DoNotOptimize(parsed); } } BENCHMARK(BM_JSONDefragSelective) ->ArgNames({"objects_per_json", "fragmentation_probability"}) ->Args({250, 0}) ->Args({250, 30}) ->Args({250, 70}) ->Args({250, 100}) ->Args({1000, 0}) ->Args({1000, 30}) ->Args({1000, 70}) ->Args({1000, 100}) ->Args({4000, 0}) ->Args({4000, 30}) ->Args({4000, 70}) ->Args({4000, 100}); ================================================ FILE: src/core/qlist.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/qlist.h" extern "C" { #include "redis/listpack.h" #include "redis/lzfP.h" #include "redis/zmalloc.h" } #include #include #include #include #include #include "base/logging.h" #include "core/page_usage/page_usage_stats.h" using namespace std; /* Maximum size in bytes of any multi-element listpack. * Larger values will live in their own isolated listpacks. * This is used only if we're limited by record count. when we're limited by * size, the maximum limit is bigger, but still safe. 
* 8k is a recommended / default size limit */ #define SIZE_SAFETY_LIMIT 8192 /* Maximum estimate of the listpack entry overhead. * Although in the worst case(sz < 64), we will waste 6 bytes in one * quicklistNode, but can avoid memory waste due to internal fragmentation * when the listpack exceeds the size limit by a few bytes (e.g. being 16388). */ #define SIZE_ESTIMATE_OVERHEAD 8 /* Minimum listpack size in bytes for attempting compression. */ #define MIN_COMPRESS_BYTES 256 /* Minimum size reduction in bytes to store compressed quicklistNode data. * This also prevents us from storing compression if the compression * resulted in a larger size than the original data. */ #define MIN_COMPRESS_IMPROVE 32 #define QL_NODE_IS_PLAIN(node) ((node)->container == QUICKLIST_NODE_CONTAINER_PLAIN) namespace dfly { namespace { static_assert(sizeof(QList) == 48); static_assert(sizeof(QList::Node) == 40); enum IterDir : uint8_t { FWD = 1, REV = 0 }; /* This is for test suite development purposes only, 0 means disabled. */ size_t packed_threshold = 0; /* Optimization levels for size-based filling. * Note that the largest possible limit is 64k, so even if each record takes * just one byte, it still won't overflow the 16 bit count field. */ const size_t kOptLevel[] = {4096, 8192, 16384, 32768, 65536}; /* Calculate the size limit of the quicklist node based on negative 'fill'. */ size_t NodeNegFillLimit(int fill) { DCHECK_LT(fill, 0); size_t offset = (-fill) - 1; constexpr size_t max_level = ABSL_ARRAYSIZE(kOptLevel); if (offset >= max_level) offset = max_level - 1; return kOptLevel[offset]; } const uint8_t* uint_ptr(string_view sv) { static uint8_t empty = 0; return sv.empty() ? 
&empty : reinterpret_cast(sv.data()); } bool IsLargeElement(size_t sz, int fill) { if (ABSL_PREDICT_FALSE(packed_threshold != 0)) return sz >= packed_threshold; if (fill >= 0) return sz > SIZE_SAFETY_LIMIT; else return sz > NodeNegFillLimit(fill); } /* Calculate the size limit or length limit of the quicklist node * based on 'fill', and is also used to limit list listpack. */ void quicklistNodeLimit(int fill, size_t* size, unsigned int* count) { *size = SIZE_MAX; *count = UINT_MAX; if (fill >= 0) { /* Ensure that one node have at least one entry */ *count = (fill == 0) ? 1 : fill; } else { *size = NodeNegFillLimit(fill); } } #define sizeMeetsSafetyLimit(sz) ((sz) <= SIZE_SAFETY_LIMIT) /* Check if the limit of the quicklist node has been reached to determine if * insertions, merges or other operations that would increase the size of * the node can be performed. * Return 1 if exceeds the limit, otherwise 0. */ int quicklistNodeExceedsLimit(int fill, size_t new_sz, unsigned int new_count) { size_t sz_limit; unsigned int count_limit; quicklistNodeLimit(fill, &sz_limit, &count_limit); if (ABSL_PREDICT_TRUE(sz_limit != SIZE_MAX)) { return new_sz > sz_limit; } else if (count_limit != UINT_MAX) { /* when we reach here we know that the limit is a size limit (which is * safe, see comments next to optimization_level and SIZE_SAFETY_LIMIT) */ if (!sizeMeetsSafetyLimit(new_sz)) return 1; return new_count > count_limit; } ABSL_UNREACHABLE(); } bool NodeAllowInsert(const QList::Node* node, const int fill, const size_t sz) { if (ABSL_PREDICT_FALSE(!node)) return false; if (ABSL_PREDICT_FALSE(QL_NODE_IS_PLAIN(node) || IsLargeElement(sz, fill))) return false; /* Estimate how many bytes will be added to the listpack by this one entry. * We prefer an overestimation, which would at worse lead to a few bytes * below the lowest limit of 4k (see optimization_level). 
* Note: No need to check for overflow below since both `node->sz` and * `sz` are to be less than 1GB after the plain/large element check above. */ size_t new_sz = node->sz + sz + SIZE_ESTIMATE_OVERHEAD; return !quicklistNodeExceedsLimit(fill, new_sz, node->count + 1); } bool NodeAllowMerge(const QList::Node* a, const QList::Node* b, const int fill) { if (!a || !b) return false; if (ABSL_PREDICT_FALSE(QL_NODE_IS_PLAIN(a) || QL_NODE_IS_PLAIN(b))) return false; /* approximate merged listpack size (- 7 to remove one listpack * header/trailer, see LP_HDR_SIZE and LP_EOF) */ unsigned int merge_sz = a->sz + b->sz - 7; // Allow merge if new node will not exceed the limit. return !quicklistNodeExceedsLimit(fill, merge_sz, a->count + b->count); } // the owner over entry is passed to the node. QList::Node* CreateRAW(int container, uint8_t* entry, size_t sz) { QList::Node* node = (QList::Node*)zmalloc(sizeof(*node)); node->entry = entry; node->count = 1; node->sz = sz; node->next = node->prev = NULL; node->encoding = QUICKLIST_NODE_ENCODING_RAW; node->container = container; node->recompress = 0; node->dont_compress = 0; node->offloaded = 0; return node; } uint8_t* LP_Insert(uint8_t* lp, string_view elem, uint8_t* pos, int lp_where) { DCHECK(pos); return lpInsertString(lp, uint_ptr(elem), elem.size(), pos, lp_where, NULL); } uint8_t* LP_Append(uint8_t* lp, string_view elem) { return lpAppend(lp, uint_ptr(elem), elem.size()); } uint8_t* LP_Prepend(uint8_t* lp, string_view elem) { return lpPrepend(lp, uint_ptr(elem), elem.size()); } QList::Node* CreateFromSV(int container, string_view value) { uint8_t* entry = nullptr; size_t sz = 0; if (container == QUICKLIST_NODE_CONTAINER_PLAIN) { DCHECK(!value.empty()); sz = value.size(); entry = (uint8_t*)zmalloc(sz); memcpy(entry, value.data(), sz); } else { entry = LP_Append(lpNew(0), value); sz = lpBytes(entry); } return CreateRAW(container, entry, sz); } // Returns the relative increase in size. 
inline ssize_t NodeSetEntry(QList::Node* node, uint8_t* entry) { node->entry = entry; size_t new_sz = lpBytes(node->entry); ssize_t diff = new_sz - node->sz; node->sz = new_sz; return diff; } /* quicklistLZF is a 8+N byte struct holding 'sz' followed by 'compressed'. * 'sz' is byte length of 'compressed' field. * 'compressed' is LZF data with total (compressed) length 'sz' * NOTE: uncompressed length is stored in quicklistNode->sz. * When quicklistNode->entry is compressed, node->entry points to a quicklistLZF */ using quicklistLZF = struct quicklistLZF { size_t sz; /* LZF size in bytes*/ char compressed[]; }; inline quicklistLZF* GetLzf(QList::Node* node) { DCHECK(node->encoding == QUICKLIST_NODE_ENCODING_LZF || node->encoding == QLIST_NODE_ENCODING_LZ4); return (quicklistLZF*)node->entry; } bool CompressLZF(QList::Node* node) { // We allocate LZF_STATE on heap, piggy-backing on the existing allocation. char* uptr = (char*)zmalloc(sizeof(quicklistLZF) + node->sz + sizeof(LZF_STATE)); quicklistLZF* lzf = (quicklistLZF*)uptr; LZF_HSLOT* sdata = (LZF_HSLOT*)(uptr + sizeof(quicklistLZF) + node->sz); /* Cancel if compression fails or doesn't compress small enough */ if (((lzf->sz = lzf_compress(node->entry, node->sz, lzf->compressed, node->sz, sdata)) == 0) || lzf->sz + MIN_COMPRESS_IMPROVE >= node->sz) { /* lzf_compress aborts/rejects compression if value not compressible. 
*/ DVLOG(2) << "Uncompressable " << node->sz << " vs " << lzf->sz; zfree(lzf); QList::stats.bad_compression_attempts++; return false; } DVLOG(2) << "Compressed " << node->sz << " to " << lzf->sz; QList::stats.compressed_bytes += lzf->sz; QList::stats.raw_compressed_bytes += node->sz; lzf = (quicklistLZF*)zrealloc(lzf, sizeof(*lzf) + lzf->sz); zfree(node->entry); node->entry = (unsigned char*)lzf; node->encoding = QUICKLIST_NODE_ENCODING_LZF; return true; } bool CompressLZ4(QList::Node* node) { LZ4F_cctx* cntx; LZ4F_errorCode_t code = LZ4F_createCompressionContext(&cntx, LZ4F_VERSION); CHECK(!LZ4F_isError(code)); LZ4F_preferences_t lz4_pref = LZ4F_INIT_PREFERENCES; lz4_pref.compressionLevel = -1; lz4_pref.frameInfo.contentSize = node->sz; size_t buf_size = LZ4F_compressFrameBound(node->sz, &lz4_pref); // We reuse quicklistLZF struct for LZ4 metadata. quicklistLZF* dest = (quicklistLZF*)zmalloc(sizeof(quicklistLZF) + buf_size); size_t compr_sz = LZ4F_compressFrame_usingCDict(cntx, dest->compressed, buf_size, node->entry, node->sz, nullptr /* dict */, &lz4_pref); CHECK(!LZ4F_isError(compr_sz)); code = LZ4F_freeCompressionContext(cntx); CHECK(!LZ4F_isError(code)); if (compr_sz + MIN_COMPRESS_IMPROVE >= node->sz) { QList::stats.bad_compression_attempts++; zfree(dest); return false; } dest->sz = compr_sz; dest = (quicklistLZF*)zrealloc(dest, sizeof(quicklistLZF) + compr_sz); QList::stats.compressed_bytes += compr_sz; QList::stats.raw_compressed_bytes += node->sz; zfree(node->entry); node->entry = (unsigned char*)dest; node->encoding = QLIST_NODE_ENCODING_LZ4; return true; } /* Compress the listpack in 'node' and update encoding details. * Returns true if listpack compressed successfully. * Returns false if compression failed or if listpack too small to compress. 
*/ bool CompressRaw(QList::Node* node, unsigned method) { DCHECK(node->encoding == QUICKLIST_NODE_ENCODING_RAW); DCHECK(!node->dont_compress); /* validate that the node is neither * tail nor head (it has prev and next)*/ DCHECK(node->prev && node->next); node->recompress = 0; /* Don't bother compressing small values */ if (node->sz < MIN_COMPRESS_BYTES) return false; QList::stats.compression_attempts++; if (method == static_cast(QList::LZF)) { return CompressLZF(node); } return CompressLZ4(node); } ssize_t TryCompress(QList::Node* node, unsigned method) { DCHECK(node); if (node->encoding == QUICKLIST_NODE_ENCODING_RAW) { node->attempted_compress = 1; if (!node->dont_compress) { if (CompressRaw(node, method)) return ssize_t(GetLzf(node)->sz) - node->sz; } } return 0; } /* Uncompress the listpack in 'node' and update encoding details. * Returns 1 on successful decode, 0 on failure to decode. */ bool DecompressRaw(bool recompress, QList::Node* node) { DCHECK(node->encoding == QUICKLIST_NODE_ENCODING_LZF || node->encoding == QLIST_NODE_ENCODING_LZ4); node->recompress = int(recompress); void* decompressed = zmalloc(node->sz); quicklistLZF* lzf = GetLzf(node); QList::stats.decompression_calls++; QList::stats.compressed_bytes -= lzf->sz; QList::stats.raw_compressed_bytes -= node->sz; if (node->encoding == QLIST_NODE_ENCODING_LZ4) { LZ4F_dctx* dctx = nullptr; LZ4F_errorCode_t code = LZ4F_createDecompressionContext(&dctx, LZ4F_VERSION); CHECK(!LZ4F_isError(code)); size_t decompressed_sz = node->sz; size_t left = LZ4F_decompress(dctx, decompressed, &decompressed_sz, lzf->compressed, &lzf->sz, nullptr); CHECK_EQ(left, 0u); CHECK_EQ(decompressed_sz, node->sz); LZ4F_freeDecompressionContext(dctx); } else { if (lzf_decompress(lzf->compressed, lzf->sz, decompressed, node->sz) == 0) { LOG(DFATAL) << "Invalid LZF compressed data"; /* Someone requested decompress, but we can't decompress. Not good. 
*/ zfree(decompressed); return false; } } zfree(lzf); node->entry = (uint8_t*)decompressed; node->encoding = QUICKLIST_NODE_ENCODING_RAW; return true; } /* Decompress only compressed nodes. recompress: if true, the node will be marked for recompression after decompression. returns by how much the size of the node has increased. */ ssize_t TryDecompressInternal(bool recompress, QList::Node* node) { if (node->encoding != QUICKLIST_NODE_ENCODING_RAW) { size_t compressed_sz = GetLzf(node)->sz; if (DecompressRaw(recompress, node)) { return node->sz - compressed_sz; } } return 0; } ssize_t RecompressOnly(QList::Node* node, unsigned method) { if (node->recompress && !node->dont_compress) { if (CompressRaw(node, method)) return (GetLzf(node))->sz - node->sz; } return 0; } // If after is true, returns a new node with elements in [offset, inf), otherwise // returns [0, offset-1]. QList::Node* SplitNode(QList::Node* node, int offset, bool after, ssize_t* diff) { DCHECK(node->container == QUICKLIST_NODE_CONTAINER_PACKED); size_t zl_sz = node->sz; uint8_t* entry = (uint8_t*)zmalloc(zl_sz); memcpy(entry, node->entry, zl_sz); /* Need positive offset for calculating extent below. */ if (offset < 0) offset = node->count + offset; /* Ranges to be trimmed: -1 here means "continue deleting until the list ends" */ int orig_start = after ? offset + 1 : 0; int orig_extent = after ? -1 : offset; int new_start = after ? 0 : offset; int new_extent = after ? 
offset + 1 : -1; ssize_t diff_existing = NodeSetEntry(node, lpDeleteRange(node->entry, orig_start, orig_extent)); node->count = lpLength(node->entry); entry = lpDeleteRange(entry, new_start, new_extent); QList::Node* new_node = CreateRAW(QUICKLIST_NODE_CONTAINER_PACKED, entry, lpBytes(entry)); new_node->count = lpLength(new_node->entry); *diff = diff_existing; return new_node; } } // namespace __thread QList::Stats QList::stats; QList::Stats& QList::Stats::operator+=(const Stats& other) { #define ADD_FIELD(field) this->field += other.field; ADD_FIELD(compression_attempts); ADD_FIELD(bad_compression_attempts); ADD_FIELD(decompression_calls); ADD_FIELD(compressed_bytes); ADD_FIELD(raw_compressed_bytes); ADD_FIELD(interior_node_reads); ADD_FIELD(total_node_reads); ADD_FIELD(offload_requests); ADD_FIELD(onload_requests); #undef ADD_FIELD return *this; } size_t QList::Node::GetLZF(void** data) const { DCHECK(encoding == QUICKLIST_NODE_ENCODING_LZF || encoding == QLIST_NODE_ENCODING_LZ4); quicklistLZF* lzf = (quicklistLZF*)entry; *data = lzf->compressed; return lzf->sz; } void QList::SetPackedThreshold(unsigned threshold) { packed_threshold = threshold; } size_t QList::DefragIfNeeded(PageUsage* page_usage) { size_t reallocated = 0; for (Node* curr = head_; curr; curr = curr->next) { if (!page_usage->IsPageForObjectUnderUtilized(curr->entry)) { continue; } // Data pointed to by the nodes is reallocated. The nodes themselves are not reallocated because // of their constant (and relatively small, ~40 bytes per object) size. Defragmentation fixes // fragmented memory allocation, which usually happens when variable-sized blocks of data are // allocated and deallocated, which is not expected with nodes. 
uint8_t* new_entry = static_cast(zmalloc(curr->sz)); memcpy(new_entry, curr->entry, curr->sz); uint8_t* old_entry = curr->entry; curr->entry = new_entry; zfree(old_entry); ++reallocated; } return reallocated; } void QList::SetTieringParams(const TieringParams& params) { tiering_params_ = make_unique(params); } QList::QList(int fill, int compress) : fill_(fill), compress_(compress), bookmark_count_(0) { compr_method_ = 0; } QList::QList(QList&& other) noexcept : head_(other.head_), count_(other.count_), len_(other.len_), fill_(other.fill_), compress_(other.compress_), bookmark_count_(other.bookmark_count_) { other.head_ = nullptr; other.len_ = other.count_ = 0; } QList::~QList() { Clear(); } QList& QList::operator=(QList&& other) noexcept { if (this != &other) { Clear(); head_ = other.head_; len_ = other.len_; count_ = other.count_; fill_ = other.fill_; compress_ = other.compress_; bookmark_count_ = other.bookmark_count_; tiering_params_ = std::move(other.tiering_params_); num_offloaded_nodes_ = other.num_offloaded_nodes_; other.head_ = nullptr; other.len_ = other.count_ = other.num_offloaded_nodes_ = 0; } return *this; } void QList::Clear() noexcept { Node* current = head_; while (len_) { Node* next = current->next; if (current->encoding != QUICKLIST_NODE_ENCODING_RAW) { quicklistLZF* lzf = (quicklistLZF*)current->entry; stats.compressed_bytes -= lzf->sz; stats.raw_compressed_bytes -= current->sz; } zfree(current->entry); zfree(current); len_--; current = next; } head_ = nullptr; count_ = 0; malloc_size_ = 0; num_offloaded_nodes_ = 0; } void QList::Push(string_view value, Where where) { DVLOG(3) << "Push " << absl::CHexEscape(value) << " " << (where == HEAD ? 
"HEAD" : "TAIL"); /* The head and tail should never be compressed (we don't attempt to decompress them) */ if (head_) { DCHECK(head_->encoding != QUICKLIST_NODE_ENCODING_LZF); DCHECK(head_->prev->encoding != QUICKLIST_NODE_ENCODING_LZF); } Node* orig = head_; uint32_t orig_id = 0; if (where == TAIL && orig) { orig = orig->prev; orig_id = len_ - 1; } InsertOpt opt = where == HEAD ? BEFORE : AFTER; size_t sz = value.size(); if (ABSL_PREDICT_FALSE(IsLargeElement(sz, fill_))) { InsertPlainNode(orig, value, orig_id, opt); return; } count_++; if (ABSL_PREDICT_TRUE(NodeAllowInsert(orig, fill_, sz))) { auto func = (where == HEAD) ? LP_Prepend : LP_Append; malloc_size_ += NodeSetEntry(orig, func(orig->entry, value)); orig->count++; if (len_ == 1) { // sanity check DCHECK_EQ(malloc_size_, orig->sz); } DCHECK(head_->prev->next == nullptr); return; } Node* node = CreateFromSV(QUICKLIST_NODE_CONTAINER_PACKED, value); InsertNode(orig, node, orig_id, opt); DCHECK(head_->prev->next == nullptr); } string QList::Pop(Where where) { DCHECK_GT(count_, 0u); Node* node = head_; if (where == TAIL) { node = head_->prev; } /* The head and tail should never be compressed */ DCHECK(node->encoding != QUICKLIST_NODE_ENCODING_LZF); DCHECK(head_->prev->next == nullptr); string res; if (ABSL_PREDICT_FALSE(QL_NODE_IS_PLAIN(node))) { // TODO: We could avoid this copy by returning the pointer of the plain node. // But the higher level APIs should support this. res.assign(reinterpret_cast(node->entry), node->sz); DelNode(node); } else { uint8_t* pos = where == HEAD ? 
lpFirst(node->entry) : lpLast(node->entry); unsigned int vlen; long long vlong; uint8_t* vstr = lpGetValue(pos, &vlen, &vlong); if (vstr) { res.assign(reinterpret_cast(vstr), vlen); } else { res = absl::StrCat(vlong); } DelPackedIndex(node, pos); } DCHECK(head_ == nullptr || head_->prev->next == nullptr); return res; } void QList::AppendListpack(unsigned char* zl) { Node* node = CreateRAW(QUICKLIST_NODE_CONTAINER_PACKED, zl, lpBytes(zl)); node->count = lpLength(node->entry); InsertNode(_Tail(), node, len_ ? len_ - 1 : 0, AFTER); count_ += node->count; } void QList::AppendPlain(unsigned char* data, size_t sz) { Node* node = CreateRAW(QUICKLIST_NODE_CONTAINER_PLAIN, data, sz); InsertNode(_Tail(), node, len_ ? len_ - 1 : 0, AFTER); ++count_; } bool QList::Insert(std::string_view pivot, std::string_view elem, InsertOpt opt) { Iterator it = GetIterator(HEAD); if (it.Valid()) { do { if (it.Get() == pivot) { Insert(it, elem, opt); return true; } } while (it.Next()); } return false; } bool QList::Replace(long index, std::string_view elem) { Iterator it = GetIterator(index); if (it.Valid()) { Replace(it, elem); return true; } return false; } size_t QList::MallocUsed(bool slow) const { size_t node_size = len_ * sizeof(Node) + znallocx(sizeof(QList)); if (slow) { for (Node* node = head_; node; node = node->next) { node_size += zmalloc_usable_size(node->entry); } return node_size; } return node_size + malloc_size_; } void QList::Iterate(IterateFunc cb, long start, long end) const { long llen = Size(); if (llen == 0) return; if (end < 0 || end >= long(Size())) end = Size() - 1; Iterator it = GetIterator(start); if (it.Valid()) { do { if (start > end || !cb(it.Get())) break; start++; } while (it.Next()); } } auto QList::InsertPlainNode(Node* old_node, string_view value, uint32_t old_node_id, InsertOpt insert_opt) -> Node* { Node* new_node = CreateFromSV(QUICKLIST_NODE_CONTAINER_PLAIN, value); InsertNode(old_node, new_node, old_node_id, insert_opt); count_++; return new_node; } 
void QList::InsertNode(Node* old_node, Node* new_node, uint32_t old_node_id, InsertOpt insert_opt) { if (insert_opt == AFTER) { new_node->prev = old_node; if (old_node) { new_node->next = old_node->next; if (old_node->next) old_node->next->prev = new_node; old_node->next = new_node; if (head_->prev == old_node) // if old_node is tail, update the tail to the new node. head_->prev = new_node; } } else { // BEFORE new_node->next = old_node; if (old_node) { new_node->prev = old_node->prev; // if old_node is not head, link its prev to the new node. // head->prev is tail, so we don't need to update it. if (old_node != head_) old_node->prev->next = new_node; old_node->prev = new_node; } if (head_ == old_node) head_ = new_node; } /* If this insert creates the only element so far, initialize head/tail. */ if (len_ == 0) { head_ = new_node; head_->prev = new_node; } /* Update len first, so in Compress we know exactly len */ len_++; malloc_size_ += new_node->sz; // Calculate final positions AFTER all linkage and len_ updates are complete. 
uint32_t new_node_id; if (insert_opt == AFTER && old_node) { new_node_id = old_node_id + 1; // new_node inserted after, old_node position unchanged } else { new_node_id = old_node_id; // new_node takes old_node's position old_node_id++; // old_node shifts one position forward } if (old_node) CoolOff(old_node, old_node_id); CoolOff(new_node, new_node_id); } void QList::Insert(Iterator it, std::string_view elem, InsertOpt insert_opt) { DCHECK(it.current_); DCHECK(it.zi_); int full = 0, at_tail = 0, at_head = 0, avail_next = 0, avail_prev = 0; Node* node = it.current_; size_t sz = elem.size(); bool after = insert_opt == AFTER; /* Populate accounting flags for easier boolean checks later */ if (!NodeAllowInsert(node, fill_, sz)) { full = 1; } if (after && (it.offset_ == node->count - 1 || it.offset_ == -1)) { at_tail = 1; if (NodeAllowInsert(node->next, fill_, sz)) { avail_next = 1; } } if (!after && (it.offset_ == 0 || it.offset_ == -(node->count))) { at_head = 1; if (NodeAllowInsert(node->prev, fill_, sz)) { avail_prev = 1; } } uint32_t node_id = it.node_id_; if (ABSL_PREDICT_FALSE(IsLargeElement(sz, fill_))) { if (QL_NODE_IS_PLAIN(node) || (at_tail && after) || (at_head && !after)) { InsertPlainNode(node, elem, node_id, insert_opt); } else { AccessForReads(true, node); ssize_t diff_existing = 0; // if after == true, the order will be node, entry_node, new_node // otherwise: new_node, entry_node, node. Node* new_node = SplitNode(node, it.offset_, after, &diff_existing); Node* entry_node = InsertPlainNode(node, elem, node_id, insert_opt); uint32_t entry_node_id = after ? node_id + 1 : node_id; InsertNode(entry_node, new_node, entry_node_id, insert_opt); malloc_size_ += diff_existing; } return; } /* Now determine where and how to insert the new element */ if (!full) { AccessForReads(true, node); uint8_t* new_entry = LP_Insert(node->entry, elem, it.zi_, after ? 
LP_AFTER : LP_BEFORE); malloc_size_ += NodeSetEntry(node, new_entry); node->count++; malloc_size_ += RecompressOnly(node, compr_method_); } else { bool insert_tail = at_tail && after; bool insert_head = at_head && !after; if (insert_tail && avail_next) { /* If we are: at tail, next has free space, and inserting after: * - insert entry at head of next node. */ auto* new_node = node->next; AccessForReads(true, new_node); malloc_size_ += NodeSetEntry(new_node, LP_Prepend(new_node->entry, elem)); new_node->count++; malloc_size_ += RecompressOnly(new_node, compr_method_); malloc_size_ += RecompressOnly(node, compr_method_); } else if (insert_head && avail_prev) { /* If we are: at head, previous has free space, and inserting before: * - insert entry at tail of previous node. */ auto* new_node = node->prev; AccessForReads(true, new_node); malloc_size_ += NodeSetEntry(new_node, LP_Append(new_node->entry, elem)); new_node->count++; malloc_size_ += RecompressOnly(new_node, compr_method_); malloc_size_ += RecompressOnly(node, compr_method_); } else if (insert_tail || insert_head) { /* If we are: full, and our prev/next has no available space, then: * - create new node and attach to qlist */ auto* new_node = CreateFromSV(QUICKLIST_NODE_CONTAINER_PACKED, elem); InsertNode(node, new_node, node_id, insert_opt); } else { /* else, node is full we need to split it. */ /* covers both after and !after cases */ AccessForReads(true, node); ssize_t diff_existing = 0; auto* new_node = SplitNode(node, it.offset_, after, &diff_existing); auto func = after ? 
LP_Prepend : LP_Append; malloc_size_ += NodeSetEntry(new_node, func(new_node->entry, elem)); new_node->count++; InsertNode(node, new_node, node_id, insert_opt); MergeNodes(node); malloc_size_ += diff_existing; } } count_++; } void QList::Replace(Iterator it, std::string_view elem) { Node* node = it.current_; uint8_t* newentry = nullptr; size_t sz = elem.size(); uint32_t node_id = it.node_id_; if (ABSL_PREDICT_TRUE(!QL_NODE_IS_PLAIN(node) && !IsLargeElement(sz, fill_) && (newentry = lpReplace(node->entry, &it.zi_, uint_ptr(elem), sz)) != NULL)) { malloc_size_ += NodeSetEntry(node, newentry); CoolOff(node, node_id); } else if (QL_NODE_IS_PLAIN(node)) { if (IsLargeElement(sz, fill_)) { zfree(node->entry); uint8_t* new_entry = (uint8_t*)zmalloc(sz); memcpy(new_entry, elem.data(), sz); malloc_size_ += NodeSetEntry(node, new_entry); CoolOff(node, node_id); } else { Insert(it, elem, AFTER); DelNode(node); } } else { /* The node is full or data is a large element */ Node *split_node = NULL, *new_node; node->dont_compress = 1; /* Prevent compression in InsertNode() */ /* If the entry is not at the tail, split the node at the entry's offset. */ if (it.offset_ != node->count - 1 && it.offset_ != -1) { ssize_t diff_existing = 0; split_node = SplitNode(node, it.offset_, 1, &diff_existing); malloc_size_ += diff_existing; } /* Create a new node and insert it after the original node. * If the original node was split, insert the split node after the new node. */ new_node = CreateFromSV(IsLargeElement(sz, fill_) ? QUICKLIST_NODE_CONTAINER_PLAIN : QUICKLIST_NODE_CONTAINER_PACKED, elem); // The order is: node, new_node, split_node. InsertNode(node, new_node, node_id, AFTER); if (split_node) InsertNode(new_node, split_node, node_id + 1, AFTER); count_++; /* Delete the replaced element. 
*/ if (node->count == 1) { DelNode(node); } else { unsigned char* p = lpSeek(node->entry, -1); DelPackedIndex(node, p); node->dont_compress = 0; /* Re-enable compression */ new_node = MergeNodes(new_node); /* We can't know if the current node and its sibling nodes are correctly compressed, * and we don't know if they are within the range of compress depth, so we need to * use UpdateCompression() for compression, which checks if node is within compress * depth before compressing. */ // TODO: node_id might be off after merges. CoolOff(new_node, node_id + 1); CoolOff(new_node->prev, node_id); if (new_node->next) CoolOff(new_node->next, node_id + 2); } } } void QList::CoolOff(Node* node, uint32_t node_id) { if (tiering_params_) { // Dry run for offloading decision. // a. Node id is withing the offloadable depth - offload it if not already offloaded. // b. Node id is outside the offloadable depth - but we have too many nodes that are not // offloaded - take the O(n) route to traverse and offload them. The reason for having such // nodes is because (a) handles node that we touch during operations. // if for example we just perform lpush, then we won't touch any interior nodes, and they // will never get offloaded. The good news is that once interior nodes are offloaded, // we won't need to traverse them again for "trivial" access patterns unless they // get accessed again. Another reason for missing offloaded nodes is that node_id can be // off due to merges (can be improved in future). if (node_id >= tiering_params_->node_depth_threshold && node_id + tiering_params_->node_depth_threshold < len_) { if (!node->offloaded) { OffloadNode(node); } } else if (num_offloaded_nodes_ * 2 + tiering_params_->node_depth_threshold * 2 < len_) { // We check `num_offloaded_nodes_ * 2` above to avoid frequent traversals. // So only when the gap between offloaded and non-offloaded nodes is large enough, // we do a traversal to offload more nodes. 
auto* fw = head_; auto* rev = head_->prev; uint32_t traverse_node_id = 0; // Traverse from both ends towards the middle as we expect more offloads towards the ends // due to usual access patterns of adding items via lpush/rpush. while (traverse_node_id <= len_ / 2 && (num_offloaded_nodes_ + 2 * tiering_params_->node_depth_threshold) < len_) { if (traverse_node_id >= tiering_params_->node_depth_threshold) { if (fw->offloaded == 0) { OffloadNode(fw); } // Avoid offloading the same node twice when fw and rev meet in the middle. if (rev != fw && rev->offloaded == 0) { OffloadNode(rev); } } fw = fw->next; rev = rev->prev; traverse_node_id++; } } } /* Force 'quicklist' to meet compression guidelines set by compress depth. * The only way to guarantee interior nodes get compressed is to iterate * to our "interior" compress depth then compress the next node we find. * If compress depth is larger than the entire list, we return immediately. */ if (node->recompress) CompressRaw(node, this->compr_method_); else this->CompressByDepth(node); } void QList::CompressByDepth(Node* node) { if (len_ == 0) return; /* The head and tail should never be compressed (we should not attempt to recompress them) */ DCHECK(head_->recompress == 0 && head_->prev->recompress == 0); /* If length is less than our compress depth (from both sides), * we can't compress anything. */ if (!AllowCompression() || len_ < (unsigned int)(compress_ * 2)) return; /* Iterate until we reach compress depth for both sides of the list.a * Note: because we do length checks at the *top* of this function, * we can skip explicit null checks below. Everything exists. 
*/ Node* forward = head_; Node* reverse = head_->prev; int depth = 0; int in_depth = 0; while (depth++ < compress_) { malloc_size_ += TryDecompressInternal(false, forward); malloc_size_ += TryDecompressInternal(false, reverse); if (forward == node || reverse == node) in_depth = 1; /* We passed into compress depth of opposite side of the quicklist * so there's no need to compress anything and we can exit. */ if (forward == reverse || forward->next == reverse) return; forward = forward->next; reverse = reverse->prev; } if (!in_depth && node) { malloc_size_ += TryCompress(node, this->compr_method_); } /* At this point, forward and reverse are one node beyond depth */ malloc_size_ += TryCompress(forward, this->compr_method_); malloc_size_ += TryCompress(reverse, this->compr_method_); } void QList::AccessForReads(bool recompress, Node* node) { DCHECK(node); stats.total_node_reads++; if (node->offloaded) { DCHECK(tiering_params_); stats.onload_requests++; num_offloaded_nodes_--; node->offloaded = 0; } if (len_ > 2 && node != head_ && node->next != nullptr) { stats.interior_node_reads++; } ssize_t res = TryDecompressInternal(recompress, node); malloc_size_ += res; } /* Attempt to merge listpacks within two nodes on either side of 'center'. * * We attempt to merge: * - (center->prev->prev, center->prev) * - (center->next, center->next->next) * - (center->prev, center) * - (center, center->next) * * Returns the new 'center' after merging. */ auto QList::MergeNodes(Node* center) -> Node* { Node *prev = NULL, *prev_prev = NULL, *next = NULL; Node *next_next = NULL, *target = NULL; if (center->prev) { prev = center->prev; if (center->prev->prev) prev_prev = center->prev->prev; } if (center->next) { next = center->next; if (center->next->next) next_next = center->next->next; } /* Try to merge prev_prev and prev */ if (NodeAllowMerge(prev, prev_prev, fill_)) { ListpackMerge(prev_prev, prev); prev_prev = prev = NULL; /* they could have moved, invalidate them. 
*/ } /* Try to merge next and next_next */ if (NodeAllowMerge(next, next_next, fill_)) { ListpackMerge(next, next_next); next = next_next = NULL; /* they could have moved, invalidate them. */ } /* Try to merge center node and previous node */ if (NodeAllowMerge(center, center->prev, fill_)) { target = ListpackMerge(center->prev, center); center = NULL; /* center could have been deleted, invalidate it. */ } else { /* else, we didn't merge here, but target needs to be valid below. */ target = center; } /* Use result of center merge (or original) to merge with next node. */ if (NodeAllowMerge(target, target->next, fill_)) { target = ListpackMerge(target, target->next); } return target; } /* Given two nodes, try to merge their listpacks. * * This helps us not have a quicklist with 3 element listpacks if * our fill factor can handle much higher levels. * * Note: 'a' must be to the LEFT of 'b'. * * After calling this function, both 'a' and 'b' should be considered * unusable. The return value from this function must be used * instead of re-using any of the quicklistNode input arguments. * * Returns the input node picked to merge against or NULL if * merging was not possible. */ auto QList::ListpackMerge(Node* a, Node* b) -> Node* { AccessForReads(false, a); AccessForReads(false, b); if ((lpMerge(&a->entry, &b->entry))) { /* We merged listpacks! Now remove the unused Node. */ Node *keep = NULL, *nokeep = NULL; if (!a->entry) { nokeep = a; keep = b; } else if (!b->entry) { nokeep = b; keep = a; } keep->count = lpLength(keep->entry); malloc_size_ += NodeSetEntry(keep, keep->entry); keep->recompress = 0; /* Prevent 'keep' from being recompressed if * it becomes head or tail after merging. */ nokeep->count = 0; DelNode(nokeep); CoolOff(keep, 0); // TODO: node_id is unknown here, so just pass 0. return keep; } /* else, the merge returned NULL and nothing changed. 
*/ return NULL; } void QList::DelNode(Node* node) { if (node->next) node->next->prev = node->prev; if (node == head_) { head_ = node->next; } else { // for non-head nodes, update prev->next to point to node->next // (If node==head, prev is tail and should always point to NULL). node->prev->next = node->next; if (node == head_->prev) // tail head_->prev = node->prev; } /* Update len first, so in CompressByDepth we know exactly len */ len_--; count_ -= node->count; malloc_size_ -= node->sz; if (node->offloaded) { num_offloaded_nodes_--; } /* If we deleted a node within our compress depth, we * now have compressed nodes needing to be decompressed. */ CompressByDepth(NULL); zfree(node->entry); zfree(node); } /* Delete one entry from list given the node for the entry and a pointer * to the entry in the node. * * Note: DelPackedIndex() *requires* uncompressed nodes because you * already had to get *p from an uncompressed node somewhere. * * Returns true if the entire node was deleted, false if node still exists. * Also updates in/out param 'p' with the next offset in the listpack. 
*/ bool QList::DelPackedIndex(Node* node, uint8_t* p) { DCHECK(!QL_NODE_IS_PLAIN(node)); if (node->count == 1) { DelNode(node); return true; } malloc_size_ += NodeSetEntry(node, lpDelete(node->entry, p, NULL)); node->count--; count_--; return false; } void QList::OffloadNode(Node* node) { DCHECK(tiering_params_ && node->offloaded == 0); num_offloaded_nodes_++; stats.offload_requests++; node->offloaded = 1; } void QList::InitIteratorEntry(Iterator* it) const { DCHECK(it->current_); const_cast(this)->AccessForReads(true, it->current_); if (QL_NODE_IS_PLAIN(it->current_)) { it->zi_ = it->current_->entry; } else { it->zi_ = lpSeek(it->current_->entry, it->offset_); } } auto QList::GetIterator(Where where) const -> Iterator { Iterator it; it.owner_ = this; it.zi_ = NULL; if (where == HEAD) { it.current_ = head_; it.offset_ = 0; it.direction_ = FWD; it.node_id_ = 0; } else { it.current_ = _Tail(); it.offset_ = -1; it.direction_ = REV; it.node_id_ = len_ - 1; } if (it.current_) { InitIteratorEntry(&it); } return it; } auto QList::GetIterator(long idx) const -> Iterator { unsigned long long accum = 0; int forward = idx < 0 ? 0 : 1; /* < 0 -> reverse, 0+ -> forward */ uint64_t index = forward ? idx : (-idx) - 1; if (index >= count_) return {}; DCHECK(head_); /* Seek in the other direction if that way is shorter. */ int seek_forward = forward; unsigned long long seek_index = index; if (index > (count_ - 1) / 2) { seek_forward = !forward; seek_index = count_ - 1 - index; } Node* n = seek_forward ? head_ : head_->prev; unsigned node_cnt = 0; while (ABSL_PREDICT_TRUE(n)) { if ((accum + n->count) > seek_index) { break; } else { accum += n->count; n = seek_forward ? n->next : n->prev; node_cnt++; } } DCHECK(n); if (!n) return {}; /* Fix accum so it looks like we seeked in the other direction. */ if (seek_forward != forward) accum = count_ - n->count - accum; Iterator iter; iter.owner_ = this; iter.direction_ = forward ? FWD : REV; iter.current_ = n; iter.node_id_ = seek_forward ? 
node_cnt : (len_ - 1 - node_cnt); if (forward) { /* forward = normal head-to-tail offset. */ iter.offset_ = index - accum; } else { /* reverse = need negative offset for tail-to-head, so undo * the result of the original index = (-idx) - 1 above. */ iter.offset_ = (-index) - 1 + accum; } InitIteratorEntry(&iter); return iter; } auto QList::Erase(Iterator it) -> Iterator { DCHECK(it.current_); Node* node = it.current_; Node* prev = node->prev; Node* next = node->next; bool deleted_node = false; if (QL_NODE_IS_PLAIN(node)) { DelNode(node); deleted_node = true; } else { deleted_node = DelPackedIndex(node, it.zi_); } it.zi_ = NULL; // Reset current entry pointer // If current node is deleted, we must update iterator node and offset. if (deleted_node) { if (it.direction_ == FWD) { it.current_ = next; it.offset_ = 0; it.node_id_++; } else if (it.direction_ == REV) { it.current_ = len_ ? prev : nullptr; it.offset_ = -1; it.node_id_ = it.node_id_ ? it.node_id_ - 1 : len_ - 1; } } if (it.current_) { InitIteratorEntry(&it); } // Sanity, should be noop in release mode. if (len_ == 1) { DCHECK_EQ(count_, head_->count); DCHECK_EQ(malloc_size_, head_->sz); } /* else if (!deleted_node), no changes needed. * we already reset iter->zi above, and the existing iter->offset * doesn't move again because: * - [1, 2, 3] => delete offset 1 => [1, 3]: next element still offset 1 * - [1, 2, 3] => delete offset 0 => [2, 3]: next element still offset 0 * if we deleted the last element at offset N and now * length of this listpack is N-1, the next call into * quicklistNext() will jump to the next node. */ return it; } bool QList::Erase(const long start, unsigned count) { if (count == 0) return false; unsigned extent = count; /* range is inclusive of start position */ if (start >= 0 && extent > (count_ - start)) { /* if requesting delete more elements than exist, limit to list size. 
*/ extent = count_ - start; } else if (start < 0 && extent > (unsigned long)(-start)) { /* else, if at negative offset, limit max size to rest of list. */ extent = -start; /* c.f. LREM -29 29; just delete until end. */ } Iterator it = GetIterator(start); Node* node = it.current_; long offset = it.offset_; /* iterate over next nodes until everything is deleted. */ while (extent) { Node* next = node->next; unsigned long del; int delete_entire_node = 0; if (offset == 0 && extent >= node->count) { /* If we are deleting more than the count of this node, we * can just delete the entire node without listpack math. */ delete_entire_node = 1; del = node->count; } else if (offset >= 0 && extent + offset >= node->count) { /* If deleting more nodes after this one, calculate delete based * on size of current node. */ del = node->count - offset; } else if (offset < 0) { /* If offset is negative, we are in the first run of this loop * and we are deleting the entire range * from this start offset to end of list. Since the Negative * offset is the number of elements until the tail of the list, * just use it directly as the deletion count. */ del = -offset; /* If the positive offset is greater than the remaining extent, * we only delete the remaining extent, not the entire offset. */ if (del > extent) del = extent; } else { /* else, we are deleting less than the extent of this node, so * use extent directly. 
*/ del = extent; } if (delete_entire_node || QL_NODE_IS_PLAIN(node)) { DelNode(node); } else { AccessForReads(true, node); malloc_size_ += NodeSetEntry(node, lpDeleteRange(node->entry, offset, del)); node->count -= del; count_ -= del; if (node->count == 0) { DelNode(node); } else { malloc_size_ += RecompressOnly(node, compr_method_); } } extent -= del; node = next; offset = 0; } return true; } uint8_t* QList::TryExtractListpack() { if (len_ != 1 || QL_NODE_IS_PLAIN(head_) || !ShouldStoreAsListPack(head_->sz) || head_->IsCompressed()) { return nullptr; } uint8_t* res = std::exchange(head_->entry, nullptr); DelNode(head_); return res; } bool QList::Iterator::Next() { if (!current_) return false; int plain = QL_NODE_IS_PLAIN(current_); // Advance to the next element in the current node. if (ABSL_PREDICT_FALSE(plain)) { zi_ = NULL; } else { unsigned char* (*nextFn)(unsigned char*, unsigned char*) = lpNext; int offset_update = 1; if (direction_ == REV) { DCHECK_EQ(REV, direction_); nextFn = lpPrev; offset_update = -1; } zi_ = nextFn(current_->entry, zi_); offset_ += offset_update; } if (zi_) return true; // Move to the next node. const_cast(owner_)->CompressByDepth(current_); if (direction_ == FWD) { /* Forward traversal, Jumping to start of next node */ current_ = current_->next; offset_ = 0; node_id_++; } else { /* Reverse traversal, Jumping to end of previous node */ DCHECK_EQ(REV, direction_); offset_ = -1; current_ = (current_ == owner_->head_) ? nullptr : current_->prev; node_id_--; } if (!current_) return false; owner_->InitIteratorEntry(this); return zi_ != nullptr; } auto QList::Iterator::Get() const -> Entry { int plain = QL_NODE_IS_PLAIN(current_); if (ABSL_PREDICT_FALSE(plain)) { char* str = reinterpret_cast(current_->entry); return Entry(str, current_->sz); } DCHECK(zi_); /* Populate value from existing listpack position */ unsigned int sz = 0; long long val; uint8_t* ptr = lpGetValue(zi_, &sz, &val); return ptr ? 
// QList is a list of elements stored in a chain of nodes. Each node holds either a listpack
// with multiple items (PACKED) or a single item as a char array (PLAIN).
// The `prev` links are circular at the head: head_->prev is the tail (see _Tail()),
// while the tail's `next` is null.
class QList {
 public:
  // Selects the end of the list an operation applies to.
  enum Where : uint8_t { TAIL, HEAD };
  // Compression method; stored in compr_method_ (0 - lzf, 1 - lz4).
  enum COMPR_METHOD : uint8_t { LZF = 0, LZ4 = 1 };

  /* Node is a 40 byte struct describing a listpack for a quicklist.
   * We use bit fields to keep the Node at 40 bytes.
   * count: 16 bits, max 65535 (max lp bytes is 65k, so max count actually < 32k).
   * encoding: 2 bits, RAW=1, LZF=2, LZ4=3 (see the *_NODE_ENCODING_* defines).
   * container: 2 bits, PLAIN=1 (a single item as char array), PACKED=2 (listpack with multiple
   * items).
   * recompress: 1 bit, bool, true if node is temporarily decompressed for usage.
   * attempted_compress: 1 bit, boolean, used for verifying during testing.
   * dont_compress: 1 bit, boolean, used for preventing compression of entry.
   * offloaded: 1 bit, boolean, true if the node is marked as offloaded to colder storage.
   * */
  struct Node {
    Node* prev; /* previous node; for the head node this is the tail */
    Node* next; /* next node; null for the tail node */
    unsigned char* entry;
    size_t sz : 48;                  /* entry size in bytes */
    size_t count : 16;               /* count of items in listpack */
    uint16_t encoding : 2;           /* RAW==1 or LZF==2 */
    uint16_t container : 2;          /* PLAIN==1 or PACKED==2 */
    uint16_t recompress : 1;         /* was this node previously compressed? */
    uint16_t attempted_compress : 1; /* node can't compress; too small */
    uint16_t dont_compress : 1;      /* prevent compression of entry that will be used later */
    uint16_t offloaded : 1;          /* node is offloaded to colder storage */
    uint16_t reserved1 : 8;          /* reserved for future use */
    uint16_t reserved2;              /* more bits to steal for future usage */
    uint32_t reserved3;              /* more bits to steal for future usage */

    // True for any encoding other than RAW.
    bool IsCompressed() const {
      return encoding != QUICKLIST_NODE_ENCODING_RAW;
    }

    size_t GetLZF(void** data) const;
  };

  using Entry = CollectionEntry;

  class Iterator {
   public:
    // Returns true if the iterator is valid (points to an element).
    bool Valid() const {
      return zi_ != nullptr;
    }

    // Returns the element the iterator points at. Must only be called on a valid iterator.
    Entry Get() const;

    // Advances to the next/prev element. Returns false if no more entries.
    bool Next();

   private:
    const QList* owner_ = nullptr;
    Node* current_ = nullptr;
    unsigned char* zi_ = nullptr; /* points to the current element */
    int32_t offset_ = 0;          /* offset in current listpack; negative counts from its end */
    int32_t node_id_ = 0;         /* node index in the list, 0 is head */
    uint8_t direction_ = 1;

    friend class QList;
  };

  using IterateFunc = absl::FunctionRef;
  // Position of an inserted element relative to the reference element.
  enum InsertOpt : uint8_t { BEFORE, AFTER };

  struct TieringParams {
    // TODO: hook functions and params that allow qlist offloading nodes to colder storage.
    // Nodes whose index is at least this far from both ends of the list are
    // candidates for offloading (see CoolOff()).
    uint32_t node_depth_threshold = 2;
  };

  /**
   * fill: The number of entries allowed per internal list node can be specified
   * as a fixed maximum size or a maximum number of elements.
   * For a fixed maximum size, use -5 through -1, meaning:
   * -5: max size: 64 Kb  <-- not recommended for normal workloads
   * -4: max size: 32 Kb  <-- not recommended
   * -3: max size: 16 Kb  <-- probably not recommended
   * -2: max size: 8 Kb   <-- good
   * -1: max size: 4 Kb   <-- good
   * Positive numbers mean store up to _exactly_ that number of elements
   * per list node.
   * The highest performing option is usually -2 (8 Kb size) or -1 (4 Kb size),
   * but if your use case is unique, adjust the settings as necessary.
   *
   *
   * Lists may also be compressed.
   * "compress" is the number of quicklist listpack nodes from *each* side of
   * the list to *exclude* from compression.  The head and tail of the list
   * are always uncompressed for fast push/pop operations.  Settings are:
   * 0: disable all list compression
   * 1: depth 1 means "don't start compressing until after 1 node into the list,
   *    going from either the head or tail"
   *    So: [head]->node->node->...->node->[tail]
   *    [head], [tail] will always be uncompressed; inner nodes will compress.
   * 2: [head]->[next]->node->node->...->node->[prev]->[tail]
   *    2 here means: don't compress head or head->next or tail->prev or tail,
   *    but compress all nodes between them.
   * 3: [head]->[next]->[next]->node->node->...->node->[prev]->[prev]->[tail]
   * etc.
   *
   */
  explicit QList(int fill = -2, int compress = 0);
  QList(QList&&) noexcept;
  QList(const QList&) = delete;
  ~QList();

  QList& operator=(const QList&) = delete;
  QList& operator=(QList&&) noexcept;

  // Returns the total number of elements in the list (not the number of nodes).
  size_t Size() const {
    return count_;
  }

  void Clear() noexcept;

  void Push(std::string_view value, Where where);

  // Returns the popped value. Precondition: list is not empty.
  std::string Pop(Where where);

  void AppendListpack(uint8_t* zl);
  void AppendPlain(uint8_t* zl, size_t sz);

  // Returns true if pivot found and elem inserted, false otherwise.
  bool Insert(std::string_view pivot, std::string_view elem, InsertOpt opt);
  void Insert(Iterator it, std::string_view elem, InsertOpt opt);

  // Returns true if item was replaced, false if index is out of range.
  bool Replace(long index, std::string_view elem);

  size_t MallocUsed(bool slow) const;

  // Iterates over entries from start to end (inclusive).
  void Iterate(IterateFunc cb, long start, long end) const;

  // Returns an iterator to tail or the head of the list.
  // A HEAD iterator moves forward from the first element; a TAIL iterator moves
  // backwards from the last element.
  // result.Valid() is true if the list is not empty.
  Iterator GetIterator(Where where) const;

  // Returns an iterator at a specific index 'idx',
  // or Invalid iterator if index is out of range.
  // negative index - means counting from the tail.
  // result.Valid() is true if the index is within range.
  Iterator GetIterator(long idx) const;

  // Number of nodes (not elements) in the list.
  uint32_t node_count() const {
    return len_;
  }

  unsigned compress_param() const {
    return compress_;
  }

  // Erases the element the iterator points at and returns an updated iterator.
  // Precondition: the iterator references a node.
  Iterator Erase(Iterator it);

  // Returns true if elements were deleted, false if list has not changed.
  // Negative start index is allowed.
  bool Erase(long start, unsigned count);

  // Needed by tests and the rdb code.
  const Node* Head() const {
    return head_;
  }

  const Node* Tail() const {
    return _Tail();
  }

  // Returns nullptr if quicklist does not fit the necessary requirements
  // to be converted to listpack, and listpack otherwise. The ownership over the listpack
  // blob is moved to the caller.
  uint8_t* TryExtractListpack();

  void set_fill(int fill) {
    fill_ = fill;
  }

  void set_compr_method(COMPR_METHOD cm) {
    compr_method_ = static_cast(cm);
  }

  static void SetPackedThreshold(unsigned threshold);

  // Moves nodes away from underused pages by reallocating if the underlying page usage is low.
  // Returns count of nodes reallocated to help in testing.
  size_t DefragIfNeeded(PageUsage* page_usage);

  void SetTieringParams(const TieringParams& params);

  struct Stats {
    uint64_t compression_attempts = 0;

    // compression attempts with compression ratio that was not good enough to keep.
    // Subset of compression_attempts.
    uint64_t bad_compression_attempts = 0;

    uint64_t decompression_calls = 0;

    // How many bytes we currently keep compressed.
    size_t compressed_bytes = 0;

    // how many bytes we compressed from.
    // Compressed savings are calculated as raw_compressed_bytes - compressed_bytes.
    size_t raw_compressed_bytes = 0;

    // Read accesses to nodes that are neither the head nor the tail (counted only when the
    // list has more than 2 nodes). Subset of total_node_reads.
    uint64_t interior_node_reads = 0;
    // Number of node read accesses (one per AccessForReads() call).
    uint64_t total_node_reads = 0;
    // Number of times a node was marked as offloaded.
    uint64_t offload_requests = 0;
    // Number of read accesses that hit an offloaded node and cleared its offloaded mark.
    uint64_t onload_requests = 0;

    Stats& operator+=(const Stats& other);
  };

  // Per-thread statistics, shared by all QList instances of the thread.
  static __thread Stats stats;

 private:
  bool AllowCompression() const {
    return compress_ != 0;
  }

  // Returns the tail node, or nullptr for an empty list. Relies on head_->prev being the tail.
  Node* _Tail() const {
    return head_ ? head_->prev : nullptr;
  }

  // Returns newly created plain node.
  Node* InsertPlainNode(Node* old_node, std::string_view elem, uint32_t old_node_id,
                        InsertOpt insert_opt);
  void InsertNode(Node* old_node, Node* new_node, uint32_t old_node_id, InsertOpt insert_opt);

  // Reduces the "warmth" of the node. Current implementation can decide on
  // compressing the node based on its position in the list.
  // When tiering params are set, it also decides whether to mark nodes as offloaded;
  // node_id is the (possibly approximate) index of the node, 0 being the head.
  void CoolOff(Node* node, uint32_t node_id);

  void Replace(Iterator it, std::string_view elem);

  // Applies the compress-depth policy around both ends of the list; `node` may be null.
  void CompressByDepth(Node* node);

  // Prepares the node for read access.
  void AccessForReads(bool recompress, Node* node);

  // Tries to merge the listpacks of the nodes around `node`; returns the node that
  // replaces `node` after merging.
  Node* MergeNodes(Node* node);

  // Deletes one of the nodes and returns the other.
  Node* ListpackMerge(Node* a, Node* b);

  // Unlinks and frees the node, updating len_, count_ and malloc_size_.
  void DelNode(Node* node);
  // Deletes the entry at `p` from a packed node. Returns true if the whole node was deleted.
  bool DelPackedIndex(Node* node, uint8_t* p);
  // Marks the node as offloaded and updates the offload counters.
  void OffloadNode(Node* node);

  // Initializes iterator's zi_ to point to the element at offset_.
  // Decompresses the node if needed. Assumes current_ is not null.
  void InitIteratorEntry(Iterator* it) const;

  Node* head_ = nullptr;
  size_t malloc_size_ = 0;  // running total of the entry sizes (Node::sz) of all nodes
  uint32_t count_ = 0;      /* total count of all entries in all listpacks */
  uint32_t len_ = 0;        /* number of quicklistNodes */
  int16_t fill_;            /* fill factor for individual nodes */
  int16_t compr_method_ : 2;  // 0 - lzf, 1 - lz4
  int16_t reserved1_ : 14;
  unsigned compress_ : QL_COMP_BITS; /* depth of end nodes not to compress;0=off */
  unsigned bookmark_count_ : QL_BM_BITS;
  unsigned reserved2_ : 12;
  uint32_t num_offloaded_nodes_ = 0;  // number of nodes currently marked as offloaded
  std::unique_ptr tiering_params_;    // null unless tiering is configured (see CoolOff())
};
errors++; } } else { if (node->encoding != QUICKLIST_NODE_ENCODING_LZF && !node->attempted_compress) { LOG(ERROR) << absl::StrFormat( "Incorrect non-compression: node %d is NOT " "compressed at depth %d ((%u, %u); total " "nodes: %lu; size: %zu; recompress: %d; attempted: %d)", at, compress_param, low_raw, high_raw, ql.node_count(), node->sz, node->recompress, node->attempted_compress); errors++; } } } } return errors; } /* Verify list metadata matches physical list contents. */ static int ql_verify(const QList& ql, uint32_t nc, uint32_t count, uint32_t head_count, uint32_t tail_count) { int errors = 0; if (nc != ql.node_count()) { LOG(ERROR) << "quicklist length wrong: expected " << nc << " got " << ql.node_count(); errors++; } if (count != ql.Size()) { LOG(ERROR) << "quicklist count wrong: expected " << count << " got " << ql.Size(); errors++; } auto* node = ql.Head(); size_t node_size = 0; while (node) { node_size += node->count; node = node->next; CHECK(node != ql.Head()); } if (node_size != ql.Size()) { LOG(ERROR) << "quicklist cached count not match actual count: expected " << ql.Size() << " got " << node_size; errors++; } node = ql.Tail(); node_size = 0; while (node) { node_size += node->count; node = (node == ql.Head()) ? nullptr : node->prev; } if (node_size != ql.Size()) { LOG(ERROR) << "has different forward count than reverse count! " "Forward count is " << ql.Size() << ", reverse count is " << node_size; errors++; } if (ql.node_count() == 0 && errors == 0) { return 0; } if (head_count != ql.Head()->count && head_count != lpLength(ql.Head()->entry)) { LOG(ERROR) << absl::StrFormat("head count wrong: expected %u got cached %u vs. actual %lu", head_count, ql.Head()->count, lpLength(ql.Head()->entry)); errors++; } if (tail_count != ql.Tail()->count && tail_count != lpLength(ql.Tail()->entry)) { LOG(ERROR) << "tail count wrong: expected " << tail_count << "got cached " << ql.Tail()->count << " vs. 
actual " << lpLength(ql.Tail()->entry); errors++; } errors += ql_verify_compress(ql); return errors; } static void SetupMalloc() { // configure redis lib zmalloc which requires mimalloc heap to work. auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); mi_option_set(mi_option_purge_delay, -1); // disable purging of segments (affects benchmarks) } class QListTest : public ::testing::Test { protected: QListTest() : mr_(mi_heap_get_backing()) { } static void SetUpTestSuite() { SetupMalloc(); } static void TearDownTestSuite() { mi_heap_collect(mi_heap_get_backing(), true); auto cb_visit = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, size_t block_size, void* arg) { LOG(ERROR) << "Unfreed allocations: block_size " << block_size << ", allocated: " << area->used * block_size; return true; }; mi_heap_visit_blocks(mi_heap_get_backing(), false /* do not visit all blocks*/, cb_visit, nullptr); } vector ToItems() const; MiMemoryResource mr_; QList ql_; }; vector QListTest::ToItems() const { vector res; auto cb = [&](const QList::Entry& e) { res.push_back(e.to_string()); return true; }; ql_.Iterate(cb, 0, ql_.Size()); return res; } TEST_F(QListTest, Basic) { EXPECT_EQ(0, ql_.Size()); ql_.Push("abc", QList::HEAD); EXPECT_EQ(1, ql_.Size()); EXPECT_TRUE(ql_.Tail() == ql_.Head()); EXPECT_LE(ql_.MallocUsed(false), ql_.MallocUsed(true)); auto it = ql_.GetIterator(QList::HEAD); ASSERT_TRUE(it.Valid()); // Iterator is valid immediately. 
// Checks that prepending to and appending to an empty listpack produce byte-identical blobs.
TEST_F(QListTest, ListPack) {
  string_view sv = "abcded"sv;
  uint8_t* lp1 = lpPrepend(lpNew(0), (uint8_t*)sv.data(), sv.size());
  uint8_t* lp2 = lpAppend(lpNew(0), (uint8_t*)sv.data(), sv.size());

  // Non-fatal expectations are used on purpose: a fatal ASSERT_* would return before the
  // lpFree() calls below and leak both listpacks.
  const size_t bytes1 = lpBytes(lp1);
  const size_t bytes2 = lpBytes(lp2);
  EXPECT_EQ(bytes1, bytes2);
  if (bytes1 == bytes2) {
    // Only compare contents when the sizes agree, so memcmp never reads past either blob.
    EXPECT_EQ(0, memcmp(lp1, lp2, bytes1));
  }

  lpFree(lp1);
  lpFree(lp2);
}
it = ql_.Erase(it); items = ToItems(); EXPECT_THAT(items, ElementsAre("abc", "123456")); ASSERT_TRUE(it.Valid()); ASSERT_EQ("abc", it.Get().view()); it = ql_.Erase(it); items = ToItems(); EXPECT_THAT(items, ElementsAre("123456")); ASSERT_TRUE(it.Valid()); ASSERT_EQ(123456, it.Get().ival()); it = ql_.Erase(it); items = ToItems(); EXPECT_THAT(items, ElementsAre()); ASSERT_FALSE(it.Valid()); EXPECT_EQ(0, ql_.Size()); } TEST_F(QListTest, EraseLastElementInNodeAdvancesToNextNode) { // Regression test for iterator semantics: when erasing the last element // within a multi-entry node and another node follows, the iterator should // correctly advance to the first element of the next node. // Create a QList with fill=2 to ensure max 2 elements per node ql_ = QList(2, QUICKLIST_NOCOMPRESS); // Push 3 elements: this creates 2 nodes (first with 2 elements, second with 1) ql_.Push("first", QList::HEAD); // Will be at index 2 after all pushes ql_.Push("second", QList::HEAD); // Will be at index 1 after all pushes ql_.Push("third", QList::HEAD); // Will be at index 0 after all pushes // Verify we have 2 nodes as expected ASSERT_EQ(2, ql_.node_count()); ASSERT_EQ(3, ql_.Size()); // Node structure should be: // Node 1: ["third", "second"] // Node 2: ["first"] auto items = ToItems(); EXPECT_THAT(items, ElementsAre("third", "second", "first")); // Get iterator to "second" (last element in first node) auto it = ql_.GetIterator(1); ASSERT_TRUE(it.Valid()); ASSERT_EQ("second", it.Get().view()); // Erase "second" - this is the last element in the first node it = ql_.Erase(it); // Iterator should now point to "first" (first element of the second node) ASSERT_TRUE(it.Valid()); EXPECT_EQ("first", it.Get().view()); // Verify the list is correct items = ToItems(); EXPECT_THAT(items, ElementsAre("third", "first")); EXPECT_EQ(2, ql_.Size()); } TEST_F(QListTest, PushPlain) { // push a value large enough to trigger plain node insertion. 
// Pushes 500 plain (non-listpack) nodes of 256 bytes each with compression enabled and
// verifies that every value can be read back in order from the tail.
TEST_F(QListTest, CompressionPlain) {
  // Zero-initialized: the whole buffer (not just the formatted prefix) is pushed into the
  // list, so its tail bytes must be well-defined and deterministic for the compressor.
  char buf[256] = {};
  QList::SetPackedThreshold(1);
  ql_ = QList(-2, 1);
  for (int i = 0; i < 500; i++) {
    /* Set to 256 to allow the node to be triggered to compress,
     * if it is less than 48(nocompress), the test will be successful. */
    snprintf(buf, sizeof(buf), "hello%d", i);
    ql_.Push(string_view{buf, sizeof(buf)}, QList::HEAD);
  }
  QList::SetPackedThreshold(0);

  // Values were pushed to the head, so walking from the tail yields them in push order.
  QList::Iterator it = ql_.GetIterator(QList::TAIL);
  int i = 0;
  ASSERT_TRUE(it.Valid());
  do {
    string_view sv = it.Get().view();
    ASSERT_EQ(sizeof(buf), sv.size());
    ASSERT_TRUE(absl::StartsWith(sv, StrCat("hello", i)));
    i++;
  } while (it.Next());
  EXPECT_EQ(500, i);
}
ql_.Erase(it); it = ql_.GetIterator(QList::TAIL); ASSERT_TRUE(it.Valid()); it = ql_.Erase(it); ASSERT_FALSE(it.Valid()); } TEST_F(QListTest, DefragListpackRaw) { PageUsage page_usage{CollectPageStats::YES, 100.0}; page_usage.SetForceReallocate(true); ql_.Push("first", QList::TAIL); ql_.Push("second", QList::TAIL); ASSERT_EQ(ql_.DefragIfNeeded(&page_usage), 1); EXPECT_THAT(ToItems(), ElementsAre("first", "second")); ql_.Clear(); } TEST_F(QListTest, DefragPlainTextRaw) { PageUsage page_usage{CollectPageStats::YES, 100.0}; page_usage.SetForceReallocate(true); string big(100000, 'x'); ql_.Push(big, QList::HEAD); ASSERT_EQ(ql_.DefragIfNeeded(&page_usage), 1); EXPECT_THAT(ToItems(), ElementsAre(big)); ql_.Clear(); } TEST_F(QListTest, DefragmentListpackCompressed) { PageUsage page_usage{CollectPageStats::YES, 100.0}; page_usage.SetForceReallocate(true); // MIN_COMPRESS_BYTES = 256 char buf[256]; constexpr auto items_per_list = 4; constexpr auto total_items = 20; ql_ = QList{items_per_list, 1}; for (auto i = 0; i < total_items; ++i) { absl::SNPrintF(buf, 256, "test__%d", i); ql_.Push(string_view{buf, 256}, QList::TAIL); } ASSERT_EQ(total_items / items_per_list, ql_.DefragIfNeeded(&page_usage)); auto i = 0; auto it = ql_.GetIterator(QList::HEAD); ASSERT_TRUE(it.Valid()); do { auto v = it.Get().view(); ASSERT_EQ(v.size(), 256); ASSERT_TRUE(absl::StartsWith(v, StrCat("test__", i))); ++i; } while (it.Next()); ASSERT_EQ(i, total_items); } TEST_F(QListTest, Tiering) { QList::stats.offload_requests = 0; ql_.SetTieringParams(QList::TieringParams{.node_depth_threshold = 1}); for (int i = 0; i < 8000; i++) { ql_.Push(absl::StrCat("value", i), QList::TAIL); } EXPECT_EQ(QList::stats.offload_requests, 9); } using FillCompress = tuple; class PrintToFillCompress { public: std::string operator()(const TestParamInfo& info) const { int fill = get<0>(info.param); int compress = get<1>(info.param); QList::COMPR_METHOD method = get<2>(info.param); string fill_str = fill >= 0 ? 
absl::StrCat("f", fill) : absl::StrCat("fminus", -fill); string method_str = method == QList::LZF ? "lzf" : "lz4"; return absl::StrCat(fill_str, "compr", compress, method_str); } }; class OptionsTest : public QListTest, public WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(Matrix, OptionsTest, Combine(Values(-5, -4, -3, -2, -1, 0, 1, 2, 32, 66, 128, 999), Values(0, 1, 2, 3, 4, 5, 6, 10), Values(QList::LZF, QList::LZ4)), PrintToFillCompress()); TEST_P(OptionsTest, Numbers) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.set_compr_method(method); array nums; for (unsigned i = 0; i < nums.size(); i++) { nums[i] = -5157318210846258176 + i; string val = absl::StrCat(nums[i]); ql_.Push(val, QList::TAIL); } ql_.Push("xxxxxxxxxxxxxxxxxxxx", QList::TAIL); for (unsigned i = 0; i < nums.size(); i++) { auto it = ql_.GetIterator(i); ASSERT_TRUE(it.Valid()); ASSERT_EQ(nums[i], it.Get().ival()) << i; } auto it = ql_.GetIterator(nums.size()); ASSERT_TRUE(it.Valid()); EXPECT_EQ("xxxxxxxxxxxxxxxxxxxx", it.Get().view()); } TEST_P(OptionsTest, NumbersIndex) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.set_compr_method(method); long long nums[5000]; for (int i = 0; i < 760; i++) { nums[i] = -5157318210846258176 + i; ql_.Push(absl::StrCat(nums[i]), QList::TAIL); } unsigned i = 437; QList::Iterator it = ql_.GetIterator(i); ASSERT_TRUE(it.Valid()); do { ASSERT_EQ(nums[i], it.Get().ival()); i++; } while (it.Next()); ASSERT_EQ(760, i); } TEST_P(OptionsTest, DelRangeA) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.set_compr_method(method); long long nums[5000]; for (int i = 0; i < 33; i++) { nums[i] = -5157318210846258176 + i; ql_.Push(absl::StrCat(nums[i]), QList::TAIL); } if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 2, 33, 32, 1)); } /* ltrim 3 3 (keep [3,3] inclusive = 1 remaining) */ ql_.Erase(0, 3); ql_.Erase(-29, 4000); /* make sure not loop forever */ if (fill == 32) { ASSERT_EQ(0, 
ql_verify(ql_, 1, 1, 1, 1)); } auto it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); EXPECT_EQ(-5157318210846258173, it.Get().ival()); } TEST_P(OptionsTest, DelRangeB) { auto [fill, _, method] = GetParam(); ql_ = QList(fill, QUICKLIST_NOCOMPRESS); // ignore compress parameter ql_.set_compr_method(method); long long nums[5000]; for (int i = 0; i < 33; i++) { nums[i] = i; ql_.Push(absl::StrCat(nums[i]), QList::TAIL); } if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 2, 33, 32, 1)); } /* ltrim 5 16 (keep [5,16] inclusive = 12 remaining) */ ql_.Erase(0, 5); ql_.Erase(-16, 16); if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 1, 12, 12, 12)); } auto it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); EXPECT_EQ(5, it.Get().ival()); it = ql_.GetIterator(-1); ASSERT_TRUE(it.Valid()); EXPECT_EQ(16, it.Get().ival()); ql_.Push("bobobob", QList::TAIL); it = ql_.GetIterator(-1); ASSERT_TRUE(it.Valid()); EXPECT_EQ("bobobob", it.Get().view()); for (int i = 0; i < 12; i++) { it = ql_.GetIterator(i); ASSERT_TRUE(it.Valid()); EXPECT_EQ(i + 5, it.Get().ival()); } } TEST_P(OptionsTest, DelRangeC) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.set_compr_method(method); long long nums[5000]; for (int i = 0; i < 33; i++) { nums[i] = -5157318210846258176 + i; ql_.Push(absl::StrCat(nums[i]), QList::TAIL); } if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 2, 33, 32, 1)); } /* ltrim 3 3 (keep [3,3] inclusive = 1 remaining) */ ql_.Erase(0, 3); ql_.Erase(-29, 4000); /* make sure not loop forever */ if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 1, 1, 1, 1)); } auto it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); ASSERT_EQ(-5157318210846258173, it.Get().ival()); } TEST_P(OptionsTest, DelRangeD) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.set_compr_method(method); long long nums[5000]; for (int i = 0; i < 33; i++) { nums[i] = -5157318210846258176 + i; ql_.Push(absl::StrCat(nums[i]), QList::TAIL); } if (fill == 32) { ASSERT_EQ(0, 
ql_verify(ql_, 2, 33, 32, 1)); } ql_.Erase(-12, 3); ASSERT_EQ(30, ql_.Size()); } TEST_P(OptionsTest, DelRangeNode) { auto [_, compress, method] = GetParam(); ql_ = QList(-2, compress); ql_.set_compr_method(method); for (int i = 0; i < 32; i++) ql_.Push(StrCat("hello", i), QList::HEAD); ASSERT_EQ(0, ql_verify(ql_, 1, 32, 32, 32)); ql_.Erase(0, 32); ASSERT_EQ(0, ql_verify(ql_, 0, 0, 0, 0)); } TEST_P(OptionsTest, DelRangeNodeOverflow) { auto [_, compress, method] = GetParam(); ql_ = QList(-2, compress); ql_.set_compr_method(method); for (int i = 0; i < 32; i++) ql_.Push(StrCat("hello", i), QList::HEAD); ASSERT_EQ(0, ql_verify(ql_, 1, 32, 32, 32)); ql_.Erase(0, 128); ASSERT_EQ(0, ql_verify(ql_, 0, 0, 0, 0)); } TEST_P(OptionsTest, DelRangeMiddle100of500) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); ASSERT_EQ(0, ql_verify(ql_, 16, 500, 32, 20)); ql_.Erase(200, 100); ASSERT_EQ(0, ql_verify(ql_, 14, 400, 32, 20)); } TEST_P(OptionsTest, DelLessFillAcrossNodes) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); ASSERT_EQ(0, ql_verify(ql_, 16, 500, 32, 20)); ql_.Erase(60, 10); ASSERT_EQ(0, ql_verify(ql_, 16, 490, 32, 20)); } TEST_P(OptionsTest, DelNegOne) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); ASSERT_EQ(0, ql_verify(ql_, 16, 500, 32, 20)); ql_.Erase(-1, 1); ASSERT_EQ(0, ql_verify(ql_, 16, 499, 32, 19)); } TEST_P(OptionsTest, DelNegOneOverflow) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); ASSERT_EQ(0, ql_verify(ql_, 16, 500, 32, 20)); ql_.Erase(-1, 128); ASSERT_EQ(0, ql_verify(ql_, 16, 499, 32, 19)); } TEST_P(OptionsTest, DelNeg100From500) { auto [_, compress, method] = 
GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); ql_.Erase(-100, 100); QList::Iterator it = ql_.GetIterator(QList::TAIL); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello400", it.Get()); ASSERT_EQ(0, ql_verify(ql_, 13, 400, 32, 16)); } TEST_P(OptionsTest, DelMin10_5_from50) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 50; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); ASSERT_EQ(0, ql_verify(ql_, 2, 50, 32, 18)); ql_.Erase(-10, 5); ASSERT_EQ(0, ql_verify(ql_, 2, 45, 32, 13)); } TEST_P(OptionsTest, DelElems) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); const char* words[] = {"abc", "foo", "bar", "foobar", "foobared", "zap", "bar", "test", "foo"}; const char* result[] = {"abc", "foo", "foobar", "foobared", "zap", "test", "foo"}; const char* resultB[] = {"abc", "foo", "foobar", "foobared", "zap", "test"}; for (int i = 0; i < 9; i++) ql_.Push(words[i], QList::TAIL); /* lrem 0 bar */ auto iter = ql_.GetIterator(QList::HEAD); while (iter.Valid()) { if (iter.Get() == "bar") { iter = ql_.Erase(iter); // iter now points to next element, don't call Next() } else { if (!iter.Next()) break; } } EXPECT_THAT(ToItems(), ElementsAreArray(result)); ql_.Push("foo", QList::TAIL); /* lrem -2 foo */ iter = ql_.GetIterator(QList::TAIL); int del = 2; while (iter.Valid()) { if (iter.Get() == "foo") { iter = ql_.Erase(iter); del--; if (del == 0) break; // iter now points to next element, don't call Next() } else { if (!iter.Next()) break; } } /* check result of lrem -2 foo */ /* (we're ignoring the '2' part and still deleting all foo * because we only have two foo) */ EXPECT_THAT(ToItems(), ElementsAreArray(resultB)); } TEST_P(OptionsTest, IterateReverse) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i), QList::HEAD); QList::Iterator it = ql_.GetIterator(QList::TAIL); int i = 
0; ASSERT_TRUE(it.Valid()); do { ASSERT_EQ(StrCat("hello", i), it.Get()); i++; } while (it.Next()); ASSERT_EQ(500, i); ASSERT_EQ(0, ql_verify(ql_, 16, 500, 20, 32)); } TEST_P(OptionsTest, Iterate500) { auto [_, compress, method] = GetParam(); ql_ = QList(32, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i), QList::HEAD); QList::Iterator it = ql_.GetIterator(QList::HEAD); int i = 499, count = 0; ASSERT_TRUE(it.Valid()); do { QList::Entry entry = it.Get(); ASSERT_EQ(StrCat("hello", i), entry); i--; count++; } while (it.Next()); EXPECT_EQ(500, count); ASSERT_EQ(0, ql_verify(ql_, 16, 500, 20, 32)); it = ql_.GetIterator(QList::TAIL); i = 0; ASSERT_TRUE(it.Valid()); do { ASSERT_EQ(StrCat("hello", i), it.Get()); i++; } while (it.Next()); EXPECT_EQ(500, i); } TEST_P(OptionsTest, IterateAfterOne) { auto [_, compress, method] = GetParam(); ql_ = QList(-2, compress); ql_.Push("hello", QList::HEAD); QList::Iterator it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); ql_.Insert(it, "abc", QList::AFTER); ASSERT_EQ(0, ql_verify(ql_, 1, 2, 2, 2)); /* verify results */ it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello", it.Get()); it = ql_.GetIterator(1); ASSERT_TRUE(it.Valid()); ASSERT_EQ("abc", it.Get()); } TEST_P(OptionsTest, IterateDelete) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.Push("abc", QList::TAIL); ql_.Push("def", QList::TAIL); ql_.Push("hij", QList::TAIL); ql_.Push("jkl", QList::TAIL); ql_.Push("oop", QList::TAIL); QList::Iterator it = ql_.GetIterator(QList::HEAD); while (it.Valid()) { if (it.Get() == "hij") { it = ql_.Erase(it); } else { it.Next(); } } ASSERT_THAT(ToItems(), ElementsAre("abc", "def", "jkl", "oop")); } TEST_P(OptionsTest, InsertBeforeOne) { auto [_, compress, method] = GetParam(); ql_ = QList(-2, compress); ql_.Push("hello", QList::HEAD); QList::Iterator it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); ql_.Insert(it, "abc", QList::BEFORE); ql_verify(ql_, 1, 2, 2, 2); /* verify 
results */ it = ql_.GetIterator(0); ASSERT_TRUE(it.Valid()); ASSERT_EQ("abc", it.Get()); it = ql_.GetIterator(1); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello", it.Get()); } TEST_P(OptionsTest, InsertWithHeadFull) { auto [_, compress, method] = GetParam(); ql_ = QList(4, compress); for (int i = 0; i < 10; i++) ql_.Push(StrCat("hello", i), QList::TAIL); ql_.set_fill(-1); QList::Iterator it = ql_.GetIterator(-10); ASSERT_TRUE(it.Valid()); char buf[4096] = {0}; ql_.Insert(it, string_view{buf, sizeof(buf)}, QList::BEFORE); ql_verify(ql_, 4, 11, 1, 2); } TEST_P(OptionsTest, InsertWithTailFull) { auto [_, compress, method] = GetParam(); ql_ = QList(4, compress); for (int i = 0; i < 10; i++) ql_.Push(StrCat("hello", i), QList::HEAD); ql_.set_fill(-1); QList::Iterator it = ql_.GetIterator(-1); ASSERT_TRUE(it.Valid()); char buf[4096] = {0}; ql_.Insert(it, string_view{buf, sizeof(buf)}, QList::AFTER); ql_verify(ql_, 4, 11, 2, 1); } TEST_P(OptionsTest, InsertOnceWhileIterating) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); ql_.Push("abc", QList::TAIL); ql_.set_fill(1); ql_.Push("def", QList::TAIL); ql_.set_fill(fill); ql_.Push("bob", QList::TAIL); ql_.Push("foo", QList::TAIL); ql_.Push("zoo", QList::TAIL); /* insert "bar" before "bob" while iterating over list. */ QList::Iterator it = ql_.GetIterator(QList::HEAD); if (it.Valid()) { do { if (it.Get() == "bob") { ql_.Insert(it, "bar", QList::BEFORE); break; /* didn't we fix insert-while-iterating? 
*/ } } while (it.Next()); } EXPECT_THAT(ToItems(), ElementsAre("abc", "def", "bar", "bob", "foo", "zoo")); } TEST_P(OptionsTest, InsertBefore250NewInMiddleOf500Elements) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); for (int i = 0; i < 500; i++) { string val = StrCat("hello", i); val.resize(32); ql_.Push(val, QList::TAIL); } for (int i = 0; i < 250; i++) { QList::Iterator it = ql_.GetIterator(250); ASSERT_TRUE(it.Valid()); ql_.Insert(it, StrCat("abc", i), QList::BEFORE); } if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 25, 750, 32, 20)); } } TEST_P(OptionsTest, InsertAfter250NewInMiddleOf500Elements) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i), QList::HEAD); for (int i = 0; i < 250; i++) { QList::Iterator it = ql_.GetIterator(250); ASSERT_TRUE(it.Valid()); ql_.Insert(it, StrCat("abc", i), QList::AFTER); } ASSERT_EQ(750, ql_.Size()); if (fill == 32) { ASSERT_EQ(0, ql_verify(ql_, 26, 750, 20, 32)); } } TEST_P(OptionsTest, NextPlain) { auto [_, compress, method] = GetParam(); ql_ = QList(-2, compress); QList::SetPackedThreshold(3); const char* strings[] = {"hello1", "hello2", "h3", "h4", "hello5"}; for (int i = 0; i < 5; ++i) ql_.Push(strings[i], QList::HEAD); QList::Iterator it = ql_.GetIterator(QList::TAIL); int j = 0; ASSERT_TRUE(it.Valid()); do { ASSERT_EQ(strings[j], it.Get()); j++; } while (it.Next()); } TEST_P(OptionsTest, IndexFrom500) { auto [fill, compress, method] = GetParam(); ql_ = QList(fill, compress); for (int i = 0; i < 500; i++) ql_.Push(StrCat("hello", i + 1), QList::TAIL); QList::Iterator it = ql_.GetIterator(1); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello2", it.Get()); it = ql_.GetIterator(200); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello201", it.Get()); it = ql_.GetIterator(-1); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello500", it.Get()); it = ql_.GetIterator(-2); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello499", it.Get()); it = 
ql_.GetIterator(-100); ASSERT_TRUE(it.Valid()); ASSERT_EQ("hello401", it.Get()); it = ql_.GetIterator(500); ASSERT_FALSE(it.Valid()); } static void BM_QListCompress(benchmark::State& state) { SetupMalloc(); string path = base::ProgramRunfile("testdata/list.txt.zst"); io::Result src = io::OpenUncompressed(path); CHECK(src) << src.error(); io::LineReader lr(*src, TAKE_OWNERSHIP); string_view line; vector lines; while (lr.Next(&line)) { lines.push_back(string(line)); } VLOG(1) << "Read " << lines.size() << " lines " << state.range(0); while (state.KeepRunning()) { QList ql(-2, state.range(0)); // uses different compression modes, see below. ql.set_compr_method(state.range(1) == 0 ? QList::LZF : QList::LZ4); for (const string& l : lines) { ql.Push(l, QList::TAIL); } DVLOG(1) << ql.node_count() << ", " << ql.MallocUsed(true); } CHECK_EQ(0, zmalloc_used_memory_tl); } BENCHMARK(BM_QListCompress) ->ArgsProduct({{1, 4, 0}, {0, 1}}); // x - compression depth, y compression method. // x = 0 no compression, 1 - compress all nodes but edges, // 4 - compress all but 4 nodes from edges. static void BM_QListUncompress(benchmark::State& state) { SetupMalloc(); string path = base::ProgramRunfile("testdata/list.txt.zst"); io::Result src = io::OpenUncompressed(path); CHECK(src) << src.error(); io::LineReader lr(*src, TAKE_OWNERSHIP); string_view line; QList ql(-2, state.range(0)); ql.set_compr_method(state.range(1) == 0 ?
QList::LZF : QList::LZ4); QList::stats.compression_attempts = 0; CHECK_EQ(QList::stats.compressed_bytes, 0u); CHECK_EQ(QList::stats.raw_compressed_bytes, 0u); size_t line_len = 0; while (lr.Next(&line)) { ql.Push(line, QList::TAIL); line_len += line.size(); } if (ql.compress_param() > 0) { CHECK_GT(QList::stats.compression_attempts, 0u); CHECK_GT(QList::stats.compressed_bytes, 0u); CHECK_GT(QList::stats.raw_compressed_bytes, QList::stats.compressed_bytes); } LOG(INFO) << "MallocUsed " << ql.compress_param() << ": " << ql.MallocUsed(true) << ", " << ql.MallocUsed(false); size_t exp_count = ql.Size(); while (state.KeepRunning()) { unsigned actual_count = 0, actual_len = 0; ql.Iterate( [&](const QList::Entry& e) { actual_len += e.view().size(); ++actual_count; return true; }, 0, -1); CHECK_EQ(exp_count, actual_count); CHECK_EQ(line_len, actual_len); } } BENCHMARK(BM_QListUncompress)->ArgsProduct({{1, 4, 0}, {0, 1}}); } // namespace dfly ================================================ FILE: src/core/score_map.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/score_map.h" #include "base/endian.h" #include "base/logging.h" #include "core/compact_object.h" #include "core/page_usage/page_usage_stats.h" #include "core/sds_utils.h" extern "C" { #include "redis/zmalloc.h" } using namespace std; namespace dfly { namespace { inline double GetValue(sds key) { char* valptr = key + sdslen(key) + 1; return absl::bit_cast(absl::little_endian::Load64(valptr)); } void* AllocateScored(string_view field, double value) { size_t meta_offset = field.size() + 1; // The layout is: // key, '\0', 8-byte double value sds newkey = AllocSdsWithSpace(field.size(), 8); if (!field.empty()) { memcpy(newkey, field.data(), field.size()); } absl::little_endian::Store64(newkey + meta_offset, absl::bit_cast(value)); return newkey; } } // namespace ScoreMap::~ScoreMap() { Clear(); } pair ScoreMap::AddOrUpdate(string_view field, double value) { void* newkey = AllocateScored(field, value); // Replace the whole entry. sds prev_entry = (sds)AddOrReplaceObj(newkey, false); if (prev_entry) { ObjDelete(prev_entry, false); return {newkey, false}; } return {newkey, true}; } std::pair ScoreMap::AddOrSkip(std::string_view field, double value) { uint64_t hashcode = Hash(&field, 1); void* obj = FindInternal(&field, hashcode, 1); // 1 - string_view if (obj) return {obj, false}; void* newkey = AllocateScored(field, value); DenseSet::AddUnique(newkey, false, hashcode); return {newkey, true}; } void* ScoreMap::AddUnique(std::string_view field, double value) { void* newkey = AllocateScored(field, value); DenseSet::AddUnique(newkey, false, Hash(&field, 1)); return newkey; } std::optional ScoreMap::Find(std::string_view field) { uint64_t hashcode = Hash(&field, 1); sds str = (sds)FindInternal(&field, hashcode, 1); if (!str) return nullopt; return GetValue(str); } uint64_t ScoreMap::Hash(const void* obj, uint32_t cookie) const { DCHECK_LT(cookie, 2u); if (cookie == 0) { sds s = (sds)obj; return CompactObj::HashCode(string_view{s, sdslen(s)}); } const 
string_view* sv = (const string_view*)obj; return CompactObj::HashCode(*sv); } bool ScoreMap::ObjEqual(const void* left, const void* right, uint32_t right_cookie) const { DCHECK_LT(right_cookie, 2u); sds s1 = (sds)left; if (right_cookie == 0) { sds s2 = (sds)right; if (sdslen(s1) != sdslen(s2)) { return false; } return sdslen(s1) == 0 || memcmp(s1, s2, sdslen(s1)) == 0; } const string_view* right_sv = (const string_view*)right; string_view left_sv{s1, sdslen(s1)}; return left_sv == (*right_sv); } size_t ScoreMap::ObjectAllocSize(const void* obj) const { sds s1 = (sds)obj; size_t res = zmalloc_usable_size(sdsAllocPtr(s1)); return res; } uint32_t ScoreMap::ObjExpireTime(const void* obj) const { // Should not reach. return UINT32_MAX; } void ScoreMap::ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) { // Should not reach. } void ScoreMap::ObjDelete(void* obj, bool has_ttl) const { sds s1 = (sds)obj; sdsfree(s1); } void* ScoreMap::ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const { return nullptr; } detail::SdsScorePair ScoreMap::iterator::BreakToPair(void* obj) { sds f = (sds)obj; return detail::SdsScorePair(f, GetValue(f)); } namespace { // Does not Release obj. 
Callers must do so explicitly if a `Reallocation` happened pair DuplicateEntryIfFragmented(void* obj, PageUsage* page_usage) { sds key = (sds)obj; size_t key_len = sdslen(key); if (!page_usage->IsPageForObjectUnderUtilized(key)) return {key, false}; sds newkey = AllocSdsWithSpace(key_len, 8); memcpy(newkey, key, key_len + 8 + 1); return {newkey, true}; } } // namespace bool ScoreMap::iterator::ReallocIfNeeded(PageUsage* page_usage, std::function cb) { auto* ptr = curr_entry_; if (ptr->IsLink()) { ptr = ptr->AsLink(); } DCHECK(!ptr->IsEmpty()); DCHECK(ptr->IsObject()); auto* obj = ptr->GetObject(); auto [new_obj, realloced] = DuplicateEntryIfFragmented(obj, page_usage); if (realloced) { if (cb) { cb((sds)obj, (sds)new_obj); } sdsfree((sds)obj); ptr->SetObject(new_obj); } return realloced; } } // namespace dfly ================================================ FILE: src/core/score_map.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include "core/dense_set.h" extern "C" { #include "redis/sds.h" } namespace dfly { class PageUsage; namespace detail { class SdsScorePair { public: SdsScorePair(sds k, double v) : first(k), second(v) { } SdsScorePair* operator->() { return this; } const SdsScorePair* operator->() const { return this; } const sds first; const double second; }; }; // namespace detail class ScoreMap : public DenseSet { public: ScoreMap() { } ~ScoreMap(); class iterator : private DenseSet::IteratorBase { static detail::SdsScorePair BreakToPair(void* obj); public: iterator() : IteratorBase() { } iterator(DenseSet* owner, bool is_end) : IteratorBase(owner, is_end) { } detail::SdsScorePair operator->() const { void* ptr = curr_entry_->GetObject(); return BreakToPair(ptr); } detail::SdsScorePair operator*() const { void* ptr = curr_entry_->GetObject(); return BreakToPair(ptr); } // Try reducing memory fragmentation of the value by re-allocating. Returns true if // re-allocation happened. // If function is set, we call it with the old and the new sds. This is used for data // structures that hold multiple storages that need to be update simultaneously. For example, // SortedMap contains both a B+ tree and a ScoreMap with the former, containing pointers // to the later. Therefore, we need to update those. This is handled by the cb below. bool ReallocIfNeeded(PageUsage* page_usage, std::function = {}); iterator& operator++() { Advance(); return *this; } bool operator==(const iterator& b) const { return curr_list_ == b.curr_list_; } bool operator!=(const iterator& b) const { return !(*this == b); } }; // Returns pointer to the internal objest and the insertion result. // i.e. true if field was added, otherwise updates its value and returns false. std::pair AddOrUpdate(std::string_view field, double value); // Returns true if field was added // false, if already exists. In that case no update is done. 
std::pair AddOrSkip(std::string_view field, double value); void* AddUnique(std::string_view field, double value); bool Erase(std::string_view field) { return EraseInternal(&field, 1); } bool Erase(sds field) { return EraseInternal(field, 0); } /// @brief Returns value of the key or nullptr if key not found. /// @param key /// @return sds std::optional Find(std::string_view key); void* FindObj(std::string_view sv) { return FindInternal(&sv, Hash(&sv, 1), 1); } iterator begin() { return iterator{this, false}; } iterator end() { return iterator{this, true}; } private: uint64_t Hash(const void* obj, uint32_t cookie) const final; bool ObjEqual(const void* left, const void* right, uint32_t right_cookie) const final; size_t ObjectAllocSize(const void* obj) const final; uint32_t ObjExpireTime(const void* obj) const final; void ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) override; void ObjDelete(void* obj, bool has_ttl) const override; void* ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const final; }; } // namespace dfly ================================================ FILE: src/core/score_map_test.cc ================================================ // Copyright 2023, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/score_map.h" #include #include "base/gtest.h" #include "base/logging.h" #include "core/mi_memory_resource.h" #include "core/page_usage/page_usage_stats.h" extern "C" { #include "redis/zmalloc.h" } using namespace std; namespace dfly { class ScoreMapTest : public ::testing::Test { protected: static void SetUpTestSuite() { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); InitTLStatelessAllocMR(PMR_NS::get_default_resource()); } static void TearDownTestSuite() { mi_heap_collect(mi_heap_get_backing(), true); auto cb_visit = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, size_t block_size, void* arg) { LOG(ERROR) << "Unfreed allocations: block_size " << block_size << ", allocated: " << area->used * block_size; return true; }; mi_heap_visit_blocks(mi_heap_get_backing(), false /* do not visit all blocks*/, cb_visit, nullptr); } ScoreMapTest() : mi_alloc_(mi_heap_get_backing()) { } void SetUp() override { sm_.reset(new ScoreMap()); } void TearDown() override { sm_.reset(); EXPECT_EQ(zmalloc_used_memory_tl, 0); } MiMemoryResource mi_alloc_; std::unique_ptr sm_; }; TEST_F(ScoreMapTest, Basic) { EXPECT_TRUE(sm_->AddOrUpdate("foo", 5).second); EXPECT_EQ(5, sm_->Find("foo")); auto it = sm_->begin(); EXPECT_STREQ("foo", it->first); EXPECT_EQ(5, it->second); ++it; EXPECT_TRUE(it == sm_->end()); for (const auto& k_v : *sm_) { EXPECT_STREQ("foo", k_v.first); EXPECT_EQ(5, k_v.second); } size_t sz = sm_->ObjMallocUsed(); EXPECT_FALSE(sm_->AddOrUpdate("foo", 17).second); EXPECT_EQ(sm_->ObjMallocUsed(), sz); it = sm_->begin(); EXPECT_EQ(17, it->second); EXPECT_FALSE(sm_->AddOrSkip("foo", 31).second); EXPECT_EQ(17, it->second); } TEST_F(ScoreMapTest, EmptyFind) { EXPECT_EQ(nullopt, sm_->Find("bar")); } uint64_t total_wasted_memory = 0; TEST_F(ScoreMapTest, ReallocIfNeeded) { auto build_str = [](size_t i) { return to_string(i) + string(131, 'a'); }; auto count_waste = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, 
size_t block_size, void* arg) { size_t used = block_size * area->used; total_wasted_memory += area->committed - used; return true; }; for (size_t i = 0; i < 10'000; i++) { sm_->AddOrUpdate(build_str(i), i); } for (size_t i = 0; i < 10'000; i++) { if (i % 10 == 0) continue; sm_->Erase(build_str(i)); } mi_heap_collect(mi_heap_get_backing(), true); mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); size_t wasted_before = total_wasted_memory; size_t underutilized = 0; PageUsage page_usage{CollectPageStats::NO, 0.9}; for (auto it = sm_->begin(); it != sm_->end(); ++it) { underutilized += page_usage.IsPageForObjectUnderUtilized(it->first); it.ReallocIfNeeded(&page_usage); } // Check there are underutilized pages CHECK_GT(underutilized, 0u); total_wasted_memory = 0; mi_heap_collect(mi_heap_get_backing(), true); mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); size_t wasted_after = total_wasted_memory; // Check we waste significanlty less now EXPECT_GT(wasted_before, wasted_after * 2); ASSERT_EQ(sm_->UpperBoundSize(), 1000); for (size_t i = 0; i < 1000; i++) { auto res = sm_->Find(build_str(i * 10)); ASSERT_EQ(res.has_value(), true); ASSERT_EQ((size_t)*res, i * 10); } } } // namespace dfly ================================================ FILE: src/core/sds_utils.cc ================================================ // Copyright 2022, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/sds_utils.h" #include "base/endian.h" extern "C" { #include "redis/sds.h" #include "redis/zmalloc.h" } namespace dfly { namespace { inline char SdsReqType(size_t string_size) { if (string_size < 1 << 5) return SDS_TYPE_5; if (string_size < 1 << 8) return SDS_TYPE_8; if (string_size < 1 << 16) return SDS_TYPE_16; if (string_size < 1ll << 32) return SDS_TYPE_32; return SDS_TYPE_64; } inline int SdsHdrSize(char type) { switch (type & SDS_TYPE_MASK) { case SDS_TYPE_5: return sizeof(struct sdshdr5); case SDS_TYPE_8: return sizeof(struct sdshdr8); case SDS_TYPE_16: return sizeof(struct sdshdr16); case SDS_TYPE_32: return sizeof(struct sdshdr32); case SDS_TYPE_64: return sizeof(struct sdshdr64); } return 0; } } // namespace void SdsUpdateExpireTime(const void* obj, uint32_t time_at, uint32_t offset) { sds str = (sds)obj; char* valptr = str + sdslen(str) + 1; absl::little_endian::Store32(valptr + offset, time_at); } char* AllocSdsWithSpace(uint32_t strlen, uint32_t space) { size_t usable; char type = SdsReqType(strlen); int hdrlen = SdsHdrSize(type); char* ptr = (char*)zmalloc_usable(hdrlen + strlen + 1 + space, &usable); char* s = ptr + hdrlen; char* fp = s - 1; switch (type) { case SDS_TYPE_5: { *fp = type | (strlen << SDS_TYPE_BITS); break; } case SDS_TYPE_8: { SDS_HDR_VAR(8, s); sh->len = strlen; sh->alloc = strlen; *fp = type; break; } case SDS_TYPE_16: { SDS_HDR_VAR(16, s); sh->len = strlen; sh->alloc = strlen; *fp = type; break; } case SDS_TYPE_32: { SDS_HDR_VAR(32, s); sh->len = strlen; sh->alloc = strlen; *fp = type; break; } case SDS_TYPE_64: { SDS_HDR_VAR(64, s); sh->len = strlen; sh->alloc = strlen; *fp = type; break; } } s[strlen] = '\0'; return s; } } // namespace dfly ================================================ FILE: src/core/sds_utils.h ================================================ // Copyright 2022, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include namespace dfly { // Allocates an sds string that has additional space at the end that // sds is not aware of. Useful when you need to allocate immutable // sds strings (keys) with metadata attached to them. char* AllocSdsWithSpace(uint32_t strlen, uint32_t space); // Updates the expire time of the sds object. The offset is the number of bytes // past the string's terminating byte at which the 32-bit time value is stored. void SdsUpdateExpireTime(const void* obj, uint32_t time_at, uint32_t offset); } // namespace dfly ================================================ FILE: src/core/search/CMakeLists.txt ================================================ gen_flex(lexer) gen_bison(parser) cur_gen_dir(gen_dir) set_source_files_properties(${gen_dir}/parser.cc PROPERTIES COMPILE_FLAGS "-Wno-maybe-uninitialized") add_library(dfly_search_core ast_expr.cc base.cc hnsw_index.cc query_driver.cc search.cc indices.cc sort_indices.cc vector_utils.cc compressed_sorted_set.cc block_list.cc renewable_quota.cc range_tree.cc synonyms.cc ${gen_dir}/parser.cc ${gen_dir}/lexer.cc) target_link_libraries(dfly_search_core dfly_page_usage base fibers2 redis_lib absl::strings TRDP::reflex TRDP::uni-algo TRDP::hnswlib Boost::headers) if(WITH_SIMSIMD) target_link_libraries(dfly_search_core TRDP::simsimd) target_compile_definitions(dfly_search_core PRIVATE WITH_SIMSIMD=1 SIMSIMD_DYNAMIC_DISPATCH=1 SIMSIMD_NATIVE_F16=$,1,0> SIMSIMD_NATIVE_BF16=$,1,0>) endif() helio_cxx_test(compressed_sorted_set_test dfly_search_core LABELS DFLY) helio_cxx_test(block_list_test dfly_search_core LABELS DFLY) helio_cxx_test(range_tree_test dfly_search_core absl::random_random LABELS DFLY) helio_cxx_test(rax_tree_test redis_test_lib LABELS DFLY) helio_cxx_test(search_parser_test dfly_search_core LABELS DFLY) helio_cxx_test(search_test redis_test_lib dfly_search_core LABELS DFLY) helio_cxx_test(mrmw_mutex_test redis_test_lib dfly_search_core fibers2 LABELS DFLY) if(WITH_SIMSIMD) target_link_libraries(search_test TRDP::simsimd) target_compile_definitions(search_test 
PRIVATE WITH_SIMSIMD=1 SIMSIMD_DYNAMIC_DISPATCH=1 SIMSIMD_NATIVE_F16=$,1,0> SIMSIMD_NATIVE_BF16=$,1,0>) endif() ================================================ FILE: src/core/search/ast_expr.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/search/ast_expr.h" #include #include #include #include #include "base/logging.h" using namespace std; namespace dfly::search { AstRangeNode::AstRangeNode(double lo, bool lo_excl, double hi, bool hi_excl) : lo{lo_excl ? nextafter(lo, hi) : lo}, hi{hi_excl ? nextafter(hi, lo) : hi} { } AstGeoNode::AstGeoNode(double lon, double lat, double radius, std::string unit) : lon(lon), lat(lat), radius(radius), unit(std::move(unit)) { } AstNegateNode::AstNegateNode(AstNode&& node) : node{make_unique(std::move(node))} { } AstLogicalNode::AstLogicalNode(AstNode&& l, AstNode&& r, LogicOp op) : op{op}, nodes{} { // If either node is already a logical node with the same op, // we can re-use it, as logical ops are associative. for (auto* node : {&l, &r}) { if (auto* ln = get_if(node); ln && ln->op == op) { *this = std::move(*ln); nodes.emplace_back(std::move(*(node == &l ? 
&r : &l))); return; } } nodes.emplace_back(std::move(l)); nodes.emplace_back(std::move(r)); } AstFieldNode::AstFieldNode(string field, AstNode&& node) : field{field.substr(1)}, node{make_unique(std::move(node))} { } AstTagsNode::AstTagsNode(TagValue tag) { tags = {std::move(tag)}; } AstTagsNode::AstTagsNode(AstExpr&& l, TagValue tag) { DCHECK(holds_alternative(l)); auto& tags_node = get(l); tags = std::move(tags_node.tags); tags.push_back(std::move(tag)); } AstKnnNode::AstKnnNode(uint32_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias, std::optional ef_runtime) : filter{nullptr}, limit{limit}, field{field.substr(1)}, vec{std::move(vec)}, score_alias{score_alias}, ef_runtime{ef_runtime} { } AstKnnNode::AstKnnNode(AstNode&& filter, AstKnnNode&& self) { *this = std::move(self); this->filter = make_unique(std::move(filter)); } AstVectorRangeNode::AstVectorRangeNode(std::string field, double radius, OwnedFtVector vec, std::string score_alias) : field{field.substr(1)}, radius{radius}, vec{std::move(vec)}, score_alias{std::move(score_alias)} { } bool AstKnnNode::HasPreFilter() const { // If we have pre filter knn query should not hold filter variable. It will be // moved to SearchAlgorithm::query_ variable. return filter == nullptr; } } // namespace dfly::search namespace std { ostream& operator<<(ostream& os, optional o) { return os; } ostream& operator<<(ostream& os, dfly::search::AstTagsNode::TagValueProxy o) { return os; } } // namespace std ================================================ FILE: src/core/search/ast_expr.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include #include #include #include "core/search/base.h" #include "core/search/tag_types.h" namespace dfly { namespace search { struct AstNode; // Matches all documents struct AstStarNode {}; // Matches all documents where this field has a non-null value struct AstStarFieldNode {}; template struct AstAffixNode { explicit AstAffixNode(std::string affix) : affix{std::move(affix)} { } std::string affix; }; using AstTermNode = AstAffixNode; using AstPrefixNode = AstAffixNode; using AstSuffixNode = AstAffixNode; using AstInfixNode = AstAffixNode; // Matches numeric range struct AstRangeNode { AstRangeNode(double lo, bool lo_excl, double hi, bool hi_excl); double lo, hi; }; struct AstGeoNode { AstGeoNode(double lon, double lat, double radius, std::string unit); double lon, lat; double radius; std::string unit; }; // Negates subtree struct AstNegateNode { AstNegateNode(AstNode&& node); AstNegateNode(const AstNegateNode&) = delete; AstNegateNode& operator=(const AstNegateNode&) = delete; AstNegateNode(AstNegateNode&&) noexcept = default; AstNegateNode& operator=(AstNegateNode&&) noexcept = default; std::unique_ptr node; }; // Applies logical operation to results of all sub-nodes struct AstLogicalNode { enum LogicOp { AND, OR }; // If either node is already a logical node with the same op, it'll be re-used. 
AstLogicalNode(AstNode&& l, AstNode&& r, LogicOp op); AstLogicalNode(const AstLogicalNode&) = delete; AstLogicalNode& operator=(const AstLogicalNode&) = delete; AstLogicalNode(AstLogicalNode&&) noexcept = default; AstLogicalNode& operator=(AstLogicalNode&&) noexcept = default; LogicOp op; std::vector nodes; }; // Selects specific field for subtree struct AstFieldNode { AstFieldNode(std::string field, AstNode&& node); AstFieldNode(const AstFieldNode&) = delete; AstFieldNode& operator=(const AstFieldNode&) = delete; AstFieldNode(AstFieldNode&&) noexcept = default; AstFieldNode& operator=(AstFieldNode&&) noexcept = default; std::string field; std::unique_ptr node; }; // Stores a list of tags for a tag query struct AstTagsNode { using TagValue = std::variant; struct TagValueProxy : public AstTagsNode::TagValue { // bison needs it to be default constructible TagValueProxy() : AstTagsNode::TagValue(AstTermNode("")) { } template TagValueProxy(AstAffixNode tv) : AstTagsNode::TagValue(std::move(tv)) { } }; AstTagsNode(TagValue); AstTagsNode(AstNode&& l, TagValue); std::vector tags; }; // Applies nearest neighbor search to the final result set struct AstKnnNode { AstKnnNode() = default; AstKnnNode(uint32_t limit, std::string_view field, OwnedFtVector vec, std::string_view score_alias, std::optional ef_runtime); AstKnnNode(AstNode&& sub, AstKnnNode&& self); AstKnnNode(const AstKnnNode&) = delete; AstKnnNode& operator=(const AstKnnNode&) = delete; AstKnnNode(AstKnnNode&&) noexcept = default; AstKnnNode& operator=(AstKnnNode&&) noexcept = default; friend std::ostream& operator<<(std::ostream& stream, const AstKnnNode& matrix) { return stream; } std::unique_ptr filter; size_t limit; std::string field; OwnedFtVector vec; std::string score_alias; std::optional ef_runtime; bool HasPreFilter() const; }; // Applies vector range search: returns all docs with distance(vec, doc_vec) <= radius struct AstVectorRangeNode { AstVectorRangeNode() = default; AstVectorRangeNode(std::string 
field, double radius, OwnedFtVector vec, std::string score_alias); AstVectorRangeNode(const AstVectorRangeNode&) = delete; AstVectorRangeNode& operator=(const AstVectorRangeNode&) = delete; AstVectorRangeNode(AstVectorRangeNode&&) noexcept = default; AstVectorRangeNode& operator=(AstVectorRangeNode&&) noexcept = default; friend std::ostream& operator<<(std::ostream& stream, const AstVectorRangeNode& /*node*/) { return stream; } std::string field; double radius; OwnedFtVector vec; std::string score_alias; }; using NodeVariants = std::variant; struct AstNode : public NodeVariants { using variant::variant; AstNode(const AstNode&) = delete; AstNode& operator=(const AstNode&) = delete; AstNode(AstNode&&) noexcept = default; AstNode& operator=(AstNode&&) noexcept = default; friend std::ostream& operator<<(std::ostream& stream, const AstNode& matrix) { return stream; } const NodeVariants& Variant() const& { return *this; } }; using AstExpr = AstNode; } // namespace search } // namespace dfly namespace std { ostream& operator<<(ostream& os, optional o); ostream& operator<<(ostream& os, dfly::search::AstTagsNode::TagValueProxy o); } // namespace std ================================================ FILE: src/core/search/base.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/search/base.h" #include namespace dfly::search { std::string_view QueryParams::operator[](std::string_view name) const { if (auto it = params.find(name); it != params.end()) return it->second; return ""; } std::string& QueryParams::operator[](std::string_view k) { return params[k]; } std::optional ParseNumericField(std::string_view value) { double value_as_double; if (absl::SimpleAtod(value, &value_as_double) && std::isfinite(value_as_double)) return value_as_double; return std::nullopt; } DefragmentResult& DefragmentResult::Merge(DefragmentResult&& other) { quota_depleted |= other.quota_depleted; objects_moved += other.objects_moved; return *this; } } // namespace dfly::search ================================================ FILE: src/core/search/base.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include #include namespace dfly { class PageUsage; } namespace dfly::search { struct DefragmentResult { bool quota_depleted{false}; size_t objects_moved{0}; DefragmentResult& Merge(DefragmentResult&& other); }; using DocId = uint32_t; using GlobalDocId = uint64_t; using ShardId = uint16_t; inline GlobalDocId CreateGlobalDocId(ShardId shard_id, DocId local_doc_id) { return ((uint64_t)shard_id << 32) | local_doc_id; } inline std::pair DecomposeGlobalDocId(GlobalDocId id) { return {(id >> 32), (id)&0xFFFFFFFF}; } enum class VectorSimilarity { L2, IP, COSINE }; using OwnedFtVector = std::pair, size_t /* dimension (size) */>; using BorrowedFtVector = const char*; // Query params represent named parameters for queries supplied via PARAMS. 
struct QueryParams { std::string_view operator[](std::string_view name) const; std::string& operator[](std::string_view k); size_t Size() const { return params.size(); } private: absl::flat_hash_map params; }; // Base class for optional search filters struct AstNode; struct OptionalFilterBase { virtual bool IsEmpty() const = 0; virtual AstNode Node(std::string field) = 0; virtual ~OptionalFilterBase() = default; }; using OptionalFilters = absl::flat_hash_map /* filter */>; // Values are either sortable as doubles or strings, or not sortable at all. using SortableValue = std::variant; // Interface for accessing document values with different data structures underneath. struct DocumentAccessor { using VectorInfo = std::variant; using StringList = absl::InlinedVector; using NumsList = absl::InlinedVector; virtual ~DocumentAccessor() = default; /* Returns nullopt if the specified field is not a list of strings */ virtual std::optional GetStrings(std::string_view active_field) const = 0; /* Returns nullopt if the specified field is not a vector */ virtual std::optional GetVector(std::string_view active_field, size_t dim) const = 0; /* Returns nullopt if the specified field is not a list of doubles */ virtual std::optional GetNumbers(std::string_view active_field) const = 0; /* Same as GetStrings, but also supports boolean values */ virtual std::optional GetTags(std::string_view active_field) const = 0; }; // Base class for type-specific indices. // // Queries should be done directly on subclasses with their distinct // query functions. All results for all index types should be sorted. 
struct BaseIndex { virtual ~BaseIndex() = default; // Returns true if the document was added / indexed virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0; virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0; // Returns documents that have non-null values for this field (used for @field:* queries) // Result must be sorted virtual std::vector GetAllDocsWithNonNullValues() const = 0; /* Called at the end of indexes rebuilding after all initial Add calls are done. Some indices may need to finalize internal structures. See RangeTree for example. */ virtual void FinalizeInitialization() { } // Defragments the index by moving objects in underutilized pages to the current malloc page. virtual DefragmentResult Defragment(PageUsage* page_usage) { return DefragmentResult{.quota_depleted = false, .objects_moved = 0}; } }; // Base class for type-specific sorting indices. struct BaseSortIndex : BaseIndex { virtual SortableValue Lookup(DocId doc) const = 0; virtual std::vector Sort(std::vector* ids, size_t limit, bool desc) const = 0; }; /* Used in iterators of inverse indices. It is used to mark iterators that can be seeked to doc id that is greater than or equal to the specified value (method name is SeekGE(DocId min_doc_id)). This is used to optimize merging of results from different indices. See index_result.h for more details. */ struct SeekableTag {}; template void BasicSeekGE(DocId min_doc_id, const Iterator& end, Iterator* it); /* Used for converting field values to double. Returns std::nullopt if the conversion fails */ std::optional ParseNumericField(std::string_view value); /* Temporary method to create an empty std::optional in DocumentAccessor::GetString and DocumentAccessor::GetNumbers methods. The problem is that due to internal implementation details of absl::InlineVector, we are getting a -Wmaybe-uninitialized compiler warning. 
To suppress this false warning, we temporarily disable it around this block of code using GCC diagnostic directives. */ template std::optional EmptyAccessResult() { #if !defined(__clang__) // GCC 13.1 throws spurious warnings around this code. #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #endif return InlinedVector{}; #if !defined(__clang__) #pragma GCC diagnostic pop #endif } // Implementation /******************************************************************/ namespace details { inline size_t GetHighestPowerOfTwo(size_t n) { static constexpr size_t kBitsNumber = sizeof(size_t) * 8; return size_t(1) << (kBitsNumber - 1 - __builtin_clzl(n)); } } // namespace details template void BasicSeekGE(DocId min_doc_id, const Iterator& end, Iterator* it) { using Category = typename std::iterator_traits::iterator_category; auto extract_doc_id = [](const auto& value) { using T = std::decay_t; if constexpr (std::is_same_v) { return value; } else { return value.first; } }; if constexpr (std::is_base_of_v) { size_t length = std::distance(*it, end); for (size_t step = details::GetHighestPowerOfTwo(length); step > 0; step >>= 1) { if (step < length) { auto next_it = *it + step; if (extract_doc_id(*next_it) < min_doc_id) { *it = next_it; length -= step; } } } } while (*it != end && extract_doc_id(**it) < min_doc_id) { ++(*it); } } } // namespace dfly::search ================================================ FILE: src/core/search/block_list.cc ================================================ #include "core/search/block_list.h" #include "core/page_usage/page_usage_stats.h" namespace { template bool DefragmentVector(PMR_NS::vector& vec, dfly::PageUsage* page_usage) { if (vec.empty() || !page_usage->IsPageForObjectUnderUtilized(vec.data())) { return false; } PMR_NS::vector new_vec(vec.get_allocator()); new_vec.reserve(vec.size()); for (auto&& element : vec) { new_vec.push_back(std::move(element)); } vec = std::move(new_vec); return true; } } // 
namespace namespace dfly::search { using namespace std; SplitResult Split(BlockList>>&& block_list) { using Entry = std::pair; DCHECK(!block_list.Empty()); const size_t elements_count = block_list.Size(); // Extract values to find median std::vector entries_values(elements_count); size_t index = 0; for (const Entry& entry : block_list) { entries_values[index++] = entry.second; } // Find median value std::nth_element(entries_values.begin(), entries_values.begin() + elements_count / 2, entries_values.end()); double median_value = entries_values[elements_count / 2]; /* Now we need to split entries into two parts, left and right, so that: 1) left has values < median_value 2) right has values >= median_value 3) both parts have approximately the same number of elements To achieve this, we first split entries into three parts: < median_value (left blocklist), == median_value (median_entries), > median_value (right blocklist). Then we add == median_value part to the smaller of the two parts (< or >). 
This guarantees that both parts have approximately the same number of elements */ BlockList> left(block_list.blocks_.get_allocator().resource(), block_list.block_size_); BlockList> right(block_list.blocks_.get_allocator().resource(), block_list.block_size_); absl::InlinedVector median_entries; left.ReserveBlocks(block_list.blocks_.size() / 2 + 1); right.ReserveBlocks(block_list.blocks_.size() / 2 + 1); double lmin = std::numeric_limits::infinity(), rmin = lmin; double lmax = -std::numeric_limits::infinity(), rmax = lmax; for (const Entry& entry : block_list) { if (entry.second < median_value) { left.PushBack(entry); lmin = std::min(lmin, entry.second); lmax = std::max(lmax, entry.second); } else if (entry.second > median_value) { right.PushBack(entry); rmin = std::min(rmin, entry.second); rmax = std::max(rmax, entry.second); } else { median_entries.push_back(entry); } } block_list.Clear(); if (left.Size() < right.Size()) { // If left is smaller, we can add median entries to it // We need to change median value to the right part and update lmax lmax = median_value; lmin = std::min(lmin, median_value); median_value = rmin; for (const auto& entry : median_entries) { left.Insert(entry); } } else { // If right part is smaller, we can add median entries to it // Median value is still the same rmax = std::max(rmax, median_value); for (const auto& entry : median_entries) { right.Insert(entry); } } return {std::move(left), std::move(right), median_value, lmin, lmax, rmax}; } template bool BlockList::Insert(ElementType t) { auto block = FindBlock(t); if (block == blocks_.end()) block = blocks_.insert(blocks_.end(), C{blocks_.get_allocator().resource()}); if (!block->Insert(std::move(t))) return false; size_++; TrySplit(block); return true; } template bool BlockList::PushBack(ElementType t) { // If the last block is full, after insert we will need to split it // So we can prevent split by creating a new block and inserting there if (blocks_.empty() || 
ShouldSplit(blocks_.back().Size() + 1)) { blocks_.insert(blocks_.end(), C{blocks_.get_allocator().resource()}); } if (!blocks_.back().Insert(std::move(t))) return false; size_++; return true; } template bool BlockList::Remove(ElementType t) { if (auto block = FindBlock(t); block != blocks_.end() && block->Remove(std::move(t))) { size_--; TryMerge(block); return true; } return false; } template DefragmentResult BlockList::Defragment(PageUsage* page_usage) { if (page_usage->QuotaDepleted()) { return DefragmentResult{.quota_depleted = true, .objects_moved = 0}; } DefragmentResult result; if (DefragmentVector(blocks_, page_usage)) { result.objects_moved += 1; } for (Container& block : blocks_) { if (result.Merge(block.Defragment(page_usage)).quota_depleted) { break; } } return result; } template typename BlockList::BlockIt BlockList::FindBlock(const ElementType& t) { DCHECK(blocks_.empty() || !blocks_.back().Empty()); if (!blocks_.empty() && t >= *blocks_.back().begin()) return --blocks_.end(); // Find first block that can't contain t auto it = std::upper_bound(blocks_.begin(), blocks_.end(), t, [](const ElementType& t, const C& l) { return *l.begin() > t; }); // Move to previous if possible if (it != blocks_.begin()) --it; DCHECK(it == blocks_.begin() || it->Size() * 2 >= block_size_); DCHECK(it == blocks_.end() || it->Size() <= 2 * block_size_); return it; } template bool BlockList::ShouldSplit(size_t block_size) const { return block_size >= block_size_ * 2; } template void BlockList::TryMerge(BlockIt block) { if (block->Size() == 0) { blocks_.erase(block); return; } if (block->Size() >= block_size_ / 2 || block == blocks_.begin()) return; // Merge strictly right with left to benefit from tail insert optimizations size_t idx = std::distance(blocks_.begin(), block); blocks_[idx - 1].Merge(std::move(*block)); blocks_.erase(block); TrySplit(blocks_.begin() + (idx - 1)); // to not overgrow it } template void BlockList::TrySplit(BlockIt block) { if 
(!ShouldSplit(block->Size() + 1)) return; auto [left, right] = std::move(*block).Split(); *block = std::move(right); blocks_.insert(block, std::move(left)); } template void BlockList::ReserveBlocks(size_t n) { blocks_.reserve(n); } template typename BlockList::BlockListIterator& BlockList::BlockListIterator::operator++() { ++block_it; if (block_it == block_end) { ++it; if (it != it_end) { block_it = it->begin(); block_end = it->end(); } else { block_it = {}; block_end = {}; } } return *this; } template void BlockList::BlockListIterator::SeekGE(DocId min_doc_id) { if (it == it_end) { block_it = {}; block_end = {}; return; } auto extract_doc_id = [](const auto& value) { using T = std::decay_t; if constexpr (std::is_same_v) { return value; } else { return value.first; } }; auto needed_block = [&](const auto& it) { return it->begin() != it->end() && min_doc_id <= extract_doc_id(it->Back()); }; // Choose the first block that has the last element >= min_doc_id if (!needed_block(it)) { while (++it != it_end) { if (needed_block(it)) { block_it = it->begin(); block_end = it->end(); break; } } if (it == it_end) { block_it = {}; block_end = {}; return; } } BasicSeekGE(min_doc_id, block_end, &block_it); DCHECK(block_it != block_end && min_doc_id <= extract_doc_id(*block_it)); } template class BlockList; template class BlockList>; template class BlockList>>; template bool SortedVector::Insert(T t) { if (entries_.empty() || t > entries_.back()) { entries_.push_back(t); return true; } auto it = std::lower_bound(entries_.begin(), entries_.end(), t); if (it != entries_.end() && *it == t) return false; entries_.insert(it, t); return true; } template bool SortedVector::Remove(T t) { auto it = std::lower_bound(entries_.begin(), entries_.end(), t); if (it != entries_.end() && *it == t) { entries_.erase(it); return true; } return false; } template void SortedVector::Merge(SortedVector&& other) { // N log N complexity in theory, but in practice used only to merge with larger values. 
// Tail insert optimization makes it linear entries_.reserve(entries_.size() + other.entries_.size()); for (T& t : other.entries_) Insert(std::move(t)); } template std::pair, SortedVector> SortedVector::Split() && { /* Moves the upper half of the entries into a new SortedVector and keeps the lower half in this one. The tail is built with this vector's allocator so the split-off half stays on the same memory resource instead of the default one. */ PMR_NS::vector tail(entries_.begin() + entries_.size() / 2, entries_.end(), entries_.get_allocator()); entries_.resize(entries_.size() / 2); return std::make_pair(std::move(*this), SortedVector{std::move(tail)}); } template DefragmentResult SortedVector::Defragment(PageUsage* page_usage) { if (DefragmentVector(entries_, page_usage)) { return DefragmentResult{.quota_depleted = false, .objects_moved = 1}; } return DefragmentResult{}; } template class SortedVector; template class SortedVector>; } // namespace dfly::search ================================================ FILE: src/core/search/block_list.h ================================================ #pragma once #include #include #include #include #include "core/search/base.h" #include "core/search/compressed_sorted_set.h" namespace dfly::search { // Forward declarations struct SplitResult; template class BlockList; template class SortedVector; /* Split into two blocks, left and right, so that both blocks have approximately the same number of elements. Returns median value of the split. Guarantees that the median is present in the right block and not present in the left block. Does not work for empty BlockList. */ // TODO: Move to RangeTree logic SplitResult Split(BlockList>>&& result); // BlockList is a container wrapper for CompressedSortedSet / vector // to divide the full sorted id range into separate blocks. This reduces modification // complexity from O(N) to O(logN + K), where K is the max block size. // // It tries to balance block sizes in the range [block_size / 2, block_size * 2] // by splitting or merging nodes when needed. 
// Container must declare an ElementType type template class BlockList { private: using BlockIt = typename PMR_NS::vector::iterator; using ConstBlockIt = typename PMR_NS::vector::const_iterator; using ElementType = typename Container::ElementType; public: BlockList(PMR_NS::memory_resource* mr, size_t block_size = 1000) : block_size_{block_size}, blocks_(mr) { } BlockList(const BlockList& other) = default; BlockList(BlockList&& other) noexcept : block_size_{other.block_size_}, size_{other.size_}, blocks_(std::move(other.blocks_)) { /* Take over the source's block size and storage. block_size_ is const, so it must be set in the initializer list; otherwise it silently falls back to its in-class default, which is why DCHECK(block_size_ == other.block_size_) used to fail here. Move-constructing blocks_ (instead of move-assigning into a default-constructed vector) also keeps the source's allocator. */ other.Clear(); } BlockList& operator=(const BlockList& other) = delete; BlockList& operator=(BlockList&& other) = delete; ~BlockList() = default; // Insert element, returns true if inserted, false if already present. bool Insert(ElementType t); bool PushBack(ElementType t); // Remove element, returns true if removed, false if not found. 
bool Remove(ElementType t); size_t Size() const { return size_; } size_t size() const { return size_; } bool Empty() const { return size_ == 0; } void Clear() { size_ = 0; blocks_.clear(); } struct BlockListIterator : public SeekableTag { // To make it work with std container constructors using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; using value_type = ElementType; using pointer = ElementType*; using reference = ElementType&; ElementType operator*() const { return *block_it; } BlockListIterator& operator++(); void SeekGE(DocId min_doc_id); friend class BlockList; bool operator==(const BlockListIterator& other) const { return it == other.it && block_it == other.block_it; } bool operator!=(const BlockListIterator& other) const { return !operator==(other); } private: BlockListIterator(ConstBlockIt begin, ConstBlockIt end) : it(begin), it_end(end) { if (it != it_end) { block_it = it->begin(); block_end = it->end(); } } ConstBlockIt it, it_end; typename Container::iterator block_it, block_end; }; BlockListIterator begin() const { return BlockListIterator{blocks_.begin(), blocks_.end()}; } BlockListIterator end() const { return BlockListIterator{blocks_.end(), blocks_.end()}; } DefragmentResult Defragment(PageUsage* page_usage); private: // Find block that should contain t. Returns end() only if empty BlockIt FindBlock(const ElementType& t); bool ShouldSplit(size_t block_size) const; void TryMerge(BlockIt block); // If needed, merge with previous block void TrySplit(BlockIt block); // If needed, split into two blocks void ReserveBlocks(size_t n); friend SplitResult Split(BlockList>>&& block_list); private: const size_t block_size_ = 1000; size_t size_ = 0; PMR_NS::vector blocks_; }; // Supports Insert and Remove operations for keeping a sorted vector internally. 
// Wrapper to use vectors with BlockList template class SortedVector { public: using ElementType = T; explicit SortedVector(PMR_NS::memory_resource* mr) : entries_(mr) { } bool Insert(T t); bool Remove(T t); void Merge(SortedVector&& other); std::pair, SortedVector> Split() &&; T& operator[](size_t idx) { return entries_[idx]; } const T& operator[](size_t idx) const { return entries_[idx]; } size_t Size() const { return entries_.size(); } bool Empty() const { return entries_.empty(); } void Clear() { entries_.clear(); } const T& Back() const { return entries_.back(); } using iterator = typename PMR_NS::vector::const_iterator; iterator begin() const { return entries_.cbegin(); } iterator end() const { return entries_.cend(); } DefragmentResult Defragment(PageUsage* page_usage); private: SortedVector(PMR_NS::vector&& v) : entries_{std::move(v)} { } PMR_NS::vector entries_; }; extern template class SortedVector; extern template class SortedVector>; extern template class BlockList; extern template class BlockList>; extern template class BlockList>>; // Used by Split method struct SplitResult { using Container = BlockList>>; Container left; Container right; // Median value of split, used as minimum value of right block double median; // Min/max values of left (lmin, lmax) and right (rmin=median, rmax) blocks double lmin, lmax, rmax; }; } // namespace dfly::search ================================================ FILE: src/core/search/block_list_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/search/block_list.h" #include #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" namespace dfly::search { using namespace std; template class TemplatedBlockListTest : public testing::Test { private: using NumericType = long long; public: using ElementType = typename C::ElementType; auto Make() { // Create list with small block size to test blocking mechanism more extensively return BlockList{PMR_NS::get_default_resource(), 10}; } auto AddNewBlockListElement(DocId doc_id) { if constexpr (std::is_same_v) { return ElementType{doc_id}; } else { static_assert(std::is_same_v>, "Unsupported ElementType for BlockListTest"); const NumericType number = dist_(rnd_); id_to_values_[doc_id].push_back(number); return ElementType{doc_id, static_cast(number)}; } } auto RemoveBlockListElement(DocId doc_id) { if constexpr (std::is_same_v) { return ElementType{doc_id}; } else { static_assert(std::is_same_v>, "Unsupported ElementType for BlockListTest"); const NumericType number = id_to_values_[doc_id].back(); id_to_values_[doc_id].pop_back(); return ElementType{doc_id, static_cast(number)}; } } DocId GetDocId(const ElementType& element) { if constexpr (std::is_same_v) { return element; } else { static_assert(std::is_same_v>, "Unsupported ElementType for GetDocId"); return element.first; } } private: // Used to save doubles for std::pair std::unordered_map> id_to_values_; // Used to generate random numbers for std::pair default_random_engine rnd_; uniform_int_distribution dist_{std::numeric_limits::min(), std::numeric_limits::max()}; }; using ContainerTypes = ::testing::Types, SortedVector>>; TYPED_TEST_SUITE(TemplatedBlockListTest, ContainerTypes); TYPED_TEST(TemplatedBlockListTest, LoopMidInsertErase) { using ElementType = typename TypeParam::ElementType; const size_t kNumElements = 50; auto list = this->Make(); for (size_t i = 0; i < kNumElements / 2; i++) { list.Insert(this->AddNewBlockListElement(i)); 
list.Insert(this->AddNewBlockListElement(i + kNumElements / 2)); } vector out(list.begin(), list.end()); ASSERT_EQ(list.Size(), kNumElements); ASSERT_EQ(out.size(), kNumElements); for (size_t i = 0; i < kNumElements; i++) ASSERT_EQ(this->GetDocId(out[i]), i); for (size_t i = 0; i < kNumElements / 2; i++) { list.Remove(this->RemoveBlockListElement(i)); list.Remove(this->RemoveBlockListElement(i + kNumElements / 2)); } out = {list.begin(), list.end()}; EXPECT_EQ(out.size(), 0u); } TYPED_TEST(TemplatedBlockListTest, InsertReverseRemoveSteps) { using ElementType = typename TypeParam::ElementType; const size_t kNumElements = 1000; auto list = this->Make(); for (size_t i = 0; i < kNumElements; i++) { list.Insert(this->AddNewBlockListElement(kNumElements - i - 1)); } for (size_t deleted_pref = 0; deleted_pref < 10; deleted_pref++) { vector out{list.begin(), list.end()}; reverse(out.begin(), out.end()); EXPECT_EQ(out.size(), kNumElements / 10 * (10 - deleted_pref)); for (size_t i = 0; i < kNumElements; i++) { if (i % 10 >= deleted_pref) { EXPECT_EQ(this->GetDocId(out.back()), DocId(i)); out.pop_back(); } } for (size_t i = 0; i < kNumElements; i++) { if (i % 10 == deleted_pref) list.Remove(this->RemoveBlockListElement(i)); } } EXPECT_EQ(list.Size(), 0u); } TYPED_TEST(TemplatedBlockListTest, RandomNumbers) { using ElementType = typename TypeParam::ElementType; const size_t kNumIterations = 1'000; auto list = this->Make(); std::set list_copy; for (size_t i = 0; i < kNumIterations; i++) { if (list_copy.size() > 100 && rand() % 5 == 0) { auto it = list_copy.begin(); std::advance(it, rand() % list_copy.size()); list.Remove(*it); list_copy.erase(it); } else { const ElementType t = this->AddNewBlockListElement(rand() % 1'000'000); list.Insert(t); list_copy.insert(t); } ASSERT_TRUE(std::equal(list.begin(), list.end(), list_copy.begin(), list_copy.end())); } } class BlockListTest : public testing::Test { protected: }; TEST_F(BlockListTest, Split) { BlockList>> 
bl{PMR_NS::get_default_resource(), 20}; const size_t max_value = 100.0; const size_t step = 23.0; size_t value = max_value; for (size_t i = 0; i < 100; i++) { bl.Insert({i, static_cast(value)}); value = (max_value + value - step) % max_value; } auto split_result = Split(std::move(bl)); auto& left = split_result.left; auto& right = split_result.right; EXPECT_EQ(left.Size(), 50); EXPECT_EQ(right.Size(), 50); // Test that all values in the left part are less than or equal to max_value for (const auto& [_, left_value] : left) { for (const auto& [__, right_value] : right) { EXPECT_LE(left_value, right_value); } } double median = split_result.median; // Test that left part values do not have this median for (const auto& [_, left_value] : left) { EXPECT_NE(left_value, median); } // Test that right part values do have this median bool is_median_found = false; for (const auto& [_, right_value] : right) { if (right_value == median) { is_median_found = true; break; } } EXPECT_TRUE(is_median_found); // Test that doc_ids in both parts are sorted DocId prev_doc_id = std::numeric_limits::min(); for (const auto& [doc_id, _] : left) { EXPECT_GE(doc_id, prev_doc_id); prev_doc_id = doc_id; } prev_doc_id = std::numeric_limits::min(); for (const auto& [doc_id, _] : right) { EXPECT_GE(doc_id, prev_doc_id); prev_doc_id = doc_id; } } TEST_F(BlockListTest, SplitHard) { // First test 70 values on the left and 30 on the right BlockList>> bl1{PMR_NS::get_default_resource(), 20}; for (size_t i = 0; i < 70; i++) { bl1.Insert({i, 1.0}); } for (size_t i = 70; i < 100; i++) { bl1.Insert({i, 2.0}); } auto split_result1 = Split(std::move(bl1)); EXPECT_EQ(split_result1.median, 2.0); EXPECT_EQ(split_result1.left.Size(), 70u); EXPECT_EQ(split_result1.right.Size(), 30u); for (const auto& [_, value] : split_result1.left) { EXPECT_EQ(value, 1.0); } for (const auto& [_, value] : split_result1.right) { EXPECT_EQ(value, 2.0); } // Now test 30 values on the left and 70 on the right BlockList>> 
bl2{PMR_NS::get_default_resource(), 20}; for (size_t i = 0; i < 30; i++) { bl2.Insert({i, 1.0}); } for (size_t i = 30; i < 100; i++) { bl2.Insert({i, 2.0}); } auto split_result2 = Split(std::move(bl2)); EXPECT_EQ(split_result2.median, 2.0); EXPECT_EQ(split_result2.left.Size(), 30u); EXPECT_EQ(split_result2.right.Size(), 70u); for (const auto& [_, value] : split_result2.left) { EXPECT_EQ(value, 1.0); } for (const auto& [_, value] : split_result2.right) { EXPECT_EQ(value, 2.0); } } TEST_F(BlockListTest, SplitSingleDoubleValue) { BlockList>> bl{PMR_NS::get_default_resource(), 20}; for (size_t i = 0; i < 100; i++) { bl.Insert({i, 1.0}); } auto split_result = Split(std::move(bl)); auto& left = split_result.left; auto& right = split_result.right; EXPECT_EQ(left.Size(), 0u); EXPECT_EQ(right.Size(), 100u); EXPECT_EQ(split_result.median, 1.0); } static void BM_Erase90PctTail(benchmark::State& state) { BlockList bl{PMR_NS::get_default_resource()}; unsigned size = state.range(0); for (size_t i = 0; i < size; i++) bl.Insert(i); size_t base = size / 10; size_t i = 0; while (state.KeepRunning()) { benchmark::DoNotOptimize(bl.Remove(base + i)); i = (i + 1) % (size * 9 / 10); } } BENCHMARK(BM_Erase90PctTail)->Args({100'000}); } // namespace dfly::search ================================================ FILE: src/core/search/compressed_sorted_set.cc ================================================ #include "core/search/compressed_sorted_set.h" #include #include #include "absl/types/span.h" #include "base/flit.h" #include "base/logging.h" namespace dfly::search { using namespace std; namespace { using VarintBuffer = array; } // namespace CompressedSortedSet::CompressedSortedSet(PMR_NS::memory_resource* mr) : diffs_{mr} { } CompressedSortedSet::ConstIterator::ConstIterator(const CompressedSortedSet& list) : stash_{}, diffs_{list.diffs_} { ReadNext(); } CompressedSortedSet::IntType CompressedSortedSet::ConstIterator::operator*() const { DCHECK(stash_); return *stash_; } 
CompressedSortedSet::ConstIterator& CompressedSortedSet::ConstIterator::operator++() { ReadNext(); return *this; } bool operator==(const CompressedSortedSet::ConstIterator& l, const CompressedSortedSet::ConstIterator& r) { return l.diffs_.data() == r.diffs_.data() && l.diffs_.size() == r.diffs_.size(); } bool operator!=(const CompressedSortedSet::ConstIterator& l, const CompressedSortedSet::ConstIterator& r) { return !(l == r); } void CompressedSortedSet::ConstIterator::ReadNext() { if (diffs_.empty()) { stash_ = nullopt; last_read_ = {nullptr, 0}; diffs_ = {nullptr, 0}; return; } IntType base = stash_.value_or(0); auto [diff, read] = CompressedSortedSet::ReadVarLen(diffs_); stash_ = base + diff; last_read_ = diffs_.subspan(0, read); diffs_.remove_prefix(read); } CompressedSortedSet::ConstIterator CompressedSortedSet::begin() const { return ConstIterator{*this}; } CompressedSortedSet::ConstIterator CompressedSortedSet::end() const { return ConstIterator{}; } // Simply encode difference and add to end of diffs array void CompressedSortedSet::PushBackDiff(IntType diff) { size_++; VarintBuffer buf; auto diff_span = WriteVarLen(diff, absl::MakeSpan(buf)); diffs_.insert(diffs_.end(), diff_span.begin(), diff_span.end()); } // Do a linear scan by encoding all diffs to find value CompressedSortedSet::EntryLocation CompressedSortedSet::LowerBound(IntType value) const { auto it = begin(), prev_it = end(), next_it = end(); while (it != end()) { next_it = it; if (*it >= value || ++next_it == end()) break; prev_it = it; it = next_it; } return EntryLocation{.value = it.stash_.value_or(0), .prev_value = prev_it.stash_.value_or(0), .diff_span = it.last_read_}; } // Insert has linear complexity. It tries to find between which two entries A and B the new value V // needs to be inserted. Then it computes the differences dif1 = V - A and diff2 = B - V that need // to be stored to encode the triple A V B. 
Those are stored where diff0 = B - A was previously // stored, possibly extending the vector bool CompressedSortedSet::Insert(IntType value) { if (tail_value_ && *tail_value_ == value) return false; if (tail_value_ && value > *tail_value_) { PushBackDiff(value - *tail_value_); tail_value_ = value; return true; } auto bound = LowerBound(value); // At least one element was read and it's equal to value: return to avoid duplicate if (bound.value == value && !bound.diff_span.empty()) return false; // Value is bigger than any other (or list is empty): append required diff at the end if (value > bound.value || bound.diff_span.empty()) { PushBackDiff(value - bound.value); tail_value_ = value; return true; } size_++; // Now the list certainly contains the bound B > V and possibly A < V (or 0 by default), // so we need to encode both differences diff1 and diff2 DCHECK_GT(bound.value, value); DCHECK_LE(bound.prev_value, value); // Compute and encode new diff1 and diff2 into buf1 and buf2 respectivaly VarintBuffer buf1, buf2; auto diff1_span = WriteVarLen(value - bound.prev_value, absl::MakeSpan(buf1)); auto diff2_span = WriteVarLen(bound.value - value, absl::MakeSpan(buf2)); // Extend the location where diff0 is stored with optional zeros before overwriting it ptrdiff_t diff_offset = bound.diff_span.data() - diffs_.data(); size_t required_len = diff1_span.size() + diff2_span.size(); DCHECK_LE(bound.diff_span.size(), required_len); // It can't shrink for sure diffs_.insert(diffs_.begin() + diff_offset, required_len - bound.diff_span.size(), 0u); // Now overwrite diff0 and 0s with the two new differences copy(diff1_span.begin(), diff1_span.end(), diffs_.begin() + diff_offset); copy(diff2_span.begin(), diff2_span.end(), diffs_.begin() + diff_offset + diff1_span.size()); return true; } // Remove has linear complexity. It tries to find the element V and its neighbors A and B, // which are encoded as diff1 = V - A and diff2 = B - V. 
Adjacently stored diff1 and diff2 // need to be replaced with diff3 = diff1 + diff2s bool CompressedSortedSet::Remove(IntType value) { auto bound = LowerBound(value); // Nothing was read or the element was not found if (bound.diff_span.empty() || bound.value != value) return false; // We're removing below unconditionally size_--; // Calculate offset where values diff is stored and determine diffs tail ptrdiff_t diff_offset = bound.diff_span.data() - diffs_.data(); auto diffs_tail = absl::MakeSpan(diffs_).subspan(diff_offset + bound.diff_span.size()); // If it's stored at the end, simply truncate it away if (diffs_tail.empty()) { diffs_.resize(diffs_.size() - bound.diff_span.size()); tail_value_ = bound.prev_value; if (diffs_.empty()) tail_value_ = nullopt; return true; } // Now the list certainly contains a succeeding element B > V and possibly A < V (or 0) // Read diff2 and calculate diff3 = diff1 + diff2 auto [diff2, diff2_read] = ReadVarLen(diffs_tail); IntType diff3 = (bound.value - bound.prev_value) + diff2; // Encode diff3 VarintBuffer buf; auto diff3_buf = WriteVarLen(diff3, absl::MakeSpan(buf)); // Shrink vector before overwriting DCHECK_LE(diff3_buf.size(), diff2_read + bound.diff_span.size()); size_t to_remove = diff2_read + bound.diff_span.size() - diff3_buf.size(); diffs_.erase(diffs_.begin() + diff_offset, diffs_.begin() + diff_offset + to_remove); // Overwrite diff1/diff2 with new diff3 copy(diff3_buf.begin(), diff3_buf.end(), diffs_.begin() + diff_offset); return true; } void CompressedSortedSet::Merge(CompressedSortedSet&& other) { // Quadratic compexity in theory, but in practice used only to merge with larger values. 
// Tail insert optimization makes it linear for (int v : other) Insert(v); } std::pair CompressedSortedSet::Split() && { DCHECK_GT(Size(), 5u); CompressedSortedSet second(diffs_.get_allocator().resource()); // Move iterator to middle position and save size of diffs tail auto it = begin(); std::advance(it, (size_ - 1) / 2); // Save last value in the first set tail_value_ = *it; ++it; size_t keep_bytes = it.last_read_.data() - diffs_.data(); // Copy second half into second set for (; it != end(); ++it) second.Insert(*it); // Erase diffs tail diffs_.resize(keep_bytes); size_ -= second.Size(); return std::make_pair(std::move(*this), std::move(second)); } // The leftmost three bits of the first byte store the number of additional bytes. All following // bits store the number itself. absl::Span CompressedSortedSet::WriteVarLen(IntType value, absl::Span buf) { // TODO: fix flit encoding of large numbers size_t written = base::flit::EncodeT(static_cast(value), buf.data()); return buf.first(written); } std::pair CompressedSortedSet::ReadVarLen( absl::Span source) { uint64_t out = 0; size_t read = 0; // We need this because ParseT may read 8 bytes even if source can be less than that // due to the encoding and we end up accessing an invalid memory location. // (not really a bug because ParseT ignores the extra bytes it reads). 
if (source.size() < 8) { VarintBuffer ranged_source{0}; memcpy(&ranged_source, source.data(), source.size()); read = base::flit::ParseT(ranged_source.data(), &out); } else { read = base::flit::ParseT(source.data(), &out); } CHECK_LE(out, numeric_limits::max()); return {out, read}; } } // namespace dfly::search ================================================ FILE: src/core/search/compressed_sorted_set.h ================================================ #pragma once #include #include #include #include #include #include "base/logging.h" #include "base/pmr/memory_resource.h" #include "core/search/base.h" namespace dfly::search { // A list of sorted unique integers with reduced memory usage. // Only differences between successive elements are stored // in a variable length encoding. class CompressedSortedSet { public: using IntType = DocId; using ElementType = IntType; // Const access iterator that decodes the compressed list on traversal struct ConstIterator { friend class CompressedSortedSet; // To make it work with std container contructors using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; using value_type = IntType; using pointer = IntType*; using reference = IntType&; IntType operator*() const; ConstIterator& operator++(); friend class CompressedSortedSet; friend bool operator==(const ConstIterator& l, const ConstIterator& r); friend bool operator!=(const ConstIterator& l, const ConstIterator& r); ConstIterator() = default; private: explicit ConstIterator(const CompressedSortedSet& list); void ReadNext(); // Decode next value to stash std::optional stash_{}; absl::Span last_read_{}; absl::Span diffs_{}; }; using iterator = ConstIterator; public: explicit CompressedSortedSet(PMR_NS::memory_resource* mr); ConstIterator begin() const; ConstIterator end() const; bool Insert(IntType value); // Insert arbitrary element, needs to scan whole list bool Remove(IntType value); // Remove arbitrary element, needs to scan whole list 
size_t Size() const { return size_; } size_t ByteSize() const { return diffs_.size(); } bool Empty() const { return size_ == 0; } void Clear() { size_ = 0; tail_value_.reset(); diffs_.clear(); } // Add all values from other void Merge(CompressedSortedSet&& other); // Split into two equally sized halves std::pair Split() &&; IntType Back() const { DCHECK(!Empty() && tail_value_.has_value()); return tail_value_.value(); } static DefragmentResult Defragment([[maybe_unused]] PageUsage* page_usage) { return {}; } private: struct EntryLocation { IntType value; // Value or 0 IntType prev_value; // Preceding value or 0 absl::Span diff_span; // Location of value encoded diff, empty if none read }; private: // Find EntryLocation of first entry that is not less than value (std::lower_bound) EntryLocation LowerBound(IntType value) const; // Push back difference without any decoding. Used only for efficient construction from sorted // list void PushBackDiff(IntType diff); // Encode integer with variable length encoding into buf and return written subspan static absl::Span WriteVarLen(IntType value, absl::Span buf); // Decode integer with variable length encoding from source static std::pair ReadVarLen(absl::Span source); private: uint32_t size_{0}; std::optional tail_value_{}; std::vector> diffs_; }; } // namespace dfly::search ================================================ FILE: src/core/search/compressed_sorted_set_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/search/compressed_sorted_set.h" #include #include #include "base/gtest.h" #include "base/logging.h" #include "core/bptree_set.h" namespace dfly::search { using namespace std; namespace { struct SetInserter { using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; using value_type = CompressedSortedSet::IntType; using pointer = value_type*; using reference = value_type&; explicit SetInserter(CompressedSortedSet* set) : set_{set} {}; SetInserter& operator*() { return *this; } SetInserter& operator++() { return *this; } SetInserter& operator=(value_type value) { set_->Insert(value); return *this; } private: CompressedSortedSet* set_; }; } // namespace class CompressedSortedSetTest : public ::testing::Test { protected: }; using IdVec = vector; TEST_F(CompressedSortedSetTest, BasicInsert) { CompressedSortedSet list{PMR_NS::get_default_resource()}; IdVec list_copy; auto current = [&list]() { return IdVec{list.begin(), list.end()}; }; auto add = [&list, &list_copy](uint32_t value) { list.Insert(value); set list_copy_set{list_copy.begin(), list_copy.end()}; list_copy_set.insert(value); list_copy = IdVec{list_copy_set.begin(), list_copy_set.end()}; }; // Check empty list is empty EXPECT_EQ(current(), list_copy); // Insert some numbers in sorted order add(10); EXPECT_EQ(current(), list_copy); add(15); EXPECT_EQ(current(), list_copy); add(22); EXPECT_EQ(current(), list_copy); add(25); add(31); EXPECT_EQ(current(), list_copy); // Now insert front add(7); EXPECT_EQ(current(), list_copy); add(2); EXPECT_EQ(current(), list_copy); // Insert in-between add(13); EXPECT_EQ(current(), list_copy); add(23); add(19); EXPECT_EQ(current(), list_copy); add(30); add(27); EXPECT_EQ(current(), list_copy); // Now add some numbers in reverse order add(41); add(40); add(37); add(34); EXPECT_EQ(current(), list_copy); // Now add a 0 add(0); EXPECT_EQ(current(), list_copy); // Make sure all test integers fit into a single byte 
EXPECT_EQ(list.ByteSize(), list.Size()); } TEST_F(CompressedSortedSetTest, BasicInsertLargeValues) { CompressedSortedSet list{PMR_NS::get_default_resource()}; IdVec list_copy; const uint32_t kBase = 1'000'000'000; // Add big integers in reverse order uint32_t base = kBase; while (base > 0) { list.Insert(base); list_copy.insert(list_copy.begin(), base); base /= 10; } EXPECT_EQ(IdVec(list.begin(), list.end()), list_copy); // Now add neighboring integers with an offset of one base = kBase; while (base > 0) { list.Insert(base + 1); list_copy.push_back(base + 1); base /= 10; } sort(list_copy.begin(), list_copy.end()); EXPECT_EQ(IdVec(list.begin(), list.end()), list_copy); // Make sure we use at least twice less memory EXPECT_LE(list.ByteSize() * 2, list.Size() * sizeof(uint32_t)); } TEST_F(CompressedSortedSetTest, SortedBackInserter) { CompressedSortedSet list{PMR_NS::get_default_resource()}; vector v1 = {1, 3, 5}; vector v2 = {2, 4, 6}; merge(v1.begin(), v1.end(), v2.begin(), v2.end(), SetInserter{&list}); EXPECT_EQ(IdVec(list.begin(), list.end()), IdVec({1, 2, 3, 4, 5, 6})); } TEST_F(CompressedSortedSetTest, BasicRemove) { CompressedSortedSet list{PMR_NS::get_default_resource()}; IdVec values = {1, 3, 4, 7, 8, 11, 15, 17, 20, 22, 27}; copy(values.begin(), values.end(), SetInserter{&list}); EXPECT_EQ(IdVec(list.begin(), list.end()), values); auto remove = [&list, &values](uint32_t value) { values.erase(find(values.begin(), values.end(), value)); list.Remove(value); }; // Remove back and front remove(27); EXPECT_EQ(IdVec(list.begin(), list.end()), values); remove(1); EXPECT_EQ(IdVec(list.begin(), list.end()), values); // Remove from middle remove(11); remove(4); EXPECT_EQ(IdVec(list.begin(), list.end()), values); remove(17); remove(8); EXPECT_EQ(IdVec(list.begin(), list.end()), values); // Remove non existing list.Remove(16); EXPECT_EQ(IdVec(list.begin(), list.end()), values); } TEST_F(CompressedSortedSetTest, BasicRemoveLargeValues) { CompressedSortedSet 
list{PMR_NS::get_default_resource()}; IdVec values = {1, 12, 123, 123'4, 123'45, 123'456, 1'234'567, 12'345'678}; copy(values.begin(), values.end(), SetInserter{&list}); EXPECT_EQ(IdVec(list.begin(), list.end()), values); auto remove = [&list, &values](uint32_t value) { values.erase(find(values.begin(), values.end(), value)); list.Remove(value); }; // Remove from middle remove(123'45); EXPECT_EQ(IdVec(list.begin(), list.end()), values); remove(12); EXPECT_EQ(IdVec(list.begin(), list.end()), values); remove(1'234'567); EXPECT_EQ(IdVec(list.begin(), list.end()), values); // Remove front remove(1); EXPECT_EQ(IdVec(list.begin(), list.end()), values); // Remove back remove(12'345'678); EXPECT_EQ(IdVec(list.begin(), list.end()), values); } TEST_F(CompressedSortedSetTest, InsertRemoveLargeValues) { CompressedSortedSet list{PMR_NS::get_default_resource()}; for (int shift = 3; shift < 30; shift++) { uint32_t value = 1u << shift; IdVec values{value + 3, value, value - 5}; for (auto v : values) list.Insert(v); sort(values.begin(), values.end()); EXPECT_EQ(IdVec(list.begin(), list.end()), values); for (auto v : values) list.Remove(v); EXPECT_EQ(IdVec(list.begin(), list.end()), IdVec({})); } } } // namespace dfly::search ================================================ FILE: src/core/search/hnsw_alg.h ================================================ // This file is copied from hnswlib and modified to fit Dragonfly's needs. 
#include #include #include #pragma once namespace dfly::search { enum class HnswErrorStatus : int8_t { SUCCESS = 0, /* markDelete errors */ LABEL_NOT_FOUND, ELEMENT_ALREADY_DELETED, }; template class HierarchicalNSW : public hnswlib::AlgorithmInterface { public: using tableint = hnswlib::tableint; using labeltype = hnswlib::labeltype; using linklistsizeint = hnswlib::linklistsizeint; using VisitedListPool = hnswlib::VisitedListPool; using vl_type = hnswlib::vl_type; using VisitedList = hnswlib::VisitedList; using BaseFilterFunctor = hnswlib::BaseFilterFunctor; using BaseSearchStopCondition = hnswlib::BaseSearchStopCondition; static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; static const unsigned char DELETE_MARK = 0x01; size_t max_elements_{0}; mutable std::atomic cur_element_count{0}; // current number of elements size_t size_data_per_element_{0}; size_t size_links_per_element_{0}; mutable std::atomic num_deleted_{0}; // number of deleted elements size_t M_{0}; size_t maxM_{0}; size_t maxM0_{0}; size_t ef_construction_{0}; size_t ef_{0}; double mult_{0.0}, revSize_{0.0}; int maxlevel_{0}; std::unique_ptr visited_list_pool_{nullptr}; // Locks operations with element by label value mutable std::vector label_op_locks_; std::mutex global; std::vector link_list_locks_; tableint enterpoint_node_{0}; size_t size_links_level0_{0}; size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{0}; char* data_level0_memory_{nullptr}; // Level 0 memory block. 
Contains links + ptr to data + label char* data_vector_memory_{nullptr}; // Memory block for copied vectors char** linkLists_{nullptr}; std::vector element_levels_; // keeps level of each element size_t data_size_{0}; hnswlib::DISTFUNC fstdistfunc_; void* dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ std::unordered_map label_lookup_; std::default_random_engine level_generator_; std::default_random_engine update_probability_generator_; mutable std::atomic metric_distance_computations{0}; mutable std::atomic metric_hops{0}; bool copy_vector_ = true; bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions std::mutex deleted_elements_lock; // lock for deleted_elements std::unordered_set deleted_elements; // contains internal ids of deleted elements HierarchicalNSW(hnswlib::SpaceInterface* s) { } HierarchicalNSW(hnswlib::SpaceInterface* s, const std::string& location, bool nmslib = false, size_t max_elements = 0, bool allow_replace_deleted = false) : allow_replace_deleted_(allow_replace_deleted) { loadIndex(location, s, max_elements); } HierarchicalNSW(hnswlib::SpaceInterface* s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100, bool copy_vector = true, bool allow_replace_deleted = false) : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), link_list_locks_(max_elements), element_levels_(max_elements), copy_vector_(copy_vector), allow_replace_deleted_(allow_replace_deleted) { max_elements_ = max_elements; num_deleted_ = 0; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); if (M <= 10000) { M_ = M; } else { HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; HNSWERR << " Cap to 10000 will be applied for the rest of the processing." 
<< std::endl; M_ = 10000; } maxM_ = M_; maxM0_ = M_ * 2; ef_construction_ = std::max(ef_construction, M_); ef_ = 10; level_generator_.seed(random_seed); update_probability_generator_.seed(random_seed + 1); // If we copy vector we don't use pointer to data size_t vector_ptr_size = copy_vector_ ? 0 : sizeof(char*); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); size_data_per_element_ = size_links_level0_ + vector_ptr_size + sizeof(labeltype); offsetData_ = size_links_level0_; label_offset_ = size_links_level0_ + vector_ptr_size; offsetLevel0_ = 0; data_level0_memory_ = (char*)mi_malloc(max_elements_ * size_data_per_element_); if (data_level0_memory_ == nullptr) throw std::runtime_error("Not enough memory"); if (copy_vector) { data_vector_memory_ = (char*)mi_malloc(max_elements_ * data_size_); if (data_vector_memory_ == nullptr) throw std::runtime_error("Not enough memory"); } cur_element_count = 0; visited_list_pool_ = std::unique_ptr(new VisitedListPool(1, max_elements)); // initializations for special treatment of the first node enterpoint_node_ = -1; maxlevel_ = -1; linkLists_ = (char**)mi_malloc(sizeof(void*) * max_elements_); if (linkLists_ == nullptr) throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); mult_ = 1 / log(1.0 * M_); revSize_ = 1.0 / mult_; } ~HierarchicalNSW() { clear(); } void clear() { mi_free(data_level0_memory_); data_level0_memory_ = nullptr; for (tableint i = 0; i < cur_element_count; i++) { if (element_levels_[i] > 0) mi_free(linkLists_[i]); } if (copy_vector_) { mi_free(data_vector_memory_); } mi_free(linkLists_); linkLists_ = nullptr; cur_element_count = 0; visited_list_pool_.reset(nullptr); } struct CompareByFirst { constexpr bool operator()(std::pair const& a, std::pair const& b) const noexcept { return a.first < b.first; } }; void setEf(size_t ef) { ef_ = ef; } inline std::mutex& 
getLabelOpMutex(labeltype label) const { // calculate hash size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); return label_op_locks_[lock_id]; } inline labeltype getExternalLabel(tableint internal_id) const { labeltype return_label; memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); return return_label; } inline void setExternalLabel(tableint internal_id, labeltype label) const { memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); } inline char* getDataPtrByInternalId(tableint internal_id) const { return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); } // Return pointer to data by internal id inline char* getDataByInternalId(tableint internal_id) const { if (copy_vector_) { return (data_vector_memory_ + internal_id * data_size_); } else { char* unaligned_data_ptr = (char*)(getDataPtrByInternalId(internal_id)); char* data_ptr = nullptr; memcpy(static_cast(&data_ptr), unaligned_data_ptr, sizeof(void*)); return data_ptr; } } int getRandomLevel(double reverse_size) { std::uniform_real_distribution distribution(0.0, 1.0); double r = -log(distribution(level_generator_)) * reverse_size; return (int)r; } size_t getMaxElements() { return max_elements_; } size_t getCurrentElementCount() { return cur_element_count; } size_t getDeletedCount() { return num_deleted_; } std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void* data_point, int layer) { VisitedList* vl = visited_list_pool_->getFreeVisitedList(); vl_type* visited_array = vl->mass; vl_type visited_array_tag = vl->curV; std::priority_queue, std::vector>, CompareByFirst> top_candidates; std::priority_queue, std::vector>, CompareByFirst> candidateSet; dist_t lowerBound; if (!isMarkedDeleted(ep_id)) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); top_candidates.emplace(dist, ep_id); 
lowerBound = dist; candidateSet.emplace(-dist, ep_id); } else { lowerBound = std::numeric_limits::max(); candidateSet.emplace(-lowerBound, ep_id); } visited_array[ep_id] = visited_array_tag; while (!candidateSet.empty()) { std::pair curr_el_pair = candidateSet.top(); if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { break; } candidateSet.pop(); tableint curNodeNum = curr_el_pair.second; std::unique_lock lock(link_list_locks_[curNodeNum]); int* data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); if (layer == 0) { data = (int*)get_linklist0(curNodeNum); } else { data = (int*)get_linklist(curNodeNum, layer); // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * // size_links_per_element_); } size_t size = getListCount((linklistsizeint*)data); tableint* datal = (tableint*)(data + 1); __builtin_prefetch((char*)(visited_array + *(data + 1)), 0, 3); __builtin_prefetch((char*)(visited_array + *(data + 1) + 64), 0, 3); __builtin_prefetch(getDataByInternalId(*datal), 0, 3); for (size_t j = 0; j < size; j++) { tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; // Request prefetching next vector data memory if (j + 1 < size) { __builtin_prefetch(getDataByInternalId(*(datal + j + 1)), 0, 3); } if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; char* currObj1 = (getDataByInternalId(candidate_id)); dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { candidateSet.emplace(-dist1, candidate_id); __builtin_prefetch(getDataByInternalId(candidateSet.top().second), 0, 3); if (!isMarkedDeleted(candidate_id)) top_candidates.emplace(dist1, candidate_id); if (top_candidates.size() > ef_construction_) top_candidates.pop(); if (!top_candidates.empty()) lowerBound = top_candidates.top().first; } } } visited_list_pool_->releaseVisitedList(vl); return top_candidates; 
} // bare_bone_search means there is no check for deletions and stop condition is ignored in return // of extra performance template std::priority_queue, std::vector>, CompareByFirst> searchBaseLayerST(tableint ep_id, const void* data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr, BaseSearchStopCondition* stop_condition = nullptr) const { VisitedList* vl = visited_list_pool_->getFreeVisitedList(); vl_type* visited_array = vl->mass; vl_type visited_array_tag = vl->curV; std::priority_queue, std::vector>, CompareByFirst> top_candidates; std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; if (bare_bone_search || (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { char* ep_data = getDataByInternalId(ep_id); dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); if (!bare_bone_search && stop_condition) { stop_condition->add_point_to_result(getExternalLabel(ep_id), ep_data, dist); } candidate_set.emplace(-dist, ep_id); } else { lowerBound = std::numeric_limits::max(); candidate_set.emplace(-lowerBound, ep_id); } visited_array[ep_id] = visited_array_tag; while (!candidate_set.empty()) { std::pair current_node_pair = candidate_set.top(); dist_t candidate_dist = -current_node_pair.first; bool flag_stop_search; if (bare_bone_search) { flag_stop_search = candidate_dist > lowerBound; } else { if (stop_condition) { flag_stop_search = stop_condition->should_stop_search(candidate_dist, lowerBound); } else { flag_stop_search = candidate_dist > lowerBound && top_candidates.size() == ef; } } if (flag_stop_search) { break; } candidate_set.pop(); tableint current_node_id = current_node_pair.second; int* data = (int*)get_linklist0(current_node_id); size_t size = getListCount((linklistsizeint*)data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); if (collect_metrics) { metric_hops++; metric_distance_computations += size; 
} __builtin_prefetch((char*)(visited_array + *(data + 1)), 0, 3); __builtin_prefetch((char*)(visited_array + *(data + 1) + 64), 0, 3); __builtin_prefetch(getDataByInternalId(*(data + 1)), 0, 3); __builtin_prefetch((char*)(data + 2), 0, 3); for (size_t j = 1; j <= size; j++) { int candidate_id = *(data + j); // if (candidate_id == 0) continue; // Request prefetching next vector data memory if (j + 1 < size) { __builtin_prefetch(getDataByInternalId(*(data + j + 1)), 0, 3); } if (!(visited_array[candidate_id] == visited_array_tag)) { visited_array[candidate_id] = visited_array_tag; char* currObj1 = (getDataByInternalId(candidate_id)); dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); bool flag_consider_candidate; if (!bare_bone_search && stop_condition) { flag_consider_candidate = stop_condition->should_consider_candidate(dist, lowerBound); } else { flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist; } if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); __builtin_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + offsetLevel0_, /////////// 0, 3); //////////////////////// if (bare_bone_search || (!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) { top_candidates.emplace(dist, candidate_id); if (!bare_bone_search && stop_condition) { stop_condition->add_point_to_result(getExternalLabel(candidate_id), currObj1, dist); } } bool flag_remove_extra = false; if (!bare_bone_search && stop_condition) { flag_remove_extra = stop_condition->should_remove_extra(); } else { flag_remove_extra = top_candidates.size() > ef; } while (flag_remove_extra) { tableint id = top_candidates.top().second; top_candidates.pop(); if (!bare_bone_search && stop_condition) { stop_condition->remove_point_from_result(getExternalLabel(id), getDataByInternalId(id), dist); flag_remove_extra = stop_condition->should_remove_extra(); } else { flag_remove_extra 
= top_candidates.size() > ef; } } if (!top_candidates.empty()) lowerBound = top_candidates.top().first; } } } } visited_list_pool_->releaseVisitedList(vl); return top_candidates; } void getNeighborsByHeuristic2( std::priority_queue, std::vector>, CompareByFirst>& top_candidates, const size_t M) { if (top_candidates.size() < M) { return; } std::priority_queue> queue_closest; std::vector> return_list; while (top_candidates.size() > 0) { queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); top_candidates.pop(); } while (queue_closest.size()) { if (return_list.size() >= M) break; std::pair curent_pair = queue_closest.top(); dist_t dist_to_query = -curent_pair.first; queue_closest.pop(); bool good = true; for (std::pair second_pair : return_list) { dist_t curdist = fstdistfunc_(getDataByInternalId(second_pair.second), getDataByInternalId(curent_pair.second), dist_func_param_); if (curdist < dist_to_query) { good = false; break; } } if (good) { return_list.push_back(curent_pair); } } for (std::pair curent_pair : return_list) { top_candidates.emplace(-curent_pair.first, curent_pair.second); } } linklistsizeint* get_linklist0(tableint internal_id) const { return (linklistsizeint*)(data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); } linklistsizeint* get_linklist0(tableint internal_id, char* data_level0_memory_) const { return (linklistsizeint*)(data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); } linklistsizeint* get_linklist(tableint internal_id, int level) const { return (linklistsizeint*)(linkLists_[internal_id] + (level - 1) * size_links_per_element_); } linklistsizeint* get_linklist_at_level(tableint internal_id, int level) const { return level == 0 ? 
get_linklist0(internal_id) : get_linklist(internal_id, level); } tableint mutuallyConnectNewElement( const void* data_point, tableint cur_c, std::priority_queue, std::vector>, CompareByFirst>& top_candidates, int level, bool isUpdate) { size_t Mcurmax = level ? maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) throw std::runtime_error( "Should be not be more than M_ candidates returned by the heuristic"); std::vector selectedNeighbors; selectedNeighbors.reserve(M_); while (top_candidates.size() > 0) { selectedNeighbors.push_back(top_candidates.top().second); top_candidates.pop(); } tableint next_closest_entry_point = selectedNeighbors.back(); { // lock only during the update // because during the addition the lock for cur_c is already acquired std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); if (isUpdate) { lock.lock(); } linklistsizeint* ll_cur; if (level == 0) ll_cur = get_linklist0(cur_c); else ll_cur = get_linklist(cur_c, level); if (*ll_cur && !isUpdate) { throw std::runtime_error("The newly inserted element should have blank link list"); } setListCount(ll_cur, selectedNeighbors.size()); tableint* data = (tableint*)(ll_cur + 1); for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { if (data[idx] && !isUpdate) throw std::runtime_error("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level"); data[idx] = selectedNeighbors[idx]; } } for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); linklistsizeint* ll_other; if (level == 0) ll_other = get_linklist0(selectedNeighbors[idx]); else ll_other = get_linklist(selectedNeighbors[idx], level); size_t sz_link_list_other = getListCount(ll_other); if (sz_link_list_other > Mcurmax) throw std::runtime_error("Bad value of sz_link_list_other"); if (selectedNeighbors[idx] == cur_c) throw 
std::runtime_error("Trying to connect an element to itself"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level"); tableint* data = (tableint*)(ll_other + 1); bool is_cur_c_present = false; if (isUpdate) { for (size_t j = 0; j < sz_link_list_other; j++) { if (data[j] == cur_c) { is_cur_c_present = true; break; } } } // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then // no need to modify any connections or run the heuristics. if (!is_cur_c_present) { if (sz_link_list_other < Mcurmax) { data[sz_link_list_other] = cur_c; setListCount(ll_other, sz_link_list_other + 1); } else { // finding the "weakest" element to replace it with the new one dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), dist_func_param_); // Heuristic: std::priority_queue, std::vector>, CompareByFirst> candidates; candidates.emplace(d_max, cur_c); for (size_t j = 0; j < sz_link_list_other; j++) { candidates.emplace( fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), dist_func_param_), data[j]); } getNeighborsByHeuristic2(candidates, Mcurmax); int indx = 0; while (candidates.size() > 0) { data[indx] = candidates.top().second; candidates.pop(); indx++; } setListCount(ll_other, indx); // Nearest K: /*int indx = -1; for (int j = 0; j < sz_link_list_other; j++) { dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); if (d > d_max) { indx = j; d_max = d; } } if (indx >= 0) { data[indx] = cur_c; } */ } } } return next_closest_entry_point; } void resizeIndex(size_t new_max_elements) { if (new_max_elements < cur_element_count) throw std::runtime_error( "Cannot resize, max element is less than the current number of elements"); visited_list_pool_.reset(new VisitedListPool(1, new_max_elements)); element_levels_.resize(new_max_elements); 
std::vector(new_max_elements).swap(link_list_locks_); // Reallocate base layer char* data_level0_memory_new = (char*)mi_realloc(data_level0_memory_, new_max_elements * size_data_per_element_); if (data_level0_memory_new == nullptr) throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); data_level0_memory_ = data_level0_memory_new; // If we copy vectors, reallocate also vector data memory if (copy_vector_) { char* data_vector_memory_new = (char*)mi_realloc(data_vector_memory_, new_max_elements * data_size_); if (data_vector_memory_new == nullptr) throw std::runtime_error("Not enough memory: resizeIndex failed to allocate vector memory"); data_vector_memory_ = data_vector_memory_new; } // Reallocate all other layers char** linkLists_new = (char**)mi_realloc(linkLists_, sizeof(void*) * new_max_elements); if (linkLists_new == nullptr) throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); linkLists_ = linkLists_new; max_elements_ = new_max_elements; } size_t indexFileSize() const { size_t size = 0; size += sizeof(offsetLevel0_); size += sizeof(max_elements_); size += sizeof(cur_element_count); size += sizeof(size_data_per_element_); size += sizeof(label_offset_); size += sizeof(offsetData_); size += sizeof(maxlevel_); size += sizeof(enterpoint_node_); size += sizeof(maxM_); size += sizeof(maxM0_); size += sizeof(M_); size += sizeof(mult_); size += sizeof(ef_construction_); size += cur_element_count * size_data_per_element_; for (size_t i = 0; i < cur_element_count; i++) { unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; size += sizeof(linkListSize); size += linkListSize; } return size; } void saveIndex(const std::string& location) { #if 0 std::ofstream output(location, std::ios::binary); std::streampos position; writeBinaryPOD(output, offsetLevel0_); writeBinaryPOD(output, max_elements_); writeBinaryPOD(output, cur_element_count); writeBinaryPOD(output, size_data_per_element_); writeBinaryPOD(output, label_offset_); writeBinaryPOD(output, offsetData_); writeBinaryPOD(output, maxlevel_); writeBinaryPOD(output, enterpoint_node_); writeBinaryPOD(output, maxM_); writeBinaryPOD(output, maxM0_); writeBinaryPOD(output, M_); writeBinaryPOD(output, mult_); writeBinaryPOD(output, ef_construction_); writeBinaryPOD(output, copy_vector_); output.write(data_level0_memory_, cur_element_count * size_data_per_element_); if(copy_vector_) { output.write(data_vector_memory_, cur_element_count * data_size_); } for (size_t i = 0; i < cur_element_count; i++) { unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; writeBinaryPOD(output, linkListSize); if (linkListSize) output.write(linkLists_[i], linkListSize); } output.close(); #endif } void loadIndex(const std::string& location, hnswlib::SpaceInterface* s, size_t max_elements_i = 0) { #if 0 std::ifstream input(location, std::ios::binary); if (!input.is_open()) throw std::runtime_error("Cannot open file"); clear(); // get file size: input.seekg(0, input.end); std::streampos total_filesize = input.tellg(); input.seekg(0, input.beg); readBinaryPOD(input, offsetLevel0_); readBinaryPOD(input, max_elements_); readBinaryPOD(input, cur_element_count); size_t max_elements = max_elements_i; if (max_elements < cur_element_count) max_elements = max_elements_; max_elements_ = max_elements; readBinaryPOD(input, size_data_per_element_); readBinaryPOD(input, label_offset_); readBinaryPOD(input, offsetData_); readBinaryPOD(input, maxlevel_); readBinaryPOD(input, enterpoint_node_); readBinaryPOD(input, maxM_); readBinaryPOD(input, maxM0_); readBinaryPOD(input, M_); readBinaryPOD(input, mult_); readBinaryPOD(input, ef_construction_); readBinaryPOD(input, copy_vector_); data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); auto pos = input.tellg(); /// Optional - check if index is ok: input.seekg(cur_element_count * size_data_per_element_, input.cur); for (size_t i = 0; i < cur_element_count; i++) { if (input.tellg() < 0 || input.tellg() >= total_filesize) { throw std::runtime_error("Index seems to be corrupted or unsupported"); } unsigned int linkListSize; readBinaryPOD(input, linkListSize); if (linkListSize != 0) { input.seekg(linkListSize, input.cur); } } // throw exception if it either corrupted or old index if (input.tellg() != total_filesize) throw std::runtime_error("Index seems to be corrupted or unsupported"); input.clear(); /// Optional check end input.seekg(pos, input.beg); data_level0_memory_ = (char *) 
mi_malloc(max_elements * size_data_per_element_); if (data_level0_memory_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); input.read(data_level0_memory_, cur_element_count * size_data_per_element_); if(copy_vector_) { data_vector_memory_ = (char *) mi_malloc(max_elements * data_size_); if (data_vector_memory_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate vector memory"); input.read(data_vector_memory_, cur_element_count * data_size_); } size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); visited_list_pool_.reset(new VisitedListPool(1, max_elements)); linkLists_ = (char **) mi_malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); element_levels_ = std::vector(max_elements); revSize_ = 1.0 / mult_; ef_ = 10; for (size_t i = 0; i < cur_element_count; i++) { label_lookup_[getExternalLabel(i)] = i; unsigned int linkListSize; readBinaryPOD(input, linkListSize); if (linkListSize == 0) { element_levels_[i] = 0; linkLists_[i] = nullptr; } else { element_levels_[i] = linkListSize / size_links_per_element_; linkLists_[i] = (char *) mi_malloc(linkListSize); if (linkLists_[i] == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); input.read(linkLists_[i], linkListSize); } } for (size_t i = 0; i < cur_element_count; i++) { if (isMarkedDeleted(i)) { num_deleted_ += 1; if (allow_replace_deleted_) deleted_elements.insert(i); } } input.close(); #endif } template std::vector getDataByLabel(labeltype label) const { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); std::unique_lock 
lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { throw std::runtime_error("Label not found"); } tableint internalId = search->second; lock_table.unlock(); char* data_ptrv = getDataByInternalId(internalId); size_t dim = *((size_t*)dist_func_param_); std::vector data; data_t* data_ptr = (data_t*)data_ptrv; for (size_t i = 0; i < dim; i++) { data.push_back(*data_ptr); data_ptr += 1; } return data; } /* * Marks an element with the given label deleted, does NOT really change the current graph. */ HnswErrorStatus markDelete(labeltype label) { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { return HnswErrorStatus::LABEL_NOT_FOUND; } tableint internalId = search->second; lock_table.unlock(); if (!markDeletedInternal(internalId)) { return HnswErrorStatus::ELEMENT_ALREADY_DELETED; } return HnswErrorStatus::SUCCESS; } /* * Uses the last 16 bits of the memory for the linked list size to store the mark, * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost * all cases. */ bool markDeletedInternal(tableint internalId) { assert(internalId < cur_element_count); if (!isMarkedDeleted(internalId)) { unsigned char* ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; *ll_cur |= DELETE_MARK; num_deleted_ += 1; if (allow_replace_deleted_) { std::unique_lock lock_deleted_elements(deleted_elements_lock); deleted_elements.insert(internalId); } return true; } else { return false; } } /* * Removes the deleted mark of the node, does NOT really change the current graph. 
* * Note: the method is not safe to use when replacement of deleted elements is enabled, * because elements marked as deleted can be completely removed by addPoint */ void unmarkDelete(labeltype label) { // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search == label_lookup_.end()) { throw std::runtime_error("Label not found"); } tableint internalId = search->second; lock_table.unlock(); unmarkDeletedInternal(internalId); } /* * Remove the deleted mark of the node. */ void unmarkDeletedInternal(tableint internalId) { assert(internalId < cur_element_count); if (isMarkedDeleted(internalId)) { unsigned char* ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; *ll_cur &= ~DELETE_MARK; num_deleted_ -= 1; if (allow_replace_deleted_) { std::unique_lock lock_deleted_elements(deleted_elements_lock); deleted_elements.erase(internalId); } } else { throw std::runtime_error("The requested to undelete element is not deleted"); } } /* * Checks the first 16 bits of the memory to see if the element is marked deleted. */ bool isMarkedDeleted(tableint internalId) const { unsigned char* ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; return *ll_cur & DELETE_MARK; } unsigned short int getListCount(linklistsizeint* ptr) const { return *((unsigned short int*)ptr); } void setListCount(linklistsizeint* ptr, unsigned short int size) const { *((unsigned short int*)(ptr)) = *((unsigned short int*)&size); } /* * Adds point. Updates the point if it is already in the index. 
* If replacement of deleted elements is enabled: replaces previously deleted point if any, * updating it with new point */ void addPoint(const void* data_point, labeltype label, bool replace_deleted = false) { if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); } // lock all operations with element by label std::unique_lock lock_label(getLabelOpMutex(label)); if (!replace_deleted) { addPoint(data_point, label, -1); return; } // check if there is vacant place tableint internal_id_replaced; std::unique_lock lock_deleted_elements(deleted_elements_lock); bool is_vacant_place = !deleted_elements.empty(); if (is_vacant_place) { internal_id_replaced = *deleted_elements.begin(); deleted_elements.erase(internal_id_replaced); } lock_deleted_elements.unlock(); // if there is no vacant place then add or update point // else add point to vacant place if (!is_vacant_place) { addPoint(data_point, label, -1); } else { // we assume that there are no concurrent operations on deleted element labeltype label_replaced = getExternalLabel(internal_id_replaced); setExternalLabel(internal_id_replaced, label); std::unique_lock lock_table(label_lookup_lock); label_lookup_.erase(label_replaced); label_lookup_[label] = internal_id_replaced; lock_table.unlock(); unmarkDeletedInternal(internal_id_replaced); updatePoint(data_point, internal_id_replaced, 1.0); } } void updatePoint(const void* dataPointIn, tableint internalId, float updateNeighborProbability) { if (copy_vector_) { memcpy(getDataByInternalId(internalId), dataPointIn, data_size_); } else { memcpy(getDataPtrByInternalId(internalId), &dataPointIn, sizeof(void*)); } const void* dataPoint = getDataByInternalId(internalId); assert(dataPoint != nullptr); int maxLevelCopy = maxlevel_; tableint entryPointCopy = enterpoint_node_; // If point to be updated is entry point and graph just contains single element then just // return. 
if (entryPointCopy == internalId && cur_element_count == 1) return; int elemLevel = element_levels_[internalId]; std::uniform_real_distribution distribution(0.0, 1.0); for (int layer = 0; layer <= elemLevel; layer++) { std::unordered_set sCand; std::unordered_set sNeigh; std::vector listOneHop = getConnectionsWithLock(internalId, layer); if (listOneHop.size() == 0) continue; sCand.insert(internalId); for (auto&& elOneHop : listOneHop) { sCand.insert(elOneHop); if (distribution(update_probability_generator_) > updateNeighborProbability) continue; sNeigh.insert(elOneHop); std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); for (auto&& elTwoHop : listTwoHop) { sCand.insert(elTwoHop); } } for (auto&& neigh : sNeigh) { // if (neigh == internalId) // continue; std::priority_queue, std::vector>, CompareByFirst> candidates; size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 size_t elementsToKeep = std::min(ef_construction_, size); for (auto&& cand : sCand) { if (cand == neigh) continue; dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); if (candidates.size() < elementsToKeep) { candidates.emplace(distance, cand); } else { if (distance < candidates.top().first) { candidates.pop(); candidates.emplace(distance, cand); } } } // Retrieve neighbours using heuristic and set connections. getNeighborsByHeuristic2(candidates, layer == 0 ? 
maxM0_ : maxM_); { std::unique_lock lock(link_list_locks_[neigh]); linklistsizeint* ll_cur; ll_cur = get_linklist_at_level(neigh, layer); size_t candSize = candidates.size(); setListCount(ll_cur, candSize); tableint* data = (tableint*)(ll_cur + 1); for (size_t idx = 0; idx < candSize; idx++) { data[idx] = candidates.top().second; candidates.pop(); } } } } repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); } void repairConnectionsForUpdate(const void* dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { tableint currObj = entryPointInternalId; if (dataPointLevel < maxLevel) { dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); for (int level = maxLevel; level > dataPointLevel; level--) { bool changed = true; while (changed) { changed = false; unsigned int* data; std::unique_lock lock(link_list_locks_[currObj]); data = get_linklist_at_level(currObj, level); int size = getListCount(data); tableint* datal = (tableint*)(data + 1); __builtin_prefetch(getDataByInternalId(*datal), 0, 3); for (int i = 0; i < size; i++) { if (i + 1 < size) { __builtin_prefetch(getDataByInternalId(*(datal + i + 1)), 1, 3); } tableint cand = datal[i]; dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; currObj = cand; changed = true; } } } } } if (dataPointLevel > maxLevel) throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); for (int level = dataPointLevel; level >= 0; level--) { std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer(currObj, dataPoint, level); std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; while (topCandidates.size() > 0) { if (topCandidates.top().second != dataPointInternalId) filteredTopCandidates.push(topCandidates.top()); topCandidates.pop(); } // Since element_levels_ is being used to get 
`dataPointLevel`, there could be cases where // `topCandidates` could just contains entry point itself. To prevent self loops, the // `topCandidates` is filtered and thus can be empty. if (filteredTopCandidates.size() > 0) { bool epDeleted = isMarkedDeleted(entryPointInternalId); if (epDeleted) { filteredTopCandidates.emplace( fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); if (filteredTopCandidates.size() > ef_construction_) filteredTopCandidates.pop(); } currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); } } } std::vector getConnectionsWithLock(tableint internalId, int level) { std::unique_lock lock(link_list_locks_[internalId]); unsigned int* data = get_linklist_at_level(internalId, level); int size = getListCount(data); std::vector result(size); tableint* ll = (tableint*)(data + 1); memcpy(result.data(), ll, size * sizeof(tableint)); return result; } tableint addPoint(const void* data_point_in, labeltype label, int level) { tableint cur_c = 0; { // Checking if the element with the same label already exists // if so, updating it *instead* of creating a new element. 
std::unique_lock lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search != label_lookup_.end()) { tableint existingInternalId = search->second; if (allow_replace_deleted_) { if (isMarkedDeleted(existingInternalId)) { throw std::runtime_error( "Can't use addPoint to update deleted elements if replacement of deleted elements " "is enabled."); } } lock_table.unlock(); if (isMarkedDeleted(existingInternalId)) { unmarkDeletedInternal(existingInternalId); } updatePoint(data_point_in, existingInternalId, 1.0); return existingInternalId; } if (cur_element_count >= max_elements_) { throw std::runtime_error("The number of elements exceeds the specified limit"); } cur_c = cur_element_count; cur_element_count++; label_lookup_[label] = cur_c; } std::unique_lock lock_el(link_list_locks_[cur_c]); int curlevel = getRandomLevel(mult_); if (level > 0) curlevel = level; element_levels_[cur_c] = curlevel; std::unique_lock templock(global); int maxlevelcopy = maxlevel_; if (curlevel <= maxlevelcopy) templock.unlock(); tableint currObj = enterpoint_node_; tableint enterpoint_copy = enterpoint_node_; memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); if (copy_vector_) { memset(data_vector_memory_ + cur_c * data_size_, 0, data_size_); } // Initialisation of the data and label setExternalLabel(cur_c, label); if (copy_vector_) { memcpy(getDataByInternalId(cur_c), data_point_in, data_size_); } else { memcpy(getDataPtrByInternalId(cur_c), &data_point_in, sizeof(void*)); } const void* data_point = getDataByInternalId(cur_c); assert(data_point != nullptr); if (curlevel) { linkLists_[cur_c] = (char*)mi_malloc(size_links_per_element_ * curlevel + 1); if (linkLists_[cur_c] == nullptr) throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); } if ((signed)currObj != -1) { if (curlevel < maxlevelcopy) { dist_t curdist = 
fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); for (int level = maxlevelcopy; level > curlevel; level--) { bool changed = true; while (changed) { changed = false; unsigned int* data; std::unique_lock lock(link_list_locks_[currObj]); data = get_linklist(currObj, level); int size = getListCount(data); tableint* datal = (tableint*)(data + 1); for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand > max_elements_) throw std::runtime_error("cand error"); dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; currObj = cand; changed = true; } } } } } bool epDeleted = isMarkedDeleted(enterpoint_copy); for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { if (level > maxlevelcopy || level < 0) // possible? throw std::runtime_error("Level error"); std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer(currObj, data_point, level); if (epDeleted) { top_candidates.emplace( fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); if (top_candidates.size() > ef_construction_) top_candidates.pop(); } currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } } else { // Do nothing for the first element enterpoint_node_ = 0; maxlevel_ = curlevel; } // Releasing lock for the maximum level if (curlevel > maxlevelcopy) { enterpoint_node_ = cur_c; maxlevel_ = curlevel; } return cur_c; } std::priority_queue> searchKnn( const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { std::priority_queue> result; if (cur_element_count == 0) return result; tableint currObj = enterpoint_node_; dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); for (int level = maxlevel_; level > 0; level--) { bool changed = true; while (changed) { changed = false; unsigned int* data; data = (unsigned 
int*)get_linklist(currObj, level); int size = getListCount(data); metric_hops++; metric_distance_computations += size; tableint* datal = (tableint*)(data + 1); for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand > max_elements_) throw std::runtime_error("cand error"); dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; currObj = cand; changed = true; } } } } std::priority_queue, std::vector>, CompareByFirst> top_candidates; bool bare_bone_search = !num_deleted_ && !isIdAllowed; if (bare_bone_search) { top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, k), isIdAllowed); } else { top_candidates = searchBaseLayerST(currObj, query_data, std::max(ef_, k), isIdAllowed); } while (top_candidates.size() > k) { top_candidates.pop(); } while (top_candidates.size() > 0) { std::pair rez = top_candidates.top(); result.push(std::pair(rez.first, getExternalLabel(rez.second))); top_candidates.pop(); } return result; } // Brute-force KNN search over a pre-filtered set of label IDs. // Computes distances for all provided IDs and returns the top-k closest, ordered by distance. 
std::priority_queue> subsetKnnSearch( const void* query_data, size_t k, const std::vector& ids) const { std::priority_queue> result; if (cur_element_count == 0 || ids.empty() || k == 0) return result; for (const auto& label : ids) { auto it = label_lookup_.find(label); if (it == label_lookup_.end()) { continue; } tableint internal_id = it->second; if (isMarkedDeleted(internal_id)) { continue; } dist_t dist = fstdistfunc_(query_data, getDataByInternalId(internal_id), dist_func_param_); if (result.size() < k) { result.emplace(dist, label); } else if (dist < result.top().first) { result.pop(); result.emplace(dist, label); } } return result; } std::vector> searchStopConditionClosest( const void* query_data, BaseSearchStopCondition& stop_condition, BaseFilterFunctor* isIdAllowed = nullptr) const { std::vector> result; if (cur_element_count == 0) return result; tableint currObj = enterpoint_node_; dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); for (int level = maxlevel_; level > 0; level--) { bool changed = true; while (changed) { changed = false; unsigned int* data; data = (unsigned int*)get_linklist(currObj, level); int size = getListCount(data); metric_hops++; metric_distance_computations += size; tableint* datal = (tableint*)(data + 1); for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand < 0 || cand > max_elements_) throw std::runtime_error("cand error"); dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; currObj = cand; changed = true; } } } } std::priority_queue, std::vector>, CompareByFirst> top_candidates; top_candidates = searchBaseLayerST(currObj, query_data, 0, isIdAllowed, &stop_condition); size_t sz = top_candidates.size(); result.resize(sz); while (!top_candidates.empty()) { result[--sz] = top_candidates.top(); top_candidates.pop(); } stop_condition.filter_results(result); return result; } // Returns all elements within 
`radius` distance from query_data. // Adapts the HNSW beam search from Malkov & Yashunin (2018), https://arxiv.org/abs/1603.09320: // Phase 1 is the standard greedy descent to find the level-0 entry point; Phase 2 replaces // the top-k heap with a radius threshold, collecting all nodes with dist <= radius. // The dynamic search boundary starts at max(entry_point_distance, radius) and shrinks as // closer out-of-radius candidates are found; `epsilon` controls the overscan factor // (default 0.01) to improve recall near the boundary. std::vector> searchRange(const void* query_data, dist_t radius, double epsilon = 0.01) const { std::vector> result; if (cur_element_count == 0) return result; // Phase 1: greedy descent from top level to find the best entry point for level 0. tableint currObj = enterpoint_node_; dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); for (int level = maxlevel_; level > 0; level--) { bool changed = true; while (changed) { changed = false; unsigned int* data = (unsigned int*)get_linklist(currObj, level); int size = getListCount(data); tableint* datal = (tableint*)(data + 1); for (int i = 0; i < size; i++) { tableint cand = datal[i]; if (cand >= max_elements_) throw std::runtime_error("cand error"); dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); if (d < curdist) { curdist = d; currObj = cand; changed = true; } } } } // Phase 2: range search on bottom layer (level 0) with dynamic search boundary. VisitedList* vl = visited_list_pool_->getFreeVisitedList(); vl_type* visited_array = vl->mass; vl_type visited_array_tag = vl->curV; std::priority_queue, std::vector>, CompareByFirst> candidate_set; // Dynamic range starts at max(entry_point_dist, radius) so we never stop early just // because the entry point is farther than radius. 
dist_t ep_dist = curdist; dist_t dynamic_range = std::max(ep_dist, radius); dist_t dyn_boundary = static_cast(dynamic_range * (1.0 + epsilon)); if (!isMarkedDeleted(currObj) && ep_dist <= radius) result.emplace_back(ep_dist, getExternalLabel(currObj)); candidate_set.emplace(-ep_dist, currObj); visited_array[currObj] = visited_array_tag; while (!candidate_set.empty()) { auto curr_pair = candidate_set.top(); dist_t curr_dist = -curr_pair.first; if (curr_dist > dyn_boundary) break; candidate_set.pop(); tableint curr_id = curr_pair.second; // Shrink dynamic_range: if candidate is between radius and current range, pull the // boundary down toward radius. If candidate is within radius and dynamic_range is // still above radius (entry point was far), clamp to radius so we stop over-scanning. if (curr_dist < dynamic_range) { if (curr_dist >= radius) { dynamic_range = curr_dist; } else if (dynamic_range > radius) { dynamic_range = radius; } dyn_boundary = static_cast(dynamic_range * (1.0 + epsilon)); } int* data = (int*)get_linklist0(curr_id); size_t size = getListCount((linklistsizeint*)data); for (size_t j = 1; j <= size; j++) { tableint candidate_id = *(data + j); if (candidate_id >= max_elements_) throw std::runtime_error("cand error"); if (j < size) __builtin_prefetch(getDataByInternalId(*(data + j + 1)), 0, 3); if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; dist_t d = fstdistfunc_(query_data, getDataByInternalId(candidate_id), dist_func_param_); if (d < dyn_boundary) { candidate_set.emplace(-d, candidate_id); if (!isMarkedDeleted(candidate_id) && d <= radius) result.emplace_back(d, getExternalLabel(candidate_id)); } } } visited_list_pool_->releaseVisitedList(vl); return result; } #if 0 void checkIntegrity() { int connections_checked = 0; std::vector inbound_connections_num(cur_element_count, 0); for (int i = 0; i < cur_element_count; i++) { for (int l = 0; l <= element_levels_[i]; l++) { linklistsizeint 
*ll_cur = get_linklist_at_level(i, l);
      int size = getListCount(ll_cur);
      tableint *data = (tableint *) (ll_cur + 1);
      std::unordered_set s;
      for (int j = 0; j < size; j++) {
        // Every link must point at an existing element and must not be a self-loop.
        assert(data[j] < cur_element_count);
        assert(data[j] != i);
        inbound_connections_num[data[j]]++;
        s.insert(data[j]);
        connections_checked++;
      }
      // The set collapses duplicates, so a size mismatch means a link is repeated.
      assert(s.size() == size);
    }
  }
  if (cur_element_count > 1) {
    // With more than one element, every element is expected to have an inbound link.
    int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0];
    for (int i=0; i < cur_element_count; i++) {
      assert(inbound_connections_num[i] > 0);
      min1 = std::min(inbound_connections_num[i], min1);
      max1 = std::max(inbound_connections_num[i], max1);
    }
    std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
  }
  std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}
#endif
};

}  // namespace dfly::search

================================================
FILE: src/core/search/hnsw_index.cc
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "core/search/hnsw_index.h" #include #include #include #include #include #include "base/logging.h" #include "core/search/hnsw_alg.h" #include "core/search/mrmw_mutex.h" #include "core/search/vector_utils.h" namespace dfly::search { using namespace std; namespace { class HnswSpace : public hnswlib::SpaceInterface { unsigned dim_; VectorSimilarity sim_; static float L2DistanceStatic(const void* pVect1, const void* pVect2, const void* param) { return L2Distance(static_cast(pVect1), static_cast(pVect2), *static_cast(param)); } static float IPDistanceStatic(const void* pVect1, const void* pVect2, const void* param) { return IPDistance(static_cast(pVect1), static_cast(pVect2), *static_cast(param)); } static float CosineDistanceStatic(const void* pVect1, const void* pVect2, const void* param) { return CosineDistance(static_cast(pVect1), static_cast(pVect2), *static_cast(param)); } public: explicit HnswSpace(size_t dim, VectorSimilarity sim) : dim_(dim), sim_(sim) { } size_t get_data_size() { return dim_ * sizeof(float); } hnswlib::DISTFUNC get_dist_func() { if (sim_ == VectorSimilarity::L2) { return L2DistanceStatic; } else if (sim_ == VectorSimilarity::COSINE) { return CosineDistanceStatic; } else { return IPDistanceStatic; } } void* get_dist_func_param() { return &dim_; } }; } // namespace // TODO: to replace it and use HierarchicalNSW directly. struct HnswlibAdapter { // Default setting of hnswlib/hnswalg constexpr static size_t kDefaultEfRuntime = 10; explicit HnswlibAdapter(const SchemaField::VectorParams& params, bool copy_vector) : space_{params.dim, params.sim}, world_{&space_, params.capacity, params.hnsw_m, params.hnsw_ef_construction, 100 /* seed*/, copy_vector}, copy_vector_{copy_vector}, data_size_{params.dim * sizeof(float)} { } // Adds a point to the index. If the write lock cannot be acquired (e.g. // serialization holds a read lock), the operation is deferred and will be // replayed by a subsequent write or TryProcessDeferred() call. 
// When copy_vector_ is false the index stores a raw pointer to external data, // so we must add the point synchronously before the caller's pointer goes out // of scope — use a blocking write lock in that case. void Add(const void* data, GlobalDocId id) { if (copy_vector_) { { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock, std::try_to_lock); if (lock.locked()) { ProcessDeferred(); DoAdd(data, id); return; } } // Could not acquire write lock — defer the operation. AddDeferredOp(id, DeferredOp(true, data, data_size_, /*copy=*/true)); TryProcessDeferred(); } else { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock); ProcessDeferred(); DoAdd(data, id); } } // Removes a point from the index. If the write lock cannot be acquired, the // operation is deferred. void Remove(GlobalDocId id) { { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock, std::try_to_lock); if (lock.locked()) { ProcessDeferred(); DoRemove(id); return; } } AddDeferredOp(id, DeferredOp(false, nullptr, 0, false)); TryProcessDeferred(); } vector> Knn(float* target, size_t k, std::optional ef) { TryProcessDeferred(); world_.setEf(ef.value_or(kDefaultEfRuntime)); MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); return QueueToVec(world_.searchKnn(target, k)); } vector> Knn(float* target, size_t k, std::optional ef, const vector& allowed) { struct BinsearchFilter : hnswlib::BaseFilterFunctor { virtual bool operator()(hnswlib::labeltype id) { return binary_search(allowed->begin(), allowed->end(), id); } BinsearchFilter(const vector* allowed) : allowed{allowed} { } const vector* allowed; }; TryProcessDeferred(); world_.setEf(ef.value_or(kDefaultEfRuntime)); BinsearchFilter filter{&allowed}; MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); return QueueToVec(world_.searchKnn(target, k, &filter)); } // Brute-force KNN search over a specific subset of documents. 
// Computes distances for all provided document IDs and returns the k nearest neighbors. vector> SubsetKnn(float* target, size_t k, const vector& docs) { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); return QueueToVec(world_.subsetKnnSearch(target, k, docs)); } // Returns all documents within the given radius, with their distances. // Uses dynamic-range exploration (searchRange) to correctly handle cases where // the entry point is farther than radius. vector> RangeSearch(float* target, float radius) { TryProcessDeferred(); MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); return world_.searchRange(target, radius); } HnswIndexMetadata GetMetadata() const { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); HnswIndexMetadata metadata; metadata.max_elements = world_.max_elements_; metadata.cur_element_count = world_.cur_element_count.load(); metadata.maxlevel = world_.maxlevel_; metadata.enterpoint_node = world_.enterpoint_node_; return metadata; } void SetMetadata(const HnswIndexMetadata& metadata) { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock); absl::WriterMutexLock resize_lock(&resize_mutex_); // SetMetadata is only called during deserialization before the index is used. // Assert the index is empty to ensure no concurrent operations are possible. DCHECK_EQ(world_.cur_element_count.load(), 0u) << "SetMetadata should only be called on an empty index during deserialization"; // Runtime check for release builds to prevent silent corruption if (world_.cur_element_count.load() != 0) { LOG(ERROR) << "SetMetadata called on non-empty HNSW index with " << world_.cur_element_count.load() << " elements, ignoring"; return; } // Pre-allocate capacity based on expected element count, but don't set cur_element_count. // cur_element_count will be set by RestoreFromNodes when the actual nodes are restored. 
    if (world_.max_elements_ < metadata.cur_element_count) {
      world_.resizeIndex(metadata.cur_element_count);
    }
    // Note: Don't set cur_element_count here - RestoreFromNodes will set it after restoring nodes.
  }

  // Current element count of the index, read under the read lock.
  size_t GetNodeCount() const {
    MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
    return world_.cur_element_count.load();
  }

  // Exports the nodes with internal ids in [start, end) for serialization.
  // The caller must already hold the read lock (see GetReadLock()); this only DCHECKs it.
  // `end` is clamped to the current element count and `start` to `end`, so an
  // out-of-range request yields a shorter (possibly empty) vector.
  std::vector GetNodesRange(size_t start, size_t end) const {
    DCHECK(mrmw_mutex_.IsReadLocked());

    size_t count = world_.cur_element_count.load();
    end = std::min(end, count);
    start = std::min(start, end);

    std::vector result;
    result.reserve(end - start);

    for (size_t internal_id = start; internal_id < end; ++internal_id) {
      HnswNodeData node_data;
      node_data.internal_id = internal_id;
      node_data.global_id = world_.getExternalLabel(internal_id);
      node_data.level = world_.element_levels_[internal_id];
      // One link vector per level, 0..level inclusive.
      node_data.levels_links.resize(node_data.level + 1);

      // Level 0 links come from the level-0 list; the ids follow the count slot.
      auto* ll0 = world_.get_linklist0(internal_id);
      unsigned short link_count0 = world_.getListCount(ll0);
      auto* links0 = reinterpret_cast(ll0 + 1);
      node_data.levels_links[0].assign(links0, links0 + link_count0);

      // Upper levels are read through get_linklist().
      for (int lvl = 1; lvl <= node_data.level; ++lvl) {
        auto* ll = world_.get_linklist(internal_id, lvl);
        unsigned short link_count = world_.getListCount(ll);
        auto* links = reinterpret_cast(ll + 1);
        node_data.levels_links[lvl].assign(links, links + link_count);
      }

      result.push_back(std::move(node_data));
    }

    return result;
  }

 private:
  // A single deferred Add or Remove operation.
  struct DeferredOp {
    bool is_add;
    bool owns_data;        // If true, data_ptr was allocated by us and must be freed.
    const void* data_ptr;  // Pointer to vector data (owned or borrowed).
    // When `copy` is set and `data` is non-null, the payload is duplicated into a
    // mi_malloc'ed buffer of `data_size` bytes owned by this object; otherwise the pointer
    // is stored as-is and not freed.
    // NOTE(review): the mi_malloc result is not checked before memcpy.
    DeferredOp(bool is_add, const void* data, size_t data_size, bool copy)
        : is_add(is_add), owns_data(copy && data != nullptr) {
      if (owns_data) {
        void* buf = mi_malloc(data_size);
        memcpy(buf, data, data_size);
        data_ptr = buf;
      } else {
        data_ptr = data;
      }
    }

    ~DeferredOp() {
      if (owns_data)
        mi_free(const_cast(data_ptr));
    }

    // Move construction steals the payload and leaves `o` non-owning.
    DeferredOp(DeferredOp&& o) noexcept
        : is_add(o.is_add), owns_data(o.owns_data), data_ptr(o.data_ptr) {
      o.owns_data = false;
      o.data_ptr = nullptr;
    }

    // Move assignment swaps all three fields, so the payload previously held by *this
    // is released when `o` is destroyed.
    DeferredOp& operator=(DeferredOp&& o) noexcept {
      auto lhs = std::tie(is_add, owns_data, data_ptr);
      auto rhs = std::tie(o.is_add, o.owns_data, o.data_ptr);
      std::swap(lhs, rhs);
      return *this;
    }

    DeferredOp(const DeferredOp&) = delete;
    DeferredOp& operator=(const DeferredOp&) = delete;
  };

  // Actually add the point. Must be called while holding mrmw write lock.
  // If addPoint() reports that the index is full, the index is grown via ResizeIfFull()
  // and the insert is retried; any other exception is logged and the point is dropped.
  void DoAdd(const void* data, GlobalDocId id) {
    while (true) {
      try {
        // Reader side of resize_mutex_: inserts exclude a concurrent resize. The lock is
        // released when the try block exits, before ResizeIfFull() runs.
        absl::ReaderMutexLock resize_lock(&resize_mutex_);
        world_.addPoint(data, id);
        return;
      } catch (const std::exception& e) {
        std::string error_msg = e.what();
        if (absl::StrContains(error_msg, "The number of elements exceeds the specified limit")) {
          ResizeIfFull();
          continue;
        }
        LOG(ERROR) << "HnswlibAdapter::DoAdd exception: " << e.what();
        return;
      }
    }
  }

  // Marks the label as deleted in the index; failures are only logged at VLOG(1).
  void DoRemove(GlobalDocId id) {
    HnswErrorStatus status = world_.markDelete(id);
    if (status != HnswErrorStatus::SUCCESS) {
      VLOG(1) << "HnswlibAdapter::Remove failed with status: " << static_cast(status)
              << " for global id: " << id;
    }
  }

  // Add a deferred operation, replacing any previous one for the same document.
  void AddDeferredOp(GlobalDocId id, DeferredOp op) {
    std::lock_guard g(deferred_mu_);
    deferred_ops_.insert_or_assign(id, std::move(op));
  }

  // Take all deferred operations out of the queue.
  // The swap happens under deferred_mu_; the returned map is then owned by the caller.
  absl::flat_hash_map TakeDeferredOps() {
    std::lock_guard g(deferred_mu_);
    absl::flat_hash_map ops;
    ops.swap(deferred_ops_);
    return ops;
  }

  // Drain the deferred operations queue. Must be called while holding the mrmw
  // write lock.
Only copy_vector_=true adds and removes can be deferred, so // ordering within the queue does not matter. void ProcessDeferred() { auto ops = TakeDeferredOps(); for (auto& [id, op] : ops) { if (op.is_add) { DoAdd(op.data_ptr, id); } else { DoRemove(id); } } } // Non-blocking attempt to drain the deferred queue. void TryProcessDeferred() { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock, std::try_to_lock); if (lock.locked()) { ProcessDeferred(); } } // Function requires that we hold mutex while resizing index. resizeIndex is not thread safe with // insertion (https://github.com/nmslib/hnswlib/issues/267) void ResizeIfFull() { { // First check with reader lock to avoid contention. absl::ReaderMutexLock lock(&resize_mutex_); if (world_.getCurrentElementCount() < world_.getMaxElements() || (world_.allow_replace_deleted_ && world_.getDeletedCount() > 0)) { return; } } try { // Upgrade to writer lock. absl::WriterMutexLock lock(&resize_mutex_); if (world_.getCurrentElementCount() == world_.getMaxElements() && (!world_.allow_replace_deleted_ || world_.getDeletedCount() == 0)) { auto max_elements = world_.getMaxElements(); world_.resizeIndex(max_elements * 2); VLOG(1) << "Resizing HNSW Index from " << max_elements << " to " << max_elements * 2; } } catch (const std::exception& e) { LOG(FATAL) << "HnswlibAdapter::ResizeIfFull exception: " << e.what(); } } template static vector> QueueToVec(Q queue) { vector> out(queue.size()); size_t idx = out.size(); while (!queue.empty()) { out[--idx] = queue.top(); queue.pop(); } return out; } public: // Restore HNSW graph structure from serialized nodes with metadata void RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata) { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock); absl::WriterMutexLock resize_lock(&resize_mutex_); if (nodes.empty()) { return; } // RestoreFromNodes is only called during deserialization on a freshly created index. 
// Assert the index is empty to prevent memory leaks from double-allocation of linkLists_. DCHECK_EQ(world_.cur_element_count.load(), 0u) << "RestoreFromNodes should only be called on an empty index during deserialization"; // Ensure we have enough capacity. // Metadata may have been captured before the snapshot read-lock, so // cur_element_count can be smaller than actual node internal_ids when // concurrent writes happen. Compute the real requirement from nodes. size_t max_internal_id = 0; for (const auto& node : nodes) { max_internal_id = std::max(max_internal_id, node.internal_id); } size_t required_capacity = std::max(metadata.cur_element_count, max_internal_id + 1); if (world_.max_elements_ < required_capacity) { world_.resizeIndex(required_capacity); } // Restore each node - directly set up memory and fields size_t restored_count = 0; for (const auto& node : nodes) { size_t internal_id = node.internal_id; // Validate internal_id is within bounds - invalid internal_id indicates corrupted data CHECK(internal_id < world_.max_elements_); // Register label in lookup table world_.label_lookup_[node.global_id] = internal_id; // Set the level world_.element_levels_[internal_id] = node.level; // Clear level 0 memory and set label. // Memory layout: each element occupies size_data_per_element_ bytes starting at // data_level0_memory_ + internal_id * size_data_per_element_. // offsetLevel0_ is always 0, so we clear exactly one element's worth of data. // This matches the pattern in hnswlib's addPoint(). memset(world_.data_level0_memory_ + internal_id * world_.size_data_per_element_, 0, world_.size_data_per_element_); world_.setExternalLabel(internal_id, node.global_id); // In copy mode, zero the vector memory so distance computations don't use // uninitialized data for nodes that are marked deleted. 
if (world_.copy_vector_) { char* data_ptr = world_.data_vector_memory_ + internal_id * world_.data_size_; memset(data_ptr, 0, world_.data_size_); } // Allocate upper layer links if needed if (node.level > 0) { world_.linkLists_[internal_id] = (char*)mi_malloc(world_.size_links_per_element_ * node.level + 1); memset(world_.linkLists_[internal_id], 0, world_.size_links_per_element_ * node.level + 1); } // Restore links for layer 0 if (!node.levels_links.empty()) { auto* ll0 = world_.get_linklist0(internal_id); world_.setListCount(ll0, node.levels_links[0].size()); auto* links0 = reinterpret_cast(ll0 + 1); std::copy(node.levels_links[0].begin(), node.levels_links[0].end(), links0); } // Restore links for upper layers for (int lvl = 1; lvl <= node.level && lvl < static_cast(node.levels_links.size()); ++lvl) { auto* ll = world_.get_linklist(internal_id, lvl); world_.setListCount(ll, node.levels_links[lvl].size()); auto* links = reinterpret_cast(ll + 1); std::copy(node.levels_links[lvl].begin(), node.levels_links[lvl].end(), links); } // Track restored count so markDeletedInternal can validate internal_id bounds. world_.cur_element_count.store(++restored_count); // Mark node as deleted until UpdateVectorData provides valid vector data. // This prevents crashes from dereferencing uninitialised data pointers // (especially in borrowed-vector mode). world_.markDeletedInternal(internal_id); } // Set the metadata for the graph world_.maxlevel_ = metadata.maxlevel; world_.enterpoint_node_ = metadata.enterpoint_node; VLOG(1) << "Restored HNSW index with " << restored_count << " nodes, maxlevel=" << metadata.maxlevel << ", enterpoint=" << metadata.enterpoint_node; } // Update vector data for an existing node (used after RestoreFromNodes). // Returns false if the node doesn't exist in the index. 
  // `data` must point at a full vector (data_size_ bytes); it is not checked for null here.
  bool UpdateVectorData(GlobalDocId id, const void* data) {
    TryProcessDeferred();
    MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock);

    // Find the internal id for this label
    auto it = world_.label_lookup_.find(id);
    if (it == world_.label_lookup_.end()) {
      VLOG(1) << "UpdateVectorData: label " << id << " not found in index";
      return false;
    }

    size_t internal_id = it->second;

    // Copy/store the vector data based on copy_vector_ mode
    if (world_.copy_vector_) {
      // Owned mode: copy data into world's vector memory
      char* data_ptr = world_.data_vector_memory_ + internal_id * world_.data_size_;
      memcpy(data_ptr, data, world_.data_size_);
    } else {
      // Borrowed mode: store pointer to external data
      // (the pointer value itself is written into the node's slot, not the vector bytes).
      char* ptr_location = world_.getDataPtrByInternalId(internal_id);
      memcpy(ptr_location, &data, sizeof(void*));
    }

    // Unmark deleted so the node participates in KNN searches now that it
    // has valid vector data. During RestoreFromNodes all nodes are marked
    // deleted by default to prevent dereferencing uninitialised data.
    if (world_.isMarkedDeleted(internal_id)) {
      world_.unmarkDeletedInternal(internal_id);
    }

    return true;
  }

  // Returns a heap-allocated read lock on the MRMW mutex; the lock is held for as long as
  // the returned object lives.
  std::unique_ptr GetReadLock() const {
    return std::make_unique(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock);
  }

 private:
  HnswSpace space_;
  HierarchicalNSW world_;
  absl::Mutex resize_mutex_;  // Readers: inserts (DoAdd). Writers: resize and restore paths.
  mutable MRMWMutex mrmw_mutex_;

  bool copy_vector_;  // Whether vectors are copied into hnswlib.
  size_t data_size_;  // Byte size of a single vector.

  mutable base::SpinLock deferred_mu_;  // Protects deferred_ops_.
absl::flat_hash_map deferred_ops_; // GUARDED_BY(deferred_mu_) }; HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, bool copy_vector, PMR_NS::memory_resource*) : copy_vector_(copy_vector), dim_{params.dim}, adapter_{make_unique(params, copy_vector)} { DCHECK(params.use_hnsw); // TODO: Patch hnsw to use MR } HnswVectorIndex::~HnswVectorIndex() { } bool HnswVectorIndex::Add(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) { auto vector_ptr = doc.GetVector(field, dim_); if (!vector_ptr) { return false; } const void* data = nullptr; if (std::holds_alternative(*vector_ptr)) { data = std::get(*vector_ptr).first.get(); } else { data = std::get(*vector_ptr); } if (!data) { return false; } adapter_->Add(data, id); return true; } std::vector> HnswVectorIndex::Knn(float* target, size_t k, std::optional ef) const { return adapter_->Knn(target, k, ef); } std::vector> HnswVectorIndex::Knn( float* target, size_t k, std::optional ef, const std::vector& allowed) const { return adapter_->Knn(target, k, ef, allowed); } std::vector> HnswVectorIndex::SubsetKnn( float* target, size_t k, const std::vector& docs) const { return adapter_->SubsetKnn(target, k, docs); } std::vector> HnswVectorIndex::RangeQuery(float* target, float radius) const { return adapter_->RangeSearch(target, radius); } void HnswVectorIndex::Remove(GlobalDocId id, const DocumentAccessor& doc, string_view field) { adapter_->Remove(id); } void HnswVectorIndex::Remove(GlobalDocId id) { adapter_->Remove(id); } HnswIndexMetadata HnswVectorIndex::GetMetadata() const { return adapter_->GetMetadata(); } void HnswVectorIndex::SetMetadata(const HnswIndexMetadata& metadata) { adapter_->SetMetadata(metadata); } size_t HnswVectorIndex::GetNodeCount() const { return adapter_->GetNodeCount(); } std::vector HnswVectorIndex::GetNodesRange(size_t start, size_t end) const { return adapter_->GetNodesRange(start, end); } void HnswVectorIndex::RestoreFromNodes(const std::vector& nodes, const 
HnswIndexMetadata& metadata) { adapter_->RestoreFromNodes(nodes, metadata); } bool HnswVectorIndex::UpdateVectorData(GlobalDocId id, const DocumentAccessor& doc, std::string_view field) { auto vector_ptr = doc.GetVector(field, dim_); if (!vector_ptr || *vector_ptr == search::DocumentAccessor::VectorInfo(search::BorrowedFtVector(nullptr))) { // Document doesn't have the vector field - mark node as deleted to prevent // "ghost" nodes with invalid vector data from participating in searches LOG(WARNING) << "UpdateVectorData: document " << id << " missing vector field, marking node as deleted in HNSW index"; adapter_->Remove(id); return false; } const void* data = nullptr; if (std::holds_alternative(*vector_ptr)) { data = std::get(*vector_ptr).first.get(); } else { data = std::get(*vector_ptr); } return adapter_->UpdateVectorData(id, data); } std::unique_ptr HnswVectorIndex::GetReadLock() const { return adapter_->GetReadLock(); } } // namespace dfly::search ================================================ FILE: src/core/search/hnsw_index.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include "core/search/mrmw_mutex.h" #include "core/search/search.h" namespace dfly::search { // Metadata structure for HNSW index serialization // Contains the key parameters needed to restore the index state struct HnswIndexMetadata { size_t max_elements = 0; // Maximum number of elements the index can hold // Note: cur_element_count may be smaller than actual node count during concurrent writes, // so we compute the real requirement from nodes during restoration. // TODO: consider removing it from metadata and rely entirely on node data for restoration. 
  size_t cur_element_count = 0;  // Current number of elements in the index
  int maxlevel = -1;             // Maximum level of the graph
  size_t enterpoint_node = 0;    // Entry point node for the graph
};

// Node data structure for HNSW serialization
// The scalar fields have no default initializers; whoever creates a node must set them.
struct HnswNodeData {
  uint32_t internal_id;
  GlobalDocId global_id;
  int level;
  std::vector> levels_links;  // Links for each level (0 to level)

  // Returns the total serialized size in bytes.
  // Format: internal_id(4) + global_id(8) + level(4)
  //         + for each level: links_num(4) + links(4 each)
  size_t TotalSize() const {
    size_t size = 4 + 8 + 4;  // internal_id + global_id + level
    for (const auto& links : levels_links) {
      size += 4 + links.size() * 4;  // links_num + links
    }
    return size;
  }
};

struct HnswlibAdapter;

// Public facade over HnswlibAdapter. The adapter is only forward-declared here and held
// through a unique_ptr, which is why the destructor is declared but not defined inline.
class HnswVectorIndex {
 public:
  explicit HnswVectorIndex(const search::SchemaField::VectorParams& params, bool copy_vector,
                           PMR_NS::memory_resource* mr = PMR_NS::get_default_resource());
  ~HnswVectorIndex();

  // Returns false if no vector could be extracted from `doc` for `field`.
  bool Add(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field);
  void Remove(search::GlobalDocId id, const search::DocumentAccessor& doc, std::string_view field);
  void Remove(search::GlobalDocId id);

  bool IsVectorCopied() const {
    return copy_vector_;
  }

  // k-nearest-neighbour queries: over the whole index, restricted to `allowed`, or
  // brute-force over `docs`.
  std::vector> Knn(float* target, size_t k, std::optional ef) const;
  std::vector> Knn(float* target, size_t k, std::optional ef,
                   const std::vector& allowed) const;
  std::vector> SubsetKnn(float* target, size_t k, const std::vector& docs) const;

  // Returns all documents within radius, with their distances.
  std::vector> RangeQuery(float* target, float radius) const;

  size_t GetDim() const {
    return dim_;
  }

  // Get metadata for serialization
  HnswIndexMetadata GetMetadata() const;

  // Set metadata (used during restoration)
  void SetMetadata(const HnswIndexMetadata& metadata);

  // Get total number of nodes in the index
  size_t GetNodeCount() const;

  // Get nodes in the specified range [start, end)
  // Returns vector of node data for serialization
  std::vector GetNodesRange(size_t start, size_t end) const;

  // Restore graph structure from serialized nodes with metadata
  // This restores the HNSW graph links but NOT the vector data
  // Vector data must be populated separately via UpdateVectorData
  void RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata);

  // Update vector data for an existing node (used after RestoreFromNodes)
  // This populates the vector data for a node that already has graph links
  bool UpdateVectorData(GlobalDocId id, const DocumentAccessor& doc, std::string_view field);

  // Acquire a read lock on the internal MRMW mutex.
  // Use this during serialization to block concurrent Add/Remove (write) operations.
  std::unique_ptr GetReadLock() const;

 private:
  bool copy_vector_;
  size_t dim_;
  std::unique_ptr adapter_;
};

}  // namespace dfly::search

================================================
FILE: src/core/search/index_result.h
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include

#include "core/search/ast_expr.h"
#include "core/search/block_list.h"
#include "core/search/range_tree.h"

namespace dfly::search {

// Represents an either owned or non-owned result set that can be accessed and merged transparently.
class IndexResult {
 private:
  using DocVec = std::vector;

  // Either an owned DocVec, a borrowed pointer to an index container, or a RangeResult.
  using Variant = std::variant*, const BlockList>*, RangeResult>;

  template  using VariantOfConstPtrs = std::variant;

  // Uniform pointer-based view over whichever alternative value_ currently holds.
  using BorrowedView = VariantOfConstPtrs, BlockList>, SingleBlockRangeResult, TwoBlocksRangeResult>;

 public:
  IndexResult() = default;

  explicit IndexResult(Variant value);

  // A null container is stored as an empty owned DocVec.
  template  explicit IndexResult(const Container* container = nullptr);

  /* It will return approximate size of the result set.
     Actual result can be smaller than the size returned by this method.
  */
  size_t ApproximateSize() const;

  BorrowedView Borrowed() const;

  // Move out of owned or copy borrowed. Take up to `limit` entries and return original size.
  // NOTE(review): in the owned case the vector is moved out whole and `limit` is not
  // applied — confirm that callers truncate if they need to.
  std::pair Take(size_t limit = std::numeric_limits::max());

 private:
  bool IsOwned() const;

  Variant value_;
};

std::vector MergeIndexResults(const IndexResult& left, const IndexResult& right,
                              AstLogicalNode::LogicOp op);

// Implementation
/******************************************************************/
inline IndexResult::IndexResult(Variant value) : value_{std::move(value)} {
}

template  IndexResult::IndexResult(const Container* container) : value_{container} {
  if (container == nullptr) {
    value_ = DocVec{};
  }
}

inline size_t IndexResult::ApproximateSize() const {
  return std::visit([](auto* set) { return set->size(); }, Borrowed());
}

// Pointer alternatives are passed through, a RangeResult is unwrapped via GetResult(),
// and the remaining (owned) alternative is exposed by address.
inline IndexResult::BorrowedView IndexResult::Borrowed() const {
  auto cb = [](const auto& v) -> BorrowedView {
    using T = std::decay_t;
    if constexpr (std::is_pointer_v>) {
      return v;
    } else if constexpr (std::is_same_v) {
      auto range_cb = [](const auto& set) -> BorrowedView { return &set; };
      return std::visit(range_cb, v.GetResult());
    } else {
      return &v;
    }
  };
  return std::visit(cb, value_);
}

inline std::pair IndexResult::Take(size_t limit) {
  if (IsOwned()) {
    auto& vec = std::get(value_);
    size_t size = vec.size();
    return {std::move(vec), size};
  }

  // Numeric ranges need to be filtered and don't know their exact size ahead
  if (std::holds_alternative(value_)) {
    auto cb = [limit](auto* range) -> std::pair {
      DocVec out;
      size_t total = 0;
      out.reserve(std::min(limit, range->size()));
      // Iterate the whole range even after `limit` entries are collected, because the
      // total is only known once every entry has been visited.
      for (auto it = range->begin(); it != range->end(); ++it) {
        total++;
        if (out.size() < limit)
          out.push_back(*it);
      }
      return {std::move(out), total};
    };
    return std::visit(cb, Borrowed());
  }

  // Generic borrowed results sets don't need to be filtered, so we can tell the result size ahead
  auto cb = [limit](auto* set) -> std::pair {
    DocVec out;
    out.reserve(std::min(limit, set->size()));
    for (auto it = set->begin(); it != set->end() && out.size() < limit; ++it)
      out.push_back(*it);
    return {std::move(out), set->size()};
  };
  return std::visit(cb, Borrowed());
}

inline bool IndexResult::IsOwned() const {
  return std::holds_alternative(value_);
}

namespace details {
using BackInserter = std::back_insert_iterator>;

template  constexpr bool IsSeekableIterator = std::is_base_of_v;

// Advances `*it` to the first element whose doc id is >= min_doc_id (or to `end`).
// Precondition: `*it` is dereferenceable and its doc id is below min_doc_id.
template  void Seek(DocId min_doc_id, const Iterator& end, Iterator* it) {
  static constexpr DocId kFastSeekThreshold = 15;

  // Elements are either bare doc ids or pairs whose first member is the doc id.
  auto extract_doc_id = [](const auto& value) {
    using T = std::decay_t;
    if constexpr (std::is_same_v) {
      return value;
    } else {
      return value.first;
    }
  };

  DocId current_value = extract_doc_id(**it);
  DCHECK(current_value < min_doc_id);
  if (min_doc_id - current_value > kFastSeekThreshold) {
    // If the gap is large, use a fast seek
    if constexpr (IsSeekableIterator) {
      it->SeekGE(min_doc_id);
    } else {
      BasicSeekGE(min_doc_id, end, it);
    }
  } else {
    // If the gap is small, just iterate
    do {
      ++(*it);
    } while (*it != end && extract_doc_id(**it) < min_doc_id);
  }
}

// Appends the doc ids common to both ranges to `out`, using Seek() to jump the lagging
// iterator forward. Seek()'s precondition means the code relies on each range being
// strictly ascending.
template
void SetIntersection(FirstIterator first_begin, FirstIterator first_end,
                     SecondIterator second_begin, SecondIterator second_end, BackInserter out) {
  auto l_it = first_begin;
  auto r_it = second_begin;

  while (l_it != first_end && r_it != second_end) {
    DocId l_value = *l_it;
    DocId r_value = *r_it;

    if (l_value == r_value) {
      *out++ = l_value;
      ++l_it;
      if (l_it != first_end) {
        Seek(*l_it, second_end, &r_it);
      }
    } else if (l_value < r_value) {
      Seek(r_value, first_end, &l_it);
    } else {
      DCHECK(l_value > r_value);
      Seek(l_value, second_end, &r_it);
    }
  }
}
}  // namespace details

// Combines two results: AND produces the intersection, any other op the sorted union.
// reserve() is only a sizing hint based on the approximate sizes.
inline std::vector MergeIndexResults(const IndexResult& left, const IndexResult& right,
                                     AstLogicalNode::LogicOp op) {
  std::vector result;

  if (op == AstLogicalNode::LogicOp::AND) {
    result.reserve(std::min(left.ApproximateSize(), right.ApproximateSize()));
    auto cb = [&result](auto* s1, auto* s2) {
      details::SetIntersection(s1->begin(), s1->end(), s2->begin(), s2->end(),
                               std::back_inserter(result));
    };
    std::visit(cb, left.Borrowed(), right.Borrowed());
  } else {
    result.reserve(std::max(left.ApproximateSize(), right.ApproximateSize()));
    auto cb = [&result](auto* s1, auto* s2) {
      std::set_union(s1->begin(), s1->end(), s2->begin(), s2->end(), std::back_inserter(result));
    };
    std::visit(cb, left.Borrowed(), right.Borrowed());
  }
  return result;
}

}  // namespace dfly::search

================================================
FILE: src/core/search/indices.cc
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/search/indices.h"

#include
#include
#include
#include
#include
#include
#include

#define UNI_ALGO_DISABLE_NFKC_NFKD

#include
#include
#include
#include
#include

#include "base/flags.h"

ABSL_FLAG(bool, use_numeric_range_tree, true,
          "Use range tree for numeric index. "
          "If false, use a simple implementation with btree_set. "
          "Range tree is more memory efficient and faster for range queries, "
          "but slower for single value queries.");

namespace dfly::search {

using namespace std;
using cmn::StringOrView;

namespace {

// True when every byte of `sv` satisfies isascii(); an empty view yields true.
bool IsAllAscii(string_view sv) {
  return all_of(sv.begin(), sv.end(), [](unsigned char c) { return isascii(c); });
}

string ToLower(string_view word) { return IsAllAscii(word) ?
absl::AsciiStrToLower(word) : una::cases::to_lowercase_utf8(word); } // Get all words from text as matched by the ICU library absl::flat_hash_set TokenizeWords(std::string_view text, const TextIndex::StopWords& stopwords, const Synonyms* synonyms) { absl::flat_hash_set words; for (std::string_view word : una::views::word_only::utf8(text)) { if (std::string word_lc = una::cases::to_lowercase_utf8(word); !stopwords.contains(word_lc)) { if (synonyms) { if (auto group_id = synonyms->GetGroupToken(word_lc); group_id) { words.insert(*group_id); } } words.insert(std::move(word_lc)); } } return words; } // Split taglist, remove duplicates and convert all to lowercase absl::flat_hash_set NormalizeTags(string_view taglist, bool case_sensitive, char separator) { // Splitting utf8 by ascii character is safe absl::flat_hash_set tags; for (string_view tag : absl::StrSplit(taglist, separator, absl::SkipEmpty())) { string_view str = absl::StripAsciiWhitespace(tag); if (case_sensitive) tags.insert(string{str}); else tags.insert(ToLower(str)); } return tags; } // Iterate over all suffixes of all words void IterateAllSuffixes(const absl::flat_hash_set& words, absl::FunctionRef cb) { for (string_view word : words) { for (size_t offs = 0; offs < word.length(); offs++) { cb(word.substr(offs)); } } } // Haversine with earth radius in meters. Used to calculate distance. 
boost::geometry::strategy::distance::haversine haversine_(6372797.560856); double ConvertToRadiusInMeters(size_t radius, std::string_view arg) { const std::string unit = absl::AsciiStrToUpper(arg); if (unit == "M") { return radius * 1; } else if (unit == "KM") { return radius * 1000; } else if (unit == "FT") { return radius * 0.3048; } else if (unit == "MI") { return radius * 1609.34; } else { return -1; } } // Verify if geo string is valid and convert to point std::optional GetGeoPoint(const string_view& geo_string) { // Empty geo string if (geo_string.empty()) return nullopt; absl::InlinedVector coordinates = absl::StrSplit(geo_string, ","); // Invalid coordinate format if (coordinates.size() != 2) return std::nullopt; // Convert coordinates to double double lon, lat; if (!absl::SimpleAtod(coordinates[0], &lon) || !absl::SimpleAtod(coordinates[1], &lat)) return nullopt; // Verify that coordinates are within valid ranges if (lon < -180 || lon > 180 || lat < -90 || lat > 90) return nullopt; return GeoIndex::point{lon, lat}; } }; // namespace class RangeTreeAdapter : public NumericIndex::RangeTreeBase { public: explicit RangeTreeAdapter(size_t max_range_block_size, PMR_NS::memory_resource* mr) : range_tree_{mr, max_range_block_size}, builder_{RangeTree::Builder{}} { } void Add(DocId id, absl::Span values) override { for (double value : values) { if (builder_) builder_->Add(id, value); else range_tree_.Add(id, value); } } void Remove(DocId id, absl::Span values) override { for (double value : values) { if (builder_) builder_->Remove(id, value); else range_tree_.Remove(id, value); } } RangeResult Range(double l, double r) const override { return range_tree_.Range(l, r); } vector GetAllDocIds() const override { // TODO: remove take return range_tree_.GetAllDocIds().Take(); } void FinalizeInitialization() override { builder_->Populate(&range_tree_, {500}); builder_.reset(); } private: RangeTree range_tree_; std::optional builder_; }; class BtreeSetImpl : public 
NumericIndex::RangeTreeBase { public: explicit BtreeSetImpl(PMR_NS::memory_resource* mr) : entries_(mr) { } void Add(DocId id, absl::Span values) override { if (values.size() > 1) { unique_ids_ = false; } for (double value : values) { entries_.insert({value, id}); } } void Remove(DocId id, absl::Span values) override { for (double value : values) { entries_.erase({value, id}); } } RangeResult Range(double l, double r) const override { DCHECK(l <= r); auto it_l = entries_.lower_bound({l, 0}); auto it_r = entries_.lower_bound({r, numeric_limits::max()}); DCHECK_GE(it_r - it_l, 0); vector out; for (auto it = it_l; it != it_r; ++it) out.push_back(it->second); sort(out.begin(), out.end()); if (!unique_ids_) { out.erase(unique(out.begin(), out.end()), out.end()); } return RangeResult(std::move(out)); } vector GetAllDocIds() const override { std::vector result; result.reserve(entries_.size()); if (unique_ids_) { // If unique_ids_ is true, we can just take the second element of each entry for (const auto& [_, doc_id] : entries_) { result.push_back(doc_id); } } else { absl::flat_hash_set unique_docs; unique_docs.reserve(entries_.size()); for (const auto& [_, doc_id] : entries_) { const auto [__, is_new] = unique_docs.insert(doc_id); if (is_new) { result.push_back(doc_id); } } } std::sort(result.begin(), result.end()); return result; } private: bool unique_ids_ = true; // If true, docs ids are unique in the index, otherwise they can repeat. 
using Entry = std::pair; absl::btree_set, PMR_NS::polymorphic_allocator> entries_; }; NumericIndex::NumericIndex(size_t max_range_block_size, PMR_NS::memory_resource* mr) { if (absl::GetFlag(FLAGS_use_numeric_range_tree)) { range_tree_ = make_unique(max_range_block_size, mr); } else { range_tree_ = make_unique(mr); } } bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) { auto numbers = doc.GetNumbers(field); if (!numbers) { return false; } range_tree_->Add(id, absl::MakeSpan(numbers.value())); return true; } void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { auto numbers = doc.GetNumbers(field).value(); range_tree_->Remove(id, absl::MakeSpan(numbers)); } void NumericIndex::FinalizeInitialization() { range_tree_->FinalizeInitialization(); } RangeResult NumericIndex::Range(double l, double r) const { if (r < l) return {}; return range_tree_->Range(l, r); } vector NumericIndex::GetAllDocsWithNonNullValues() const { return range_tree_->GetAllDocIds(); } template BaseStringIndex::BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive, bool with_suffix) : case_sensitive_{case_sensitive}, entries_{mr} { if (with_suffix) suffix_trie_.emplace(mr); } template const typename BaseStringIndex::Container* BaseStringIndex::Matching( string_view word, bool strip_whitespace) const { if (strip_whitespace) word = absl::StripAsciiWhitespace(word); auto it = entries_.find(NormalizeQueryWord(word).view()); return (it != entries_.end()) ? &it->second : nullptr; } template void BaseStringIndex::MatchPrefix(std::string_view prefix, absl::FunctionRef cb) const { StringOrView prefix_norm{NormalizeQueryWord(prefix)}; prefix = prefix_norm.view(); // TODO(vlad): Use right iterator to avoid string comparison? 
for (auto it = entries_.lower_bound(prefix); it != entries_.end() && (*it).first.rfind(prefix, 0) == 0; ++it) { cb(&(*it).second); } } template void BaseStringIndex::MatchSuffix(std::string_view suffix, absl::FunctionRef cb) const { StringOrView suffix_norm{NormalizeQueryWord(suffix)}; suffix = suffix_norm.view(); // If we have a suffix trie built, we just need to fetch the relevant suffix if (suffix_trie_) { auto it = suffix_trie_->find(suffix); cb((it != suffix_trie_->end()) ? &it->second : nullptr); return; } // Otherwise, iterate over all entries and look for the suffix for (const auto& entry : entries_) { int32_t start = entry.first.size() - suffix.size(); if (start >= 0 && entry.first.substr(start) == suffix) cb(&entry.second); } } template void BaseStringIndex::MatchInfix(std::string_view infix, absl::FunctionRef cb) const { StringOrView infix_norm{NormalizeQueryWord(infix)}; infix = infix_norm.view(); // If we have a suffix trie built, we just need to match the prefix if (suffix_trie_) { for (auto it = suffix_trie_->lower_bound(infix); it != suffix_trie_->end() && (*it).first.rfind(infix, 0) == 0; ++it) cb(&(*it).second); return; } // Otherwise, iterate over all entries and check if it contains the entry for (const auto& entry : entries_) { if (entry.first.find(infix) != string::npos) cb(&entry.second); } } template bool BaseStringIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) { auto strings_list = GetStrings(doc, field); if (!strings_list) { return false; } absl::flat_hash_set tokens; for (string_view str : strings_list.value()) tokens.merge(Tokenize(str)); if (tokens.size() > 1) unique_ids_ = false; for (string_view token : tokens) GetOrCreate(&entries_, token)->Insert(id); if (suffix_trie_) IterateAllSuffixes(tokens, [&](string_view str) { GetOrCreate(&*suffix_trie_, str)->Insert(id); }); return true; } template void BaseStringIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { auto strings_list = 
GetStrings(doc, field).value(); absl::flat_hash_set tokens; for (string_view str : strings_list) tokens.merge(Tokenize(str)); for (string_view token : tokens) Remove(&entries_, id, token); if (suffix_trie_) IterateAllSuffixes(tokens, [&](string_view str) { Remove(&*suffix_trie_, id, str); }); } template vector BaseStringIndex::GetTerms() const { vector res; res.reserve(entries_.size()); for (const auto& [term, _] : entries_) { res.push_back(string{term}); } return res; } template vector BaseStringIndex::GetAllDocsWithNonNullValues() const { std::vector result; result.reserve(entries_.size()); if (unique_ids_) { // If unique_ids_ is true, we can just take the second element of each entry for (const auto& [_, container] : entries_) { for (const auto& doc_id : container) { result.push_back(doc_id); } } } else { absl::flat_hash_set unique_docs; unique_docs.reserve(entries_.size()); for (const auto& [_, container] : entries_) { for (const auto& doc_id : container) { auto [_, is_new] = unique_docs.insert(doc_id); if (is_new) { result.push_back(doc_id); } } } } std::sort(result.begin(), result.end()); return result; } template StringOrView BaseStringIndex::NormalizeQueryWord(std::string_view query) const { if (case_sensitive_) return StringOrView::FromView(query); return StringOrView::FromString(ToLower(query)); } template typename BaseStringIndex::Container* BaseStringIndex::GetOrCreate( search::RaxTreeMap* map, string_view word) { auto* mr = map->get_allocator().resource(); return &map->try_emplace(PMR_NS::string{word, mr}, mr, 1000 /* block size */).first->second; } template void BaseStringIndex::Remove(search::RaxTreeMap* map, DocId id, string_view word) { auto it = map->find(word); if (it == map->end()) return; it->second.Remove(id); if (it->second.Size() == 0) map->erase(it); } template struct BaseStringIndex; template struct BaseStringIndex>; TextIndex::TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords, const Synonyms* synonyms, bool 
with_suffixtrie) : BaseStringIndex(mr, false, with_suffixtrie), stopwords_{stopwords}, synonyms_{synonyms} { } std::optional TextIndex::GetStrings(const DocumentAccessor& doc, std::string_view field) const { return doc.GetStrings(field); } absl::flat_hash_set TextIndex::Tokenize(std::string_view value) const { return TokenizeWords(value, *stopwords_, synonyms_); } DefragmentResult TagIndex::Defragment(PageUsage* page_usage) { auto defrag = [&](auto& tree, string* key) { DefragmentMap dm{tree, key}; return dm.Defragment(page_usage); }; DefragmentResult result = defrag(entries_, &next_defrag_entry_); if (suffix_trie_) { result.Merge(defrag(suffix_trie_.value(), &next_defrag_suffix_entry_)); } return result; } std::optional TagIndex::GetStrings(const DocumentAccessor& doc, std::string_view field) const { return doc.GetTags(field); } absl::flat_hash_set TagIndex::Tokenize(std::string_view value) const { return NormalizeTags(value, case_sensitive_, separator_); } BaseVectorIndex::BaseVectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim} { } std::pair BaseVectorIndex::Info() const { return {dim_, sim_}; } bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { auto vector = doc.GetVector(field, dim_); if (!vector) return false; if (std::holds_alternative(*vector)) { const auto& owned_vector = std::get(*vector); AddVector(id, owned_vector.first.get()); } else { const auto& borrowed_vector = std::get(*vector); AddVector(id, borrowed_vector); } return true; } // Each document occupies (dim_ + 1) floats in entries_: dim_ floats for the vector data, // followed by one float as a presence marker (1.0 = present, 0.0 = absent/removed). // This avoids the previous heuristic of treating all-zero vectors as null. 
static constexpr float kPresent = 1.0f; static constexpr float kAbsent = 0.0f; FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr) : BaseVectorIndex{params.dim, params.sim}, entries_{mr} { DCHECK(!params.use_hnsw); entries_.reserve(params.capacity * (params.dim + 1)); } void FlatVectorIndex::AddVector(DocId id, const void* vector) { const size_t stride = dim_ + 1; DCHECK_LE(id * stride, entries_.size()); if (id * stride == entries_.size()) entries_.resize((id + 1) * stride, 0.0f); if (vector) { memcpy(&entries_[id * stride], vector, dim_ * sizeof(float)); entries_[id * stride + dim_] = kPresent; } } void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { const size_t stride = dim_ + 1; if (id * stride + dim_ < entries_.size()) entries_[id * stride + dim_] = kAbsent; } const float* FlatVectorIndex::Get(DocId doc) const { const size_t stride = dim_ + 1; if (doc * stride + dim_ >= entries_.size() || entries_[doc * stride + dim_] != kPresent) return nullptr; return &entries_[doc * stride]; } std::vector FlatVectorIndex::GetAllDocsWithNonNullValues() const { const size_t stride = dim_ + 1; size_t num_slots = entries_.size() / stride; std::vector result; result.reserve(num_slots); for (DocId id = 0; id < num_slots; ++id) { if (entries_[id * stride + dim_] == kPresent) result.push_back(id); } return result; } GeoIndex::GeoIndex(PMR_NS::memory_resource* mr) : rtree_(make_unique()) { } GeoIndex::~GeoIndex() { } bool GeoIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { auto geo_string = doc.GetStrings(field); if (!geo_string) { return false; } // If field doesn't exists don't add to index. 
if (geo_string->empty()) { return true; } std::vector points; for (string_view str : *geo_string) { auto doc_point = GetGeoPoint(str); if (!doc_point) { return false; } points.emplace_back(*doc_point); } for (point p : points) { rtree_->insert({p, id}); } return true; } void GeoIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { auto geo_string = doc.GetStrings(field); if (!geo_string || geo_string->empty()) { return; } std::vector points; for (string_view str : *geo_string) { auto doc_point = GetGeoPoint(str); if (!doc_point) { return; } points.emplace_back(*doc_point); } for (point p : points) { rtree_->remove({p, id}); } } std::vector GeoIndex::RadiusSearch(double lon, double lat, double radius, std::string_view unit) { std::set unique_results; // Get radius in meters double converted_radius = ConvertToRadiusInMeters(radius, unit); // Declare the geographic_point_circle strategy with 4 points boost::geometry::strategy::buffer::geographic_point_circle<> point_strategy(4); // Declare the distance strategy in meters around the point boost::geometry::strategy::buffer::distance_symmetric distance_strategy(converted_radius); // Declare other necessary strategies, unused for point boost::geometry::strategy::buffer::join_round join_strategy; boost::geometry::strategy::buffer::end_round end_strategy; boost::geometry::strategy::buffer::side_straight side_strategy; point p{lon, lat}; // Create polygon with 4 point around point boost::geometry::model::multi_polygon> buffer_polygon; boost::geometry::buffer(p, buffer_polygon, distance_strategy, side_strategy, join_strategy, end_strategy, point_strategy); // Create bouding box around polygon to include all possible points boost::geometry::model::box box; boost::geometry::envelope(buffer_polygon, box); rtree_->query(boost::geometry::index::within(box), boost::make_function_output_iterator( [&unique_results, &p, &converted_radius](auto const& val) { if (haversine_.apply(val.first, p) <= converted_radius) { 
unique_results.insert(val.second); } })); // TODO: we should return sorted results by radius distance return {unique_results.begin(), unique_results.end()}; } std::vector GeoIndex::GetAllDocsWithNonNullValues() const { std::set unique_results; std::for_each(boost::geometry::index::begin(*rtree_), boost::geometry::index::end(*rtree_), [&unique_results](auto const& val) { unique_results.insert(val.second); }); return {unique_results.begin(), unique_results.end()}; } } // namespace dfly::search ================================================ FILE: src/core/search/indices.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include // Wrong warning reported when geometry.hpp is loaded #ifndef __clang__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #endif #include #ifndef __clang__ #pragma GCC diagnostic pop #endif #include #include #include #include #include "base/pmr/memory_resource.h" #include "core/page_usage/page_usage_stats.h" #include "core/search/base.h" #include "core/search/block_list.h" #include "core/search/compressed_sorted_set.h" #include "core/search/range_tree.h" #include "core/search/rax_tree.h" // TODO: move core field definitions out of big header #include "common/string_or_view.h" #include "core/search/search.h" namespace dfly::search { // Index for integer fields. // Range bounds are queried in logarithmic time, iteration is constant. struct NumericIndex : public BaseIndex { // Temporary base class for range tree. // It is used to use two different range trees depending on the flag use_range_tree. // If the flag is true, RangeTree is used, otherwise a simple implementation with btree_set. struct RangeTreeBase { virtual void Add(DocId id, absl::Span values) = 0; virtual void Remove(DocId id, absl::Span values) = 0; // Returns all DocIds that match the range [l, r]. 
virtual RangeResult Range(double l, double r) const = 0; // Returns all DocIds that have non-null values in the index. virtual std::vector GetAllDocIds() const = 0; virtual void FinalizeInitialization(){}; virtual ~RangeTreeBase() = default; }; // max_range_block_size is the maximum number of entries in a single range block. // It is used in RangeTree. Check RangeTree for details. explicit NumericIndex(size_t max_range_block_size, PMR_NS::memory_resource* mr); bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; void FinalizeInitialization() override; RangeResult Range(double l, double r) const; std::vector GetAllDocsWithNonNullValues() const override; private: std::unique_ptr range_tree_; }; // Base index for string based indices. template struct BaseStringIndex : public BaseIndex { using Container = BlockList; using VecOrPtr = std::variant, const Container*>; BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive, bool with_suffixtrie); bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; // Pointer is valid as long as index is not mutated. Nullptr if not found const Container* Matching(std::string_view str, bool strip_whitespace = true) const; // Iterate over all nodes matching on prefix. void MatchPrefix(std::string_view prefix, absl::FunctionRef cb) const; // Iterate over all nodes matching suffix query. Faster if suffix trie is built. void MatchSuffix(std::string_view suffix, absl::FunctionRef cb) const; // Iterate over all nodes matching infix query. Faster if suffix trie is built. void MatchInfix(std::string_view prefix, absl::FunctionRef cb) const; // Returns all the terms that appear as keys in the reverse index. 
std::vector GetTerms() const; std::vector GetAllDocsWithNonNullValues() const override; protected: using StringList = DocumentAccessor::StringList; // Used by Add & Remove to get strings from document virtual std::optional GetStrings(const DocumentAccessor& doc, std::string_view field) const = 0; // Used by Add & Remove to tokenize text value virtual absl::flat_hash_set Tokenize(std::string_view value) const = 0; cmn::StringOrView NormalizeQueryWord(std::string_view word) const; static Container* GetOrCreate(search::RaxTreeMap* map, std::string_view word); static void Remove(search::RaxTreeMap* map, DocId id, std::string_view word); bool case_sensitive_ = false; bool unique_ids_ = true; // If true, docs ids are unique in the index, otherwise they can repeat. search::RaxTreeMap entries_; std::optional> suffix_trie_; }; // Index for text fields. // Hashmap based lookup per word. struct TextIndex : public BaseStringIndex { using StopWords = absl::flat_hash_set; TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords, const Synonyms* synonyms, bool with_suffixtrie); protected: std::optional GetStrings(const DocumentAccessor& doc, std::string_view field) const override; absl::flat_hash_set Tokenize(std::string_view value) const override; private: const StopWords* stopwords_; const Synonyms* synonyms_; }; // Index for text fields. // Hashmap based lookup per word. 
// String index over tag lists. Case sensitivity, suffix-trie usage and the tag separator
// all come from the schema's TagParams.
struct TagIndex : public BaseStringIndex> {
  TagIndex(PMR_NS::memory_resource* mr, SchemaField::TagParams params)
      : BaseStringIndex(mr, params.case_sensitive, params.with_suffixtrie),
        separator_{params.separator} {
  }

  DefragmentResult Defragment(PageUsage* page_usage) override;

 protected:
  std::optional GetStrings(const DocumentAccessor& doc, std::string_view field) const override;
  absl::flat_hash_set Tokenize(std::string_view value) const override;

 private:
  char separator_;  // Character that separates tags inside a single field value.

  // Resume keys handed to DefragmentMap by Defragment(): one for the main map and one for
  // the suffix trie.
  std::string next_defrag_entry_;
  std::string next_defrag_suffix_entry_;
};

// Common base of vector indices: remembers dimension and similarity, and implements Add() once
// in terms of the AddVector() hook.
struct BaseVectorIndex : public BaseIndex {
  // Returns the (dimension, similarity) pair passed to the constructor.
  std::pair Info() const;

  bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;

 protected:
  BaseVectorIndex(size_t dim, VectorSimilarity sim);

  // Receives the raw vector data for `id`.
  virtual void AddVector(DocId id, const void* vector) = 0;

  size_t dim_;
  VectorSimilarity sim_;
};

// Index for vector fields.
// Only supports lookup by id.
struct FlatVectorIndex : public BaseVectorIndex {
  FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);

  void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

  // Returns the stored vector of `doc`, or nullptr when none is present.
  const float* Get(DocId doc) const;

  // Return all documents that have vectors in this index
  std::vector GetAllDocsWithNonNullValues() const override;

 protected:
  void AddVector(DocId id, const void* vector) override;

 private:
  PMR_NS::vector entries_;  // Flat storage holding one fixed-size slot per doc id.
};

// Index for geo fields, backed by a boost::geometry r-tree of (point, doc id) entries.
struct GeoIndex : public BaseIndex {
  using point = boost::geometry::model::point>;
  using index_entry = std::pair;

  explicit GeoIndex(PMR_NS::memory_resource* mr);
  ~GeoIndex();

  bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
  void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

  // Returns the documents within `radius` of (lon, lat); `arg` names the unit of `radius`.
  std::vector RadiusSearch(double lon, double lat, double radius, std::string_view arg);
  std::vector GetAllDocsWithNonNullValues() const override;

 private:
  using rtree = boost::geometry::index::rtree>;
  std::unique_ptr rtree_;
};

// Defragments a map like data
// structure. The values in the map must have a `Defragment` method.
// Works with rax tree map and hash based maps
template struct DefragmentMap {
  using ValueType = Container::value_type;
  using Iterator = Container::iterator;

  // Positions the iterator where a previous pass stopped. `key` is both input (resume point)
  // and output (updated by Defragment()). An empty or no longer locatable key restarts from
  // the beginning of the container.
  DefragmentMap(Container& container, std::string* key) : key{key} {
    if (key->empty()) {
      it = container.end();
    } else if constexpr (requires { container.lower_bound(*key); }) {
      // Ordered containers resume at the first entry not less than the key.
      it = container.lower_bound(*key);
    } else {
      // Containers without lower_bound need an exact match.
      it = container.find(*key);
    }

    if (it == container.end()) {
      it = container.begin();
    }
    end = container.end();
  }

  // The key is set if the defragmentation has to stop mid way due to depleted quota
  DefragmentResult Defragment(PageUsage* page_usage) {
    if (page_usage->QuotaDepleted()) {
      return DefragmentResult{.quota_depleted = true, .objects_moved = 0};
    }

    DefragmentResult result;
    for (; it != end; ++it) {
      const auto& [k, map] = *it;
      if (result.Merge(DefragmentIndex(map, page_usage)).quota_depleted) {
        // Remember the entry that ran out of quota so the next pass starts from it.
        *key = k;
        break;
      }
    }

    // A completed pass clears the key so the next one starts from the beginning.
    if (it == end) {
      key->clear();
    }
    return result;
  }

 private:
  // Calls Defragment through `->` when the value supports that syntax, through `.` otherwise.
  template static auto DefragmentIndex(T& t, PageUsage* page_usage) {
    if constexpr (requires { t->Defragment(page_usage); }) {
      return t->Defragment(page_usage);
    } else {
      return t.Defragment(page_usage);
    }
  }

  std::string* key;
  Iterator it;
  Iterator end;
};

}  // namespace dfly::search

================================================
FILE: src/core/search/lexer.lex
================================================
%top{
  // Our lexer need to know about Parser::symbol_type
  #include "core/search/parser.hh"
  #include "core/search/tag_types.h"  // Include TagType enum
}

%{
  #include
  #include
  #include "base/logging.h"

  #define DFLY_LEXER_CC 1
  #include "core/search/scanner.h"
  #undef DFLY_LEXER_CC
%}

%o bison-cc-namespace="dfly.search" bison-cc-parser="Parser"
%o namespace="dfly.search"
%o class="Scanner" lex="Lex"
%o nodefault batch case-insensitive
/* %o debug */

/* Declarations before lexer implementation.  */
%{
  // A number symbol corresponding to the value in S.
  using dfly::search::Parser;
  using namespace std;
  using dfly::search::TagType;

  // Both helpers are defined after the rules section.
  Parser::symbol_type make_StringLit(string_view src, const Parser::location_type& loc);
  Parser::symbol_type make_Tag(string_view src, TagType type, const Parser::location_type& loc);
%}

dq \"
sq \'
esc_chars ['"\?\\abfnrtv]
esc_seq \\{esc_chars}
term_ch \w
tag_val_base_ch [^,.<>{}\[\]\\\"\?':;!@#$%^&*()\-+=~\/| ]|\\.
tag_val_ch {tag_val_base_ch}+(:+{tag_val_base_ch}*)*
astrsk_ch \*

%{
  // Code run each time a pattern is matched.
%}

%%

%{
  // Code run each time lex() is called.
%}

[[:space:]]+ // skip white space

"(" return Parser::make_LPAREN (loc());
")" return Parser::make_RPAREN (loc());
"*" return Parser::make_STAR (loc());
"-" return Parser::make_NOT_OP (loc());
":" return Parser::make_COLON (loc());
"=>" return Parser::make_ARROW (loc());
"[" return Parser::make_LBRACKET (loc());
"]" return Parser::make_RBRACKET (loc());
"{" return Parser::make_LCURLBR (loc());
"}" return Parser::make_RCURLBR (loc());
"|" return Parser::make_OR_OP (loc());
"," return Parser::make_COMMA (loc());
"KNN" return Parser::make_KNN (loc());
"AS" return Parser::make_AS (loc());
"EF_RUNTIME" return Parser::make_EF_RUNTIME (loc());
"VECTOR_RANGE" return Parser::make_VECTOR_RANGE (loc());
"$YIELD_DISTANCE_AS" return Parser::make_YIELD_DISTANCE_AS (loc());

[0-9]{1,9} return Parser::make_UINT32(str(), loc());
[+-]?(([0-9]*[.])?[0-9]+|inf) return Parser::make_DOUBLE(str(), loc());

{dq}([^"]|{esc_seq})*{dq} return make_StringLit(matched_view(1, 1), loc());
{sq}([^']|{esc_seq})*{sq} return make_StringLit(matched_view(1, 1), loc());

"$"{term_ch}+ return ParseParam(str(), loc());
"@"{term_ch}+ return Parser::make_FIELD(str(), loc());
{astrsk_ch}{term_ch}+{astrsk_ch} return Parser::make_INFIX(string{matched_view(1, 1)}, loc());
{term_ch}+{astrsk_ch} return Parser::make_PREFIX(string{matched_view(0, 1)}, loc());
{astrsk_ch}{term_ch}+ return Parser::make_SUFFIX(string{matched_view(1, 0)}, loc());

{term_ch}+ return Parser::make_TERM(str(), loc());
{tag_val_ch}+{astrsk_ch} return make_Tag(str(), TagType::PREFIX, loc());
{astrsk_ch}{tag_val_ch}+ return make_Tag(str(), TagType::SUFFIX, loc());
{astrsk_ch}{tag_val_ch}+{astrsk_ch} return make_Tag(str(), TagType::INFIX, loc());
{tag_val_ch}+ return make_Tag(str(), TagType::REGULAR, loc());

<> return Parser::make_YYEOF(loc());
%%

// Builds a TERM token from the body of a quoted string literal (quotes already removed by the
// rule). Throws Parser::syntax_error when the escape sequences cannot be decoded.
Parser::symbol_type make_StringLit(string_view src, const Parser::location_type& loc) {
  string res;
  if (!absl::CUnescape(src, &res))
    throw Parser::syntax_error (loc, "bad escaped string: " + string(src));

  return Parser::make_TERM(res, loc);
}

// Builds the token for a tag value: drops the leading and/or trailing '*' that marks a
// suffix/prefix/infix match, removes escaping backslashes, and picks the token kind from `type`.
Parser::symbol_type make_Tag(string_view src, TagType type, const Parser::location_type& loc) {
  string res;
  res.reserve(src.size());

  // Determine processing boundaries
  size_t start = (type == TagType::SUFFIX || type == TagType::INFIX) ? 1 : 0;
  size_t end = src.size();
  if (type == TagType::PREFIX || type == TagType::INFIX) {
    end--; // Skip the last '*' character
  }

  // Handle escaping
  bool escaped = false;
  for (size_t i = start; i < end; ++i) {
    if (escaped) {
      // The character after a backslash is copied verbatim, even if it is a backslash.
      escaped = false;
    } else if (src[i] == '\\') {
      // Drop the backslash itself and remember to take the next character literally.
      escaped = true;
      continue;
    }
    res.push_back(src[i]);
  }

  // Return the appropriate token type
  switch (type) {
    case TagType::PREFIX:
      return Parser::make_PREFIX(res, loc);
    case TagType::SUFFIX:
      return Parser::make_SUFFIX(res, loc);
    case TagType::INFIX:
      return Parser::make_INFIX(res, loc);
    case TagType::REGULAR:
    default:
      return Parser::make_TAG_VAL(res, loc);
  }
}

================================================
FILE: src/core/search/mrmw_mutex.h
================================================
// Copyright 2025, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
// #pragma once #include #include #include "base/logging.h" #include "base/spinlock.h" namespace dfly::search { // Simple implementation of multi-Reader multi-Writer Mutex // MRMWMutex supports concurrent reads or concurrent writes but not a mix of // concurrent reads and writes at the same time. class MRMWMutex { public: enum class LockMode : uint8_t { kReadLock, kWriteLock }; MRMWMutex() : lock_mode_(LockMode::kReadLock) { } void Lock(LockMode mode) { std::unique_lock lk(mutex_); // If we have any active_runners we need to check lock mode if (active_runners_) { auto& waiters = GetWaiters(mode); waiters++; GetCondVar(mode).wait(lk, [&] { return lock_mode_ == mode; }); waiters--; } else { // No active runners so just update to requested lock mode lock_mode_ = mode; } active_runners_++; } void Unlock(LockMode mode) { std::lock_guard lk(mutex_); LockMode inverse_mode = GetInverseMode(mode); active_runners_--; // If this was last runner and there are waiters on inverse mode if (!active_runners_ && GetWaiters(inverse_mode) > 0) { lock_mode_ = inverse_mode; GetCondVar(inverse_mode).notify_all(); } } // Check if the mutex is currently held in read mode with at least one active runner. // For use in DCHECKs only - not thread-safe without external synchronization. bool IsReadLocked() const { return active_runners_ > 0 && lock_mode_ == LockMode::kReadLock; } // Non-blocking lock attempt. Returns true if the lock was acquired. bool TryLock(LockMode mode) { if (!mutex_.try_lock()) { return false; } if (active_runners_ && lock_mode_ != mode) { mutex_.unlock(); return false; } if (!active_runners_) { lock_mode_ = mode; } active_runners_++; mutex_.unlock(); return true; } private: inline size_t& GetWaiters(LockMode target_mode) { return target_mode == LockMode::kReadLock ? reader_waiters_ : writer_waiters_; }; inline std::condition_variable_any& GetCondVar(LockMode target_mode) { return target_mode == LockMode::kReadLock ? 
reader_cond_var_ : writer_cond_var_; }; static inline LockMode GetInverseMode(LockMode mode) { return mode == LockMode::kReadLock ? LockMode::kWriteLock : LockMode::kReadLock; } // TODO: use fiber sync primitives in future base::SpinLock mutex_; std::condition_variable_any reader_cond_var_, writer_cond_var_; size_t writer_waiters_ = 0, reader_waiters_ = 0; size_t active_runners_ = 0; LockMode lock_mode_; }; class MRMWMutexLock { public: // Blocking lock. explicit MRMWMutexLock(MRMWMutex* mutex, MRMWMutex::LockMode mode) : mutex_(mutex), lock_mode_(mode), locked_(true) { mutex->Lock(lock_mode_); } // Non-blocking try-lock. Check locked() to see if the lock was acquired. MRMWMutexLock(MRMWMutex* mutex, MRMWMutex::LockMode mode, std::try_to_lock_t) : mutex_(mutex), lock_mode_(mode), locked_(mutex->TryLock(mode)) { } bool locked() const { return locked_; } ~MRMWMutexLock() { if (locked_) mutex_->Unlock(lock_mode_); } MRMWMutexLock(const MRMWMutexLock&) = delete; MRMWMutexLock(MRMWMutexLock&&) = delete; MRMWMutexLock& operator=(const MRMWMutexLock&) = delete; MRMWMutexLock& operator=(MRMWMutexLock&&) = delete; private: MRMWMutex* const mutex_; MRMWMutex::LockMode lock_mode_; bool locked_; }; } // namespace dfly::search ================================================ FILE: src/core/search/mrmw_mutex_test.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#include "core/search/mrmw_mutex.h"

#include
#include

#include "absl/flags/flag.h"
#include "base/gtest.h"
#include "base/logging.h"
#include "util/fibers/pool.h"

ABSL_FLAG(bool, force_epoll, false, "If true, uses epoll api instead iouring to run tests");

namespace dfly::search {

namespace {

// Helper function to simulate reading operation.
// Increments `read_count` before taking the read lock, holds the lock for
// `sleep_time` milliseconds and decrements the counter before releasing it, so
// the counter is back at its starting value once the task has finished.
void ReadTask(MRMWMutex* mutex, std::atomic& read_count, size_t sleep_time) {
  read_count.fetch_add(1, std::memory_order_relaxed);
  MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kReadLock);
  util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
  read_count.fetch_sub(1, std::memory_order_relaxed);
}

// Helper function to simulate writing operation.
// Same shape as ReadTask but takes the mutex in write mode.
void WriteTask(MRMWMutex* mutex, std::atomic& write_count, size_t sleep_time) {
  write_count.fetch_add(1, std::memory_order_relaxed);
  MRMWMutexLock lock(mutex, MRMWMutex::LockMode::kWriteLock);
  util::ThisFiber::SleepFor(std::chrono::milliseconds(sleep_time));
  write_count.fetch_sub(1, std::memory_order_relaxed);
}

// Hold times, in milliseconds, passed to ReadTask / WriteTask
constexpr size_t kReadTaskSleepTime = 50;
constexpr size_t kWriteTaskSleepTime = 100;

}  // namespace

// Fixture that owns the mutex under test and a two-proactor pool (pp_->at(0)
// and pp_->at(1)) on which the tests launch their fibers.
class MRMWMutexTest : public ::testing::Test {
 protected:
  MRMWMutex mutex_;
  std::mt19937 generator_;

  void SetUp() override {
#ifdef __linux__
    if (absl::GetFlag(FLAGS_force_epoll)) {
      pp_.reset(util::fb2::Pool::Epoll(2));
    } else {
      pp_.reset(util::fb2::Pool::IOUring(16, 2));
    }
#else
    pp_.reset(util::fb2::Pool::Epoll(2));
#endif
    pp_->Run();
  }

  void TearDown() override {
    pp_->Stop();
    pp_.reset();
  }

  std::unique_ptr pp_;
};

// Test 1: Multiple readers can lock concurrently
TEST_F(MRMWMutexTest, MultipleReadersConcurrently) {
  std::atomic read_count(0);
  const int num_readers = 5;

  std::vector readers;
  readers.reserve(num_readers);

  for (int i = 0; i < num_readers; ++i) {
    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
    }));
  }

  // Wait for all reader fibers to finish
  for (auto& t : readers) {
    t.Join();
  }

  // Every reader has finished, so the counter must be back at zero
  EXPECT_EQ(read_count.load(), 0);
}

// Test 2: Readers are started first; a writer launched afterwards must still complete
TEST_F(MRMWMutexTest, ReadersBlockWriters) {
  std::atomic read_count(0);
  std::atomic write_count(0);

  const int num_readers = 10;

  // Start multiple readers
  std::vector readers;
  readers.reserve(num_readers);

  for (int i = 0; i < num_readers; ++i) {
    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
    }));
  }

  // Give readers time to acquire the lock
  util::ThisFiber::SleepFor(std::chrono::milliseconds(10));

  // The writer runs on the other proactor and is joined right away
  pp_->at(1)
      ->LaunchFiber(util::fb2::Launch::post,
                    [&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
      .Join();

  // Wait for all reader fibers to finish
  for (auto& t : readers) {
    t.Join();
  }

  EXPECT_EQ(read_count.load(), 0);
  EXPECT_EQ(write_count.load(), 0);
}

// Test 3: Unlock transitions correctly and wakes up waiting fibers
TEST_F(MRMWMutexTest, ReaderAfterWriter) {
  std::atomic write_count(0);
  std::atomic read_count(0);

  // Start a writer fiber
  auto writer = pp_->at(1)->LaunchFiber(util::fb2::Launch::post, [&] {
    WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
  });

  // Give writer time to acquire the lock
  util::ThisFiber::SleepFor(std::chrono::milliseconds(10));

  // Now start a reader task that will block until the writer is done
  pp_->at(0)
      ->LaunchFiber(util::fb2::Launch::post,
                    [&] { ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime); })
      .Join();

  // Ensure that writer has completed
  writer.Join();

  EXPECT_EQ(read_count.load(), 0);
  EXPECT_EQ(write_count.load(), 0);
}

// Test 4: Ensure writer gets the lock after readers finish
TEST_F(MRMWMutexTest, WriterAfterReaders) {
  std::atomic read_count(0);
  std::atomic write_count(0);

  // Start multiple readers
  const int num_readers = 10;
  std::vector readers;
  readers.reserve(num_readers);

  for (int i = 0; i < num_readers; ++i) {
    readers.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
      ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
    }));
  }

  // Wait for all readers to acquire and release the lock
  for (auto& t : readers) {
    t.Join();
  }

  // Start the writer after all readers are done
  pp_->at(1)
      ->LaunchFiber(util::fb2::Launch::post,
                    [&] { WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime); })
      .Join();

  EXPECT_EQ(read_count.load(), 0);
  EXPECT_EQ(write_count.load(), 0);
}

// Test 5: Random mix of tasks; readers always run on proactor 0, writers on proactor 1.
// The test has no assertions beyond every fiber finishing.
TEST_F(MRMWMutexTest, MixWritersReadersOnDifferentFibers) {
  std::atomic read_count(0);
  std::atomic write_count(0);

  // Start multiple readers and writers
  const int num_threads = 100;
  std::vector threads;
  threads.reserve(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    // Two out of three draws launch a reader
    if (rand() % 3) {
      threads.emplace_back(pp_->at(0)->LaunchFiber(util::fb2::Launch::post, [&] {
        ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
      }));
    } else {
      threads.emplace_back(pp_->at(1)->LaunchFiber(util::fb2::Launch::post, [&] {
        WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
      }));
    }
  }

  // Wait for all readers and writers to acquire and release the lock
  for (auto& t : threads) {
    t.Join();
  }
}

// TODO: Once we have fiber locking we can test the scenario where we write/read on the same fibers.
// The current implementation blocks the thread, so it is not possible to test this for now.
// Test 6: Mix of readers and writes on random fibers
// TEST_F(MRMWMutexTest, MixWritersReadersOnFibers) {
//   std::atomic read_count(0);
//   std::atomic write_count(0);

//   // Start multiple readers and writers
//   const int num_threads = 100;
//   std::vector threads;
//   threads.reserve(num_threads + 1);

//   // Add long read task that will block all write tasks
//   threads.emplace_back(
//       pp_->at(0)->LaunchFiber([&] { ReadTask(&mutex_, std::ref(read_count), 2000); }));

//   // Give the long reader time to acquire the lock
//   util::ThisFiber::SleepFor(std::chrono::milliseconds(100));

//   size_t write_threads = 0;
//   for (int i = 0; i < num_threads; ++i) {
//     size_t fiber_id = rand() % 2;
//     if (rand() % 3) {
//       threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
//         ReadTask(&mutex_, std::ref(read_count), kReadTaskSleepTime);
//       }));
//     } else {
//       write_threads++;
//       threads.emplace_back(pp_->at(fiber_id)->LaunchFiber(util::fb2::Launch::post, [&] {
//         WriteTask(&mutex_, std::ref(write_count), kWriteTaskSleepTime);
//       }));
//     }
//   }

//   // All shorter threads should be done and only long one remains
//   util::ThisFiber::SleepFor(std::chrono::milliseconds(500));

//   EXPECT_EQ(read_count.load(), 1);

//   EXPECT_EQ(write_count.load(), write_threads);

//   // Wait for all readers to acquire and release the lock
//   for (auto& t : threads) {
//     t.Join();
//   }
// }

// IsReadLocked() is true exactly while at least one read hold is outstanding.
TEST_F(MRMWMutexTest, IsReadLockedReflectsState) {
  // Initially no lock is held.
  EXPECT_FALSE(mutex_.IsReadLocked());

  // Acquire a read lock and verify.
  mutex_.Lock(MRMWMutex::LockMode::kReadLock);
  EXPECT_TRUE(mutex_.IsReadLocked());

  // A second concurrent reader should still report read-locked.
  mutex_.Lock(MRMWMutex::LockMode::kReadLock);
  EXPECT_TRUE(mutex_.IsReadLocked());

  // Release one reader — still locked by the other.
  mutex_.Unlock(MRMWMutex::LockMode::kReadLock);
  EXPECT_TRUE(mutex_.IsReadLocked());

  // Release the last reader.
  mutex_.Unlock(MRMWMutex::LockMode::kReadLock);
  EXPECT_FALSE(mutex_.IsReadLocked());
}

// A write hold must not be reported as read-locked.
TEST_F(MRMWMutexTest, IsReadLockedFalseUnderWriteLock) {
  mutex_.Lock(MRMWMutex::LockMode::kWriteLock);
  EXPECT_FALSE(mutex_.IsReadLocked());
  mutex_.Unlock(MRMWMutex::LockMode::kWriteLock);
}

TEST_F(MRMWMutexTest, TryLockSucceedsWhenFree) {
  // TryLock on a free mutex should succeed for both modes.
  EXPECT_TRUE(mutex_.TryLock(MRMWMutex::LockMode::kReadLock));
  mutex_.Unlock(MRMWMutex::LockMode::kReadLock);

  EXPECT_TRUE(mutex_.TryLock(MRMWMutex::LockMode::kWriteLock));
  mutex_.Unlock(MRMWMutex::LockMode::kWriteLock);
}

TEST_F(MRMWMutexTest, TryLockFailsOnConflict) {
  // Hold a read lock, then try-lock for write should fail.
  mutex_.Lock(MRMWMutex::LockMode::kReadLock);
  EXPECT_FALSE(mutex_.TryLock(MRMWMutex::LockMode::kWriteLock));
  mutex_.Unlock(MRMWMutex::LockMode::kReadLock);

  // Hold a write lock, then try-lock for read should fail.
  mutex_.Lock(MRMWMutex::LockMode::kWriteLock);
  EXPECT_FALSE(mutex_.TryLock(MRMWMutex::LockMode::kReadLock));
  mutex_.Unlock(MRMWMutex::LockMode::kWriteLock);
}

TEST_F(MRMWMutexTest, TryLockSucceedsForSameMode) {
  // Multiple readers via TryLock should all succeed.
  mutex_.Lock(MRMWMutex::LockMode::kReadLock);
  EXPECT_TRUE(mutex_.TryLock(MRMWMutex::LockMode::kReadLock));
  // Two holds were taken above, so two unlocks are needed
  mutex_.Unlock(MRMWMutex::LockMode::kReadLock);
  mutex_.Unlock(MRMWMutex::LockMode::kReadLock);

  // Multiple writers via TryLock should all succeed.
  mutex_.Lock(MRMWMutex::LockMode::kWriteLock);
  EXPECT_TRUE(mutex_.TryLock(MRMWMutex::LockMode::kWriteLock));
  mutex_.Unlock(MRMWMutex::LockMode::kWriteLock);
  mutex_.Unlock(MRMWMutex::LockMode::kWriteLock);
}

TEST_F(MRMWMutexTest, MRMWMutexLockTryLockSemantics) {
  // Hold a read lock, then try a MRMWMutexLock for write — should not be locked.
  MRMWMutexLock read_lock(&mutex_, MRMWMutex::LockMode::kReadLock);

  MRMWMutexLock try_write(&mutex_, MRMWMutex::LockMode::kWriteLock, std::try_to_lock);
  EXPECT_FALSE(try_write.locked());

  // Same-mode try-lock via RAII should succeed.
  MRMWMutexLock try_read(&mutex_, MRMWMutex::LockMode::kReadLock, std::try_to_lock);
  EXPECT_TRUE(try_read.locked());
}

}  // namespace dfly::search

================================================
FILE: src/core/search/parser.y
================================================
%skeleton "lalr1.cc" // -*- C++ -*-
%require "3.5"  // fedora 32 has this one.

%defines  // %header starts from 3.8.1

%define api.namespace {dfly::search}

%define api.token.raw
%define api.token.constructor
%define api.value.type variant
%define api.parser.class {Parser}
%define parse.assert
%define api.value.automove true

// Added to header file before parser declaration.
%code requires {
  #include "core/search/ast_expr.h"

  namespace dfly {
  namespace search {
    class QueryDriver;
  }
  }
}

// Added to cc file
%code {
#include
#include "core/search/query_driver.h"
#include "core/search/vector_utils.h"

// The generated parser pulls tokens from the driver's scanner
#define yylex driver->scanner()->Lex

using namespace std;

uint32_t toUint32(string_view src);
double toDouble(string_view src);

}

%parse-param { QueryDriver *driver }

%locations

%define parse.trace
%define parse.error verbose  // detailed
%define parse.lac full
%define api.token.prefix {TOK_}

%token
  LPAREN      "("
  RPAREN      ")"
  STAR        "*"
  ARROW       "=>"
  COLON       ":"
  LBRACKET    "["
  RBRACKET    "]"
  LCURLBR     "{"
  RCURLBR     "}"
  OR_OP       "|"
  COMMA       ","
  KNN         "KNN"
  AS          "AS"
  EF_RUNTIME  "EF_RUNTIME"
  VECTOR_RANGE "VECTOR_RANGE"
  YIELD_DISTANCE_AS "$YIELD_DISTANCE_AS"
;

%token AND_OP

// Needed 0 at the end to satisfy bison 3.5.1
%token YYEOF 0
%token TERM "term" TAG_VAL "tag_val" PARAM "param" FIELD "field" PREFIX "prefix" SUFFIX "suffix" INFIX "infix"

%precedence TERM TAG_VAL
%left OR_OP
%left AND_OP
%right NOT_OP
%precedence LPAREN RPAREN

%token DOUBLE "double"
%token UINT32 "uint32"
%nterm final_query filter star_expr
  search_expr search_unary_expr search_or_expr search_and_expr bracket_filter_expr
%nterm field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
%nterm tag_list_element

%nterm knn_query
%nterm opt_knn_alias
%nterm geounit
%nterm > opt_ef_runtime
%nterm vector_range_query
%nterm vec_range_radius

%printer { yyo << $$; } <*>;

%%

/* Entry point: the finished expression is handed to the driver via Set(). */
final_query:
  filter
      { driver->Set(std::move($1)); }
  | filter ARROW knn_query
      { driver->Set(AstKnnNode(std::move($1), std::move($3))); }
  | vector_range_query
      { driver->Set(std::move($1)); }

knn_query:
  LBRACKET KNN UINT32 FIELD TERM opt_ef_runtime opt_knn_alias RBRACKET
    {
      // Accept any string as vector - validation happens later during search execution
      uint32_t knn_count = toUint32($3);
      auto field = std::move($4);
      auto alias = std::move($7);
      auto ef = $6;
      auto vec_result = BytesToFtVectorSafe($5);
      if (!vec_result) {
        // Create empty vector for invalid data - will return empty results during search
        auto empty_vec = std::make_unique(0);
        $$ = AstKnnNode(knn_count, std::move(field), std::make_pair(std::move(empty_vec), size_t{0}), std::move(alias), ef);
      } else {
        $$ = AstKnnNode(knn_count, std::move(field), std::move(*vec_result), std::move(alias), ef);
      }
    }

/* Optional "AS <alias>"; yields an empty string when absent. */
opt_knn_alias:
  AS TERM { $$ = std::move($2); }
  | { $$ = std::string{}; }

/* Optional "EF_RUNTIME <n>"; yields nullopt when absent. */
opt_ef_runtime:
  /* empty */ { $$ = std::nullopt; }
  | EF_RUNTIME UINT32 { $$ = toUint32($2); }

vector_range_query:
  FIELD COLON LBRACKET VECTOR_RANGE vec_range_radius TERM RBRACKET ARROW LCURLBR YIELD_DISTANCE_AS COLON TERM RCURLBR
    {
      double radius = $5;
      auto field = std::move($1);
      auto alias = std::move($12);
      auto vec_result = BytesToFtVectorSafe($6);
      if (!vec_result) {
        // Same fallback as in knn_query: an empty vector stands in for unparsable data
        auto empty_vec = std::make_unique(0);
        $$ = AstVectorRangeNode(std::move(field), radius, {std::move(empty_vec), size_t{0}}, std::move(alias));
      } else {
        $$ = AstVectorRangeNode(std::move(field), radius, std::move(*vec_result), std::move(alias));
      }
    }

/* The radius may arrive as DOUBLE, UINT32 or a TERM; a TERM that is not a
   number aborts the parse. */
vec_range_radius:
  DOUBLE { $$ = toDouble($1); }
  | UINT32 { $$ = static_cast(toUint32($1)); }
  | TERM { double v = 0; if (!absl::SimpleAtod($1, &v)) YYABORT; $$ = v; }

filter:
  search_expr               { $$ = std::move($1); }
  | star_expr               { $$ = std::move($1); }

star_expr:
  STAR                            { $$ = AstStarNode(); }
  | LPAREN star_expr RPAREN       { $$ = std::move($2); }

search_expr:
  search_unary_expr               { $$ = std::move($1); }
  | search_and_expr               { $$ = std::move($1); }
  | search_or_expr                { $$ = std::move($1); }

/* AND has no token in the query text: two adjacent unary expressions are AND-ed. */
search_and_expr:
  search_unary_expr search_unary_expr %prec AND_OP { $$ = AstLogicalNode(std::move($1), std::move($2), AstLogicalNode::AND); }
  | search_and_expr search_unary_expr %prec AND_OP { $$ = AstLogicalNode(std::move($1), std::move($2), AstLogicalNode::AND); }

search_or_expr:
  search_expr OR_OP search_and_expr                { $$ = AstLogicalNode(std::move($1), std::move($3), AstLogicalNode::OR); }
  | search_expr OR_OP search_unary_expr            { $$ = AstLogicalNode(std::move($1), std::move($3), AstLogicalNode::OR); }

search_unary_expr:
  LPAREN search_expr RPAREN           { $$ = std::move($2); }
  | NOT_OP search_unary_expr          { $$ = AstNegateNode(std::move($2)); }
  | TERM                              { $$ = AstTermNode(std::move($1)); }
  | PREFIX                            { $$ = AstPrefixNode(std::move($1)); }
  | SUFFIX                            { $$ = AstSuffixNode(std::move($1)); }
  | INFIX                             { $$ = AstInfixNode(std::move($1)); }
  | UINT32                            { $$ = AstTermNode(std::move($1)); }
  | FIELD COLON field_cond            { $$ = AstFieldNode(std::move($1), std::move($3)); }

/* Condition that follows "@field:". */
field_cond:
  TERM                                                  { $$ = AstTermNode(std::move($1)); }
  | UINT32                                              { $$ = AstTermNode(std::move($1)); }
  | STAR                                                { $$ = AstStarFieldNode(); }
  | NOT_OP field_cond                                   { $$ = AstNegateNode(std::move($2)); }
  | LPAREN field_cond_expr RPAREN                       { $$ = std::move($2); }
  | LBRACKET bracket_filter_expr RBRACKET               { $$ = std::move($2); }
  | LCURLBR tag_list RCURLBR                            { $$ = std::move($2); }
  | PREFIX                                              { $$ = AstPrefixNode(std::move($1)); }
  | SUFFIX                                              { $$ = AstSuffixNode(std::move($1)); }
  | INFIX                                               { $$ = AstInfixNode(std::move($1)); }

/* In every AstRangeNode below, the bool after a bound is true when that bound
   was preceded by LPAREN. */
bracket_filter_expr:
  /* Numeric filter has form [(] UINT32|DOUBLE [COMMA] [(] UINT32|DOUBLE */
  DOUBLE DOUBLE                                    { $$ = AstRangeNode(toDouble($1), false, toDouble($2), false); }
  | LPAREN DOUBLE DOUBLE                           { $$ = AstRangeNode(toDouble($2), true, toDouble($3), false); }
  | DOUBLE LPAREN DOUBLE                           { $$ = AstRangeNode(toDouble($1), false, toDouble($3), true); }
  | LPAREN DOUBLE LPAREN DOUBLE                    { $$ = AstRangeNode(toDouble($2), true, toDouble($4), true); }
  | DOUBLE UINT32                                  { $$ = AstRangeNode(toDouble($1), false, toUint32($2), false); }
  | LPAREN DOUBLE UINT32                           { $$ = AstRangeNode(toDouble($2), true, toUint32($3), false); }
  | DOUBLE LPAREN UINT32                           { $$ = AstRangeNode(toDouble($1), false, toUint32($3), true); }
  | LPAREN DOUBLE LPAREN UINT32                    { $$ = AstRangeNode(toDouble($2), true, toUint32($4), true); }
  | UINT32 DOUBLE                                  { $$ = AstRangeNode(toUint32($1), false, toDouble($2), false); }
  | LPAREN UINT32 DOUBLE                           { $$ = AstRangeNode(toUint32($2), true, toDouble($3), false); }
  | UINT32 LPAREN DOUBLE                           { $$ = AstRangeNode(toUint32($1), false, toDouble($3), true); }
  | LPAREN UINT32 LPAREN DOUBLE                    { $$ = AstRangeNode(toUint32($2), true, toDouble($4), true); }
  | UINT32 UINT32                                  { $$ = AstRangeNode(toUint32($1), false, toUint32($2), false); }
  | LPAREN UINT32 UINT32                           { $$ = AstRangeNode(toUint32($2), true, toUint32($3), false); }
  | UINT32 LPAREN UINT32                           { $$ = AstRangeNode(toUint32($1), false, toUint32($3), true); }
  | LPAREN UINT32 LPAREN UINT32                    { $$ = AstRangeNode(toUint32($2), true, toUint32($4), true); }
  | DOUBLE COMMA DOUBLE                            { $$ = AstRangeNode(toDouble($1), false, toDouble($3), false); }
  | DOUBLE COMMA UINT32                            { $$ = AstRangeNode(toDouble($1), false, toUint32($3), false); }
  | UINT32 COMMA DOUBLE                            { $$ = AstRangeNode(toUint32($1), false, toDouble($3), false); }
  | UINT32 COMMA UINT32                            { $$ = AstRangeNode(toUint32($1), false, toUint32($3), false); }
  | LPAREN DOUBLE COMMA DOUBLE                     { $$ = AstRangeNode(toDouble($2), true, toDouble($4), false); }
  | DOUBLE COMMA LPAREN DOUBLE                     { $$ = AstRangeNode(toDouble($1), false, toDouble($4), true); }
  | LPAREN DOUBLE COMMA LPAREN DOUBLE              { $$ = AstRangeNode(toDouble($2), true, toDouble($5), true); }
  | LPAREN DOUBLE COMMA UINT32                     { $$ = AstRangeNode(toDouble($2), true, toUint32($4), false); }
  | DOUBLE COMMA LPAREN UINT32                     { $$ = AstRangeNode(toDouble($1), false, toUint32($4), true); }
  | LPAREN DOUBLE COMMA LPAREN UINT32              { $$ = AstRangeNode(toDouble($2), true, toUint32($5), true); }
  | LPAREN UINT32 COMMA DOUBLE                     { $$ = AstRangeNode(toUint32($2), true, toDouble($4), false); }
  | UINT32 COMMA LPAREN DOUBLE                     { $$ = AstRangeNode(toUint32($1), false, toDouble($4), true); }
  | LPAREN UINT32 COMMA LPAREN DOUBLE              { $$ = AstRangeNode(toUint32($2), true, toDouble($5), true); }
  | LPAREN UINT32 COMMA UINT32                     { $$ = AstRangeNode(toUint32($2), true, toUint32($4), false); }
  | UINT32 COMMA LPAREN UINT32                     { $$ = AstRangeNode(toUint32($1), false, toUint32($4), true); }
  | LPAREN UINT32 COMMA LPAREN UINT32              { $$ = AstRangeNode(toUint32($2), true, toUint32($5), true); }
  /* GEO filter */
  | DOUBLE DOUBLE UINT32 geounit                   { $$ = AstGeoNode(toDouble($1), toDouble($2), toUint32($3), std::move($4)); }
  | DOUBLE DOUBLE DOUBLE geounit                   { $$ = AstGeoNode(toDouble($1), toDouble($2), toDouble($3), std::move($4)); }

/* Unit of a GEO filter: M, KM, MI or FT in any letter case, returned upper-cased.
   Any other term aborts the parse. */
geounit:
  TERM
    {
      std::string unit = $1;
      absl::AsciiStrToUpper(&unit);
      if ((unit == "M") || (unit == "KM") || (unit == "MI") || (unit == "FT")) {
        $$ = unit;
      } else {
        YYABORT;
      }
    }

field_cond_expr:
  field_unary_expr                       { $$ = std::move($1); }
  | field_and_expr                       { $$ = std::move($1); }
  | field_or_expr                        { $$ = std::move($1); }

field_and_expr:
  field_unary_expr field_unary_expr %prec AND_OP  { $$ = AstLogicalNode(std::move($1), std::move($2), AstLogicalNode::AND); }
  | field_and_expr field_unary_expr %prec AND_OP  { $$ = AstLogicalNode(std::move($1), std::move($2), AstLogicalNode::AND); }

field_or_expr:
  field_cond_expr OR_OP field_unary_expr          { $$ = AstLogicalNode(std::move($1), std::move($3), AstLogicalNode::OR); }
  | field_cond_expr OR_OP field_and_expr          { $$ = AstLogicalNode(std::move($1), std::move($3), AstLogicalNode::OR); }

field_unary_expr:
  LPAREN field_cond_expr RPAREN                  { $$ = std::move($2); }
  | NOT_OP field_unary_expr                      { $$ = AstNegateNode(std::move($2)); }
  | TERM                                         { $$ = AstTermNode(std::move($1)); }
  | UINT32                                       { $$ = AstTermNode(std::move($1)); }

/* Contents of "{ ... }": one or more elements separated by OR_OP. */
tag_list:
  tag_list_element                       { $$ = AstTagsNode(std::move($1)); }
  | tag_list OR_OP tag_list_element      { $$ = AstTagsNode(std::move($1), std::move($3)); }

tag_list_element:
  TERM                                   { $$ = AstTermNode(std::move($1)); }
  | PREFIX                               { $$ = AstPrefixNode(std::move($1)); }
  | SUFFIX                               { $$ = AstSuffixNode(std::move($1)); }
  | INFIX                                { $$ = AstInfixNode(std::move($1)); }
  | UINT32                               { $$ = AstTermNode(std::move($1)); }
  | DOUBLE                               { $$ = AstTermNode(std::move($1)); }
  | TAG_VAL                              { $$ = AstTermNode(std::move($1)); }

%%

// Parser error hook: forwards the location and message to the driver.
void
dfly::search::Parser::error(const location_type& l, const string& m)
{
  driver->Error(l, m);
}

// Converts the text of a UINT32 token. Returns 0 if the conversion fails.
std::uint32_t toUint32(string_view str) {
  uint32_t val = 0;
  std::ignore = absl::SimpleAtoi(str, &val); // no need to check the result because str is parsed by regex
  return val;
}

// Converts the text of a DOUBLE token. Returns 0 if the conversion fails.
double toDouble(string_view str) {
  double val = 0;
  std::ignore = absl::SimpleAtod(str, &val); // no need to check the result because str is parsed by regex
  return val;
}

================================================
FILE: src/core/search/query_driver.cc
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "core/search/query_driver.h" namespace dfly { namespace search { QueryDriver::QueryDriver() : scanner_(std::make_unique()) { } QueryDriver::~QueryDriver() { } void QueryDriver::ResetScanner() { scanner_ = std::make_unique(); scanner_->SetParams(params_); } void QueryDriver::Error(const Parser::location_type& loc, std::string_view msg) { VLOG(1) << "Parse error " << loc << ": " << msg; } void QueryDriver::SetOptionalFilters(const OptionalFilters* filters) { if (filters) { for (auto& [field, filter] : *filters) { expr_ = AstLogicalNode(std::move(expr_), filter->Node(field), AstLogicalNode::AND); } } } } // namespace search } // namespace dfly ================================================ FILE: src/core/search/query_driver.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include "core/search/ast_expr.h" #include "core/search/base.h" #include "core/search/parser.hh" #include "core/search/scanner.h" namespace dfly { namespace search { class QueryDriver { public: QueryDriver(); ~QueryDriver(); void SetInput(std::string str) { cur_str_ = std::move(str); scanner()->in(cur_str_); } void SetParams(const QueryParams* params) { params_ = params; scanner_->SetParams(params); } void SetOptionalFilters(const OptionalFilters* filters); Parser::symbol_type Lex() { return scanner()->Lex(); } void ResetScanner(); void Set(AstExpr expr) { expr_ = std::move(expr); } AstExpr Take() { return std::move(expr_); } const QueryParams& GetParams() const { return *params_; } Scanner* scanner() { return scanner_.get(); } void Error(const Parser::location_type& loc, std::string_view msg); public: Parser::location_type location; private: const QueryParams* params_; AstExpr expr_; std::string cur_str_; std::unique_ptr scanner_; }; } // namespace search } // namespace dfly ================================================ FILE: src/core/search/range_tree.cc 
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/search/range_tree.h"

namespace dfly::search {

namespace {

// Merges the entries of `blocks` that MakeBegin(block, l, r) iterates over into
// one ascending list. Only meant for zero or three and more blocks (see the
// DCHECK); one and two blocks are handled by dedicated result types.
std::vector MergeAllResults(absl::Span blocks, double l, double r) {
  DCHECK(blocks.size() != 1 && blocks.size() != 2);

  // After the benchmarking, it is better to use inlined vector
  // than std::priority_queue
  absl::InlinedVector heap;
  heap.reserve(blocks.size());

  // Upper bound for the result size: whole block sizes, not only matching entries
  size_t doc_ids_count = 0;
  for (const auto* block : blocks) {
    auto it = MakeBegin(*block, l, r);
    if (!it.HasReachedEnd()) {
      heap.emplace_back(it);
      doc_ids_count += block->Size();
    }
  }

  std::vector result;
  result.reserve(doc_ids_count);

  // heap[0, size) are the iterators that still have entries
  size_t size = heap.size();
  while (size) {
    DCHECK(!heap[0].HasReachedEnd());

    // Linear scan for the iterator with the smallest current value
    size_t min_doc_id_index = 0;
    for (size_t i = 1; i < size; ++i) {
      DCHECK(!heap[i].HasReachedEnd());
      if (*heap[i] < *heap[min_doc_id_index]) {
        min_doc_id_index = i;
      }
    }

    auto& it = heap[min_doc_id_index];
    result.push_back(*it);
    ++it;
    if (it.HasReachedEnd()) {
      // If we reached the end of the current block, remove it from the heap
      std::swap(heap[min_doc_id_index], heap[size - 1]);
      --size;
    }
  }

  DCHECK(std::is_sorted(result.begin(), result.end()));
  return result;
}

// Returns the iterator to the entry with the greatest key that is <= value.
// `entries` must not be empty and its first key must be <= value.
template auto FindRangeBlockImpl(MapT& entries, double value) {
  DCHECK(!entries.empty());

  auto it = entries.lower_bound(value);
  if (it != entries.begin() && (it == entries.end() || it->first > value)) {
    // TODO: remove this, we do log N here
    // we can use negative left bounding to find the block
    --it;  // Move to the block that contains the value
  }

  DCHECK(it != entries.end() && it->first <= value);
  return it;
}

}  // namespace

RangeTree::RangeTree(PMR_NS::memory_resource* mr, size_t max_range_block_size)
    : max_range_block_size_(max_range_block_size), entries_(mr) {
  // The tree has at least always a block with a negative infinity bound, so that any new insertion
  // goes at least somewhere
  CreateEmptyBlock(-std::numeric_limits::infinity());
}

// Inserts {id, value} into the block whose range covers `value`, splitting the
// block afterwards when it grew past max_range_block_size_.
void RangeTree::Add(DocId id, double value) {
  DCHECK(std::isfinite(value));

  auto it = FindRangeBlock(value);
  auto& [lower_bound, block] = *it;

  // Don't disrupt large monovalue blocks, instead create new nextafter block
  if (block.Size() >= max_range_block_size_ && lower_bound == block.max_seen /* monovalue */ &&
      value != lower_bound /* but new value is different*/
  ) {
    // We use nextafter as the lower bound to "catch" all other possible inserts into the block,
    // as a decreasing `value` sequence would otherwise create lots of single-value blocks
    double lb2 = std::nextafter(lower_bound, std::numeric_limits::infinity());
    CreateEmptyBlock(lb2)->second.Insert({id, value});
    return;
  }

  auto insert_result = block.Insert({id, value});
  LOG_IF(ERROR, !insert_result) << "RangeTree: Failed to insert id: " << id << ", value: " << value;

  // Small block or large monovalue block, not reducible by splitting
  if (block.Size() <= max_range_block_size_ || lower_bound == block.max_seen)
    return;

  SplitBlock(it);
}

// Removes {id, value} from the block whose range covers `value` and merges the
// block into its left neighbour when both have become small.
void RangeTree::Remove(DocId id, double value) {
  DCHECK(std::isfinite(value));

  auto it = FindRangeBlock(value);
  RangeBlock& block = it->second;

  auto remove_result = block.Remove({id, value});
  LOG_IF(ERROR, !remove_result) << "RangeTree: Failed to remove id: " << id << ", value: " << value;

  // Merge with left block if both are relatively small and won't be forced to split soon
  // NOTE(review): this is the only place that calls size() rather than Size() - confirm they agree
  if (block.size() < max_range_block_size_ / 4 && it != entries_.begin()) {
    auto lit = it;
    --lit;

    auto& lblock = lit->second;
    if (block.Size() + lblock.Size() < max_range_block_size_ / 2) {
      for (auto e : block)
        lblock.Insert(e);
      entries_.erase(it);
      stats_.merges++;
    }
  }
}

RangeResult RangeTree::Range(double l, double r) const {
  return {RangeBlocks(l, r), l, r};
}

// Collects, in key order, every block from the one covering `l` up to and
// including the one covering `r`.
absl::InlinedVector RangeTree::RangeBlocks(double l, double r) const {
  DCHECK(l <= r);

  auto it_l = FindRangeBlock(l);
  auto it_r = FindRangeBlock(r);

  absl::InlinedVector blocks;
  // it_r is inclusive, hence the break after the push instead of a loop condition
  for (auto it = it_l;; ++it) {
    blocks.push_back(&it->second);
    if (it == it_r) {
      break;
    }
  }

  DCHECK(!blocks.empty());
  return blocks;
}

RangeResult RangeTree::GetAllDocIds() const {
  return RangeResult{GetAllBlocks()};
}

absl::InlinedVector RangeTree::GetAllBlocks() const {
  absl::InlinedVector blocks;
  blocks.reserve(entries_.size());

  for (const auto& entry : entries_) {
    blocks.push_back(&entry.second);
  }

  return blocks;
}

RangeTree::Map::iterator RangeTree::FindRangeBlock(double value) {
  return FindRangeBlockImpl(entries_, value);
}

RangeTree::Map::const_iterator RangeTree::FindRangeBlock(double value) const {
  return FindRangeBlockImpl(entries_, value);
}

// Adds an empty block with lower bound `lb`, allocated from the map's memory
// resource, and returns its iterator (or the existing entry for that key).
RangeTree::Map::iterator RangeTree::CreateEmptyBlock(double lb) {
  return entries_
      .emplace(std::piecewise_construct, std::forward_as_tuple(lb),
               std::forward_as_tuple(entries_.get_allocator().resource(), max_range_block_size_))
      .first;
}

/*
There is an edge case in the SplitBlock method:
If split_result.left.Size() == 0, it means that all values in the block
were equal to the median value.
Because split works like this:
 - at the beginning it does not insert median values into the left or right block,
 - then it checks if left block is smaller than right block,
   if so, it adds median values to the left block, otherwise it adds it to the right block.
So if left block is empty, it means that left.Size() < right.Size() was false,
which means that right.Size() was also zero. After that all median entries were added to the right
block. That means that we have equal values in the whole block, and their count is greater than
max_range_block_size_. So we will do cascade splits of the right block.
TODO: we can optimize this case by splitting to three blocks:
 - empty left block with range [l, m),
 - middle block with range [m, std::nextafter(m, +inf)),
 - empty right block with range [std::nextafter(m, +inf), r)
*/
void RangeTree::SplitBlock(Map::iterator it) {
  double lower_bound = it->first;

  auto split_result = Split(std::move(it->second));

  const double m = split_result.median;
  DCHECK(!split_result.right.Empty());

  entries_.erase(it);
  stats_.splits++;

  // Insert left block if it's not empty or if its the first one (negative inf bound)
  if (!split_result.left.Empty() || std::isinf(lower_bound)) {
    if (!std::isinf(lower_bound))  // keep negative inf bound
      lower_bound = split_result.lmin;

    entries_.emplace(std::piecewise_construct, std::forward_as_tuple(lower_bound),
                     std::forward_as_tuple(std::move(split_result.left), split_result.lmax));
  }

  // The right half is keyed by the median
  entries_.emplace(std::piecewise_construct, std::forward_as_tuple(m),
                   std::forward_as_tuple(std::move(split_result.right), split_result.rmax));

  DCHECK(TreeIsInCorrectState());
}

RangeTree::Stats RangeTree::GetStats() const {
  return Stats{.splits = stats_.splits, .merges = stats_.merges, .block_count = entries_.size()};
}

// Used for DCHECKs to check that the tree is in a correct state.
[[maybe_unused]] bool RangeTree::TreeIsInCorrectState() const {
  // A map without any block is reported as an incorrect state.
  if (entries_.empty()) {
    return false;
  }

  double prev_range = entries_.begin()->first;
  for (auto it = std::next(entries_.begin()); it != entries_.end(); ++it) {
    const double& current_range = it->first;

    // Check that ranges are non-overlapping and sorted
    // Also there can not be gaps between ranges
    if (prev_range >= current_range) {
      return false;
    }

    prev_range = current_range;
  }

  return true;
}

// Wraps an already materialized list of document ids.
RangeResult::RangeResult(std::vector doc_ids) : result_(std::move(doc_ids)) {
}

// Builds a result over `blocks` with the value range set to [-infinity, +infinity].
RangeResult::RangeResult(absl::InlinedVector blocks)
    : RangeResult(std::move(blocks), -std::numeric_limits::infinity(),
                  std::numeric_limits::infinity()) {
}

// Builds a result over `blocks` restricted to [l, r]. Exactly one block is stored as a
// SingleBlockRangeResult, exactly two as a TwoBlocksRangeResult; any other count (including zero)
// is handed to MergeAllResults.
RangeResult::RangeResult(absl::InlinedVector blocks, double l, double r) {
  if (blocks.size() == 1) {
    result_ = SingleBlockRangeResult(blocks[0], l, r);
  } else if (blocks.size() == 2) {
    result_ = TwoBlocksRangeResult(blocks[0], blocks[1], l, r);
  } else {
    result_ = MergeAllResults(absl::MakeSpan(blocks), l, r);
  }
}

// Returns the result as a vector. When result_ holds the alternative tested below (its template
// argument is not visible in this view; presumably the vector alternative), that value is moved
// out of result_. Otherwise the held range is copied element by element.
std::vector RangeResult::Take() {
  if (std::holds_alternative(result_)) {
    DCHECK(std::is_sorted(std::get(result_).begin(), std::get(result_).end()));
    return std::get(std::move(result_));
  }

  // Generic visitor: copies whatever range is held into a fresh vector.
  auto cb = [](const auto& v) {
    std::vector result;
    result.reserve(v.size());
    std::copy(v.begin(), v.end(), std::back_inserter(result));
    DCHECK(std::is_sorted(result.begin(), result.end()));
    return result;
  };

  return std::visit(cb, result_);
}

// Queues the insertion of (id, value). DCHECKs that the pair was not already queued.
void RangeTree::Builder::Add(DocId id, double value) {
  bool inserted = updates_.emplace(id, value).second;
  DCHECK(inserted);
}

// Cancels a queued insertion of (id, value) if there is one; otherwise records the pair in
// delayed_erased_ so that Populate removes it from the tree.
void RangeTree::Builder::Remove(DocId id, double value) {
  if (!updates_.erase({id, value}))
    delayed_erased_.emplace(id, value);
}

// Fills `tree` from the queued updates: first a bulk load of the value-sorted entries into
// blocks, then up to three passes that apply updates queued while this function was running.
// quota.Check() is called between batches.
void RangeTree::Builder::Populate(RangeTree* tree, const RenewableQuota& quota) {
  // Sort all elements by value
  std::vector sorted_entries(updates_.begin(), updates_.end());
  std::ranges::sort(sorted_entries, {}, &Entry::second);
  updates_.clear();

  quota.Check();  // TODO: sort might take a long time

  // Add sorted elements in batches
  size_t max_size = tree->max_range_block_size_;
  RangeBlock* block = &tree->entries_.begin()->second;
  for (size_t idx = 0; idx < sorted_entries.size();) {
    // Create new block for each insertion batch (first goes into only first block)
    if (idx)
      block = &tree->CreateEmptyBlock(sorted_entries[idx].second)->second;

    // Insert until we filled a block and a new value started (equal value must be in same block)
    while (idx < sorted_entries.size()) {
      // NOTE(review): with idx == 0 the `idx - 1` below is only skipped because the first block
      // is expected not to be full yet (short-circuit) - confirm Populate is never called on a
      // tree whose first block already holds max_size entries.
      if (block->Size() >= max_size &&
          sorted_entries[idx - 1].second != sorted_entries[idx].second)
        break;

      block->Insert(sorted_entries[idx]);
      idx++;

      // If we filled a new multiple of the block size due to equal entries, check quota
      if ((block->Size() - 1) / max_size != block->Size() / max_size)
        quota.Check();
    }

    quota.Check();  // Yield if needed
  }

  // Update entries accumulated during yields in batches while respecting quota.
  // Last loop is atomic (without quota checks) to ensure consistency
  size_t iterations = 3;
  while (iterations--) {
    // Take updates to allow new ones during suspensions
    auto stolen_erased = std::move(delayed_erased_);
    auto stolen_updates = std::move(updates_);
    delayed_erased_.clear();
    updates_.clear();

    // `iterations` is captured by reference and is already decremented inside the loop body, so
    // it is 0 during the third pass and no quota check happens there.
    auto check_quota = [&, ops = size_t(0)]() mutable {
      ops++;
      if (iterations && ops / max_size != (ops + 1) / max_size)
        quota.Check();
    };

    for (auto [id, v] : stolen_erased) {
      tree->Remove(id, v);
      check_quota();
    }

    for (auto [id, v] : stolen_updates) {
      tree->Add(id, v);
      check_quota();
    }
  }

  // Because last iteration was atomic
  DCHECK(updates_.empty());
  DCHECK(delayed_erased_.empty());
}

}  // namespace dfly::search

================================================
FILE: src/core/search/range_tree.h
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

// NOTE(review): the next four #include directives carry no header name in this view (their
// angle-bracket targets look stripped by extraction) - restore them from the original file.
#include
#include
#include
#include
#include "base/pmr/memory_resource.h"
#include "core/search/base.h"
#include "core/search/block_list.h"
#include "core/search/renewable_quota.h"

namespace dfly::search {

class RangeResult;

/* RangeTree is an index structure for numeric fields that allows efficient range queries.
   It maps disjoint numeric ranges (e.g., [0, 5), [5, 10), [10, 15), ...) to sorted sets of
   document IDs.

   Internally, it uses absl::btree_map, RangeBlock>, where each key represents a numeric value
   range, and the corresponding RangeBlock (similar to std::vector) stores (DocId, value) pairs,
   sorted by DocId.

   The parameter `max_range_block_size_` defines the maximum number of entries in a single
   RangeBlock. When a block exceeds this limit, it is split into two to maintain balanced
   performance.
*/
class RangeTree {
 public:
  friend class RangeResult;

  // A (document id, value) pair as passed to RangeBlock::Insert.
  using Entry = std::pair;

  // More efficient builder for range tree where updates are batched
  // and then applied in an optimized order inside Populate.
  struct Builder {
    // Queues the insertion of (id, value).
    void Add(DocId id, double value);

    // Queues the removal of (id, value).
    void Remove(DocId id, double value);

    // Build tree from batched updates. Accepts new updates during suspensions.
    void Populate(RangeTree* tree, const RenewableQuota& quota);

   private:
    absl::flat_hash_set updates_, delayed_erased_;
  };

  // Main node of numeric tree
  struct RangeBlock : public BlockList> {
    // Forwards the memory resource and any extra arguments to the BlockList base.
    template
    explicit RangeBlock(PMR_NS::memory_resource* mr, Ts... ts) : BlockList{mr, ts...} {
    }

    // Takes over an existing block list together with its already known maximum value.
    RangeBlock(BlockList>&& bs, double maxv) : BlockList{std::move(bs)}, max_seen{maxv} {
    }

    // Inserts `e`, raising max_seen to e.second if it is larger. Returns BlockList::Insert's
    // result.
    bool Insert(Entry e) {
      max_seen = std::max(max_seen, e.second);
      return BlockList::Insert(e);
    }

    // Max value seen, might be not present anymore
    double max_seen = -std::numeric_limits::infinity();
  };

  static constexpr size_t kDefaultMaxRangeBlockSize = 10'000;

  explicit RangeTree(PMR_NS::memory_resource* mr,
                     size_t max_range_block_size = kDefaultMaxRangeBlockSize);

  // Adds a document with a value to the index.
  void Add(DocId id, double value);

  // Removes a document with a value from the index.
  void Remove(DocId id, double value);

  // Returns all documents with values in the range [l, r].
  RangeResult Range(double l, double r) const;

  // Same as Range, but returns the blocks that contain the results.
  absl::InlinedVector RangeBlocks(double l, double r) const;

  // Returns a RangeResult built over every block of the tree.
  RangeResult GetAllDocIds() const;

  // Returns all blocks in the tree.
  absl::InlinedVector GetAllBlocks() const;

  // Counters reported by GetStats().
  struct Stats {
    size_t splits = 0;
    size_t merges = 0;
    size_t block_count = 0;
  };

  Stats GetStats() const;

 private:
  using Map = absl::btree_map, PMR_NS::polymorphic_allocator>>;

  Map::iterator FindRangeBlock(double value);
  Map::const_iterator FindRangeBlock(double value) const;

  // Inserts an empty block keyed by `lb` and returns an iterator to it.
  Map::iterator CreateEmptyBlock(double lb);
  void SplitBlock(Map::iterator it);

  // Used for DCHECKs
  bool TreeIsInCorrectState() const;

 private:
  // The maximum size of a range block. If a block exceeds this size, it will be split
  size_t max_range_block_size_;
  Map entries_;

  // Running counters; exposed through GetStats().
  struct {
    size_t splits = 0;
    size_t merges = 0;
  } stats_;
};

/* This iterator filters out entries that are not in the range [l, r].
   It is used to iterate over the RangeBlock and return only the entries that are within the
   specified range. The iterator is initialized with a range [l, r] and will skip entries that
   are outside this range.
*/
class RangeFilterIterator : public SeekableTag {
 private:
  // Sentinel id passed to SkipInvalidEntries when there is no previous id.
  static constexpr DocId kInvalidDocId = std::numeric_limits::max();

  using RangeBlock = RangeTree::RangeBlock;
  using BaseIterator = RangeBlock::BlockListIterator;

 public:
  using iterator_category = BaseIterator::iterator_category;
  using difference_type = BaseIterator::difference_type;
  using value_type = DocId;
  using pointer = value_type*;
  using reference = value_type&;

  // Iterates [begin, end) and yields only entries whose value lies in [l, r].
  RangeFilterIterator(BaseIterator begin, BaseIterator end, double l, double r);

  // Returns the document id of the current entry.
  value_type operator*() const;

  // Advances to the next in-range entry whose id differs from the current one.
  RangeFilterIterator& operator++();

  // Seeks the underlying iterator to `min_doc_id`, then skips entries that are out of range.
  void SeekGE(DocId min_doc_id);

  // Comparison looks at the current position only, not at the range or the end position.
  bool operator==(const RangeFilterIterator& other) const;
  bool operator!=(const RangeFilterIterator& other) const;

  // True when the current position equals the end position.
  bool HasReachedEnd() const;

 private:
  // Skips entries that are out of range or whose id equals `last_id`.
  void SkipInvalidEntries(DocId last_id);

  // True when the value of the entry at `it` lies in [l_, r_].
  bool InRange(BaseIterator it) const;

  double l_, r_;
  BaseIterator current_, end_;
};

// Filtering iterator positioned at the first matching entry of `block`.
RangeFilterIterator MakeBegin(const RangeTree::RangeBlock& block, double l, double r);

// Filtering iterator positioned at the end of `block`.
RangeFilterIterator MakeEnd(const RangeTree::RangeBlock& block, double l, double r);

/* Separate class for merging results from a single RangeBlock.
   It provides an iterator interface to iterate over the entries in the block that are within
   the specified range [l, r]. This is used when the result of a range query is contained within
   a single block. It is needed to avoid unnecessary complexity in the RangeResult class, which
   can handle both single and multiple blocks. It provides better performance and clarity when
   dealing with single block results.
*/
class SingleBlockRangeResult {
 public:
  // `block` must not be null (DCHECKed by the constructor).
  SingleBlockRangeResult(const RangeTree::RangeBlock* block, double l, double r);

  RangeFilterIterator begin() const;
  RangeFilterIterator end() const;

  // Returns the size of the whole block, i.e. the count before range filtering is applied.
  size_t size() const;

 private:
  double l_;
  double r_;
  const RangeTree::RangeBlock* block_ = nullptr;
};

/* Separate class for merging results from two RangeBlocks.
   It provides an iterator interface to iterate over the entries in both blocks that are within
   the specified range [l, r].
   It automatically merges the results from both blocks and provides a unified view.
   This is used when the result of a range query spans two blocks.
   It provides a more efficient way to handle results that span multiple blocks, avoiding
   unnecessary complexity in the RangeResult class.

   TODO: Implement efficient merging for more than two blocks and remove this class.
*/
class TwoBlocksRangeResult {
 public:
  // Both block pointers must be non-null (DCHECKed by the constructor).
  TwoBlocksRangeResult(const RangeTree::RangeBlock* left_block,
                       const RangeTree::RangeBlock* right_block, double l, double r);

  // Returns the sum of both block sizes, i.e. the count before range filtering is applied.
  size_t size() const;

  // Merges two filtered iterators and yields the smaller current id of the two.
  class MergingIterator : public SeekableTag {
   private:
    // Value of current_min_ before InitializeMin() has run.
    static constexpr DocId kInvalidDocId = std::numeric_limits::max();

   public:
    using iterator_category = RangeFilterIterator::iterator_category;
    using difference_type = RangeFilterIterator::difference_type;
    using value_type = RangeFilterIterator::value_type;
    using pointer = RangeFilterIterator::pointer;
    using reference = RangeFilterIterator::reference;

    MergingIterator(RangeFilterIterator l, RangeFilterIterator r);

    // Returns the current minimum id.
    value_type operator*() const;

    MergingIterator& operator++();

    // Seeks both underlying iterators to `min_doc_id` and recomputes the current minimum.
    void SeekGE(DocId min_doc_id);

    // Two merging iterators are equal when both underlying iterators compare equal.
    bool operator==(const MergingIterator& other) const;
    bool operator!=(const MergingIterator& other) const;

   private:
    // Sets current_min_ to the smaller of the two current ids.
    void InitializeMin();

    DocId current_min_ = kInvalidDocId;
    RangeFilterIterator l_;
    RangeFilterIterator r_;
  };

  MergingIterator begin() const;
  MergingIterator end() const;

 private:
  double l_;
  double r_;
  const RangeTree::RangeBlock* left_block_ = nullptr;
  const RangeTree::RangeBlock* right_block_ = nullptr;
};

/* Represent the result of a range query on the RangeTree.
   It can contain results from a single block, two blocks, or several blocks.
   Several blocks are merged into a single result, which is represented by vector.
   TODO: Implement efficient merging for more than two blocks
*/
class RangeResult {
 private:
  using RangeBlockPointer = const RangeTree::RangeBlock*;
  using RangeBlockIterator = RangeTree::RangeBlock::BlockListIterator;

  using DocsList = std::vector;
  using Variant = std::variant;

 public:
  RangeResult() = default;

  explicit RangeResult(std::vector doc_ids);

  explicit RangeResult(absl::InlinedVector blocks);

  RangeResult(absl::InlinedVector blocks, double l, double r);

  // Returns the result as a vector.
  std::vector Take();

  // Access to the underlying variant.
  Variant& GetResult();
  const Variant& GetResult() const;

 private:
  Variant result_;
};

// Implementation
/******************************************************************/
// Stores the range and positions, then moves to the first entry that is in range.
inline RangeFilterIterator::RangeFilterIterator(BaseIterator begin, BaseIterator end, double l,
                                                double r)
    : l_(l), r_(r), current_(begin), end_(end) {
  SkipInvalidEntries(kInvalidDocId);
}

// Returns the document id (first member) of the current entry.
inline RangeFilterIterator::value_type RangeFilterIterator::operator*() const {
  return (*current_).first;
}

// Steps past the current entry and past any following entry that is out of range or repeats the
// id that was just returned.
inline RangeFilterIterator& RangeFilterIterator::operator++() {
  const DocId last_id = (*current_).first;
  ++current_;
  SkipInvalidEntries(last_id);
  return *this;
}

// Seeks the underlying iterator and then advances to the first in-range entry. Unlike
// operator++ this does not compare against a previously returned id.
inline void RangeFilterIterator::SeekGE(DocId min_doc_id) {
  current_.SeekGE(min_doc_id);
  while (current_ != end_ && !InRange(current_)) {
    DCHECK((*current_).first >= min_doc_id);
    ++current_;
  }
}

// Only the current position takes part in the comparison.
inline bool RangeFilterIterator::operator==(const RangeFilterIterator& other) const {
  return current_ == other.current_;
}

inline bool RangeFilterIterator::operator!=(const RangeFilterIterator& other) const {
  return current_ != other.current_;
}

inline bool RangeFilterIterator::HasReachedEnd() const {
  return current_ == end_;
}

// Advances while the current entry is out of range or carries the id `last_id`.
inline void RangeFilterIterator::SkipInvalidEntries(DocId last_id) {
  // Faster than using std::find_if
  while (current_ != end_ && (!InRange(current_) || (*current_).first == last_id)) {
    ++current_;
  }
}

// Inclusive on both ends: l_ <= value <= r_.
inline bool RangeFilterIterator::InRange(BaseIterator it) const {
  return l_ <= (*it).second && (*it).second <= r_;
}

inline RangeFilterIterator MakeBegin(const RangeTree::RangeBlock& block, double l, double r) {
  return {block.begin(), block.end(), l, r};
}

inline RangeFilterIterator MakeEnd(const RangeTree::RangeBlock& block, double l, double r) {
  return {block.end(), block.end(), l, r};
}

inline SingleBlockRangeResult::SingleBlockRangeResult(const RangeTree::RangeBlock* block,
                                                      double l, double r)
    : l_(l), r_(r), block_(block) {
  DCHECK(block_ != nullptr);
}

inline RangeFilterIterator SingleBlockRangeResult::begin() const {
  return MakeBegin(*block_, l_, r_);
}

inline RangeFilterIterator SingleBlockRangeResult::end() const {
  return MakeEnd(*block_, l_, r_);
}

// Size of the whole block; entries outside [l_, r_] are included in the count.
inline size_t SingleBlockRangeResult::size() const {
  return block_->Size();
}

inline TwoBlocksRangeResult::TwoBlocksRangeResult(const RangeTree::RangeBlock* left_block,
                                                  const RangeTree::RangeBlock* right_block,
                                                  double l, double r)
    : l_(l), r_(r), left_block_(left_block), right_block_(right_block) {
  DCHECK(left_block_ != nullptr);
  DCHECK(right_block_ != nullptr);
}

// Sum of both block sizes; entries outside [l_, r_] are included in the count.
inline size_t TwoBlocksRangeResult::size() const {
  return left_block_->Size() + right_block_->Size();
}

inline TwoBlocksRangeResult::MergingIterator::MergingIterator(RangeFilterIterator l,
                                                              RangeFilterIterator r)
    : l_(std::move(l)), r_(std::move(r)) {
  InitializeMin();
}

inline TwoBlocksRangeResult::MergingIterator::value_type
TwoBlocksRangeResult::MergingIterator::operator*() const {
  return current_min_;
}

// Advances the merged view. If one side is exhausted only the other side moves; otherwise every
// side currently sitting on current_min_ moves, so an id present in both blocks is yielded once.
// NOTE(review): when both sides are already exhausted this still increments r_ - callers are
// presumably expected not to advance an end iterator.
inline TwoBlocksRangeResult::MergingIterator&
TwoBlocksRangeResult::MergingIterator::operator++() {
  // Steps `it` and takes its new id (or the max sentinel once it is exhausted) as the minimum.
  auto increase_iterator = [&](RangeFilterIterator& it) {
    ++it;
    current_min_ = !it.HasReachedEnd() ? *it : std::numeric_limits::max();
  };

  if (l_.HasReachedEnd()) {
    increase_iterator(r_);
  } else if (r_.HasReachedEnd()) {
    increase_iterator(l_);
  } else {
    DCHECK(!l_.HasReachedEnd() && !r_.HasReachedEnd());
    if (*l_ == current_min_) {
      ++l_;
    }
    if (*r_ == current_min_) {
      ++r_;
    }
    InitializeMin();
  }
  return *this;
}

inline void TwoBlocksRangeResult::MergingIterator::SeekGE(DocId min_doc_id) {
  l_.SeekGE(min_doc_id);
  r_.SeekGE(min_doc_id);
  InitializeMin();
}

inline bool TwoBlocksRangeResult::MergingIterator::operator==(
    const TwoBlocksRangeResult::MergingIterator& other) const {
  return l_ == other.l_ && r_ == other.r_;
}

inline bool TwoBlocksRangeResult::MergingIterator::operator!=(
    const TwoBlocksRangeResult::MergingIterator& other) const {
  return !(*this == other);
}

// Recomputes current_min_ from both sides; an exhausted side contributes the max sentinel.
inline void TwoBlocksRangeResult::MergingIterator::InitializeMin() {
  DocId left_value = !l_.HasReachedEnd() ? *l_ : std::numeric_limits::max();
  DocId right_value = !r_.HasReachedEnd() ? *r_ : std::numeric_limits::max();
  current_min_ = std::min(left_value, right_value);
}

inline TwoBlocksRangeResult::MergingIterator TwoBlocksRangeResult::begin() const {
  return MergingIterator{MakeBegin(*left_block_, l_, r_), MakeBegin(*right_block_, l_, r_)};
}

inline TwoBlocksRangeResult::MergingIterator TwoBlocksRangeResult::end() const {
  return MergingIterator{MakeEnd(*left_block_, l_, r_), MakeEnd(*right_block_, l_, r_)};
}

inline RangeResult::Variant& RangeResult::GetResult() {
  return result_;
}

inline const RangeResult::Variant& RangeResult::GetResult() const {
  return result_;
}

}  // namespace dfly::search

================================================
FILE: src/core/search/range_tree_test.cc
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "core/search/range_tree.h" #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" #include "util/fibers/fibers.h" namespace dfly::search { class RangeTreeTest : public testing::Test { protected: }; static constexpr double kMinRangeValue = std::numeric_limits::min(); static constexpr double kMaxRangeValue = std::numeric_limits::max(); using Entry = std::pair; using BlocksList = absl::InlinedVector; std::vector ExtractDocPairs(const BlocksList& result) { std::vector out; for (const auto& block : result) { for (const auto& entry : *block) { out.push_back(entry); } } return out; } std::vector> ExtractAllBlocks(const BlocksList& result) { std::vector> all; for (const auto& block : result) { std::vector block_entries; for (const auto& entry : *block) { block_entries.push_back(entry); } all.push_back(std::move(block_entries)); } return all; } MATCHER_P(UnorderedElementsAreDocPairsMatcher, expected_matchers, "") { return testing::ExplainMatchResult(testing::UnorderedElementsAreArray(expected_matchers), ExtractDocPairs(arg), result_listener); } MATCHER_P(BlocksAreMatcher, expected_blocks, "") { std::vector>> matchers; for (const auto& expected_entries : expected_blocks) { matchers.push_back(testing::UnorderedElementsAreArray(expected_entries)); } return testing::ExplainMatchResult(testing::ElementsAreArray(matchers), ExtractAllBlocks(arg), result_listener); } auto UnorderedElementsAreDocPairs(std::vector list) { return UnorderedElementsAreDocPairsMatcher(std::move(list)); } auto BlocksAre(std::initializer_list> blocks) { return BlocksAreMatcher(std::vector>(blocks)); } std::vector ExtractDocIdsFromRange(const std::vector& entries, double l, double r) { std::vector result; for (const auto& entry : entries) { if (entry.second >= l && entry.second <= r) { result.push_back(entry.first); } } std::sort(result.begin(), result.end()); result.erase(std::unique(result.begin(), result.end()), result.end()); return result; } 
std::vector MergeTwoBlocksRangeResult(const RangeTree& tree, double l, double r) { auto result = tree.Range(l, r).GetResult(); DCHECK(std::holds_alternative(result)); auto& two_blocks_result = std::get(result); return {two_blocks_result.begin(), two_blocks_result.end()}; } TEST_F(RangeTreeTest, AddSimple) { RangeTree tree{PMR_NS::get_default_resource()}; // Add some values tree.Add(1, 10.0); tree.Add(2, 20.0); tree.Add(2, 10.0); tree.Add(3, 30.0); tree.Add(4, 40.0); tree.Add(4, 60.0); auto result = tree.GetAllBlocks(); EXPECT_THAT(result, UnorderedElementsAreDocPairs( {{1, 10.0}, {2, 10.0}, {2, 20.0}, {3, 30.0}, {4, 40.0}, {4, 60.0}})); } TEST_F(RangeTreeTest, Add) { RangeTree tree{PMR_NS::get_default_resource(), 2}; // Add some values tree.Add(1, 10.0); tree.Add(1, 20.0); tree.Add(2, 20.0); tree.Add(3, 20.0); tree.Add(4, 30.0); tree.Add(5, 30.0); tree.Add(6, 30.0); auto result = tree.RangeBlocks(10.0, 30.0); EXPECT_THAT(result, UnorderedElementsAreDocPairs( {{1, 10.0}, {1, 20.0}, {2, 20.0}, {3, 20.0}, {4, 30.0}, {5, 30.0}, {6, 30.0}})); // Test that the ranges was split correctly result = tree.RangeBlocks(kMinRangeValue, 19.0); EXPECT_THAT(result, UnorderedElementsAreDocPairs({{1, 10.0}})); result = tree.RangeBlocks(20.0, 29.0); EXPECT_THAT(result, UnorderedElementsAreDocPairs({{1, 20.0}, {2, 20.0}, {3, 20.0}})); result = tree.RangeBlocks(30.0, kMaxRangeValue); EXPECT_THAT(result, UnorderedElementsAreDocPairs({{4, 30.0}, {5, 30.0}, {6, 30.0}})); } TEST_F(RangeTreeTest, RemoveSimple) { RangeTree tree{PMR_NS::get_default_resource(), 2}; // Add some values tree.Add(1, 10.0); tree.Add(2, 20.0); tree.Add(3, 30.0); tree.Add(4, 40.0); // Remove some values tree.Remove(1, 10.0); tree.Remove(2, 20.0); auto result = tree.GetAllBlocks(); EXPECT_THAT(result, UnorderedElementsAreDocPairs({{3, 30.0}, {4, 40.0}})); } TEST_F(RangeTreeTest, Remove) { using Container = std::vector; Container expected_values; RangeTree tree{PMR_NS::get_default_resource(), 2}; const long long 
max_value = 100; long long step = 23; long long current_value = max_value; auto do_add = [&](DocId i) { const double value = static_cast(current_value); auto it = std::find(expected_values.begin(), expected_values.end(), std::make_pair(i, value)); if (it != expected_values.end()) { // If the value already exists, we do not add it again // The problem is that for now RangeTree does not support duplicates // TODO: fix this return; } // Otherwise, we add it to the expected values and to the tree expected_values.emplace_back(i, value); tree.Add(i, value); current_value = (max_value + current_value - step) % max_value; }; auto add_entries_with_step = [&](size_t step) { for (size_t i = 0; i < 100; i += step) { do_add(i); } }; auto do_remove = [&](size_t i) { auto pair = expected_values[i]; tree.Remove(pair.first, pair.second); }; auto remove_entries_with_step = [&](size_t step) { Container expected_values_copy; for (size_t i = 0; i < expected_values.size(); i++) { if (i % step == 0) { do_remove(i); } else { expected_values_copy.push_back(expected_values[i]); } } expected_values = std::move(expected_values_copy); }; // First wave of Add and Remove add_entries_with_step(1); step = 37; current_value = max_value; add_entries_with_step(3); // Remove some values remove_entries_with_step(3); auto result = tree.GetAllBlocks(); EXPECT_THAT(result, UnorderedElementsAreDocPairs(expected_values)); // Second wave of Add and Remove step = 31; current_value = max_value; add_entries_with_step(5); // Remove a first half of the values remove_entries_with_step(2); result = tree.GetAllBlocks(); EXPECT_THAT(result, UnorderedElementsAreDocPairs(expected_values)); // Remove all values remove_entries_with_step(1); result = tree.GetAllBlocks(); EXPECT_THAT(result, UnorderedElementsAreDocPairs({})); } TEST_F(RangeTreeTest, RangeSimple) { RangeTree tree{PMR_NS::get_default_resource(), 1}; // Add some values tree.Add(1, 10.0); tree.Add(1, 20.0); tree.Add(2, 20.0); tree.Add(2, 30.0); tree.Add(3, 
30.0); tree.Add(3, 40.0); tree.Add(4, 40.0); auto result = tree.RangeBlocks(10.0, 10.0); EXPECT_THAT(result, BlocksAre({{{1, 10.0}}})); result = tree.RangeBlocks(20.0, 20.0); EXPECT_THAT(result, BlocksAre({{{1, 20.0}, {2, 20.0}}})); result = tree.RangeBlocks(30.0, 30.0); EXPECT_THAT(result, BlocksAre({{{2, 30.0}, {3, 30.0}}})); result = tree.RangeBlocks(40.0, 40.0); EXPECT_THAT(result, BlocksAre({{{3, 40.0}, {4, 40.0}}})); result = tree.RangeBlocks(10.0, 30.0); EXPECT_THAT(result, BlocksAre({{{1, 10.0}}, {{1, 20.0}, {2, 20.0}}, {{2, 30.0}, {3, 30.0}}})); result = tree.RangeBlocks(20.0, 40.0); EXPECT_THAT(result, BlocksAre({{{1, 20.0}, {2, 20.0}}, {{2, 30.0}, {3, 30.0}}, {{3, 40.0}, {4, 40.0}}})); result = tree.RangeBlocks(10.0, 40.0); EXPECT_THAT( result, BlocksAre( {{{1, 10.0}}, {{1, 20.0}, {2, 20.0}}, {{2, 30.0}, {3, 30.0}}, {{3, 40.0}, {4, 40.0}}})); } TEST_F(RangeTreeTest, Range) { { RangeTree tree{PMR_NS::get_default_resource(), 4}; tree.Add(1, 10.0); tree.Add(1, 20.0); tree.Add(2, 20.0); tree.Add(3, 30.0); tree.Add(4, 20.0); tree.Add(4, 30.0); auto result = tree.RangeBlocks(10.0, 30.0); EXPECT_THAT( result, BlocksAre({{{1, 10.0}}, {{1, 20.0}, {2, 20.0}, {4, 20.0}}, {{3, 30.0}, {4, 30.0}}})); } { RangeTree tree{PMR_NS::get_default_resource(), 4}; tree.Add(1, 10.0); tree.Add(1, 20.0); tree.Add(2, 20.0); tree.Add(3, 20.0); tree.Add(4, 20.0); auto result = tree.RangeBlocks(10.0, 20.0); EXPECT_THAT(result, BlocksAre({{{1, 10.0}}, {{1, 20.0}, {2, 20.0}, {3, 20.0}, {4, 20.0}}})); } { RangeTree tree{PMR_NS::get_default_resource(), 4}; tree.Add(1, 10.0); tree.Add(2, 10.0); tree.Add(3, 10.0); tree.Add(4, 20.0); tree.Add(4, 10.0); auto result = tree.RangeBlocks(10.0, 20.0); EXPECT_THAT(result, BlocksAre({{{1, 10.0}, {2, 10.0}, {3, 10.0}, {4, 10.0}}, {{4, 20.0}}})); } } // Don't split single block with same value TEST_F(RangeTreeTest, SingleBlockSplit) { RangeTree tree{PMR_NS::get_default_resource(), 4}; for (DocId id = 1; id <= 16; id++) tree.Add(id, 5.0); // One split 
was made to create an empty leftmost block auto stats = tree.GetStats(); EXPECT_EQ(stats.splits, 1u); EXPECT_EQ(stats.block_count, 2u); // Add value that causes a new block to be started tree.Add(20, 6.0); stats = tree.GetStats(); EXPECT_EQ(stats.splits, 1u); // detected ahead, so no split EXPECT_EQ(stats.block_count, 3u); // but new block // No more splits with same 5.0 tree.Add(17, 5.0); stats = tree.GetStats(); EXPECT_EQ(stats.splits, 1u); // Verify block sizes auto blocks = tree.GetAllBlocks(); EXPECT_EQ(blocks[0]->Size(), 0u); EXPECT_EQ(blocks[1]->Size(), 17u); EXPECT_EQ(blocks[2]->Size(), 1u); } // Make tree split and then delete every nth value to see if blocks merge properly TEST_F(RangeTreeTest, BlockMerge) { RangeTree tree{PMR_NS::get_default_resource(), 8}; for (DocId id = 1; id <= 64; id++) tree.Add(id, id); auto stats = tree.GetStats(); uint64_t splits = stats.splits; EXPECT_GT(splits, 8u); // Blocks have at least half occupancy EXPECT_GT(stats.block_count, 64 / 8); EXPECT_LT(stats.block_count, 2 * 64 / 8); // Delete all except %8 = 0, should trigger merge std::vector expected; for (DocId id = 1; id <= 64; id++) { if (id % 8) tree.Remove(id, id); else expected.emplace_back(id, id); } // Only one block left now stats = tree.GetStats(); size_t blocks = stats.block_count; EXPECT_LT(blocks, 4u); EXPECT_EQ(stats.merges + blocks - 1, splits); // Check the two entries remained auto result = tree.GetAllBlocks(); EXPECT_THAT(result, UnorderedElementsAreDocPairs(expected)); } TEST_F(RangeTreeTest, BugNotUniqueDoubleValues) { // TODO: fix the bug GTEST_SKIP() << "Bug not fixed yet"; RangeTree tree{PMR_NS::get_default_resource()}; tree.Add(1, 10.0); tree.Add(1, 10.0); tree.Remove(1, 10.0); auto result = tree.GetAllBlocks(); EXPECT_THAT(result, BlocksAre({{{1, 10.0}}})); } TEST_F(RangeTreeTest, RangeResultTwoBlocksSimple) { RangeTree tree{PMR_NS::get_default_resource(), 4}; // First block: [[1, 10.0], [16, 12.0], [12, 15.0], [5, 17.0]] // Second block: [[8, 20.0], 
[5, 30.0], [12, 50.0], [20, 55.0]] // [10.0, 12.0, 15.0, 17.0] | [20.0, 30.0, 50.0, 55.0] tree.Add(1, 10.0); // 1 tree.Add(5, 30.0); // 2 tree.Add(20, 55.0); // 2 tree.Add(5, 17.0); // 1 tree.Add(8, 20.0); // 2 tree.Add(16, 12.0); // 1 tree.Add(12, 15.0); // 1 tree.Add(12, 50.0); // 2 EXPECT_THAT(tree.RangeBlocks(10.0, 55.0), BlocksAre({{{1, 10.0}, {16, 12.0}, {12, 15.0}, {5, 17.0}}, {{8, 20.0}, {5, 30.0}, {12, 50.0}, {20, 55.0}}})); std::vector entries = {{1, 10.0}, {16, 12.0}, {12, 15.0}, {5, 17.0}, {8, 20.0}, {5, 30.0}, {12, 50.0}, {20, 55.0}}; for (size_t i = 0; i < entries.size() / 2; i++) { const double l = entries[i].second; for (size_t j = entries.size() / 2; j < entries.size(); j++) { const double r = entries[j].second; auto range_result = MergeTwoBlocksRangeResult(tree, l, r); EXPECT_THAT(range_result, testing::ElementsAreArray(ExtractDocIdsFromRange(entries, l, r))); } } } TEST_F(RangeTreeTest, RangeResultTwoBlocks) { RangeTree tree{PMR_NS::get_default_resource(), 50}; const long long max_value = 100; long long step = 23; long long current_value = max_value; std::vector entries; for (size_t i = 0; i < 20; i++) { const double value = static_cast(current_value); entries.emplace_back(i, value); entries.emplace_back(i, value + 100.0); current_value = (max_value + current_value - step) % max_value; } for (size_t i = 20; i < 80; i++) { const double value = static_cast(current_value); entries.emplace_back(i, value); current_value = (max_value + current_value - step) % max_value; } DCHECK(entries.size() == 100); std::sort(entries.begin(), entries.end(), [](const Entry& a, const Entry& b) { return a.second < b.second; }); auto add_entries = [&tree, &entries](size_t start, size_t end) { for (size_t i = start; i < end; i++) { tree.Add(entries[i].first, entries[i].second); } }; add_entries(0, 25); add_entries(50, 76); add_entries(25, 50); add_entries(76, entries.size()); for (size_t i = 0; i < 50; i++) { const double l = entries[i].second; for (size_t j = 50; j < 
entries.size(); j++) { const double r = entries[j].second; auto range_result = MergeTwoBlocksRangeResult(tree, l, r); EXPECT_THAT(range_result, testing::ElementsAreArray(ExtractDocIdsFromRange(entries, l, r))); } } } struct BuilderTest : public RangeTreeTest { static void Shuffle(std::vector* entries) { std::random_device rd; std::shuffle(entries->begin(), entries->end(), std::mt19937(rd())); } }; // Test if the builder builds the tree correctly TEST_F(BuilderTest, Builder) { RangeTree tree{PMR_NS::get_default_resource(), 4}; RangeTree::Builder builder; // Prepare entries shuffled std::vector entries; entries.reserve(100); for (size_t i = 0; i < 120; i++) entries.emplace_back(i, double(i) / 2); Shuffle(&entries); // Add fake entries for (auto [id, v] : entries) { builder.Add(id, v * 2); } // Add all entries for real for (auto [id, v] : entries) { builder.Remove(id, v * 2); builder.Add(id, v); } // Shuffle again Shuffle(&entries); // Remove last while (entries.size() > 100) { builder.Remove(entries.back().first, entries.back().second); entries.pop_back(); } // Build tree builder.Populate(&tree, RenewableQuota::Unlimited()); // Sort for comparisons std::ranges::sort(entries, {}, &RangeTree::Entry::first); auto entry_ids = entries | std::views::keys; // Check correctness of all ids { auto all_values = tree.Range(-1000, +1000); auto got_ids = all_values.Take(); EXPECT_TRUE(std::ranges::equal(got_ids, entry_ids)); } // Check correctness of all values including ids { auto all_pairs = ExtractDocPairs(tree.GetAllBlocks()); std::sort(all_pairs.begin(), all_pairs.end()); EXPECT_EQ(all_pairs, entries); } } TEST_F(BuilderTest, BuilderUpdates) { RangeTree tree{PMR_NS::get_default_resource(), 5}; RangeTree::Builder builder; // Prepare entries shuffled std::vector entries; entries.reserve(1000); for (size_t i = 0; i < 1000; i++) { entries.emplace_back(i, double(i) / 2); entries.emplace_back(i, double(i) / 2 + 0.25); } Shuffle(&entries); // Insert entries for (auto entry : 
entries) builder.Add(entry.first, entry.second); // Construct while suspending at every node bool done = false; util::fb2::Fiber populate_fb{[&] { builder.Populate(&tree, {0}); // suspend each time done = true; }}; // In the meantime insert new entries DocId current = entries.size(); bool add = false; size_t added = 0; absl::InsecureBitGen gen; while (!done) { if (add) { entries.emplace_back(current, double(current) / 2); builder.Add(entries.back().first, entries.back().second); current++; } else { size_t idx = absl::Uniform(gen, size_t{0}, entries.size()); auto it = entries.begin() + idx; builder.Remove(it->first, it->second); // Change our mind with 50% prob and just update if (current % 2 == 0) { it->second += 1; builder.Add(it->first, it->second); } else { entries.erase(it); } } add = !add; added++; util::ThisFiber::Yield(); } EXPECT_GT(added, 5u); // At least some updates were performed populate_fb.Join(); // Sort for comparisons std::sort(entries.begin(), entries.end()); // auto entry_ids_view = entries | std::views::keys; // Check correctness of all ids // TODO: Range tree doesn't filter duplicate ids //{ // auto all_values = tree.Range(-100000, +100000); // auto got_ids = all_values.Take(); // // std::set entry_ids_set(entry_ids_view.begin(), entry_ids_view.end()); // std::vector entry_ids_vec(entry_ids_set.begin(), entry_ids_set.end()); // // EXPECT_EQ(got_ids, entry_ids_vec); //} // Check correctness of all values including ids { auto all_pairs = ExtractDocPairs(tree.GetAllBlocks()); std::sort(all_pairs.begin(), all_pairs.end()); EXPECT_EQ(all_pairs, entries); } } // Test tree doesn't create unnecessary nodes after initialization TEST_F(RangeTreeTest, DiscreteIntialization) { RangeTree tree{PMR_NS::get_default_resource(), 4}; RangeTree::Builder builder; for (size_t i = 0; i < 32; i++) { builder.Add(i, i % 4); } builder.Populate(&tree, RenewableQuota::Unlimited()); auto result = tree.GetAllBlocks(); EXPECT_EQ(result.size(), 4u); } // Benchmark tree 
// insertion performance with set of discrete values
// Every benchmark iteration adds one fresh document id with a value obtained from
// absl::Uniform(gen, 0u, variety), where `variety` is the benchmark argument.
static void BM_DiscreteInsertion(benchmark::State& state) {
  RangeTree tree{PMR_NS::get_default_resource()};

  absl::InsecureBitGen gen{};
  size_t variety = state.range(0);

  DocId id = 0;
  for (auto _ : state) {
    double v = absl::Uniform(gen, 0u, variety);
    tree.Add(id++, v);
  }
}

BENCHMARK(BM_DiscreteInsertion)->Arg(2)->Arg(12)->Arg(128)->Arg(1024);

}  // namespace dfly::search

================================================
FILE: src/core/search/rax_tree.h
================================================
#pragma once

// NOTE(review): the next five #include directives carry no header name in this view (their
// angle-bracket targets look stripped by extraction) - restore them from the original file.
#include
#include
#include
#include
#include
#include "base/pmr/memory_resource.h"

extern "C" {
#include "redis/rax.h"
}

namespace detail {

// Copies an iterators state into another by performing a fresh seek on the source's key. While this
// is a little more expensive, it is done to avoid deep copying pointers from raxIterator and
// raxStart while taking care of self-reference links in both structs. The return value is used to
// decide whether to advance iterator after a successful seek.
inline bool CopyIteratorState(raxIterator& destination, raxIterator& source) {
  raxStart(&destination, source.rt);
  // A source without a tree leaves the destination without one as well; nothing to seek.
  if (!destination.rt)
    return false;

  if (!raxSeek(&destination, "=", source.key, source.key_len)) {
    // called from constructor, so no error can be returned. but set up the same state as
    // the SeekIterator constructor, so that it will return true on comparison to RaxTreeMap::end()
    raxStop(&destination);
    destination.rt = nullptr;
    return false;
  }
  return true;
}

}  // namespace detail

namespace dfly::search {

// absl::flat_hash_map/std::unordered_map compatible tree map based on rax tree.
// Allocates all objects on heap (with custom memory resource) as rax tree operates fully on
// pointers.
// TODO: Add full support for polymorphic allocators, including rax trie node allocations template struct RaxTreeMap { using value_type = V; struct FindIterator; // Simple seeking iterator struct SeekIterator { SeekIterator() { it_.rt = nullptr; } SeekIterator(rax* tree, const char* op, std::string_view key) { raxStart(&it_, tree); if (raxSeek(&it_, op, to_key_ptr(key), key.size())) { // Successfuly seeked operator++(); } else { InvalidateIterator(); } } explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) { } SeekIterator(SeekIterator&& other) noexcept : it_{} { *this = std::move(other); } SeekIterator& operator=(SeekIterator&& other) noexcept { if (this != &other) { if (IsValid()) { InvalidateIterator(); } if (::detail::CopyIteratorState(it_, other.it_)) operator++(); if (other.IsValid()) other.InvalidateIterator(); } return *this; } /* Copy constructor deleted to avoid double iterator invalidation */ SeekIterator(const SeekIterator&) = delete; SeekIterator& operator=(const SeekIterator&) = delete; ~SeekIterator() { if (IsValid()) { InvalidateIterator(); } } bool operator==(const SeekIterator& rhs) const { if (!IsValid() || !rhs.IsValid()) return !IsValid() && !rhs.IsValid(); return it_.node == rhs.it_.node; } bool operator!=(const SeekIterator& rhs) const { return !operator==(rhs); } SeekIterator& operator++() { int next_result = raxNext(&it_); if (!next_result) { // OOM or we reached the end of the tree InvalidateIterator(); } return *this; } /* After operator++() the first value (string_view) is invalid. So make sure your copied it to * string */ std::pair operator*() const { assert(IsValid() && it_.node && it_.node->iskey && it_.data); return {std::string_view{reinterpret_cast(it_.key), it_.key_len}, *reinterpret_cast(it_.data)}; } bool IsValid() const { return it_.rt; } private: void InvalidateIterator() { raxStop(&it_); it_.rt = nullptr; } raxIterator it_; }; using iterator = SeekIterator; // Result of find() call. 
Inherits from pair to mimic iterator interface, not incrementable. struct FindIterator : public std::optional> { bool operator==(const SeekIterator& rhs) const { if (!this->has_value() || !rhs.IsValid()) return !this->has_value() && !rhs.IsValid(); return (*this)->first == (*rhs).first; } bool operator!=(const SeekIterator& rhs) const { return !operator==(rhs); } }; public: explicit RaxTreeMap(PMR_NS::memory_resource* mr) : tree_(raxNew()), alloc_(mr) { } ~RaxTreeMap() { using Allocator = decltype(alloc_); auto free_callback = [](void* data, void* context) { Allocator* allocator = static_cast(context); V* ptr = static_cast(data); std::allocator_traits::destroy(*allocator, ptr); allocator->deallocate(ptr, 1); }; raxFreeWithCallbackAndArgument(tree_, free_callback, &alloc_); } size_t size() const { return raxSize(tree_); } auto begin() const { return SeekIterator{tree_}; } auto end() const { return SeekIterator{}; } auto lower_bound(std::string_view key) const { return SeekIterator{tree_, ">=", key}; } FindIterator find(std::string_view key) const { if (void* ptr = nullptr; raxFind(tree_, to_key_ptr(key), key.size(), &ptr)) return FindIterator{std::pair(std::string(key), *reinterpret_cast(ptr))}; return FindIterator{std::nullopt}; } template std::pair try_emplace(std::string_view key, Args&&... args); void erase(FindIterator it) { V* old = nullptr; raxRemove(tree_, to_key_ptr(it->first.data()), it->first.size(), reinterpret_cast(&old)); std::allocator_traits::destroy(alloc_, old); alloc_.deallocate(old, 1); } auto& get_allocator() const { return alloc_; } private: static unsigned char* to_key_ptr(std::string_view key) { return reinterpret_cast(const_cast(key.data())); } rax* tree_; PMR_NS::polymorphic_allocator alloc_; }; template template std::pair::FindIterator, bool> RaxTreeMap::try_emplace( std::string_view key, Args&&... 
args) { if (auto it = find(key); it) return {it, false}; V* ptr = alloc_.allocate(1); std::allocator_traits::construct(alloc_, ptr, std::forward(args)...); V* old = nullptr; raxInsert(tree_, to_key_ptr(key), key.size(), ptr, reinterpret_cast(&old)); assert(!old); auto it = std::make_optional(std::pair(std::string(key), *ptr)); return std::make_pair(std::move(FindIterator{it}), true); } } // namespace dfly::search ================================================ FILE: src/core/search/rax_tree_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/search/rax_tree.h" #include #include #include #include #include #include #include "base/gtest.h" #include "base/iterator.h" #include "base/logging.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly::search { using namespace std; struct RaxTreeTest : public ::testing::Test { static void SetUpTestSuite() { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); } }; TEST_F(RaxTreeTest, EmplaceAndIterate) { RaxTreeMap map(pmr::get_default_resource()); vector> elements(90); for (int i = 10; i < 100; i++) elements[i - 10] = make_pair(absl::StrCat("key-", i), absl::StrCat("value-", i)); for (auto& [key, value] : elements) { auto [it, inserted] = map.try_emplace(key, value); EXPECT_TRUE(inserted); EXPECT_EQ(it->first, key); EXPECT_EQ(it->second, value); } size_t i = 0; for (auto [key, value] : map) { EXPECT_EQ(elements[i].first, key); EXPECT_EQ(elements[i].second, value); i++; } } TEST_F(RaxTreeTest, LowerBound) { RaxTreeMap map(pmr::get_default_resource()); vector keys; for (unsigned i = 0; i < 5; i++) { for (unsigned j = 0; j < 5; j++) { keys.emplace_back(absl::StrCat("key-", string(1, 'a' + i), "-", j)); map.try_emplace(keys.back(), 0); } } auto it1 = map.lower_bound("key-c-3"); auto it2 = lower_bound(keys.begin(), keys.end(), "key-c-3"); while (it1 != map.end()) { EXPECT_EQ((*it1).first, 
*it2); ++it1; ++it2; } EXPECT_TRUE(it1 == map.end()); EXPECT_TRUE(it2 == keys.end()); // Test lower bound empty string vector keys2; for (auto it = map.lower_bound(string_view{}); it != map.end(); ++it) keys2.emplace_back((*it).first); EXPECT_EQ(keys, keys2); } TEST_F(RaxTreeTest, Find) { RaxTreeMap map(pmr::get_default_resource()); for (unsigned i = 100; i < 999; i += 2) map.try_emplace(absl::StrCat("value-", i), i); auto it = map.begin(); for (unsigned i = 100; i < 999; i++) { auto fit = map.find(absl::StrCat("value-", i)); if (i % 2 == 0) { EXPECT_TRUE(fit == it); EXPECT_EQ(fit->second, i); ++it; } else { EXPECT_TRUE(fit == map.end()); } } // Test find with empty string EXPECT_TRUE(map.find(string_view{}) == map.end()); } /* Run with mimalloc to make sure there is no double free */ TEST_F(RaxTreeTest, Iterate) { const char* kKeys[] = { "aaaaaaaaaaaaaaaaaaaa", "bbbbbbbbbbbbbbbbbbbbbb" "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd" "eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee", }; RaxTreeMap map(pmr::get_default_resource()); for (const char* key : kKeys) { map.try_emplace(key, 2); } for (auto it = map.begin(); it != map.end(); ++it) { EXPECT_EQ((*it).second, 2); } for (auto it = map.begin(); it != map.end(); ++it) { EXPECT_EQ((*it).second, 2); } } TEST_F(RaxTreeTest, MoveIterator) { RaxTreeMap m{pmr::get_default_resource()}; RaxTreeMap::SeekIterator tmp; { // empty map, iterator invalidated on construction tmp = m.begin(); const auto it = std::move(tmp); EXPECT_FALSE(tmp.IsValid()); EXPECT_FALSE(it.IsValid()); } { tmp = m.end(); const auto it = std::move(tmp); EXPECT_FALSE(tmp.IsValid()); EXPECT_FALSE(it.IsValid()); EXPECT_EQ(it, m.end()); } m.try_emplace("first", true); m.try_emplace("second", false); { tmp = m.begin(); RaxTreeMap::SeekIterator it{std::move(tmp)}; 
EXPECT_FALSE(tmp.IsValid()); EXPECT_TRUE(it.IsValid()); EXPECT_EQ((*it).first, "first"); EXPECT_TRUE((*it).second); ++it; EXPECT_EQ((*it).first, "second"); EXPECT_FALSE((*it).second); ++it; EXPECT_EQ(it, m.end()); } { // advance before moving, the moved-to iterator should pick where the moved-from left off tmp = m.lower_bound("fig"); EXPECT_TRUE(tmp.IsValid()); ++tmp; EXPECT_EQ((*tmp).first, "second"); auto it = std::move(tmp); EXPECT_FALSE(tmp.IsValid()); EXPECT_TRUE(it.IsValid()); EXPECT_EQ((*it).first, "second"); ++it; EXPECT_FALSE(it.IsValid()); EXPECT_EQ(it, m.end()); } { // move into valid iterator auto it = m.begin(); EXPECT_EQ((*it).first, "first"); tmp = m.lower_bound("sea"); EXPECT_EQ((*tmp).first, "second"); it = std::move(tmp); EXPECT_FALSE(tmp.IsValid()); EXPECT_TRUE(it.IsValid()); EXPECT_EQ((*it).first, "second"); ++it; EXPECT_FALSE(it.IsValid()); EXPECT_EQ(it, m.end()); } { auto it = m.lower_bound("sea"); EXPECT_EQ((*it).first, "second"); tmp = m.end(); it = std::move(tmp); EXPECT_FALSE(it.IsValid()); } } } // namespace dfly::search ================================================ FILE: src/core/search/renewable_quota.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/search/renewable_quota.h" #include "base/cycle_clock.h" #include "base/logging.h" #include "util/fibers/fibers.h" namespace dfly::search { RenewableQuota RenewableQuota::Unlimited() { return RenewableQuota{std::numeric_limits::max()}; } // Quota that yields if the fiber is running for too long void RenewableQuota::Check(std::source_location location) const { size_t cycles = util::ThisFiber::GetRunningTimeCycles(); size_t usec = base::CycleClock::ToUsec(cycles); if (usec >= max_usec) { size_t ms = usec / 1'000; VLOG_IF(1, ms >= 50) << "Grabbed " << ms << "ms for " << location.file_name() << ":" << location.line(); util::ThisFiber::Yield(); } } } // namespace dfly::search ================================================ FILE: src/core/search/renewable_quota.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include namespace dfly::search { // Running time quota that can be reset by suspending the fiber struct RenewableQuota { // Create unlimited quota static RenewableQuota Unlimited(); // Check if quota is remaining and suspend the fiber if it ran out void Check(std::source_location location = std::source_location::current()) const; const size_t max_usec; }; } // namespace dfly::search ================================================ FILE: src/core/search/scanner.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once // We should not include lexer.h when compiling from lexer.cc file because it already // includes lexer.h #ifndef DFLY_LEXER_CC #include "core/search/lexer.h" #endif #include #include "base/logging.h" namespace dfly { namespace search { class Scanner : public Lexer { public: Scanner() : params_{nullptr} { } Parser::symbol_type Lex(); void SetParams(const QueryParams* params) { params_ = params; } private: std::string_view matched_view(size_t skip_left = 0, size_t skip_right = 0) const { std::string_view res(matcher().begin() + skip_left, matcher().size() - skip_left - skip_right); return res; } dfly::search::location loc() { return location(); } Parser::symbol_type ParseParam(std::string_view name, const Parser::location_type& loc) { name.remove_prefix(1); // drop $ symbol std::string_view str = (*params_)[name]; if (str.empty()) throw std::runtime_error(absl::StrCat("Query parameter ", name, " not found")); uint32_t val = 0; if (!absl::SimpleAtoi(str, &val)) return Parser::make_TERM(std::string{str}, loc); return Parser::make_UINT32(std::string{str}, loc); } private: const QueryParams* params_; }; } // namespace search } // namespace dfly ================================================ FILE: src/core/search/search.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/search/search.h" #include #include #include #include #include #include #include #include "base/logging.h" #include "core/overloaded.h" #include "core/search/ast_expr.h" #include "core/search/index_result.h" #include "core/search/indices.h" #include "core/search/query_driver.h" #include "core/search/sort_indices.h" #include "core/search/tag_types.h" #include "core/search/vector_utils.h" using namespace std; namespace dfly::search { namespace { AstExpr ParseQuery(std::string_view query, const QueryParams* params, const OptionalFilters* filters) { QueryDriver driver{}; driver.ResetScanner(); driver.SetParams(params); driver.SetInput(std::string{query}); (void)Parser (&driver)(); // can throw driver.SetOptionalFilters(filters); return driver.Take(); } // GCC 12 yields a wrong warning in a deeply inlined call in UnifyResults, only ignoring the whole // scope solves it #ifndef __clang__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #endif struct ProfileBuilder { struct NodeFormatter { template void operator()(std::string* out, const AstAffixNode& node) const { out->append(node.affix); } void operator()(std::string* out, const AstTagsNode::TagValue& value) const { visit([this, out](const auto& n) { this->operator()(out, n); }, value); } }; string GetNodeInfo(const AstNode& node) { Overloaded node_info{ [](monostate) -> string { return ""s; }, [](const AstTermNode& n) { return absl::StrCat("Term{", n.affix, "}"); }, [](const AstPrefixNode& n) { return absl::StrCat("Prefix{", n.affix, "}"); }, [](const AstSuffixNode& n) { return absl::StrCat("Suffix{", n.affix, "}"); }, [](const AstInfixNode& n) { return absl::StrCat("Infix{", n.affix, "}"); }, [](const AstRangeNode& n) { return absl::StrCat("Range{", n.lo, "<>", n.hi, "}"); }, [](const AstLogicalNode& n) { auto op = n.op == AstLogicalNode::AND ? 
"and" : "or"; return absl::StrCat("Logical{n=", n.nodes.size(), ",o=", op, "}"); }, [](const AstTagsNode& n) { return absl::StrCat("Tags{", absl::StrJoin(n.tags, ",", NodeFormatter()), "}"); }, [](const AstFieldNode& n) { return absl::StrCat("Field{", n.field, "}"); }, [](const AstKnnNode& n) { return absl::StrCat("KNN{l=", n.limit, "}"); }, [](const AstNegateNode& n) { return absl::StrCat("Negate{}"); }, [](const AstStarNode& n) { return absl::StrCat("Star{}"); }, [](const AstStarFieldNode& n) { return absl::StrCat("StarField{}"); }, [](const AstGeoNode& n) { return absl::StrCat("Geo{", n.lat, " ", n.lon, " ", n.radius, " ", n.unit, "}"); }, [](const AstVectorRangeNode& n) { return absl::StrCat("VectorRange{r=", n.radius, "}"); }, }; return visit(node_info, node.Variant()); } using Tp = std::chrono::steady_clock::time_point; Tp Start() { depth_++; return chrono::steady_clock::now(); } void Finish(Tp start, const AstNode& node, const IndexResult& result) { DCHECK_GE(depth_, 1u); auto took = chrono::steady_clock::now() - start; size_t micros = chrono::duration_cast(took).count(); auto descr = GetNodeInfo(node); profile_.events.push_back({std::move(descr), micros, depth_ - 1, result.ApproximateSize()}); depth_--; } AlgorithmProfile Take() { reverse(profile_.events.begin(), profile_.events.end()); return std::move(profile_); } private: size_t depth_; AlgorithmProfile profile_; }; struct BasicSearch { using LogicOp = AstLogicalNode::LogicOp; BasicSearch(const FieldIndices* indices) : indices_{indices} { } void EnableProfiling() { profile_builder_ = ProfileBuilder{}; } BaseIndex* GetBaseIndex(string_view field) { auto index = indices_->GetIndex(field); if (!index) { error_ = absl::StrCat("Invalid field: ", field); return nullptr; } return index; } // Get casted sub index by field template T* GetIndex(string_view field) { static_assert(is_base_of_v); auto base_index = GetBaseIndex(field); if (!base_index) { return nullptr; } auto* casted_ptr = dynamic_cast(base_index); 
if (!casted_ptr) { error_ = absl::StrCat("Wrong access type for field: ", field); return nullptr; } return casted_ptr; } BaseSortIndex* GetSortIndex(string_view field) { auto index = indices_->GetSortIndex(field); if (!index) { error_ = absl::StrCat("Invalid sort field: ", field); return nullptr; } return index; } // Collect all index results from F(C[i]) template vector GetSubResults(const C& container, const F& f) { vector sub_results(container.size()); for (size_t i = 0; i < container.size(); i++) sub_results[i] = IndexResult{f(container[i])}; return sub_results; } void Merge(IndexResult matched, IndexResult* current_ptr, LogicOp op) { IndexResult& current = *current_ptr; auto vec = MergeIndexResults(matched, current, op); current = IndexResult{std::move(vec)}; } // Efficiently unify multiple sub results with specified logical op IndexResult UnifyResults(vector&& sub_results, LogicOp op) { if (sub_results.empty()) return IndexResult{}; // Unifying from smallest to largest is more efficient. // AND: the result only shrinks, so starting with the smallest is most optimal. // OR: unifying smaller sets first reduces the number of element traversals on average. 
sort(sub_results.begin(), sub_results.end(), [](const auto& l, const auto& r) { return l.ApproximateSize() < r.ApproximateSize(); }); IndexResult out{std::move(sub_results[0])}; for (auto& matched : absl::MakeSpan(sub_results).subspan(1)) Merge(std::move(matched), &out, op); return out; } template IndexResult CollectMatches(BaseStringIndex* index, std::string_view word, F&& f) { IndexResult result{}; invoke(f, *index, word, [&result, this](const auto* c) { Merge(IndexResult{c}, &result, LogicOp::OR); }); return result; } IndexResult Search(monostate, string_view) { return IndexResult{}; } IndexResult Search(const AstStarNode& node, string_view active_field) { DCHECK(active_field.empty()); return IndexResult{&indices_->GetAllDocs()}; } IndexResult Search(const AstStarFieldNode& node, string_view active_field) { // Try to get a sort index first, as `@field:*` might imply wanting sortable behavior BaseSortIndex* sort_index = indices_->GetSortIndex(active_field); if (sort_index) { return IndexResult{sort_index->GetAllDocsWithNonNullValues()}; } // If sort index doesn't exist try regular index BaseIndex* base_index = GetBaseIndex(active_field); return base_index ? 
IndexResult{base_index->GetAllDocsWithNonNullValues()} : IndexResult{}; } template IndexResult Search(const AstAffixNode& node, string_view active_field) { vector indices; if (!active_field.empty()) { if (auto* index = GetIndex(active_field); index) indices = {index}; else return IndexResult{}; } else { indices = indices_->GetAllTextIndices(); } auto mapping = [&node, this](TextIndex* index) { if constexpr (T == TagType::PREFIX) return CollectMatches(index, node.affix, &TextIndex::MatchPrefix); else if constexpr (T == TagType::SUFFIX) return CollectMatches(index, node.affix, &TextIndex::MatchSuffix); else if constexpr (T == TagType::INFIX) return CollectMatches(index, node.affix, &TextIndex::MatchInfix); else return vector{}; }; return UnifyResults(GetSubResults(indices, mapping), LogicOp::OR); } // "term": access field's text index or unify results from all text indices if no field is set IndexResult Search(const AstAffixNode node, string_view active_field) { std::string term = node.affix; bool strip_whitespace = true; if (auto synonyms = indices_->GetSynonyms(); synonyms) { if (auto group_id = synonyms->GetGroupToken(term); group_id) { term = *group_id; strip_whitespace = false; } } if (!active_field.empty()) { if (auto* index = GetIndex(active_field); index) return IndexResult{index->Matching(term, strip_whitespace)}; return IndexResult{}; } vector selected_indices = indices_->GetAllTextIndices(); auto mapping = [&term, strip_whitespace](TextIndex* index) { return index->Matching(term, strip_whitespace); }; return UnifyResults(GetSubResults(selected_indices, mapping), LogicOp::OR); } // [range]: access field's numeric index IndexResult Search(const AstRangeNode& node, string_view active_field) { DCHECK(!active_field.empty()); if (auto* index = GetIndex(active_field); index) { return IndexResult{index->Range(node.lo, node.hi)}; } return IndexResult{}; } IndexResult Search(const AstGeoNode& node, string_view active_field) { DCHECK(!active_field.empty()); if (auto* 
index = GetIndex(active_field); index) { return IndexResult{index->RadiusSearch(node.lon, node.lat, node.radius, node.unit)}; } return IndexResult{}; } // negate -(*subquery*): explicitly compute result complement. Needs further optimizations IndexResult Search(const AstNegateNode& node, string_view active_field) { auto matched = SearchGeneric(*node.node, active_field).Take().first; vector all = indices_->GetAllDocs(); // To negate a result, we have to find the complement of matched to all documents, // so we remove all matched documents from the set of all documents. auto pred = [&matched](DocId doc) { return binary_search(matched.begin(), matched.end(), doc); }; all.erase(remove_if(all.begin(), all.end(), pred), all.end()); return IndexResult{std::move(all)}; } // logical query: unify all sub results IndexResult Search(const AstLogicalNode& node, string_view active_field) { auto mapping = [&](auto& node) { return SearchGeneric(node, active_field); }; return UnifyResults(GetSubResults(node.nodes, mapping), node.op); } // @field: set active field for sub tree IndexResult Search(const AstFieldNode& node, string_view active_field) { DCHECK(active_field.empty()); DCHECK(node.node); return SearchGeneric(*node.node, node.field); } // {tags | ...}: Unify results for all tags IndexResult Search(const AstTagsNode& node, string_view active_field) { auto* tag_index = GetIndex(active_field); if (!tag_index) return IndexResult{}; Overloaded ov{[tag_index](const AstTermNode& term) -> IndexResult { return IndexResult{tag_index->Matching(term.affix)}; }, [tag_index, this](const AstPrefixNode& prefix) { return CollectMatches(tag_index, prefix.affix, &TagIndex::MatchPrefix); }, [tag_index, this](const AstSuffixNode& suffix) { return CollectMatches(tag_index, suffix.affix, &TagIndex::MatchSuffix); }, [tag_index, this](const AstInfixNode& infix) { return CollectMatches(tag_index, infix.affix, &TagIndex::MatchInfix); }}; auto mapping = [ov](const auto& tag) { return visit(ov, tag); }; 
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR); } void SearchKnnFlat(FlatVectorIndex* vec_index, const AstKnnNode& knn, IndexResult&& sub_results) { knn_distances_.reserve(sub_results.ApproximateSize()); auto cb = [&](auto* set) { auto [dim, sim] = vec_index->Info(); for (DocId matched_doc : *set) { const float* vec = vec_index->Get(matched_doc); if (!vec) continue; float dist = VectorDistance(knn.vec.first.get(), vec, dim, sim); knn_distances_.emplace_back(dist, matched_doc); } }; visit(cb, sub_results.Borrowed()); size_t prefix_size = min(knn.limit, knn_distances_.size()); partial_sort(knn_distances_.begin(), knn_distances_.begin() + prefix_size, knn_distances_.end()); knn_distances_.resize(prefix_size); } void SearchVectorRangeFlat(FlatVectorIndex* vec_index, const AstVectorRangeNode& node) { const auto& all_docs = indices_->GetAllDocs(); auto [dim, sim] = vec_index->Info(); for (DocId doc : all_docs) { const float* vec = vec_index->Get(doc); if (!vec) continue; float dist = VectorDistance(node.vec.first.get(), vec, dim, sim); if (dist <= static_cast(node.radius)) { knn_scores_.emplace_back(doc, dist); } } } // [@field:[VECTOR_RANGE r vec]=>{$YIELD_DISTANCE_AS: alias}]: // Return all docs within distance radius, storing distances in knn_scores_ IndexResult Search(const AstVectorRangeNode& node, string_view active_field) { DCHECK(active_field.empty()); auto* vec_index = GetIndex(node.field); if (!vec_index) return IndexResult{}; if (node.vec.second == 0) return IndexResult{}; if (node.radius < 0 || std::isnan(node.radius)) { error_ = absl::StrCat("VECTOR_RANGE radius must be non-negative, got: ", node.radius); return IndexResult{}; } if (auto [dim, _] = vec_index->Info(); dim != node.vec.second) { error_ = absl::StrCat("Wrong vector index dimensions, got: ", node.vec.second, ", expected: ", dim); return IndexResult{}; } knn_scores_.clear(); // HNSW fields are not stored in FieldIndices::indices_, so GetIndex above // returns nullptr for HNSW 
before we reach this point. // HNSW range search support is planned separately (see hnsw_index.h). if (auto* flat_index = dynamic_cast(vec_index); flat_index) SearchVectorRangeFlat(flat_index, node); vector out(knn_scores_.size()); for (size_t i = 0; i < knn_scores_.size(); i++) out[i] = knn_scores_[i].first; return IndexResult{std::move(out)}; } // [KNN limit @field vec]: Compute distance from `vec` to all vectors keep closest `limit` IndexResult Search(const AstKnnNode& knn, string_view active_field) { DCHECK(active_field.empty()); auto sub_results = SearchGeneric(*knn.filter, active_field); auto* vec_index = GetIndex(knn.field); if (!vec_index) return IndexResult{}; // If vector dimension is 0, treat as placeholder/invalid - return empty results // This allows tests to use dummy vector values like "" if (knn.vec.second == 0) return IndexResult{}; if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second) { error_ = absl::StrCat("Wrong vector index dimensions, got: ", knn.vec.second, ", expected: ", dim); return IndexResult{}; } knn_scores_.clear(); if (auto flat_index = dynamic_cast(vec_index); flat_index) SearchKnnFlat(dynamic_cast(vec_index), knn, std::move(sub_results)); vector out(knn_distances_.size()); knn_scores_.reserve(knn_distances_.size()); for (size_t i = 0; i < knn_distances_.size(); i++) { knn_scores_.emplace_back(knn_distances_[i].second, knn_distances_[i].first); out[i] = knn_distances_[i].second; } return IndexResult{std::move(out)}; } // Determine node type and call specific search function IndexResult SearchGeneric(const AstNode& node, string_view active_field, bool top_level = false) { if (!error_.empty()) return IndexResult{}; ProfileBuilder::Tp start = profile_builder_ ? 
profile_builder_->Start() : ProfileBuilder::Tp{}; auto cb = [this, active_field](const auto& inner) { return Search(inner, active_field); }; auto result = visit(cb, node.Variant()); // Top level results don't need to be sorted, because they will be scored, sorted by fields or // used by knn DCHECK(top_level || holds_alternative(node.Variant()) || holds_alternative(node.Variant()) || holds_alternative(node.Variant()) || visit([](auto* set) { return is_sorted(set->begin(), set->end()); }, result.Borrowed())); if (profile_builder_) profile_builder_->Finish(start, node, result); return result; } SearchResult Search(const AstNode& query, size_t cuttoff_limit) { IndexResult result = SearchGeneric(query, "", true); // Extract profile if enabled optional profile = profile_builder_ ? make_optional(profile_builder_->Take()) : nullopt; auto [out, total_size] = result.Take(cuttoff_limit); return SearchResult{total_size, std::move(out), std::move(knn_scores_), std::move(profile), std::move(error_)}; } const FieldIndices* indices_; string error_; optional profile_builder_ = ProfileBuilder{}; std::vector> knn_scores_; vector> knn_distances_; }; #ifndef __clang__ #pragma GCC diagnostic pop #endif } // namespace AstNode OptionalNumericFilter::Node(std::string field) { return AstFieldNode{"@" + field, AstRangeNode(lo_, false, hi_, false)}; } string_view Schema::LookupAlias(string_view alias) const { if (auto it = field_names.find(alias); it != field_names.end()) return it->second; return alias; } string_view Schema::LookupIdentifier(string_view identifier) const { if (auto it = fields.find(identifier); it != fields.end()) return it->second.short_name; return identifier; } IndicesOptions::IndicesOptions() { static absl::flat_hash_set kDefaultStopwords{ "a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "it", "no", "not", "of", "on", "or", "such", "that", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"}; 
stopwords = kDefaultStopwords; } FieldIndices::FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr, const Synonyms* synonyms) : schema_{schema}, options_{options}, synonyms_{synonyms} { CreateIndices(mr); CreateSortIndices(); } void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) { for (const auto& [field_ident, field_info] : schema_.fields) { if ((field_info.flags & SchemaField::NOINDEX) > 0) continue; switch (field_info.type) { case SchemaField::TEXT: { const auto& tparams = std::get(field_info.special_params); indices_[field_ident] = make_unique(mr, &options_.stopwords, synonyms_, tparams.with_suffixtrie); break; } case SchemaField::NUMERIC: { const auto& nparams = std::get(field_info.special_params); indices_[field_ident] = make_unique(nparams.block_size, mr); break; } case SchemaField::TAG: { const auto& tparams = std::get(field_info.special_params); indices_[field_ident] = make_unique(mr, tparams); break; } case SchemaField::VECTOR: { unique_ptr vector_index; DCHECK(holds_alternative(field_info.special_params)); const auto& vparams = std::get(field_info.special_params); // Use global HNSW index if (vparams.use_hnsw) break; vector_index = make_unique(vparams, mr); indices_[field_ident] = std::move(vector_index); break; } case SchemaField::GEO: { indices_[field_ident] = make_unique(mr); break; } } } } void FieldIndices::CreateSortIndices() { for (const auto& [field_ident, field_info] : schema_.fields) { if ((field_info.flags & SchemaField::SORTABLE) == 0) continue; switch (field_info.type) { case SchemaField::TAG: case SchemaField::TEXT: sort_indices_[field_ident] = make_unique(); break; case SchemaField::NUMERIC: sort_indices_[field_ident] = make_unique(); break; case SchemaField::VECTOR: case SchemaField::GEO: break; } } } bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) { bool was_added = true; std::vector> successfully_added_indices; successfully_added_indices.reserve(indices_.size() + 
sort_indices_.size()); auto try_add = [&](const auto& indices_container) { for (auto& [field, index] : indices_container) { if (index->Add(doc, access, field)) { successfully_added_indices.emplace_back(field, index.get()); } else { was_added = false; break; } } }; try_add(indices_); if (was_added) { try_add(sort_indices_); } if (!was_added) { for (auto& [field, index] : successfully_added_indices) { index->Remove(doc, access, field); } return false; } all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc); return true; } void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) { for (auto& [field, index] : indices_) index->Remove(doc, access, field); for (auto& [field, sort_index] : sort_indices_) sort_index->Remove(doc, access, field); auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc); DCHECK(it != all_ids_.end() && *it == doc); all_ids_.erase(it); } BaseIndex* FieldIndices::GetIndex(string_view field) const { auto it = indices_.find(schema_.LookupAlias(field)); return it != indices_.end() ? it->second.get() : nullptr; } BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const { auto it = sort_indices_.find(schema_.LookupAlias(field)); return it != sort_indices_.end() ? 
it->second.get() : nullptr; } std::vector FieldIndices::GetAllTextIndices() const { vector out; for (const auto& [field_name, field_info] : schema_.fields) { if (field_info.type != SchemaField::TEXT || (field_info.flags & SchemaField::NOINDEX) > 0) continue; auto* index = dynamic_cast(GetIndex(field_name)); DCHECK(index); out.push_back(index); } return out; } const vector& FieldIndices::GetAllDocs() const { return all_ids_; } const Schema& FieldIndices::GetSchema() const { return schema_; } SortableValue FieldIndices::GetSortIndexValue(DocId doc, std::string_view field_identifier) const { auto it = sort_indices_.find(field_identifier); DCHECK(it != sort_indices_.end()); return it->second->Lookup(doc); } void FieldIndices::FinalizeInitialization() { for (auto& [field, index] : indices_) { index->FinalizeInitialization(); } } DefragmentResult FieldIndices::Defragment(PageUsage* page_usage) { auto defrag = [&](auto& indices, string* key) { DefragmentMap dm{indices, key}; return dm.Defragment(page_usage); }; DefragmentResult result = defrag(indices_, &next_defrag_field_); result.Merge(defrag(sort_indices_, &next_defrag_sort_field_)); return result; } const Synonyms* FieldIndices::GetSynonyms() const { return synonyms_; } SearchAlgorithm::SearchAlgorithm() = default; SearchAlgorithm::~SearchAlgorithm() = default; bool SearchAlgorithm::Init(string_view query, const QueryParams* params, const OptionalFilters* filters) { try { query_ = make_unique(ParseQuery(query, params, filters)); } catch (const Parser::syntax_error& se) { LOG(INFO) << "Failed to parse query \"" << query << "\":" << se.what(); return false; } catch (...) 
{ LOG_EVERY_T(INFO, 10) << "Unexpected query parser error \"" << query << "\""; return false; } if (holds_alternative(*query_)) { LOG_EVERY_T(INFO, 10) << "Empty result after parsing query \"" << query << "\""; return false; } return true; } SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_limit) const { DCHECK(query_); auto bs = BasicSearch{index}; if (profiling_enabled_) bs.EnableProfiling(); return bs.Search(*query_, cuttoff_limit); } std::optional SearchAlgorithm::GetKnnScoreSortOption() const { // HNSW KNN query if (knn_hnsw_score_sort_option_) { return knn_hnsw_score_sort_option_; } // FLAT KNN query if (auto* knn = get_if(query_.get()); knn) return KnnScoreSortOption{string_view{knn->score_alias}, knn->limit}; return nullopt; } bool SearchAlgorithm::IsKnnQuery() const { DCHECK(query_); return std::holds_alternative(*query_); } AstKnnNode* SearchAlgorithm::GetKnnNode() const { if (auto* knn = get_if(query_.get()); knn) { return knn; } return nullptr; } std::unique_ptr SearchAlgorithm::PopKnnNode() { if (auto* knn = get_if(query_.get()); knn) { // Save knn score sort option knn_hnsw_score_sort_option_ = KnnScoreSortOption{string_view{knn->score_alias}, knn->limit}; auto node = std::move(query_); AstKnnNode* moved_knn_node = reinterpret_cast(node.get()); if (!std::holds_alternative(*moved_knn_node->filter)) query_.swap(moved_knn_node->filter); return node; } LOG(DFATAL) << "Should not reach here"; return nullptr; } void SearchAlgorithm::EnableProfiling() { profiling_enabled_ = true; } const AstVectorRangeNode* SearchAlgorithm::GetVectorRangeNode() const { return get_if(query_.get()); } } // namespace dfly::search ================================================ FILE: src/core/search/search.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include #include #include #include "base/pmr/memory_resource.h" #include "core/search/base.h" #include "core/search/range_tree.h" #include "core/search/synonyms.h" namespace dfly::search { struct AstNode; struct TextIndex; struct AstKnnNode; struct AstVectorRangeNode; // Optional FILTER struct OptionalNumericFilter : public OptionalFilterBase { OptionalNumericFilter(size_t lo, size_t hi) : empty_(false), lo_(lo), hi_(hi) { } bool IsEmpty() const override { return empty_; } AstNode Node(std::string field) override; void AddRange(size_t lo, size_t hi) { if (empty_) { return; } if ((hi_ < lo) || (hi < lo_)) { empty_ = true; } else { lo_ = std::max(lo_, lo); hi_ = std::min(hi_, hi); } } private: bool empty_; size_t lo_; size_t hi_; }; // Describes a specific index field struct SchemaField { enum FieldType { TAG, TEXT, NUMERIC, VECTOR, GEO }; enum FieldFlags : uint8_t { NOINDEX = 1 << 0, SORTABLE = 1 << 1 }; struct VectorParams { bool use_hnsw = false; size_t dim = 0u; // dimension of knn vectors VectorSimilarity sim = VectorSimilarity::L2; // similarity type size_t capacity = 1000; // initial capacity size_t hnsw_ef_construction = 200; size_t hnsw_m = 16; }; struct TagParams { char separator = ','; bool case_sensitive = false; bool with_suffixtrie = false; // see TextParams }; struct TextParams { // if enabled, suffix trie is build for efficient suffix and infix queries bool with_suffixtrie = false; }; struct NumericParams { // Block size of the range tree // Check RangeTree for details. size_t block_size = RangeTree::kDefaultMaxRangeBlockSize; }; bool IsIndexableHnswField() const { return type == VECTOR && !(flags & NOINDEX) && std::get(special_params).use_hnsw; } using ParamsVariant = std::variant; FieldType type; uint8_t flags; std::string short_name; // equal to ident if none provided ParamsVariant special_params{std::monostate{}}; }; // Describes the fields of an index struct Schema { // List of fields by identifier. 
absl::flat_hash_map fields; // Mapping for short field names (aliases). absl::flat_hash_map field_names; // Return identifier for alias if found, otherwise return passed value std::string_view LookupAlias(std::string_view alias) const; // Return alias for identifier if found, otherwise return passed value std::string_view LookupIdentifier(std::string_view identifier) const; }; struct IndicesOptions { IndicesOptions(); explicit IndicesOptions(absl::flat_hash_set stopwords) : stopwords{std::move(stopwords)} { } absl::flat_hash_set stopwords; }; // Collection of indices for all fields in schema class FieldIndices { public: // Create indices based on schema and options. Both must outlive the indices FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr, const Synonyms* synonyms); // Returns true if document was added bool Add(DocId doc, const DocumentAccessor& access); void Remove(DocId doc, const DocumentAccessor& access); BaseIndex* GetIndex(std::string_view field) const; BaseSortIndex* GetSortIndex(std::string_view field) const; std::vector GetAllTextIndices() const; const std::vector& GetAllDocs() const; const Schema& GetSchema() const; const Synonyms* GetSynonyms() const; SortableValue GetSortIndexValue(DocId doc, std::string_view field_identifier) const; void FinalizeInitialization(); DefragmentResult Defragment(PageUsage* page_usage); private: void CreateIndices(PMR_NS::memory_resource* mr); void CreateSortIndices(); const Schema& schema_; const IndicesOptions& options_; std::vector all_ids_; absl::flat_hash_map> indices_; absl::flat_hash_map> sort_indices_; const Synonyms* synonyms_; std::string next_defrag_field_; std::string next_defrag_sort_field_; }; struct AlgorithmProfile { struct ProfileEvent { std::string descr; size_t micros; // time event took in microseconds size_t depth; // tree depth of event size_t num_processed; // number of results processed by the event }; std::vector events; }; // Represents a search 
result returned from the search algorithm. struct SearchResult { size_t total; // how many documents were matched in total // The ids of the matched documents std::vector ids; // Contains final scores if an aggregation was present std::vector> knn_scores; // If profiling was enabled std::optional profile; // If an error occurred, last recent one std::string error; }; struct KnnScoreSortOption { std::string_view score_field_alias; size_t limit = std::numeric_limits::max(); }; // SearchAlgorithm allows searching field indices with a query class SearchAlgorithm { public: SearchAlgorithm(); ~SearchAlgorithm(); // Init with query and optional filters and return true if successful. bool Init(std::string_view query, const QueryParams* params, const OptionalFilters* filters = nullptr); // Search on given index with predefined limit for cutting off result ids SearchResult Search(const FieldIndices* index, size_t cuttoff_limit = std::numeric_limits::max()) const; std::optional GetKnnScoreSortOption() const; bool IsKnnQuery() const; AstKnnNode* GetKnnNode() const; std::unique_ptr PopKnnNode(); const AstVectorRangeNode* GetVectorRangeNode() const; void EnableProfiling(); private: bool profiling_enabled_ = false; std::unique_ptr query_; std::optional knn_hnsw_score_sort_option_; }; } // namespace dfly::search ================================================ FILE: src/core/search/search_parser_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "base/gtest.h" #include "base/logging.h" #include "core/search/base.h" #include "core/search/query_driver.h" #include "core/search/search.h" namespace dfly::search { using namespace std; class SearchParserTest : public ::testing::Test { protected: SearchParserTest() { query_driver_.scanner()->set_debug(1); } void SetInput(const std::string& str) { query_driver_.SetInput(str); } Parser::symbol_type Lex() { return query_driver_.Lex(); } int Parse(const std::string& str) { query_driver_.ResetScanner(); query_driver_.SetInput(str); return Parser(&query_driver_)(); } void SetParams(const QueryParams* params) { query_driver_.SetParams(params); } QueryDriver query_driver_; }; // tokens are not assignable, so we can not reuse them. This macros reduce the boilerplate. #define NEXT_EQ(tok_enum, type, val) \ { \ auto tok = Lex(); \ ASSERT_EQ(tok.type_get(), Parser::token::tok_enum); \ EXPECT_EQ(val, tok.value.as()); \ } #define NEXT_TOK(tok_enum) \ { \ auto tok = Lex(); \ ASSERT_EQ(tok.type_get(), Parser::token::tok_enum); \ } #define NEXT_ERROR() \ { \ bool caught = false; \ try { \ auto tok = Lex(); \ } catch (const Parser::syntax_error& e) { \ caught = true; \ } \ ASSERT_TRUE(caught); \ } TEST_F(SearchParserTest, Scanner) { SetInput("ab cd"); // 3.5.1 does not have name() method. 
// EXPECT_STREQ("term", tok.name()); NEXT_EQ(TOK_TERM, string, "ab"); NEXT_EQ(TOK_TERM, string, "cd"); NEXT_TOK(TOK_YYEOF); SetInput("*"); NEXT_TOK(TOK_STAR); SetInput("(5a 6) "); NEXT_TOK(TOK_LPAREN); NEXT_EQ(TOK_TERM, string, "5a"); NEXT_EQ(TOK_UINT32, string, "6"); NEXT_TOK(TOK_RPAREN); SetInput(R"( "hello\"world" )"); NEXT_EQ(TOK_TERM, string, R"(hello"world)"); SetInput("@field:hello"); NEXT_EQ(TOK_FIELD, string, "@field"); NEXT_TOK(TOK_COLON); NEXT_EQ(TOK_TERM, string, "hello"); SetInput("@field:{ tag }"); NEXT_EQ(TOK_FIELD, string, "@field"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TERM, string, "tag"); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\,1\\\\\\$\\+}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, R"(blue,1\$+)"); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\.1\\\"\\%\\=}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue.1\"%="); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\<1\\'\\^\\~}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue<1'^~"); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\>1\\:\\&\\/}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue>1:&/"); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\{1\\;\\*\\ }"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue{1;* "); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\}1\\!\\(}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue}1!("); NEXT_TOK(TOK_RCURLBR); SetInput("@color:{blue\\[1\\@\\)}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue[1@)"); NEXT_TOK(TOK_RCURLBR); 
SetInput("@color:{blue\\]1\\#\\-}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "blue]1#-"); NEXT_TOK(TOK_RCURLBR); // Colon in tag value (unescaped) SetInput("@t:{Tag:value}"); NEXT_EQ(TOK_FIELD, string, "@t"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TAG_VAL, string, "Tag:value"); NEXT_TOK(TOK_RCURLBR); // Prefix simple SetInput("pre*"); NEXT_EQ(TOK_PREFIX, string, "pre"); // TODO: uncomment when we support escaped terms // Prefix escaped (redis doesn't support quoted prefix matches) // SetInput("pre\\**"); // NEXT_EQ(TOK_PREFIX, string, "pre*"); // Prefix in tag SetInput("@color:{prefix*}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_PREFIX, string, "prefix"); NEXT_TOK(TOK_RCURLBR); // Prefix escaped star SetInput("@color:{\"prefix*\"}"); NEXT_EQ(TOK_FIELD, string, "@color"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_TERM, string, "prefix*"); NEXT_TOK(TOK_RCURLBR); // Prefix spaced with star SetInput("pre *"); NEXT_EQ(TOK_TERM, string, "pre"); NEXT_TOK(TOK_STAR); SetInput("почтальон Печкин"); NEXT_EQ(TOK_TERM, string, "почтальон"); NEXT_EQ(TOK_TERM, string, "Печкин"); SetInput("33.3"); NEXT_EQ(TOK_DOUBLE, string, "33.3"); } TEST_F(SearchParserTest, EscapedTagPrefixes) { SetInput("@name:{escape\\-err*}"); NEXT_EQ(TOK_FIELD, string, "@name"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_PREFIX, string, "escape-err"); NEXT_TOK(TOK_RCURLBR); SetInput("@name:{escape\\+pre*}"); NEXT_EQ(TOK_FIELD, string, "@name"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_PREFIX, string, "escape+pre"); NEXT_TOK(TOK_RCURLBR); SetInput("@name:{escape\\.pre*}"); NEXT_EQ(TOK_FIELD, string, "@name"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_PREFIX, string, "escape.pre"); NEXT_TOK(TOK_RCURLBR); SetInput("@name:{complex\\-escape\\+with\\.many\\*chars*}"); NEXT_EQ(TOK_FIELD, string, "@name"); 
NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LCURLBR); NEXT_EQ(TOK_PREFIX, string, "complex-escape+with.many*chars"); NEXT_TOK(TOK_RCURLBR); } TEST_F(SearchParserTest, Parse) { EXPECT_EQ(0, Parse(" foo bar (baz) ")); EXPECT_EQ(0, Parse(" -(foo) @foo:bar @ss:[1 2]")); EXPECT_EQ(0, Parse("@foo:{ tag1 | tag2 }")); EXPECT_EQ(0, Parse("@foo:{1|2}")); EXPECT_EQ(0, Parse("@foo:{1|2.0|4|3.0}")); EXPECT_EQ(0, Parse("@foo:{1|hello|3.0|world|4}")); EXPECT_EQ(0, Parse("@name:{escape\\-err*}")); // Parenthesized star - used by LangChain for KNN queries (issue #6342) EXPECT_EQ(0, Parse("(*)")); EXPECT_EQ(0, Parse("((*))")); EXPECT_EQ(0, Parse("(((*)))")); // Colon in tag value EXPECT_EQ(0, Parse("@t:{Tag:value}")); EXPECT_EQ(0, Parse("@t:{Tag:*}")); EXPECT_EQ(0, Parse("@category:{Product:Electronics}")); EXPECT_EQ(1, Parse(" -(foo ")); EXPECT_EQ(1, Parse(" foo:bar ")); EXPECT_EQ(1, Parse(" @foo:@bar ")); EXPECT_EQ(1, Parse(" @foo: ")); EXPECT_EQ(0, Parse("*suffix")); EXPECT_EQ(0, Parse("*infix*")); EXPECT_EQ(1, Parse("pre***")); // Geo units EXPECT_EQ(0, Parse("@t:{km}")); EXPECT_EQ(0, Parse("@t:{Km|M}")); EXPECT_EQ(0, Parse("@t:{ft|mi}")); EXPECT_EQ(0, Parse("@location:[0.0 0.0 1 m]")); EXPECT_EQ(0, Parse("@location:[0.0 0.0 1 Km]")); EXPECT_EQ(1, Parse("@location:[0.0 0.0 1 yd]")); } TEST_F(SearchParserTest, ParseParams) { QueryParams params; params["k"] = "10"; params["name"] = "alex"; SetParams(¶ms); SetInput("$name $k"); NEXT_EQ(TOK_TERM, string, "alex"); NEXT_EQ(TOK_UINT32, string, "10"); } TEST_F(SearchParserTest, Quotes) { SetInput(" \"fir st\" 'sec@o@nd' \":third:\" 'four\\\"th' "); NEXT_EQ(TOK_TERM, string, "fir st"); NEXT_EQ(TOK_TERM, string, "sec@o@nd"); NEXT_EQ(TOK_TERM, string, ":third:"); NEXT_EQ(TOK_TERM, string, "four\"th"); } TEST_F(SearchParserTest, Numeric) { SetInput("11 123123123123 '22'"); NEXT_EQ(TOK_UINT32, string, "11"); NEXT_EQ(TOK_DOUBLE, string, "123123123123"); NEXT_EQ(TOK_TERM, string, "22"); } TEST_F(SearchParserTest, VectorRange) { // Full vector range 
query tokenization SetInput("@vector:[VECTOR_RANGE $radius $vec]=>{$YIELD_DISTANCE_AS: dist}"); NEXT_EQ(TOK_FIELD, string, "@vector"); NEXT_TOK(TOK_COLON); NEXT_TOK(TOK_LBRACKET); NEXT_TOK(TOK_VECTOR_RANGE); } TEST_F(SearchParserTest, VectorRangeParse) { QueryParams params; params["radius"] = "1"; // 4 bytes = one float dimension params["vec"] = std::string(4, '\0'); SetParams(¶ms); // Basic syntax parses without error EXPECT_EQ(0, Parse("@f:[VECTOR_RANGE $radius $vec]=>{$YIELD_DISTANCE_AS: dist}")); } TEST_F(SearchParserTest, KNN) { SetInput("*=>[KNN 1 @vector field_vec]"); NEXT_TOK(TOK_STAR); NEXT_TOK(TOK_ARROW); NEXT_TOK(TOK_LBRACKET); } TEST_F(SearchParserTest, KNNfull) { SetInput("*=>[Knn 1 @vector field_vec EF_Runtime 15 as vec_sort]"); NEXT_TOK(TOK_STAR); NEXT_TOK(TOK_ARROW); NEXT_TOK(TOK_LBRACKET); NEXT_TOK(TOK_KNN); NEXT_EQ(TOK_UINT32, string, "1"); NEXT_TOK(TOK_FIELD); NEXT_TOK(TOK_TERM); NEXT_TOK(TOK_EF_RUNTIME); NEXT_EQ(TOK_UINT32, string, "15"); NEXT_TOK(TOK_AS); NEXT_EQ(TOK_TERM, string, "vec_sort"); NEXT_TOK(TOK_RBRACKET); } } // namespace dfly::search ================================================ FILE: src/core/search/search_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/search/search.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include "absl/base/macros.h" #include "base/gtest.h" #include "base/logging.h" #include "core/search/base.h" #include "core/search/hnsw_index.h" #include "core/search/query_driver.h" #include "core/search/stateless_allocator.h" #include "core/search/vector_utils.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly { namespace search { using namespace std; using ::testing::HasSubstr; // Used for NumericIndex benchmarks. // The value is used to determine the maximum size of a range block in the range tree. 
constexpr size_t kMaxRangeBlockSize = 500000; struct MockedDocument : public DocumentAccessor { public: using Map = absl::flat_hash_map; MockedDocument() = default; MockedDocument(Map map) : fields_{map} { } MockedDocument(std::string test_field) : fields_{{"field", test_field}} { } std::optional GetStrings(string_view field) const override { auto it = fields_.find(field); if (it == fields_.end()) { return EmptyAccessResult(); } return StringList{string_view{it->second}}; } std::optional GetTags(string_view field) const override { return GetStrings(field); } std::optional GetVector(string_view field, size_t dim) const override { auto strings_list = GetStrings(field); if (!strings_list) return std::nullopt; return !strings_list->empty() ? BytesToFtVectorSafe(strings_list->front()) : OwnedFtVector{}; } std::optional GetNumbers(std::string_view field) const override { auto strings_list = GetStrings(field); if (!strings_list) return std::nullopt; NumsList nums_list; nums_list.reserve(strings_list->size()); for (auto str : strings_list.value()) { auto num = ParseNumericField(str); if (!num) { return std::nullopt; } nums_list.push_back(num.value()); } return nums_list; } string DebugFormat() { string out = "{"; for (const auto& [field, value] : fields_) absl::StrAppend(&out, field, "=", value, ","); if (out.size() > 1) out.pop_back(); out += "}"; return out; } void Set(Map hset) { fields_ = hset; } private: Map fields_{}; }; IndicesOptions kEmptyOptions{{}}; struct SchemaFieldInitializer { SchemaFieldInitializer(std::string_view name, SchemaField::FieldType type) : name{name}, type{type} { switch (type) { case SchemaField::TAG: special_params = SchemaField::TagParams{}; break; case SchemaField::TEXT: special_params = SchemaField::TextParams{}; break; case SchemaField::NUMERIC: special_params = SchemaField::NumericParams{}; break; case SchemaField::VECTOR: special_params = SchemaField::VectorParams{}; break; case SchemaField::GEO: break; } } 
SchemaFieldInitializer(std::string_view name, SchemaField::FieldType type, SchemaField::ParamsVariant special_params) : name{name}, type{type}, special_params{special_params} { } std::string_view name; SchemaField::FieldType type; SchemaField::ParamsVariant special_params{std::monostate{}}; }; Schema MakeSimpleSchema(initializer_list ilist, bool make_sortable = false) { Schema schema; uint8_t flags = make_sortable ? SchemaField::SORTABLE : 0; for (auto ifield : ilist) { auto& field = schema.fields[ifield.name]; field = {ifield.type, flags, string{ifield.name}, ifield.special_params}; } return schema; } class SearchTest : public ::testing::Test { protected: static void SetUpTestSuite() { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); // Initialize SimSIMD runtime for tests that may exercise vector kernels InitSimSIMD(); } SearchTest() { PrepareSchema({{"field", SchemaField::TEXT}}); } ~SearchTest() { EXPECT_EQ(entries_.size(), 0u) << "Missing check"; } void PrepareSchema(initializer_list ilist) { schema_ = MakeSimpleSchema(ilist); } void PrepareQuery(string_view query) { query_ = query; } template void ExpectAll(Args... args) { (entries_.emplace_back(args, true), ...); } template void ExpectNone(Args... 
args) { (entries_.emplace_back(args, false), ...); } bool Check() { absl::Cleanup cl{[this] { entries_.clear(); }}; FieldIndices index{schema_, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; shuffle(entries_.begin(), entries_.end(), default_random_engine{}); for (DocId i = 0; i < entries_.size(); i++) index.Add(i, entries_[i].first); index.FinalizeInitialization(); SearchAlgorithm search_algo{}; if (!search_algo.Init(query_, ¶ms_)) { error_ = "Failed to parse query"; return false; } auto matched = search_algo.Search(&index); if (!is_sorted(matched.ids.begin(), matched.ids.end())) LOG(FATAL) << "Search result is not sorted"; for (DocId i = 0; i < entries_.size(); i++) { bool doc_matched = binary_search(matched.ids.begin(), matched.ids.end(), i); if (doc_matched != entries_[i].second) { error_ = "doc: \"" + entries_[i].first.DebugFormat() + "\"" + " was expected" + (entries_[i].second ? "" : " not") + " to match" + " query: \"" + query_ + "\""; return false; } } return true; } string_view GetError() const { return error_; } private: using DocEntry = pair; QueryParams params_; Schema schema_; vector entries_; string query_, error_; }; TEST_F(SearchTest, MatchTerm) { PrepareQuery("foo"); // Check basic cases ExpectAll("foo", "foo bar", "more foo bar"); ExpectNone("wrong", "nomatch"); // Check part of sentence + case. ExpectAll("Foo is cool.", "Where is foo?", "One. FOO!. 
More", "Foo is foo."); // Check part of word is not matched ExpectNone("foocool", "veryfoos", "ufoo", "morefoomore", "thefoo"); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, MatchNotTerm) { PrepareQuery("-foo"); ExpectAll("faa", "definitielyright"); ExpectNone("foo", "foo bar", "more foo bar"); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, MatchLogicalNode) { { PrepareQuery("foo bar"); ExpectAll("foo bar", "bar foo", "more bar and foo"); ExpectNone("wrong", "foo", "bar", "foob", "far"); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("foo | bar"); ExpectAll("foo bar", "foo", "bar", "foo and more", "or only bar"); ExpectNone("wrong", "only far"); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("foo bar baz"); ExpectAll("baz bar foo", "bar and foo and baz"); ExpectNone("wrong", "foo baz", "bar baz", "and foo"); EXPECT_TRUE(Check()) << GetError(); } } TEST_F(SearchTest, MatchParenthesis) { PrepareQuery("( foo | oof ) ( bar | rab )"); ExpectAll("foo bar", "oof rab", "foo rab", "oof bar", "foo oof bar rab"); ExpectNone("wrong", "bar rab", "foo oof"); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, CheckNotPriority) { for (auto expr : {"-bar foo baz", "foo -bar baz", "foo baz -bar"}) { PrepareQuery(expr); ExpectAll("foo baz", "foo rab baz", "baz rab foo"); ExpectNone("wrong", "bar", "foo bar baz", "foo baz bar"); EXPECT_TRUE(Check()) << GetError(); } for (auto expr : {"-bar | foo", "foo | -bar"}) { PrepareQuery(expr); ExpectAll("foo", "right", "foo bar"); ExpectNone("bar", "bar baz"); EXPECT_TRUE(Check()) << GetError(); } for (auto expr : {"-bar far|-foo tam"}) { PrepareQuery(expr); ExpectAll("far baz", "far foo", "bar tam"); ExpectNone("bar far", "foo tam", "bar foo", "far bar foo"); EXPECT_TRUE(Check()) << GetError(); } } TEST_F(SearchTest, CheckParenthesisPriority) { { PrepareQuery("foo | -(bar baz)"); ExpectAll("foo", "not b/r and b/z", "foo bar baz", "single bar", "only baz"); ExpectNone("bar baz", "some more bar and baz"); 
EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("( foo (bar | baz) (rab | zab) ) | true"); ExpectAll("true", "foo bar rab", "foo baz zab", "foo bar zab"); ExpectNone("wrong", "foo bar baz", "foo rab zab", "foo bar what", "foo rab foo"); EXPECT_TRUE(Check()) << GetError(); } } TEST_F(SearchTest, CheckPrefix) { { PrepareQuery("pre*"); ExpectAll("pre", "prepre", "preachers", "prepared", "pRetty", "PRedators", "prEcisely!"); ExpectNone("pristine", "represent", "repair", "depreciation"); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("new*"); ExpectAll("new", "New York", "Newham", "newbie", "news", "Welcome to Newark!"); ExpectNone("ne", "renew", "nev", "ne-w", "notnew", "casino in neVada"); EXPECT_TRUE(Check()) << GetError(); } } using Map = MockedDocument::Map; TEST_F(SearchTest, MatchField) { PrepareSchema({{"f1", SchemaField::TEXT}, {"f2", SchemaField::TEXT}, {"f3", SchemaField::TEXT}}); PrepareQuery("@f1:foo @f2:bar @f3:baz"); ExpectAll(Map{{"f1", "foo"}, {"f2", "bar"}, {"f3", "baz"}}); ExpectNone(Map{{"f1", "foo"}, {"f2", "bar"}, {"f3", "last is wrong"}}, Map{{"f1", "its"}, {"f2", "totally"}, {"f3", "wrong"}}, Map{{"f1", "im foo but its only me and"}, {"f2", "bar"}}); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, MatchRange) { PrepareSchema({{"f1", SchemaField::NUMERIC}, {"f2", SchemaField::NUMERIC}}); PrepareQuery("@f1:[1 10] @f2:[50 100]"); ExpectAll(Map{{"f1", "5"}, {"f2", "50"}}, Map{{"f1", "1"}, {"f2", "100"}}, Map{{"f1", "10"}, {"f2", "50"}}); ExpectNone(Map{{"f1", "11"}, {"f2", "49"}}, Map{{"f1", "0"}, {"f2", "101"}}); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, MatchDoubleRange) { PrepareSchema({{"f1", SchemaField::NUMERIC}}); { PrepareQuery("@f1: [100.03 199.97]"); ExpectAll(Map{{"f1", "130"}}, Map{{"f1", "170"}}, Map{{"f1", "100.03"}}, Map{{"f1", "199.97"}}); ExpectNone(Map{{"f1", "0"}}, Map{{"f1", "200"}}, Map{{"f1", "100.02999"}}, Map{{"f1", "199.9700001"}}); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("@f1: 
[(100 (199.9]"); ExpectAll(Map{{"f1", "150"}}, Map{{"f1", "100.00001"}}, Map{{"f1", "199.8999999"}}); ExpectNone(Map{{"f1", "50"}}, Map{{"f1", "100"}}, Map{{"f1", "199.9"}}, Map{{"f1", "200"}}); EXPECT_TRUE(Check()) << GetError(); } } TEST_F(SearchTest, MatchStar) { PrepareQuery("*"); ExpectAll("one", "two", "three", "and", "all", "documents"); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, CheckExprInField) { PrepareSchema({{"f1", SchemaField::TEXT}, {"f2", SchemaField::TEXT}, {"f3", SchemaField::TEXT}}); { PrepareQuery("@f1:(a|b) @f2:(c d) @f3:-e"); ExpectAll(Map{{"f1", "a"}, {"f2", "c and d"}, {"f3", "right"}}, Map{{"f1", "b"}, {"f2", "d and c"}, {"f3", "ok"}}); ExpectNone(Map{{"f1", "none"}, {"f2", "only d"}, {"f3", "ok"}}, Map{{"f1", "b"}, {"f2", "d and c"}, {"f3", "it has an e"}}); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery({"@f1:(a (b | c) -(d | e)) @f2:-(a|b)"}); ExpectAll(Map{{"f1", "a b w"}, {"f2", "c"}}); ExpectNone(Map{{"f1", "a b d"}, {"f2", "c"}}, Map{{"f1", "a b w"}, {"f2", "a"}}, Map{{"f1", "a w"}, {"f2", "c"}}); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("@f1:(-a c|-b d)"); ExpectAll(Map{{"f1", "c"}}, Map{{"f1", "d"}}); ExpectNone(Map{{"f1", "a"}}, Map{{"f1", "b"}}); EXPECT_TRUE(Check()) << GetError(); } } TEST_F(SearchTest, CheckTag) { PrepareSchema({{"f1", SchemaField::TAG}, {"f2", SchemaField::TAG}}); PrepareQuery("@f1:{red | blue} @f2:{circle | square}"); ExpectAll(Map{{"f1", "red"}, {"f2", "square"}}, Map{{"f1", "blue"}, {"f2", "square"}}, Map{{"f1", "red"}, {"f2", "circle"}}, Map{{"f1", "red"}, {"f2", "circle, square"}}, Map{{"f1", "red"}, {"f2", "triangle, circle"}}, Map{{"f1", "red, green"}, {"f2", "square"}}, Map{{"f1", "green, blue"}, {"f2", "circle"}}); ExpectNone(Map{{"f1", "green"}, {"f2", "square"}}, Map{{"f1", "green"}, {"f2", "circle"}}, Map{{"f1", "red"}, {"f2", "triangle"}}, Map{{"f1", "blue"}, {"f2", "line, triangle"}}); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, CheckTagPrefix) { 
PrepareSchema({{"color", SchemaField::TAG}}); PrepareQuery("@color:{green* | orange | yellow*}"); ExpectAll(Map{{"color", "green"}}, Map{{"color", "yellow"}}, Map{{"color", "greenish"}}, Map{{"color", "yellowish"}}, Map{{"color", "green-forestish"}}, Map{{"color", "yellowsunish"}}, Map{{"color", "orange"}}); ExpectNone(Map{{"color", "red"}}, Map{{"color", "blue"}}, Map{{"color", "orangeish"}}, Map{{"color", "darkgreen"}}, Map{{"color", "light-yellow"}}); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, IntegerTerms) { PrepareSchema({{"status", SchemaField::TAG}, {"title", SchemaField::TEXT}}); PrepareQuery("@status:{1} @title:33"); ExpectAll(Map{{"status", "1"}, {"title", "33 cars on the road"}}); ExpectNone(Map{{"status", "0"}, {"title", "22 trains on the tracks"}}); EXPECT_TRUE(Check()) << GetError(); } TEST_F(SearchTest, StopWords) { auto schema = MakeSimpleSchema({{"title", SchemaField::TEXT}}); IndicesOptions options{{"some", "words", "are", "left", "out"}}; FieldIndices indices{schema, options, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo{}; QueryParams params; vector documents = {"some words left out", // "some can be found", // "words are never matched", // "explicitly found!"}; for (size_t i = 0; i < documents.size(); i++) { MockedDocument doc{{{"title", documents[i]}}}; indices.Add(i, doc); } // words is a stopword algo.Init("words", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre()); // some is a stopword algo.Init("some", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre()); // found is not a stopword algo.Init("found", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 3)); } class SearchRaxTest : public SearchTest, public testing::WithParamInterface> { }; TEST_P(SearchRaxTest, SuffixInfix) { auto [with_trie, use_tag] = GetParam(); Schema schema = MakeSimpleSchema({{"title", use_tag ? 
SchemaField::TAG : SchemaField::TEXT}}); if (use_tag) { schema.fields["title"].special_params = SchemaField::TagParams{.with_suffixtrie = with_trie}; } else { schema.fields["title"].special_params = SchemaField::TextParams{.with_suffixtrie = with_trie}; } FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo{}; QueryParams params; vector documents = {"Berries", "BlueBeRRies", "Blackberries", "APPLES", "CranbeRRies", "Wolfberry", "StraWberry"}; for (size_t i = 0; i < documents.size(); i++) { MockedDocument doc{{{"title", documents[i]}}}; indices.Add(i, doc); } auto prepare = [&, use_tag = use_tag](string q) { if (use_tag) q = "@title:{"s + q + "}"s; algo.Init(q, ¶ms); }; // suffix queries prepare("*Es"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4)); prepare("*beRRies"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 4)); prepare("*les"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(3)); prepare("*lueBERRies"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1)); prepare("*berrY"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(5, 6)); // infix queries prepare("*berr*"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 4, 5, 6)); prepare("*ANB*"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4)); prepare("*berries*"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 4)); prepare("*bL*"); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2)); } INSTANTIATE_TEST_SUITE_P(NoTrieText, SearchRaxTest, testing::Values(pair{false, false})); INSTANTIATE_TEST_SUITE_P(WithTrieText, SearchRaxTest, testing::Values(pair{true, false})); INSTANTIATE_TEST_SUITE_P(NoTrieTag, SearchRaxTest, testing::Values(pair{false, true})); INSTANTIATE_TEST_SUITE_P(WithTrieTag, SearchRaxTest, 
testing::Values(pair{true, true}));

// NOTE(review): template argument lists (e.g. on absl::Span / reinterpret_cast below) and the
// "&params" arguments look like they were lost in text extraction - confirm against the real file.

// Returns a string built from the sizeof(float) * vec.size() bytes that start at vec.data().
std::string ToBytes(absl::Span vec) {
  return string{reinterpret_cast(vec.data()), sizeof(float) * vec.size()};
}

// Checks the error text reported by Search() for three invalid queries against a schema with a
// NUMERIC, a TAG and a 1-dimensional VECTOR field.
TEST_F(SearchTest, Errors) {
  auto schema = MakeSimpleSchema(
      {{"score", SchemaField::NUMERIC}, {"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1};

  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  SearchAlgorithm algo{};
  QueryParams params;

  // Non-existent field
  algo.Init("@cantfindme:[1 10]", ¶ms);
  EXPECT_THAT(algo.Search(&indices).error, HasSubstr("Invalid field"));

  // Invalid type
  algo.Init("@even:[1 10]", ¶ms);
  EXPECT_THAT(algo.Search(&indices).error, HasSubstr("Wrong access type"));

  // Wrong vector index dimensions (4 values supplied, the field was declared with dimension 1)
  params["vec"] = ToBytes({1, 2, 3, 4});
  algo.Init("* => [KNN 5 @pos $vec]", ¶ms);
  EXPECT_THAT(algo.Search(&indices).error, HasSubstr("Wrong vector index dimensions"));
}

// Numeric range queries whose two bounds are separated by spaces and/or commas.
TEST_F(SearchTest, MatchNumericRangeWithCommas) {
  PrepareSchema({{"f1", SchemaField::NUMERIC}, {"draw_end", SchemaField::NUMERIC}});

  // Main tests for point range with identical values and different delimiters
  {
    PrepareQuery("@draw_end:[1742916180 1742916180]");
    ExpectAll(Map{{"draw_end", "1742916180"}});
    ExpectNone(Map{{"draw_end", "1742916181"}}, Map{{"draw_end", "1742916179"}});
    EXPECT_TRUE(Check()) << GetError();
  }

  {
    PrepareQuery("@draw_end:[1742916180, 1742916180]");
    ExpectAll(Map{{"draw_end", "1742916180"}});
    ExpectNone(Map{{"draw_end", "1742916181"}}, Map{{"draw_end", "1742916179"}});
    EXPECT_TRUE(Check()) << GetError();
  }

  {
    PrepareQuery("@draw_end:[1742916180 ,1742916180]");
    ExpectAll(Map{{"draw_end", "1742916180"}});
    ExpectNone(Map{{"draw_end", "1742916181"}}, Map{{"draw_end", "1742916179"}});
    EXPECT_TRUE(Check()) << GetError();
  }

  {
    PrepareQuery("@draw_end:[1742916180 1742916180]");
    ExpectAll(Map{{"draw_end", "1742916180"}});
    ExpectNone(Map{{"draw_end", "1742916181"}}, Map{{"draw_end", "1742916179"}});
    EXPECT_TRUE(Check()) << GetError();
  }

  // A real (non-point) range: both bounds 100 and 200 are expected to match.
  {
    PrepareQuery("@f1:[100 , 200]");
    ExpectAll(Map{{"f1", "100"}}, Map{{"f1", "150"}}, Map{{"f1", "200"}});
    ExpectNone(Map{{"f1", "99"}}, Map{{"f1", "201"}});
    EXPECT_TRUE(Check()) << GetError();
  }
}

// Fixture for the KNN tests; adds nothing on top of SearchTest.
class KnnTest : public SearchTest {};

// Fixture for the VECTOR_RANGE / "@field:*" tests on vector fields.
class VectorRangeTest : public ::testing::Test {
 protected:
  // Once per suite: hands the mi_heap_get_backing() heap to init_zmalloc_threadlocal() and then
  // calls InitSimSIMD().
  static void SetUpTestSuite() {
    auto* tlh = mi_heap_get_backing();
    init_zmalloc_threadlocal(tlh);
    InitSimSIMD();
  }
};

// VECTOR_RANGE queries over ten 1-dimensional points with several radii.
TEST_F(VectorRangeTest, FlatRange1D) {
  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  // Place 10 points on a line: 1, 2, ..., 10 (avoid zero vector for doc 0)
  for (size_t i = 0; i < 10; i++) {
    MockedDocument doc{Map{{"pos", ToBytes({float(i + 1)})}}};
    indices.Add(i, doc);
  }

  SearchAlgorithm algo{};
  QueryParams params;

  // Query at 5.0 with radius 1.5 → points at pos 4,5,6 → doc ids 3,4,5
  {
    params["vec"] = ToBytes({5.0f});
    algo.Init("@pos:[VECTOR_RANGE 1.5 $vec]=>{$YIELD_DISTANCE_AS: dist}", ¶ms);
    auto result = algo.Search(&indices);
    EXPECT_THAT(result.ids, testing::UnorderedElementsAre(3, 4, 5));
  }

  // Exact match at pos 4.0 with radius 0 → only doc 3
  {
    params["vec"] = ToBytes({4.0f});
    algo.Init("@pos:[VECTOR_RANGE 0 $vec]=>{$YIELD_DISTANCE_AS: dist}", ¶ms);
    auto result = algo.Search(&indices);
    EXPECT_THAT(result.ids, testing::UnorderedElementsAre(3));
  }

  // Large radius → all 10 points
  {
    params["vec"] = ToBytes({5.0f});
    algo.Init("@pos:[VECTOR_RANGE 100 $vec]=>{$YIELD_DISTANCE_AS: dist}", ¶ms);
    auto result = algo.Search(&indices);
    EXPECT_EQ(result.ids.size(), 10u);
  }

  // Empty result when radius is too small
  {
    params["vec"] = ToBytes({5.5f});
    algo.Init("@pos:[VECTOR_RANGE 0.1 $vec]=>{$YIELD_DISTANCE_AS: dist}", ¶ms);
    auto result = algo.Search(&indices);
    EXPECT_TRUE(result.ids.empty());
  }
}

// Checks the parsed score alias of a VECTOR_RANGE query and that the search result carries one
// knn_scores entry per matched document.
TEST_F(VectorRangeTest, FlatRangeDistancesStoredInScores) {
  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  // Use i+1 so doc positions are 1..5 (query radius 1.5 from pos 2.0 catches docs 0,1,2)
  for (size_t i = 0; i < 5; i++) {
    MockedDocument doc{Map{{"pos", ToBytes({float(i + 1)})}}};
    indices.Add(i, doc);
  }

  SearchAlgorithm algo{};
  QueryParams params;
  params["vec"] = ToBytes({2.0f});
  algo.Init("@pos:[VECTOR_RANGE 1.5 $vec]=>{$YIELD_DISTANCE_AS: vector_distance}", ¶ms);

  // ASSERT (not EXPECT): the node pointer is dereferenced on the next line.
  ASSERT_NE(nullptr, algo.GetVectorRangeNode());
  EXPECT_STREQ("vector_distance", algo.GetVectorRangeNode()->score_alias.c_str());

  auto result = algo.Search(&indices);
  // Positions 1,2,3 (docs 0,1,2) are within L2 distance 1.5 from query pos 2.0
  EXPECT_THAT(result.ids, testing::UnorderedElementsAre(0, 1, 2));
  // knn_scores should contain distances for all matched docs
  EXPECT_EQ(result.knn_scores.size(), 3u);
}

TEST_F(VectorRangeTest, FlatStarQueryZeroVectorIsValid) {
  // Regression: @field:* on a FLAT vector index uses GetAllDocsWithNonNullValues(), which
  // incorrectly skips zero vectors. The zero vector [0.0,...,0.0] is a valid embedding.
  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 2};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  // doc 0: zero vector [0.0, 0.0] — valid embedding, must not be skipped
  indices.Add(0, MockedDocument{Map{{"pos", ToBytes({0.0f, 0.0f})}}});
  // doc 1: non-zero vector [1.0, 0.0]
  indices.Add(1, MockedDocument{Map{{"pos", ToBytes({1.0f, 0.0f})}}});

  SearchAlgorithm algo{};
  QueryParams params;
  algo.Init("@pos:*", ¶ms);
  auto result = algo.Search(&indices);

  // Both docs must appear — zero vector is NOT null
  EXPECT_THAT(result.ids, testing::UnorderedElementsAre(0, 1));
}

TEST_F(VectorRangeTest, FlatStarQueryRemovedDocNotMatched) {
  // Regression: @field:* on a FLAT vector index uses GetAllDocsWithNonNullValues(), which
  // iterates entries_ directly and does NOT respect all_ids_. After Remove(), the doc's
  // slot in entries_ is still non-zero, so the removed doc incorrectly appears in results.
  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  indices.Add(0, MockedDocument{Map{{"pos", ToBytes({1.0f})}}});
  indices.Add(1, MockedDocument{Map{{"pos", ToBytes({2.0f})}}});
  indices.Add(2, MockedDocument{Map{{"pos", ToBytes({3.0f})}}});

  // Remove doc 1
  MockedDocument doc1{Map{{"pos", ToBytes({2.0f})}}};
  indices.Remove(1, doc1);

  SearchAlgorithm algo{};
  QueryParams params;
  algo.Init("@pos:*", ¶ms);
  auto result = algo.Search(&indices);

  // Doc 1 was removed, only docs 0 and 2 should appear
  EXPECT_THAT(result.ids, testing::UnorderedElementsAre(0, 2));
}

// KNN over 100 points on a line (doc i sits at position i), with and without a TAG pre-filter.
TEST_F(KnnTest, Simple1D) {
  auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  // Place points on a straight line
  for (size_t i = 0; i < 100; i++) {
    Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
    MockedDocument doc{values};
    indices.Add(i, doc);
  }

  SearchAlgorithm algo{};
  QueryParams params;

  // Five closest to 50
  {
    params["vec"] = ToBytes({50.0});
    algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52));
  }

  // Five closest to 0
  {
    params["vec"] = ToBytes({0.0});
    algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
  }

  // Five closest to 20, all even
  {
    params["vec"] = ToBytes({20.0});
    algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24));
  }

  // Three closest to 31, all odd
  {
    params["vec"] = ToBytes({31.0});
    algo.Init("@even:{no} =>[KNN 3 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33));
  }

  // Two closest to 70.5
  {
    params["vec"] = ToBytes({70.5});
    algo.Init("* =>[KNN 2 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71));
  }

  // Two closest to 70.5, this time with the score aliased via "as vector_distance"
  {
    params["vec"] = ToBytes({70.5});
    algo.Init("* =>[KNN 2 @pos $vec as vector_distance]", ¶ms);
    EXPECT_EQ("vector_distance", algo.GetKnnScoreSortOption()->score_field_alias);
    SearchResult result = algo.Search(&indices);
    EXPECT_THAT(result.ids, testing::UnorderedElementsAre(70, 71));
  }
}

// KNN over the four corners of a unit square plus its center.
TEST_F(KnnTest, Simple2D) {
  // Square:
  // 3      2
  //    4
  // 0      1
  const pair kTestCoords[] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}, {0.5, 0.5}};

  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 2};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
    string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
    MockedDocument doc{Map{{"pos", coords}}};
    indices.Add(i, doc);
  }

  SearchAlgorithm algo{};
  QueryParams params;

  // Single center
  {
    params["vec"] = ToBytes({0.5, 0.5});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4));
  }

  // Lower left
  {
    params["vec"] = ToBytes({0, 0});
    algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4));
  }

  // Upper right
  {
    params["vec"] = ToBytes({1, 1});
    algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4));
  }

  // Request more than there is
  {
    params["vec"] = ToBytes({0, 0});
    algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
  }

  // Test correct order: (0.7, 0.15)
  {
    params["vec"] = ToBytes({0.7, 0.15});
    algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3));
  }

  // Test correct order: (0.8, 0.9)
  {
    params["vec"] = ToBytes({0.8, 0.9});
    algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0));
  }
}

// KNN with VectorSimilarity::COSINE over four axis-aligned unit vectors.
TEST_F(KnnTest, Cosine) {
  // Four arrows, closest cosine distance will be closest by angle
  // 0 🡢 1 🡣 2 🡠 3 🡡
  const pair kTestCoords[] = {{1, 0}, {0, -1}, {-1, 0}, {0, 1}};

  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params =
      SchemaField::VectorParams{false, 2, VectorSimilarity::COSINE};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
    string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
    MockedDocument doc{Map{{"pos", coords}}};
    indices.Add(i, doc);
  }

  SearchAlgorithm algo{};
  QueryParams params;

  // Point down
  {
    params["vec"] = ToBytes({-0.1, -10});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1));
  }

  // Point left
  {
    params["vec"] = ToBytes({-0.1, -0.01});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(2));
  }

  // Point up
  {
    params["vec"] = ToBytes({0, 5});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(3));
  }

  // Point right
  {
    params["vec"] = ToBytes({0.2, 0.05});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0));
  }
}

// KNN with VectorSimilarity::IP over four axis-aligned unit vectors.
TEST_F(KnnTest, IP) {
  // Test with normalized unit vectors for IP distance
  // Using unit vectors pointing in different directions
  const pair kTestCoords[] = {
      {1.0f, 0.0f}, {0.0f, 1.0f}, {-1.0f, 0.0f}, {0.0f, -1.0f}};

  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 2, VectorSimilarity::IP};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
    string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
    MockedDocument doc{Map{{"pos", coords}}};
    indices.Add(i, doc);
  }

  SearchAlgorithm algo{};
  QueryParams params;

  // Query with vector pointing right - should find exact match (highest dot product)
  {
    params["vec"] = ToBytes({1.0f, 0.0f});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0));
  }

  // Query with vector pointing up - should find exact match (highest dot product)
  {
    params["vec"] = ToBytes({0.0f, 1.0f});
    algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1));
  }
}

// KNN results after documents are removed from the index and then added back.
TEST_F(KnnTest, AddRemove) {
  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1, VectorSimilarity::L2};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  vector documents(10);
  for (size_t i = 0; i < 10; i++) {
    documents[i] = Map{{"pos", ToBytes({float(i)})}};
    indices.Add(i, documents[i]);
  }

  SearchAlgorithm algo{};
  QueryParams params;

  // search leftmost 5
  {
    params["vec"] = ToBytes({-1.0});
    algo.Init("* =>[KNN 5 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(0, 1, 2, 3, 4));
  }

  // delete leftmost 5
  for (size_t i = 0; i < 5; i++)
    indices.Remove(i, documents[i]);

  // search leftmost 5 again
  {
    params["vec"] = ToBytes({-1.0});
    algo.Init("* =>[KNN 5 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(5, 6, 7, 8, 9));
  }

  // add removed elements
  for (size_t i = 0; i < 5; i++)
    indices.Add(i, documents[i]);

  // repeat first search
  {
    params["vec"] = ToBytes({-1.0});
    algo.Init("* =>[KNN 5 @pos $vec]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(0, 1, 2, 3, 4));
  }
}

// Adds 100 documents to a vector field configured with a capacity of 5.
TEST_F(KnnTest, AutoResize) {
  // Make sure index resizes automatically even with a small initial capacity
  const size_t kInitialCapacity = 5;

  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params =
      SchemaField::VectorParams{false, 1, VectorSimilarity::L2, kInitialCapacity};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  for (size_t i = 0; i < 100; i++) {
    MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}};
    indices.Add(i, doc);
  }

  EXPECT_EQ(indices.GetAllDocs().size(), 100);
}

// Parameterized HNSW serialization round-trip test.
// Parameters: {num_elements, dim, similarity}
// NOTE(review): template argument lists (e.g. on TestWithParam, vector, static_cast, std::max) and
// the "&params" arguments look like they were lost in text extraction - confirm against the real
// file.
struct HnswSerParam {
  size_t num_elements;
  size_t dim;
  VectorSimilarity sim;

  // Streams "<num_elements>el_<dim>d_<similarity name>"; the instantiation below uses this output
  // as the test name. The name table is indexed by the numeric value of `sim`.
  friend std::ostream& operator<<(std::ostream& os, const HnswSerParam& p) {
    const char* sim_name[] = {"L2", "IP", "COSINE"};
    return os << p.num_elements << "el_" << p.dim << "d_" << sim_name[static_cast(p.sim)];
  }
};

// Fixture: passes the default memory resource to InitTLSearchMR() before each test and resets it
// to nullptr afterwards.
class HnswSerializationTest : public ::testing::TestWithParam {
 protected:
  void SetUp() override {
    InitTLSearchMR(PMR_NS::get_default_resource());
  }

  void TearDown() override {
    InitTLSearchMR(nullptr);
  }
};

// Builds an HNSW index from seeded random vectors, copies its metadata and nodes into a second
// index and checks that metadata, graph links and KNN answers are the same in both.
TEST_P(HnswSerializationTest, RoundTrip) {
  const auto [num_elements, dim, sim] = GetParam();

  SchemaField::VectorParams params;
  params.use_hnsw = true;
  params.dim = dim;
  params.sim = sim;
  params.capacity = std::max(num_elements, 10);
  params.hnsw_m = 16;
  params.hnsw_ef_construction = 200;

  HnswVectorIndex original(params, /*copy_vector=*/true);

  // Fixed seed, so the generated vectors are the same on every run.
  std::mt19937 rng(42);
  std::uniform_real_distribution dist(0.0f, 1.0f);

  vector docs(num_elements);
  for (size_t i = 0; i < num_elements; i++) {
    vector coords(dim);
    for (size_t d = 0; d < dim; d++)
      coords[d] = dist(rng);
    docs[i] = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}};
    original.Add(i, docs[i], "vec");
  }

  // Serialize
  auto metadata = original.GetMetadata();
  ASSERT_EQ(metadata.cur_element_count, num_elements);

  std::vector nodes;
  {
    auto lock = original.GetReadLock();
    nodes = original.GetNodesRange(0, metadata.cur_element_count);
  }
  ASSERT_EQ(nodes.size(), num_elements);

  // Verify node data integrity
  for (const auto& node : nodes) {
    EXPECT_EQ(node.levels_links.size(), static_cast(node.level + 1));
    EXPECT_GT(node.TotalSize(), 0u);
  }

  // Deserialize into a fresh index
  HnswVectorIndex restored(params, /*copy_vector=*/true);
  restored.SetMetadata(metadata);
  restored.RestoreFromNodes(nodes, metadata);

  // Before UpdateVectorData, all nodes must be marked deleted.
  // KNN should safely return empty results (no crash from nullptr dereference).
  if (num_elements > 0) {
    vector probe(dim, 0.5f);
    auto pre_results = restored.Knn(probe.data(), 10, std::nullopt);
    EXPECT_TRUE(pre_results.empty()) << "All nodes should be deleted before UpdateVectorData";
  }

  for (size_t i = 0; i < num_elements; i++)
    restored.UpdateVectorData(i, docs[i], "vec");

  // Metadata must match
  auto rm = restored.GetMetadata();
  EXPECT_EQ(rm.cur_element_count, metadata.cur_element_count);
  EXPECT_EQ(rm.maxlevel, metadata.maxlevel);
  EXPECT_EQ(rm.enterpoint_node, metadata.enterpoint_node);

  // Graph links must be identical
  std::vector restored_nodes;
  {
    auto lock = restored.GetReadLock();
    restored_nodes = restored.GetNodesRange(0, rm.cur_element_count);
  }
  ASSERT_EQ(restored_nodes.size(), nodes.size());
  for (size_t i = 0; i < nodes.size(); i++) {
    EXPECT_EQ(restored_nodes[i].internal_id, nodes[i].internal_id);
    EXPECT_EQ(restored_nodes[i].global_id, nodes[i].global_id);
    EXPECT_EQ(restored_nodes[i].level, nodes[i].level);
    // ASSERT: the loop below indexes both link lists up to this size.
    ASSERT_EQ(restored_nodes[i].levels_links.size(), nodes[i].levels_links.size());
    for (size_t lvl = 0; lvl < nodes[i].levels_links.size(); lvl++)
      EXPECT_EQ(restored_nodes[i].levels_links[lvl], nodes[i].levels_links[lvl]);
  }

  // Nothing left to query in the empty-index parameterization.
  if (num_elements == 0)
    return;

  // KNN results must match for several queries
  auto compare_knn = [&](vector query, size_t k) {
    auto orig = original.Knn(query.data(), k, std::nullopt);
    auto rest = restored.Knn(query.data(), k, std::nullopt);
    ASSERT_EQ(orig.size(), rest.size());
    for (size_t j = 0; j < orig.size(); j++) {
      EXPECT_EQ(orig[j].second, rest[j].second);
      EXPECT_NEAR(orig[j].first, rest[j].first, 1e-5);
    }
  };

  size_t k = std::min(num_elements, 10);
  compare_knn(vector(dim, 0.0f), k);
  compare_knn(vector(dim, 0.5f), k);
  compare_knn(vector(dim, 1.0f), k);

  // Filtered KNN must also match (the filter is every second id)
  vector allowed;
  for (size_t i = 0; i < num_elements; i += 2)
    allowed.push_back(i);

  size_t fk = std::min(allowed.size(), 5);
  vector q(dim, 0.5f);
  auto orig_f = original.Knn(q.data(), fk, std::nullopt, allowed);
  auto rest_f = restored.Knn(q.data(), fk, std::nullopt, allowed);
  ASSERT_EQ(orig_f.size(), rest_f.size());
  for (size_t i = 0; i < orig_f.size(); i++) {
    EXPECT_EQ(orig_f[i].second, rest_f[i].second);
    EXPECT_NEAR(orig_f[i].first, rest_f[i].first, 1e-5);
  }
}

// Index sizes 0..10000 across L2, COSINE and IP; the test name is the streamed HnswSerParam.
INSTANTIATE_TEST_SUITE_P(HnswSer, HnswSerializationTest,
                         testing::Values(HnswSerParam{0, 2, VectorSimilarity::L2},
                                         HnswSerParam{10, 2, VectorSimilarity::L2},
                                         HnswSerParam{1000, 4, VectorSimilarity::L2},
                                         HnswSerParam{10000, 8, VectorSimilarity::L2},
                                         HnswSerParam{10, 3, VectorSimilarity::COSINE},
                                         HnswSerParam{1000, 4, VectorSimilarity::COSINE},
                                         HnswSerParam{10, 2, VectorSimilarity::IP},
                                         HnswSerParam{1000, 4, VectorSimilarity::IP}),
                         [](const testing::TestParamInfo& info) {
                           std::ostringstream name;
                           name << info.param;
                           return name.str();
                         });

// Test fixture for HNSW deferred operations.
// Verifies that Add/Remove called while a read lock is held are properly
// deferred and replayed once the lock is released.
class HnswDeferredOpsTest : public ::testing::Test {
 protected:
  static constexpr size_t kDim = 4;
  static constexpr size_t kCapacity = 100;

  // Creates a fresh kDim-dimensional L2 HNSW index for every test.
  void SetUp() override {
    InitTLSearchMR(PMR_NS::get_default_resource());
    SchemaField::VectorParams params;
    params.use_hnsw = true;
    params.dim = kDim;
    params.sim = VectorSimilarity::L2;
    params.capacity = kCapacity;
    params.hnsw_m = 16;
    params.hnsw_ef_construction = 200;
    index_ = std::make_unique(params, /*copy_vector=*/true);
  }

  // Destroys the index before InitTLSearchMR(nullptr) is called.
  void TearDown() override {
    index_.reset();
    InitTLSearchMR(nullptr);
  }

  // Wraps `coords` into a document with a single "vec" field.
  MockedDocument MakeDoc(std::initializer_list coords) {
    return MockedDocument::Map{{"vec", ToBytes(coords)}};
  }

  // Helper: run KNN for the zero vector and return the set of found GlobalDocIds.
  absl::flat_hash_set KnnIds(size_t k) {
    vector q(kDim, 0.0f);
    auto results = index_->Knn(q.data(), k, std::nullopt);
    absl::flat_hash_set ids;
    for (auto& [dist, id] : results)
      ids.insert(id);
    return ids;
  }

  std::unique_ptr index_;
};

TEST_F(HnswDeferredOpsTest, AddWhileReadLocked) {
  // Hold a read lock (simulating serialization), then add elements.
  auto doc0 = MakeDoc({1, 0, 0, 0});
  auto doc1 = MakeDoc({0, 1, 0, 0});

  {
    auto lock = index_->GetReadLock();

    // These Adds cannot acquire the write lock and must be deferred.
    index_->Add(0, doc0, "vec");
    index_->Add(1, doc1, "vec");

    // While the read lock is still held, KNN should not find the deferred docs.
    auto ids = KnnIds(10);
    EXPECT_TRUE(ids.empty());
  }

  // After the read lock is released, deferred ops should replay.
  // The next operation that touches the index triggers ProcessDeferred.
  auto ids = KnnIds(10);
  EXPECT_EQ(ids.size(), 2u);
  EXPECT_TRUE(ids.contains(0));
  EXPECT_TRUE(ids.contains(1));
}

TEST_F(HnswDeferredOpsTest, RemoveWhileReadLocked) {
  // Pre-populate the index.
  auto doc0 = MakeDoc({1, 0, 0, 0});
  auto doc1 = MakeDoc({0, 1, 0, 0});
  auto doc2 = MakeDoc({0, 0, 1, 0});
  index_->Add(0, doc0, "vec");
  index_->Add(1, doc1, "vec");
  index_->Add(2, doc2, "vec");

  {
    auto lock = index_->GetReadLock();

    // Remove doc1 while read-locked — should be deferred.
    index_->Remove(1, doc1, "vec");

    // doc1 is still visible because the remove is deferred.
    auto ids = KnnIds(10);
    EXPECT_EQ(ids.size(), 3u);
  }

  // After releasing the lock, removal should take effect.
  auto ids = KnnIds(10);
  EXPECT_EQ(ids.size(), 2u);
  EXPECT_TRUE(ids.contains(0));
  EXPECT_TRUE(ids.contains(2));
  EXPECT_FALSE(ids.contains(1));
}

TEST_F(HnswDeferredOpsTest, DuplicateDeferredOpsKeepLatest) {
  // Pre-populate with doc0.
  auto doc0 = MakeDoc({1, 0, 0, 0});
  index_->Add(0, doc0, "vec");

  auto doc1 = MakeDoc({0, 1, 0, 0});

  {
    auto lock = index_->GetReadLock();

    // Add doc1, then remove doc1 — both deferred for the same id.
    // Only the last operation (remove) should survive.
    index_->Add(1, doc1, "vec");
    index_->Remove(1, doc1, "vec");
  }

  // After lock release, doc1 should not exist (remove was last).
  auto ids = KnnIds(10);
  EXPECT_EQ(ids.size(), 1u);
  EXPECT_TRUE(ids.contains(0));
  EXPECT_FALSE(ids.contains(1));
}

TEST_F(HnswDeferredOpsTest, DuplicateDeferredOpsAddOverridesRemove) {
  // Pre-populate with doc0 and doc1.
  auto doc0 = MakeDoc({1, 0, 0, 0});
  auto doc1 = MakeDoc({0, 1, 0, 0});
  index_->Add(0, doc0, "vec");
  index_->Add(1, doc1, "vec");

  auto doc1_new = MakeDoc({0, 0, 1, 0});

  {
    auto lock = index_->GetReadLock();

    // Remove doc1, then re-add it with new data — the add should win.
    index_->Remove(1, doc1, "vec");
    index_->Add(1, doc1_new, "vec");
  }

  // After lock release, doc1 should still be present with updated data.
  // (Only the presence of id 1 is asserted here, not the stored vector.)
  auto ids = KnnIds(10);
  EXPECT_EQ(ids.size(), 2u);
  EXPECT_TRUE(ids.contains(0));
  EXPECT_TRUE(ids.contains(1));
}

// Verify that Remove without a read lock also works correctly.
TEST_F(HnswDeferredOpsTest, RemoveWithoutReadLock) {
  auto doc0 = MakeDoc({1, 0, 0, 0});
  auto doc1 = MakeDoc({0, 1, 0, 0});
  index_->Add(0, doc0, "vec");
  index_->Add(1, doc1, "vec");

  index_->Remove(1, doc1, "vec");

  auto ids = KnnIds(10);
  EXPECT_EQ(ids.size(), 1u);
  EXPECT_TRUE(ids.contains(0));
  EXPECT_FALSE(ids.contains(1));
}

// Fixture for the SubsetKnn tests; GetParam() supplies the similarity under test.
class HnswSubsetKnnTest : public ::testing::TestWithParam {
 protected:
  void SetUp() override {
    InitTLSearchMR(PMR_NS::get_default_resource());
  }

  void TearDown() override {
    InitTLSearchMR(nullptr);
  }

  // Helper to create a simple index with vectors on a line for easy verification
  // (doc i holds the 1-D vector {i}).
  unique_ptr CreateSimple1DIndex(size_t num_elements, VectorSimilarity sim) {
    SchemaField::VectorParams params;
    params.use_hnsw = true;
    params.dim = 1;
    params.sim = sim;
    params.capacity = std::max(num_elements, 10);
    params.hnsw_m = 16;
    params.hnsw_ef_construction = 200;

    auto index = make_unique(params, /*copy_vector=*/true);

    for (size_t i = 0; i < num_elements; i++) {
      vector coords = {static_cast(i)};
      auto doc = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}};
      index->Add(i, MockedDocument(doc), "vec");
    }

    return index;
  }

  // Helper to create a 2D index with unit-circle vectors, for COSINE similarity testing.
  // Vector i is placed at angle i * (2π / num_elements), giving meaningful cosine distances.
  unique_ptr CreateCircle2DIndex(size_t num_elements, VectorSimilarity sim) {
    SchemaField::VectorParams params;
    params.use_hnsw = true;
    params.dim = 2;
    params.sim = sim;
    params.capacity = std::max(num_elements, 10);
    params.hnsw_m = 16;
    params.hnsw_ef_construction = 200;

    auto index = make_unique(params, /*copy_vector=*/true);

    // acos(-1.0) is π, so `step` is the angle between two neighbouring vectors.
    const float step = 2.0f * static_cast(acos(-1.0)) / static_cast(num_elements);
    for (size_t i = 0; i < num_elements; i++) {
      float angle = step * static_cast(i);
      vector coords = {cosf(angle), sinf(angle)};
      auto doc = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}};
      index->Add(i, MockedDocument(doc), "vec");
    }

    return index;
  }
};

TEST_P(HnswSubsetKnnTest, CorrectResults) {
  // Test that SubsetKnn returns correct top-k from a subset
  auto sim = GetParam();
  auto index = CreateSimple1DIndex(100, sim);

  vector query = {50.0f};
  vector subset;

  // Create subset: only even numbers from 40 to 60
  for (size_t i = 40; i <= 60; i += 2) {
    subset.push_back(i);
  }

  // Ask for top 5
  auto results = index->SubsetKnn(query.data(), 5, subset);

  // Should get exactly 5 results
  ASSERT_EQ(results.size(), 5u);

  // All results should be from the subset
  for (const auto& [dist, id] : results) {
    EXPECT_TRUE(std::find(subset.begin(), subset.end(), id) != subset.end())
        << "Result ID " << id << " not in subset";
  }

  // For L2 similarity, verify the closest point is 50
  if (sim == VectorSimilarity::L2) {
    bool found_50 = false;
    for (const auto& [dist, id] : results) {
      if (id == 50) {
        found_50 = true;
        break;
      }
    }
    EXPECT_TRUE(found_50) << "For L2, point 50 should be in top 5 closest to query {50}";
  }
}

TEST_P(HnswSubsetKnnTest, EmptySubset) {
  // Test edge case: empty subset
  auto sim = GetParam();
  auto index = CreateSimple1DIndex(10, sim);

  vector query = {5.0f};
  vector empty_subset;

  auto results = index->SubsetKnn(query.data(), 5, empty_subset);

  EXPECT_TRUE(results.empty()) << "SubsetKnn with empty subset should return empty results";
}

TEST_P(HnswSubsetKnnTest, KEqualsZero) {
  // Test edge case: k = 0
  auto sim = GetParam();
  auto index = CreateSimple1DIndex(10, sim);

  vector query = {5.0f};
  vector subset = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};

  auto results = index->SubsetKnn(query.data(), 0, subset);

  EXPECT_TRUE(results.empty()) << "SubsetKnn with k=0 should return empty results";
}

TEST_P(HnswSubsetKnnTest, KGreaterThanSubsetSize) {
  // Test edge case: k > number of valid documents in subset
  auto sim = GetParam();
  auto index = CreateSimple1DIndex(10, sim);

  vector query = {5.0f};
  vector subset = {1, 3, 5};  // Only 3 elements

  auto results = index->SubsetKnn(query.data(), 10, subset);  // Ask for 10

  EXPECT_EQ(results.size(), 3u) << "SubsetKnn should return at most subset.size() results";

  // Verify all 3 are returned
  vector result_ids;
  for (const auto& [dist, id] : results) {
    result_ids.push_back(id);
  }
  EXPECT_THAT(result_ids, testing::UnorderedElementsAre(1, 3, 5));
}

TEST_P(HnswSubsetKnnTest, NonExistentIds) {
  // Test that non-existent IDs in subset are gracefully ignored
  auto sim = GetParam();
  auto index = CreateSimple1DIndex(10, sim);

  vector query = {5.0f};
  // Mix of valid (0-9) and invalid (100-105) IDs
  vector subset = {100, 4, 101, 5, 102, 6, 103, 104, 105};

  auto results = index->SubsetKnn(query.data(), 3, subset);

  EXPECT_EQ(results.size(), 3u);

  // Should only return valid IDs: 5, 4, 6 (closest to 5)
  vector result_ids;
  for (const auto& [dist, id] : results) {
    result_ids.push_back(id);
  }
  EXPECT_THAT(result_ids, testing::UnorderedElementsAre(4, 5, 6));
}

TEST_P(HnswSubsetKnnTest, AllDeletedDocuments) {
  // Test edge case: all documents in subset are marked deleted
  auto sim = GetParam();

  SchemaField::VectorParams params;
  params.use_hnsw = true;
  params.dim = 1;
  params.sim = sim;
  params.capacity = 10;
  params.hnsw_m = 16;
  params.hnsw_ef_construction = 200;

  HnswVectorIndex index(params, /*copy_vector=*/true);

  // Add and then remove documents
  vector docs;
  for (size_t i = 0; i < 5; i++) {
    vector coords = {static_cast(i)};
    docs.push_back(
        MockedDocument(MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}}));
    index.Add(i, docs[i], "vec");
  }

  // Delete all documents
  for (size_t i = 0; i < 5; i++) {
    index.Remove(i, docs[i], "vec");
  }

  vector query = {2.5f};
  vector subset = {0, 1, 2, 3, 4};

  auto results = index.SubsetKnn(query.data(), 3, subset);

  EXPECT_TRUE(results.empty()) << "SubsetKnn should return empty when all docs are deleted";
}

TEST_P(HnswSubsetKnnTest, MixedDeletedAndValidDocs) {
  // Test with a mix of deleted and valid documents
  auto sim = GetParam();

  SchemaField::VectorParams params;
  params.use_hnsw = true;
  params.dim = 1;
  params.sim = sim;
  params.capacity = 10;
  params.hnsw_m = 16;
  params.hnsw_ef_construction = 200;

  HnswVectorIndex index(params, /*copy_vector=*/true);

  // Add documents
  vector docs;
  for (size_t i = 0; i < 10; i++) {
    vector coords = {static_cast(i)};
    docs.push_back(
        MockedDocument(MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}}));
    index.Add(i, docs[i], "vec");
  }

  // Delete even documents
  for (size_t i = 0; i < 10; i += 2) {
    index.Remove(i, docs[i], "vec");
  }

  vector query = {5.0f};
  // Subset includes both deleted (even) and valid (odd) docs
  vector subset = {2, 3, 4, 5, 6, 7, 8};

  auto results = index.SubsetKnn(query.data(), 3, subset);

  EXPECT_EQ(results.size(), 3u);

  // Should only return odd (non-deleted) IDs: 5, 3, 7 (closest to 5)
  vector result_ids;
  for (const auto& [dist, id] : results) {
    result_ids.push_back(id);
  }
  EXPECT_THAT(result_ids, testing::UnorderedElementsAre(3, 5, 7));
}

TEST_P(HnswSubsetKnnTest, CompareWithFilteredKnn) {
  // Integration test: verify SubsetKnn produces similar results to filtered Knn
  // SubsetKnn uses brute-force exact search, while Knn uses HNSW approximate search
  // So results may differ slightly, but should have significant overlap
  constexpr double kMinOverlapRatio = 0.7;  // 70% minimum overlap threshold

  auto sim = GetParam();

  // COSINE similarity is undefined for 1D positive vectors (all share the same direction,
  // so all cosine distances equal 0). Use 2D unit-circle vectors instead, where element i
  // is at angle i * 2π/100, giving each pair a distinct, meaningful cosine distance.
  unique_ptr index;
  vector query;
  if (sim == VectorSimilarity::COSINE) {
    constexpr size_t kNumElements = 100;
    index = CreateCircle2DIndex(kNumElements, sim);
    const float step = 2.0f * static_cast(acos(-1.0)) / static_cast(kNumElements);
    // The query is the unit vector at the angle of element 50.
    float angle = step * 50.0f;
    query = {cosf(angle), sinf(angle)};
  } else {
    index = CreateSimple1DIndex(100, sim);
    query = {50.0f};
  }

  vector subset;

  // Create a small subset (well below typical 8192 threshold)
  for (size_t i = 40; i <= 60; i++) {
    subset.push_back(i);
  }

  size_t k = 10;

  // Get results from SubsetKnn (exact brute-force)
  auto subset_results = index->SubsetKnn(query.data(), k, subset);

  // Get results from regular filtered Knn (HNSW approximate)
  auto knn_results = index->Knn(query.data(), k, std::nullopt, subset);

  // Both should return k results (or fewer if subset is smaller)
  EXPECT_LE(subset_results.size(), k);
  EXPECT_LE(knn_results.size(), k);

  // Extract IDs from both
  std::set subset_ids;
  for (const auto& [dist, id] : subset_results) {
    subset_ids.insert(id);
  }

  std::set knn_ids;
  for (const auto& [dist, id] : knn_results) {
    knn_ids.insert(id);
  }

  // Count overlap - since HNSW is approximate, we expect good but not perfect overlap
  size_t overlap = 0;
  for (const auto& id : subset_ids) {
    if (knn_ids.count(id) > 0) {
      overlap++;
    }
  }

  // Expect at least kMinOverlapRatio overlap (HNSW is approximate, so some difference is expected)
  size_t min_overlap =
      static_cast(std::min(subset_ids.size(), knn_ids.size()) * kMinOverlapRatio);
  EXPECT_GE(overlap, min_overlap) << "Expected at least " << min_overlap
                                  << " overlapping results, got " << overlap;
}

// Runs every HnswSubsetKnnTest for L2, COSINE and IP; the test name is the similarity name.
INSTANTIATE_TEST_SUITE_P(SubsetKnnSimilarities, HnswSubsetKnnTest,
                         testing::Values(VectorSimilarity::L2, VectorSimilarity::COSINE,
                                         VectorSimilarity::IP),
                         [](const testing::TestParamInfo& info) {
                           switch (info.param) {
                             case VectorSimilarity::L2:
                               return "L2";
                             case VectorSimilarity::COSINE:
                               return "COSINE";
                             case VectorSimilarity::IP:
                               return "IP";
                             default:
                               return "Unknown";
                           }
                         });

// Tests for HnswVectorIndex::RangeQuery
class HnswRangeQueryTest : public ::testing::TestWithParam {
 protected:
  void SetUp() override {
    InitTLSearchMR(PMR_NS::get_default_resource());
  }

  void TearDown() override {
    InitTLSearchMR(nullptr);
  }

  // 1-D index: doc i has vector {float(i)}, GlobalDocId = i
  // The similarity is fixed to L2 here; the test parameter is not consulted.
  unique_ptr CreateSimple1DIndex(size_t num_elements) {
    SchemaField::VectorParams params;
    params.use_hnsw = true;
    params.dim = 1;
    params.sim = VectorSimilarity::L2;
    params.capacity = std::max(num_elements, 10);
    params.hnsw_m = 16;
    params.hnsw_ef_construction = 200;

    auto index = make_unique(params, /*copy_vector=*/true);
    for (size_t i = 0; i < num_elements; i++) {
      vector coords = {static_cast(i)};
      index->Add(i,
                 MockedDocument(MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}}),
                 "vec");
    }
    return index;
  }
};

TEST_P(HnswRangeQueryTest, BasicRange) {
  // 10 docs at positions 0..9. Query at 5.0 with radius 1.5 → docs 4,5,6 (dist 1.0,0.0,1.0)
  (void)GetParam();  // L2 only for 1-D
  auto index = CreateSimple1DIndex(10);

  vector query = {5.0f};
  auto results = index->RangeQuery(query.data(), 1.5f);

  set ids;
  for (const auto& [dist, id] : results)
    ids.insert(id);

  EXPECT_THAT(ids, testing::UnorderedElementsAre(4, 5, 6));
}

TEST_P(HnswRangeQueryTest, ExactMatch) {
  // Radius 0: only the doc at exact position
  (void)GetParam();
  auto index = CreateSimple1DIndex(10);

  vector query = {3.0f};
  auto results = index->RangeQuery(query.data(), 0.0f);

  // ASSERT: results[0] is read right below.
  ASSERT_EQ(results.size(), 1u);
  EXPECT_EQ(results[0].second, GlobalDocId{3});
  EXPECT_FLOAT_EQ(results[0].first, 0.0f);
}

// A radius far larger than the data range returns every document.
TEST_P(HnswRangeQueryTest, LargeRadiusReturnsAll) {
  (void)GetParam();
  auto index = CreateSimple1DIndex(20);

  vector query = {10.0f};
  auto results = index->RangeQuery(query.data(), 1000.0f);

  EXPECT_EQ(results.size(), 20u);
}

// Query at 5.5 with radius 0.1: the nearest docs (5 and 6) are 0.5 away, so nothing matches.
TEST_P(HnswRangeQueryTest, EmptyResultOutsideRadius) {
  (void)GetParam();
  auto index = CreateSimple1DIndex(10);

  vector query = {5.5f};
  auto results = index->RangeQuery(query.data(), 0.1f);

  EXPECT_TRUE(results.empty());
}

// RangeQuery on an index without any documents returns nothing.
TEST_P(HnswRangeQueryTest, EmptyIndex) {
  (void)GetParam();
  auto index = CreateSimple1DIndex(0);

  vector query = {0.0f};
  auto results = index->RangeQuery(query.data(), 100.0f);

  EXPECT_TRUE(results.empty());
}

TEST_P(HnswRangeQueryTest, DistancesCorrect) {
  // Verify returned distances match actual L2 distances
  (void)GetParam();
  auto index = CreateSimple1DIndex(10);

  vector query = {5.0f};
  auto results = index->RangeQuery(query.data(), 2.0f);  // docs 3,4,5,6,7

  EXPECT_EQ(results.size(), 5u);
  for (const auto& [dist, id] : results) {
    float expected = std::abs(static_cast(id) - 5.0f);
    // L2Distance returns sqrt(sum of squares); for 1-D: sqrt((a-b)²) = |a-b|
    EXPECT_FLOAT_EQ(dist, expected);
  }
}

// A removed document must not be returned even when it lies inside the radius.
TEST_P(HnswRangeQueryTest, DeletedDocNotReturned) {
  (void)GetParam();
  auto index = CreateSimple1DIndex(10);

  // Remove doc 5 (at position 5.0, distance 0 from query)
  index->Remove(5);

  vector query = {5.0f};
  auto results = index->RangeQuery(query.data(), 1.5f);

  set ids;
  for (const auto& [dist, id] : results)
    ids.insert(id);

  EXPECT_THAT(ids, testing::UnorderedElementsAre(4, 6));
  EXPECT_THAT(ids, testing::Not(testing::Contains(GlobalDocId{5})));
}

TEST_P(HnswRangeQueryTest, ConsistentWithBruteForce) {
  // Compare RangeQuery results against a brute-force scan over all positions
  (void)GetParam();
  const size_t n = 50;
  auto index = CreateSimple1DIndex(n);

  vector query = {25.0f};
  float radius = 5.0f;
  auto results = index->RangeQuery(query.data(), radius);

  // Brute force: collect all docs within radius.
  // L2Distance returns |a-b| for 1-D vectors (actual Euclidean, not squared).
  set expected;
  for (size_t i = 0; i < n; i++) {
    float dist = std::abs(static_cast(i) - 25.0f);
    if (dist <= radius)
      expected.insert(i);
  }

  set got;
  for (const auto& [dist, id] : results)
    got.insert(id);

  EXPECT_EQ(got, expected);
}

// Single instantiation; the name generator always returns "L2".
INSTANTIATE_TEST_SUITE_P(HnswRangeL2, HnswRangeQueryTest, testing::Values(VectorSimilarity::L2),
                         [](const testing::TestParamInfo&) { return "L2"; });

// GEO radius and "@location:*" queries over four named points, combined with a text prefix, and
// re-checked after adding and removing a document.
TEST_F(SearchTest, GeoSearch) {
  auto schema = MakeSimpleSchema({{"name", SchemaField::TEXT}, {"location", SchemaField::GEO}});
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  indices.Add(0, MockedDocument(Map{{"name", "Mountain View"}, {"location", "-122.08, 37.386"}}));
  indices.Add(1, MockedDocument(Map{{"name", "Palo Alto"}, {"location", "-122.143, 37.444"}}));
  indices.Add(2, MockedDocument(Map{{"name", "San Jose"}, {"location", "-121.886, 37.338"}}));
  indices.Add(3, MockedDocument(Map{{"name", "San Francisco"}, {"location", "-122.419, 37.774"}}));

  SearchAlgorithm algo{};
  QueryParams params;

  // Search around Mountain View 30 miles - San Francisco not included
  {
    algo.Init("@location:[-122.083 37.386 30 mi]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2));
  }

  // Search around Mountain View 50 miles - all points included
  {
    algo.Init("@location:[-122.083 37.386 50 mi]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3));
  }

  // Return all indexes
  {
    algo.Init("@location:*", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3));
  }

  // Search around Mountain View 50 miles - all points included and filter on prefix
  {
    algo.Init("San* @location:[-122.083 37.386 50 mi]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(2, 3));
  }

  // Add duplicate point of San Francisco and search again to include this point also
  {
    indices.Add(4,
                MockedDocument(Map{{"name", "San Francisco"}, {"location", "-122.419, 37.774"}}));
    algo.Init("San* @location:[-122.083 37.386 50 mi]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(2, 3, 4));
  }

  // Remove first index of San Francisco (id = 3) and search
  {
    indices.Remove(
        3, MockedDocument(Map{{"name", "San Francisco"}, {"location", "-122.419, 37.774"}}));
    algo.Init("San* @location:[-122.083 37.386 50 mi]", ¶ms);
    EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(2, 4));
  }
}

TEST_F(SearchTest, VectorDistanceBasic) {
  // Test basic vector distance calculations
  std::vector vec1 = {1.0f, 2.0f, 3.0f};
  std::vector vec2 = {4.0f, 5.0f, 6.0f};

  // Test L2 distance
  float l2_dist = VectorDistance(vec1.data(), vec2.data(), 3, VectorSimilarity::L2);
  EXPECT_GT(l2_dist, 0.0f);
  EXPECT_LT(l2_dist, 10.0f);  // Should be reasonable value

  // Test Cosine distance
  float cos_dist = VectorDistance(vec1.data(), vec2.data(), 3, VectorSimilarity::COSINE);
  EXPECT_GE(cos_dist, 0.0f);
  EXPECT_LE(cos_dist, 2.0f);  // Cosine distance range

  // Test IP distance
  float ip_dist = VectorDistance(vec1.data(), vec2.data(), 3, VectorSimilarity::IP);
  // IP distance can be negative for non-normalized vectors
  EXPECT_NE(ip_dist, 0.0f);  // Should be non-zero for different vectors

  // Test identical vectors
  float l2_same = VectorDistance(vec1.data(), vec1.data(), 3, VectorSimilarity::L2);
  EXPECT_NEAR(l2_same, 0.0f, 1e-6);

  float cos_same = VectorDistance(vec1.data(), vec1.data(), 3, VectorSimilarity::COSINE);
  EXPECT_NEAR(cos_same, 0.0f, 1e-6);

  float ip_same = VectorDistance(vec1.data(), vec1.data(), 3, VectorSimilarity::IP);
  // For identical vectors: IP = 1 - dot_product(v, v) = 1 - ||v||^2
  // For vec1 = {1, 2, 3}: ||v||^2 = 1 + 4 + 9 = 14, so IP = 1 - 14 = -13
  EXPECT_LT(ip_same, 0.0f);  // Should be negative for non-normalized vectors
}

TEST_F(SearchTest, VectorDistanceConsistency) {
  // Test that results are consistent across multiple calls
  std::vector vec1 = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
  std::vector vec2 = {0.6f, 0.7f, 0.8f, 0.9f, 1.0f};

  // Exact equality is intended: the same inputs are passed to the same function twice.
  float l2_dist1 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::L2);
  float l2_dist2 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::L2);
  EXPECT_EQ(l2_dist1, l2_dist2);

  float cos_dist1 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::COSINE);
  float cos_dist2 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::COSINE);
  EXPECT_EQ(cos_dist1, cos_dist2);

  float ip_dist1 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::IP);
  float ip_dist2 = VectorDistance(vec1.data(), vec2.data(), 5, VectorSimilarity::IP);
  EXPECT_EQ(ip_dist1, ip_dist2);
}

static void BM_VectorSearch(benchmark::State& state) {
  // Ensure SimSIMD dynamic dispatch is initialized for the benchmark
  InitSimSIMD();

  unsigned ndims = state.range(0);
  unsigned nvecs = state.range(1);

  auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
  schema.fields["pos"].special_params = SchemaField::VectorParams{false, ndims};
  FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr};

  auto random_vec = [ndims]() {
    vector coords;
    for (size_t j = 0; j < ndims; j++)
      coords.push_back(static_cast(rand()) / static_cast(RAND_MAX));
    return coords;
  };

  for (size_t i = 0; i < nvecs; i++) {
    auto rv = random_vec();
    MockedDocument doc{Map{{"pos", ToBytes(rv)}}};
indices.Add(i, doc); } SearchAlgorithm algo{}; QueryParams params; auto rv = random_vec(); params["vec"] = ToBytes(rv); algo.Init("* =>[KNN 1 @pos $vec]", ¶ms); while (state.KeepRunningBatch(10)) { for (size_t i = 0; i < 10; i++) benchmark::DoNotOptimize(algo.Search(&indices)); } } BENCHMARK(BM_VectorSearch)->Args({120, 10'000}); TEST_F(SearchTest, MatchNonNullField) { PrepareSchema({{"text_field", SchemaField::TEXT}, {"tag_field", SchemaField::TAG}, {"num_field", SchemaField::NUMERIC}}); { PrepareQuery("@text_field:*"); ExpectAll(Map{{"text_field", "any value"}}, Map{{"text_field", "another value"}}, Map{{"text_field", "third"}, {"tag_field", "tag1"}}); ExpectNone(Map{{"tag_field", "wrong field"}}, Map{{"num_field", "123"}}, Map{}); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("@tag_field:*"); ExpectAll(Map{{"tag_field", "tag1"}}, Map{{"tag_field", "tag2"}}, Map{{"text_field", "value"}, {"tag_field", "tag3"}}); ExpectNone(Map{{"text_field", "wrong field"}}, Map{{"num_field", "456"}}, Map{}); EXPECT_TRUE(Check()) << GetError(); } { PrepareQuery("@num_field:*"); ExpectAll(Map{{"num_field", "123"}}, Map{{"num_field", "456"}}, Map{{"text_field", "value"}, {"num_field", "789"}}); ExpectNone(Map{{"text_field", "wrong field"}}, Map{{"tag_field", "tag1"}}, Map{}); EXPECT_TRUE(Check()) << GetError(); } } TEST_F(SearchTest, InvalidVectorParameter) { search::Schema schema; schema.fields["v"] = search::SchemaField{ search::SchemaField::VECTOR, 0, // flags "v" // short_name }; search::SchemaField::VectorParams params; params.use_hnsw = true; params.dim = 2; params.sim = search::VectorSimilarity::L2; params.capacity = 10; params.hnsw_m = 16; params.hnsw_ef_construction = 200; schema.fields["v"].special_params = params; search::IndicesOptions options; search::FieldIndices indices{schema, options, PMR_NS::get_default_resource(), nullptr}; search::SearchAlgorithm algo; search::QueryParams query_params; query_params["b"] = "abcdefg"; // Parser accepts any string as 
placeholder // Invalid vectors result in empty vector (dimension 0) which returns empty results ASSERT_TRUE(algo.Init("*=>[KNN 2 @v $b]", &query_params)); // Search should return empty results for invalid vector auto result = algo.Search(&indices); EXPECT_TRUE(result.ids.empty()); } class SortIndexTest : public testing::Test { protected: void SetUp() override { InitTLSearchMR(PMR_NS::get_default_resource()); } void TearDown() override { InitTLSearchMR(nullptr); } }; TEST_F(SortIndexTest, StringSort) { constexpr auto field = "name"; const auto schema = MakeSimpleSchema({{field, SchemaField::TAG}}, true); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; indices.Add(0, MockedDocument{Map{{field, "charlie"}}}); indices.Add(1, MockedDocument{Map{{field, "alpha"}}}); indices.Add(2, MockedDocument{Map{{field, "bravo"}}}); std::vector ids{0, 1, 2}; constexpr bool desc = false; const auto index = indices.GetSortIndex(field); index->Sort(&ids, ids.size(), desc); std::vector expected{1, 2, 0}; EXPECT_EQ(ids, expected); index->Sort(&ids, ids.size(), !desc); expected = {0, 2, 1}; EXPECT_EQ(ids, expected); // conversion from stateless to normal string auto lookup = index->Lookup(1); EXPECT_TRUE(std::holds_alternative(lookup)); EXPECT_EQ(std::get(lookup), "alpha"); } TEST_F(SortIndexTest, NumSort) { constexpr auto field = "cost"; const auto schema = MakeSimpleSchema({{field, SchemaField::NUMERIC}}, true); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; indices.Add(0, MockedDocument{Map{{field, "2999"}}}); indices.Add(1, MockedDocument{Map{{field, "999"}}}); indices.Add(2, MockedDocument{Map{{field, "12"}}}); std::vector ids{0, 1, 2}; constexpr bool desc = false; auto index = indices.GetSortIndex(field); index->Sort(&ids, ids.size(), desc); std::vector expected{2, 1, 0}; EXPECT_EQ(ids, expected); index->Sort(&ids, ids.size(), !desc); expected = {0, 1, 2}; EXPECT_EQ(ids, expected); auto lookup = 
index->Lookup(1); EXPECT_TRUE(std::holds_alternative(lookup)); EXPECT_EQ(std::get(lookup), 999); } // Enumeration for different search types enum class SearchType { PREFIX = 0, SUFFIX = 1, INFIX = 2 }; // Helper function to generate content with ASCII characters static std::string GenerateWordSequence(size_t word_count, size_t doc_offset = 0) { std::string content; for (size_t i = 0; i < word_count; ++i) { std::string word; char start_char = 'a' + ((doc_offset + i) % 26); size_t word_len = 3 + (i % 5); // Word length 3-7 chars for (size_t j = 0; j < word_len; ++j) { char c = start_char + (j % 26); if (c > 'z') c = 'a' + (c - 'z' - 1); word += c; } if (i > 0) content += " "; content += word; } return content; } // Helper function to generate pattern with variety static std::string GeneratePattern(SearchType search_type, size_t pattern_len, bool use_uniform) { if (use_uniform) { // Original uniform pattern for comparison switch (search_type) { case SearchType::PREFIX: return std::string(pattern_len, 'p'); case SearchType::SUFFIX: return std::string(pattern_len, 's'); case SearchType::INFIX: return std::string(pattern_len, 'i'); } } else { // Diverse ASCII pattern std::string pattern; char base_char = (search_type == SearchType::PREFIX) ? 'p' : (search_type == SearchType::SUFFIX) ? 
's' : 'i'; for (size_t i = 0; i < pattern_len; ++i) { char c = base_char + (i % 10); // Use variety of chars if (c > 'z') c = 'a' + (c - 'z' - 1); pattern += c; } return pattern; } return ""; } static void BM_SearchByTypeImpl(benchmark::State& state, bool use_diverse_pattern) { size_t num_docs = state.range(0); size_t pattern_len = state.range(1); SearchType search_type = static_cast(state.range(2)); auto schema = MakeSimpleSchema({{"title", SchemaField::TEXT}}); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; // Generate pattern std::string pattern = GeneratePattern(search_type, pattern_len, !use_diverse_pattern); std::string search_type_name = (search_type == SearchType::PREFIX) ? "prefix" : (search_type == SearchType::SUFFIX) ? "suffix" : "infix"; // Generate test data with more realistic content for (size_t i = 0; i < num_docs; i++) { std::string content; if (i < num_docs / 2) { // Half documents have the pattern in appropriate position std::string base_content = GenerateWordSequence(5 + (i % 5), i); switch (search_type) { case SearchType::PREFIX: content = pattern + base_content; break; case SearchType::SUFFIX: content = base_content + pattern; break; case SearchType::INFIX: // Fix: embed pattern inside a word, not as separate word size_t split_pos = base_content.length() / 2; content = base_content.substr(0, split_pos) + pattern + base_content.substr(split_pos); break; } } else { // Half don't have the pattern - generate different content content = GenerateWordSequence(8 + (i % 3), i + 1000); } MockedDocument doc{Map{{"title", content}}}; indices.Add(i, doc); } SearchAlgorithm algo{}; QueryParams params; std::string query; // Generate query based on search type switch (search_type) { case SearchType::PREFIX: query = pattern + "*"; break; case SearchType::SUFFIX: query = "*" + pattern; break; case SearchType::INFIX: query = "*" + pattern + "*"; break; } if (!algo.Init(query, ¶ms)) { state.SkipWithError("Failed to 
initialize " + search_type_name + " search"); return; } while (state.KeepRunning()) { auto result = algo.Search(&indices); benchmark::DoNotOptimize(result); // If result has error, skip the benchmark if (!result.error.empty()) { state.SkipWithError(search_type_name + " search returned error: " + result.error); return; } } // Set counters for analysis state.counters["docs_total"] = num_docs; state.counters["pattern_length"] = pattern_len; state.counters["diverse_pattern"] = use_diverse_pattern ? 1 : 0; state.SetLabel(search_type_name + (use_diverse_pattern ? "_diverse" : "_uniform")); } // Instantiate template functions static void BM_SearchByType_Uniform(benchmark::State& state) { BM_SearchByTypeImpl(state, false); } static void BM_SearchByType_Diverse(benchmark::State& state) { BM_SearchByTypeImpl(state, true); } // Benchmark to compare all search types - removed 100K docs per romange's suggestion BENCHMARK(BM_SearchByType_Uniform) // Uniform patterns (original test) ->Args({1000, 3, static_cast(SearchType::PREFIX)}) ->Args({1000, 5, static_cast(SearchType::PREFIX)}) ->Args({10000, 3, static_cast(SearchType::PREFIX)}) ->Args({10000, 5, static_cast(SearchType::PREFIX)}) ->Args({1000, 3, static_cast(SearchType::SUFFIX)}) ->Args({1000, 5, static_cast(SearchType::SUFFIX)}) ->Args({10000, 3, static_cast(SearchType::SUFFIX)}) ->Args({10000, 5, static_cast(SearchType::SUFFIX)}) ->Args({1000, 3, static_cast(SearchType::INFIX)}) ->Args({1000, 5, static_cast(SearchType::INFIX)}) ->Args({10000, 3, static_cast(SearchType::INFIX)}) ->Args({10000, 5, static_cast(SearchType::INFIX)}) ->ArgNames({"docs", "pattern_len", "search_type"}) ->Unit(benchmark::kMicrosecond); BENCHMARK(BM_SearchByType_Diverse) // Diverse patterns (new test with ASCII variety) ->Args({1000, 3, static_cast(SearchType::PREFIX)}) ->Args({1000, 5, static_cast(SearchType::PREFIX)}) ->Args({10000, 3, static_cast(SearchType::PREFIX)}) ->Args({10000, 5, static_cast(SearchType::PREFIX)}) ->Args({1000, 3, 
static_cast(SearchType::SUFFIX)}) ->Args({1000, 5, static_cast(SearchType::SUFFIX)}) ->Args({10000, 3, static_cast(SearchType::SUFFIX)}) ->Args({10000, 5, static_cast(SearchType::SUFFIX)}) ->Args({1000, 3, static_cast(SearchType::INFIX)}) ->Args({1000, 5, static_cast(SearchType::INFIX)}) ->Args({10000, 3, static_cast(SearchType::INFIX)}) ->Args({10000, 5, static_cast(SearchType::INFIX)}) ->ArgNames({"docs", "pattern_len", "search_type"}) ->Unit(benchmark::kMicrosecond); // Helper function to generate random vector static std::vector GenerateRandomVector(size_t dims, unsigned seed = 42) { std::mt19937 gen(seed); std::uniform_real_distribution dis(-1.0f, 1.0f); std::vector vec(dims); for (size_t i = 0; i < dims; ++i) { vec[i] = dis(gen); } return vec; } static void BM_SearchDocIds(benchmark::State& state) { auto schema = MakeSimpleSchema({{"score", SchemaField::NUMERIC}, {"tag", SchemaField::TAG}}); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; default_random_engine rnd; const char* tag_vals[] = {"test", "example", "sample", "demo", "demo2"}; uniform_int_distribution tag_dist(0, ABSL_ARRAYSIZE(tag_vals) - 1); uniform_int_distribution score_dist(0, 100); for (size_t i = 0; i < 1000; i++) { MockedDocument doc{ Map{{"score", std::to_string(score_dist(rnd))}, {"tag", tag_vals[tag_dist(rnd)]}}}; indices.Add(i, doc); } std::string queries[] = {"@tag:{test} @score:[10 50]", "@tag: *", "@score:*"}; size_t query_type = state.range(0); CHECK_LT(query_type, ABSL_ARRAYSIZE(queries)); CHECK(algo.Init(queries[query_type], ¶ms)); while (state.KeepRunning()) { auto result = algo.Search(&indices); CHECK(result.error.empty()); } } BENCHMARK(BM_SearchDocIds)->Range(0, 2); static void BM_SearchNumericIndexes(benchmark::State& state) { auto schema = MakeSimpleSchema({{"numeric", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}}); FieldIndices indices{schema, kEmptyOptions, 
PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; default_random_engine rnd; using NumericType = long long; uniform_int_distribution dist(std::numeric_limits::min(), std::numeric_limits::max()); const size_t num_docs = state.range(0); for (size_t i = 0; i < num_docs; i++) { MockedDocument doc{Map{{"numeric", std::to_string(dist(rnd))}}}; indices.Add(i, doc); } std::string queries[] = {"@numeric:[15 +inf]", "@numeric:[-inf 20]", "@numeric:[-inf +inf]", "@numeric:[0 100000]"}; std::unordered_map> expected_results_per_num_docs = { {10000, {4982, 5018, 10000, 0}}, {100000, {49885, 50115, 100000, 0}}, {1000000, {500853, 499147, 1000000, 0}}, }; while (state.KeepRunning()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(queries); ++i) { const auto& query = queries[i]; CHECK(algo.Init(query, ¶ms)); auto result = algo.Search(&indices); CHECK(result.error.empty()); const size_t expected_result = expected_results_per_num_docs[num_docs][i]; CHECK_EQ(result.total, expected_result); CHECK_EQ(result.ids.size(), expected_result); } } } BENCHMARK(BM_SearchNumericIndexes)->Arg(10000)->Arg(100000)->Arg(1000000)->ArgNames({"num_docs"}); static void BM_SearchNumericIndexesSmallRanges(benchmark::State& state) { auto schema = MakeSimpleSchema({{"numeric", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}}); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; default_random_engine rnd; using NumericType = uint16_t; uniform_int_distribution dist(0, std::numeric_limits::max()); const size_t num_docs = state.range(0); // Insert zero values for (size_t i = 0; i < num_docs / 50; i++) { MockedDocument doc{Map{{"numeric", "0"}}}; indices.Add(i, doc); } for (size_t i = num_docs / 50; i < num_docs; i++) { MockedDocument doc{Map{{"numeric", std::to_string(dist(rnd))}}}; indices.Add(i, doc); } std::string queries[] = {"@numeric:[0 40000]", "@numeric:[-inf +inf]"}; 
std::unordered_map> expected_results_per_num_docs = { {100000, {61939, 100000}}, {1000000, {618365, 1000000}}, }; while (state.KeepRunning()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(queries); ++i) { const auto& query = queries[i]; CHECK(algo.Init(query, ¶ms)); auto result = algo.Search(&indices); CHECK(result.error.empty()); const size_t expected_result = expected_results_per_num_docs[num_docs][i]; CHECK_EQ(result.total, expected_result); CHECK_EQ(result.ids.size(), expected_result); } } } BENCHMARK(BM_SearchNumericIndexesSmallRanges) ->Arg(100000) // One block ->Arg(1000000) // Two blocks ->ArgNames({"num_docs"}); static void BM_SearchTwoNumericIndexes(benchmark::State& state) { auto schema = MakeSimpleSchema({ {"numeric1", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}, {"numeric2", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}, }); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; std::default_random_engine rnd; using NumericType = long long; uniform_int_distribution dist1(std::numeric_limits::min(), std::numeric_limits::max()); uniform_int_distribution dist2(std::numeric_limits::min(), std::numeric_limits::max()); const size_t num_docs = state.range(0); for (size_t i = 0; i < num_docs; ++i) { MockedDocument doc{Map{ {"numeric1", std::to_string(dist1(rnd))}, {"numeric2", std::to_string(dist2(rnd))}, }}; indices.Add(i, doc); } std::string queries[] = {absl::StrCat("@numeric1:[15 +inf] @numeric2:[-inf 20]"), absl::StrCat("@numeric1:[-inf 20] @numeric2:[15 +inf]"), absl::StrCat("@numeric1:[0 100000] @numeric2:[-100000 0]"), absl::StrCat("@numeric1:[-100000 0] @numeric2:[0 100000]")}; std::unordered_map> expected_results_per_num_docs = { {10000, {2508, 2507, 0, 0}}, {100000, {25119, 25232, 0, 0}}, {1000000, {250623, 250643, 0, 0}}, }; while (state.KeepRunning()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(queries); ++i) 
{ const auto& query = queries[i]; CHECK(algo.Init(query, ¶ms)); auto result = algo.Search(&indices); CHECK(result.error.empty()); const size_t expected_result = expected_results_per_num_docs[num_docs][i]; CHECK_EQ(result.total, expected_result); CHECK_EQ(result.ids.size(), expected_result); } } } BENCHMARK(BM_SearchTwoNumericIndexes) ->Arg(10000) ->Arg(100000) ->Arg(1000000) ->ArgNames({"num_docs"}); static void BM_SearchNumericAndTagIndexes(benchmark::State& state) { auto schema = MakeSimpleSchema({{"tag", SchemaField::TAG}, {"numeric", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}}); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; default_random_engine rnd; using NumericType = long long; uniform_int_distribution dist(std::numeric_limits::min(), std::numeric_limits::max()); size_t tag_number = 0; const size_t max_tag_number = 1000; const size_t num_docs = state.range(0); for (size_t i = 0; i < num_docs; i++) { MockedDocument doc{ Map{{"tag", absl::StrCat("tag", tag_number)}, {"numeric", std::to_string(dist(rnd))}}}; indices.Add(i, doc); tag_number = (tag_number + 1) % max_tag_number; } std::string queries[] = {absl::StrCat("@tag:{tag230|tag3|tag942} @numeric:[15 +inf]"), absl::StrCat("@tag:{tag1|tag829|tag236} @numeric:[-inf 20]"), absl::StrCat("@tag:{tag0|tag999} @numeric:[-1000000 +inf]")}; std::unordered_map> expected_results_per_num_docs = { {10000, {19, 16, 8}}, {100000, {164, 157, 97}}, {1000000, {1528, 1518, 1017}}, }; while (state.KeepRunning()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(queries); ++i) { const auto& query = queries[i]; CHECK(algo.Init(query, ¶ms)); auto result = algo.Search(&indices); CHECK(result.error.empty()); const size_t expected_result = expected_results_per_num_docs[num_docs][i]; CHECK_EQ(result.total, expected_result); CHECK_EQ(result.ids.size(), expected_result); } } } BENCHMARK(BM_SearchNumericAndTagIndexes) 
->Arg(10000) ->Arg(100000) ->Arg(1000000) ->ArgNames({"num_docs"}); static void BM_SearchSeveralNumericAndTagIndexes(benchmark::State& state) { auto schema = MakeSimpleSchema({{"tag", SchemaField::TAG}, {"numeric1", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}, {"numeric2", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}, {"numeric3", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}}); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; default_random_engine rnd; using NumericType = uint16_t; uniform_int_distribution dist(std::numeric_limits::min(), std::numeric_limits::max()); const size_t num_docs = state.range(0); size_t tag_number = 0; const size_t max_tag_number = num_docs / 30; for (size_t i = 0; i < num_docs; i++) { MockedDocument doc{Map{{"tag", absl::StrCat("tag", tag_number)}, {"numeric1", std::to_string(dist(rnd))}, {"numeric2", std::to_string(dist(rnd))}, {"numeric3", std::to_string(dist(rnd))}}}; indices.Add(i, doc); tag_number = (tag_number + 1) % max_tag_number; } std::string queries[] = { absl::StrCat( "@tag:{tag230|tag3} @numeric1:[0 10000] @numeric2:[20000 30000] @numeric3:[-1000 +inf]"), absl::StrCat("@tag:{tag829|tag236} @numeric1:[-inf 10000] @numeric2:[40000 +inf] " "@numeric3:[10000 30000]"), absl::StrCat( "@tag:{tag0|tag999} @numeric1:[-inf +inf] @numeric2:[20 +inf] @numeric3:[1000 10000]")}; std::unordered_map> expected_results_per_num_docs = { {10000, {1, 0, 4}}, {100000, {1, 1, 10}}, {1000000, {0, 1, 9}}, }; while (state.KeepRunning()) { for (size_t i = 0; i < ABSL_ARRAYSIZE(queries); ++i) { const auto& query = queries[i]; CHECK(algo.Init(query, ¶ms)); auto result = algo.Search(&indices); CHECK(result.error.empty()); const size_t expected_result = expected_results_per_num_docs[num_docs][i]; CHECK_EQ(result.total, expected_result); 
CHECK_EQ(result.ids.size(), expected_result); } } } BENCHMARK(BM_SearchSeveralNumericAndTagIndexes) ->Arg(10000) ->Arg(100000) ->Arg(1000000) ->ArgNames({"num_docs"}); static void BM_SearchMergeEqualSets(benchmark::State& state) { auto schema = MakeSimpleSchema({ {"numeric1", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}, {"numeric2", SchemaField::NUMERIC, SchemaField::NumericParams{.block_size = kMaxRangeBlockSize}}, }); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; SearchAlgorithm algo; QueryParams params; std::default_random_engine rnd; using NumericType = long long; uniform_int_distribution dist1(std::numeric_limits::min(), std::numeric_limits::max()); uniform_int_distribution dist2(std::numeric_limits::min(), std::numeric_limits::max()); const size_t num_docs = state.range(0); for (size_t i = 0; i < num_docs; ++i) { MockedDocument doc{Map{ {"numeric1", std::to_string(dist1(rnd))}, {"numeric2", std::to_string(dist2(rnd))}, }}; indices.Add(i, doc); } std::string query = absl::StrCat("@numeric1:[-inf +inf] @numeric2:[-inf +inf]"); while (state.KeepRunning()) { CHECK(algo.Init(query, ¶ms)); auto result = algo.Search(&indices); CHECK(result.error.empty()); // All documents should match both conditions, so total should equal num_docs CHECK_EQ(result.total, num_docs); CHECK_EQ(result.ids.size(), num_docs); } } BENCHMARK(BM_SearchMergeEqualSets) ->Arg(100) ->Arg(1000) ->Arg(10000) ->Arg(100000) ->Arg(1000000) ->ArgNames({"num_docs"}); static void BM_SearchRangeTreeSplits(benchmark::State& state) { auto schema = MakeSimpleSchema({ {"num", SchemaField::NUMERIC, SchemaField::NumericParams{}}, }); FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource(), nullptr}; const size_t batch_size = state.range(0); std::default_random_engine rnd; using NumericType = long long; uniform_int_distribution dist(0, batch_size + 1); size_t doc_index = 0; while (state.KeepRunning()) { for 
(size_t i = 0; i < batch_size; i++) { MockedDocument doc{Map{{"num", std::to_string(dist(rnd))}}}; indices.Add(doc_index++, doc); } } } BENCHMARK(BM_SearchRangeTreeSplits) ->Arg(100000) ->Arg(1000000) ->Arg(3000000) ->ArgNames({"batch_size"}); // Semantics test for cosine on zero vectors (independent of SimSIMD) TEST(CosineDistanceTest, ZeroVectors) { const size_t dims = 128; std::vector zero(dims, 0.0f); float d = VectorDistance(zero.data(), zero.data(), dims, VectorSimilarity::COSINE); EXPECT_EQ(d, 0.0f); } // Unified vector distance benchmarks using VectorDistance function static void BM_VectorDistance(benchmark::State& state) { // Ensure SimSIMD dynamic dispatch is initialized for the benchmark InitSimSIMD(); size_t dims = state.range(0); size_t num_pairs = state.range(1); VectorSimilarity sim = static_cast(state.range(2)); std::vector> vectors_a, vectors_b; vectors_a.reserve(num_pairs); vectors_b.reserve(num_pairs); for (size_t i = 0; i < num_pairs; ++i) { vectors_a.push_back(GenerateRandomVector(dims, i)); vectors_b.push_back(GenerateRandomVector(dims, i + 1000)); } size_t pair_idx = 0; for (auto _ : state) { float distance = VectorDistance(vectors_a[pair_idx].data(), vectors_b[pair_idx].data(), dims, sim); benchmark::DoNotOptimize(distance); pair_idx = (pair_idx + 1) % num_pairs; } state.counters["dims"] = dims; state.counters["pairs"] = num_pairs; std::string sim_name = (sim == VectorSimilarity::L2) ? "L2" : (sim == VectorSimilarity::COSINE) ? 
"Cosine" : "IP"; state.SetLabel(sim_name); } // Intensive benchmark with batch processing static void BM_VectorDistance_Intensive(benchmark::State& state) { // Ensure SimSIMD dynamic dispatch is initialized for the benchmark InitSimSIMD(); size_t dims = 512; // Fixed medium size size_t batch_size = 1000; VectorSimilarity sim = static_cast(state.range(0)); std::vector> vectors_a, vectors_b; vectors_a.reserve(batch_size); vectors_b.reserve(batch_size); for (size_t i = 0; i < batch_size; ++i) { vectors_a.push_back(GenerateRandomVector(dims, i)); vectors_b.push_back(GenerateRandomVector(dims, i + 4000)); } size_t total_ops = 0; while (state.KeepRunning()) { for (size_t i = 0; i < batch_size; ++i) { float distance = VectorDistance(vectors_a[i].data(), vectors_b[i].data(), dims, sim); benchmark::DoNotOptimize(distance); ++total_ops; } } state.counters["ops"] = total_ops; state.counters["ops_per_sec"] = benchmark::Counter(total_ops, benchmark::Counter::kIsRate); std::string sim_name = (sim == VectorSimilarity::L2) ? "L2" : (sim == VectorSimilarity::COSINE) ? 
"Cosine" : "IP"; state.SetLabel(sim_name + "_Intensive"); } // Benchmark declarations BENCHMARK(BM_VectorDistance) // Small vectors - L2 Distance ->Args({32, 100, static_cast(VectorSimilarity::L2)}) ->Args({32, 1000, static_cast(VectorSimilarity::L2)}) ->Args({32, 10000, static_cast(VectorSimilarity::L2)}) // Medium vectors - L2 Distance ->Args({128, 100, static_cast(VectorSimilarity::L2)}) ->Args({128, 1000, static_cast(VectorSimilarity::L2)}) ->Args({128, 10000, static_cast(VectorSimilarity::L2)}) // Large vectors - L2 Distance ->Args({512, 100, static_cast(VectorSimilarity::L2)}) ->Args({512, 1000, static_cast(VectorSimilarity::L2)}) ->Args({512, 5000, static_cast(VectorSimilarity::L2)}) // Very large vectors - L2 Distance ->Args({1536, 100, static_cast(VectorSimilarity::L2)}) ->Args({1536, 1000, static_cast(VectorSimilarity::L2)}) // Small vectors - Cosine Distance ->Args({32, 100, static_cast(VectorSimilarity::COSINE)}) ->Args({32, 1000, static_cast(VectorSimilarity::COSINE)}) ->Args({32, 10000, static_cast(VectorSimilarity::COSINE)}) // Medium vectors - Cosine Distance ->Args({128, 100, static_cast(VectorSimilarity::COSINE)}) ->Args({128, 1000, static_cast(VectorSimilarity::COSINE)}) ->Args({128, 10000, static_cast(VectorSimilarity::COSINE)}) // Large vectors - Cosine Distance ->Args({512, 100, static_cast(VectorSimilarity::COSINE)}) ->Args({512, 1000, static_cast(VectorSimilarity::COSINE)}) ->Args({512, 5000, static_cast(VectorSimilarity::COSINE)}) // Very large vectors - Cosine Distance ->Args({1536, 100, static_cast(VectorSimilarity::COSINE)}) ->Args({1536, 1000, static_cast(VectorSimilarity::COSINE)}) // Small vectors - IP Distance ->Args({32, 100, static_cast(VectorSimilarity::IP)}) ->Args({32, 1000, static_cast(VectorSimilarity::IP)}) ->Args({32, 10000, static_cast(VectorSimilarity::IP)}) // Medium vectors - IP Distance ->Args({128, 100, static_cast(VectorSimilarity::IP)}) ->Args({128, 1000, static_cast(VectorSimilarity::IP)}) ->Args({128, 10000, 
static_cast(VectorSimilarity::IP)}) // Large vectors - IP Distance ->Args({512, 100, static_cast(VectorSimilarity::IP)}) ->Args({512, 1000, static_cast(VectorSimilarity::IP)}) ->Args({512, 5000, static_cast(VectorSimilarity::IP)}) // Very large vectors - IP Distance ->Args({1536, 100, static_cast(VectorSimilarity::IP)}) ->Args({1536, 1000, static_cast(VectorSimilarity::IP)}) ->ArgNames({"dims", "pairs", "similarity"}) ->Unit(benchmark::kMicrosecond); BENCHMARK(BM_VectorDistance_Intensive) ->Arg(static_cast(VectorSimilarity::L2)) ->Arg(static_cast(VectorSimilarity::COSINE)) ->Arg(static_cast(VectorSimilarity::IP)) ->ArgNames({"similarity_type"}) ->Unit(benchmark::kMicrosecond); } // namespace search } // namespace dfly ================================================ FILE: src/core/search/sort_indices.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/search/sort_indices.h" #include #include #include #include #include #include #include namespace dfly::search { using namespace std; namespace { template using ScoreT = std::conditional_t, std::string, T>; } // namespace template bool SimpleValueSortIndex::ParsedSortValue::HasValue() const { return !std::holds_alternative(value); } template bool SimpleValueSortIndex::ParsedSortValue::IsNullValue() const { return std::holds_alternative(value); } template SortableValue SimpleValueSortIndex::Lookup(DocId doc) const { DCHECK_LT(doc, occupied_.size()); if (!occupied_[doc]) return std::monostate{}; DCHECK_LT(doc, values_.size()); return ScoreT{values_[doc]}; } template std::vector SimpleValueSortIndex::Sort(std::vector* ids, size_t limit, bool desc) const { auto cb = [this, desc](const auto& lhs, const auto& rhs) { // null values are at the end auto p1 = make_pair(!occupied_[lhs], cref(values_[lhs])); auto p2 = make_pair(!occupied_[rhs], cref(values_[rhs])); return desc ? 
(p1 > p2) : (p1 < p2); }; std::partial_sort(ids->begin(), ids->begin() + std::min(ids->size(), limit), ids->end(), cb); // Turn stateless string into std::string vector out(min(ids->size(), limit)); for (size_t i = 0; i < out.size(); i++) out[i] = ScoreT{values_[(*ids)[i]]}; return out; } template bool SimpleValueSortIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { auto field_value = Get(doc, field); if (!field_value.HasValue()) { return false; } if (id >= values_.size()) { values_.resize(id + 1); occupied_.resize(id + 1); } if (!field_value.IsNullValue()) { values_[id] = std::move(std::get(field_value.value)); occupied_[id] = true; } return true; } template void SimpleValueSortIndex::Remove(DocId id, const DocumentAccessor& doc, std::string_view field) { DCHECK_LT(id, values_.size()); DCHECK_EQ(values_.size(), occupied_.size()); values_[id] = T{}; occupied_[id] = false; } template std::vector SimpleValueSortIndex::GetAllDocsWithNonNullValues() const { std::vector result; result.reserve(values_.size()); for (DocId id = 0; id < values_.size(); ++id) { if (occupied_[id]) result.push_back(id); } return result; } template struct SimpleValueSortIndex; template struct SimpleValueSortIndex; SimpleValueSortIndex::ParsedSortValue NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) { auto numbers_list = doc.GetNumbers(field); if (!numbers_list) { return {}; } if (numbers_list->empty()) { return ParsedSortValue{std::nullopt}; } return ParsedSortValue{numbers_list->front()}; } SimpleValueSortIndex::ParsedSortValue StringSortIndex::Get( const DocumentAccessor& doc, std::string_view field) { auto strings_list = doc.GetTags(field); if (!strings_list) { return {}; } if (strings_list->empty()) { return ParsedSortValue{std::nullopt}; } return ParsedSortValue{StatelessString{strings_list->front()}}; } } // namespace dfly::search ================================================ FILE: src/core/search/sort_indices.h 
================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include "core/search/base.h" #include "core/search/stateless_allocator.h" namespace dfly::search { using StatelessString = std::basic_string, StatelessSearchAllocator>; static_assert(sizeof(StatelessString) == sizeof(std::string)); template using StatelessVector = std::vector>; static_assert(sizeof(StatelessVector) == sizeof(std::vector)); template struct SimpleValueSortIndex : BaseSortIndex { protected: struct ParsedSortValue { bool HasValue() const; bool IsNullValue() const; // std::monostate - no value was found. // std::nullopt - found value is null. // T - found value. std::variant value; }; public: SortableValue Lookup(DocId doc) const override; std::vector Sort(std::vector* ids, size_t limit, bool desc) const override; bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; // Override GetAllResults to return all documents with non-null values std::vector GetAllDocsWithNonNullValues() const override; protected: virtual ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field_value) = 0; private: StatelessVector values_; StatelessVector occupied_; // instead of optional in values to avoid memory overhead }; struct NumericSortIndex : SimpleValueSortIndex { ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override; }; // TODO: Map tags to integers for fast sort struct StringSortIndex : SimpleValueSortIndex { ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override; }; } // namespace dfly::search ================================================ FILE: src/core/search/stateless_allocator.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
#pragma once #include #include "base/pmr/memory_resource.h" #include "core/detail/stateless_allocator.h" namespace dfly { namespace detail { inline thread_local PMR_NS::memory_resource* search_tl_mr = nullptr; } template class StatelessSearchAllocator : public StatelessAllocatorBase> { public: StatelessSearchAllocator() noexcept { assert(detail::search_tl_mr != nullptr); } template StatelessSearchAllocator(const StatelessSearchAllocator&) noexcept { // NOLINT } static PMR_NS::memory_resource* resource() { return detail::search_tl_mr; } }; template bool operator==(const StatelessSearchAllocator&, const StatelessSearchAllocator&) noexcept { return true; } template bool operator!=(const StatelessSearchAllocator&, const StatelessSearchAllocator&) noexcept { return false; } inline void InitTLSearchMR(PMR_NS::memory_resource* mr) { detail::search_tl_mr = mr; } } // namespace dfly ================================================ FILE: src/core/search/synonyms.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "synonyms.h" #include #include namespace dfly::search { const absl::flat_hash_map& Synonyms::GetGroups() const { return groups_; } void Synonyms::UpdateGroup(const std::string_view& id, const std::vector& terms) { auto& group = groups_[id]; // Convert all terms to lowercase before adding them to the group for (const std::string_view& term : terms) { group.insert(una::cases::to_lowercase_utf8(term)); } } std::optional Synonyms::GetGroupToken(std::string term) const { term = una::cases::to_lowercase_utf8(term); for (const auto& [id, group] : groups_) { if (group.count(term)) { // Add space before group id to avoid matching the term itself return absl::StrCat(" ", id); } } return std::nullopt; } } // namespace dfly::search ================================================ FILE: src/core/search/synonyms.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include namespace dfly::search { // Class that manages synonym groups for search indices. // Allows defining groups of related terms that should be considered equivalent during search. // All terms are converted to lowercase for normalization. // // When retrieving a group token via GetGroupToken, the group identifier is returned with a space // prefix. The space is intentionally added to avoid matching with the term itself during text // tokenization and to distinguish the group identifier from regular terms during search. 
class Synonyms { public: // Represents a group of synonymous terms using Group = absl::flat_hash_set; // Get all synonym groups const absl::flat_hash_map& GetGroups() const; // Update or create a synonym group void UpdateGroup(const std::string_view& id, const std::vector& terms); // Get the group ID for a term std::optional GetGroupToken(std::string term) const; private: // Maps group ID to synonym group absl::flat_hash_map groups_; }; } // namespace dfly::search ================================================ FILE: src/core/search/tag_types.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once namespace dfly { namespace search { enum class TagType { PREFIX, SUFFIX, INFIX, REGULAR }; } // namespace search } // namespace dfly ================================================ FILE: src/core/search/vector_utils.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/search/vector_utils.h" #include #include #include "base/logging.h" namespace dfly::search { using namespace std; namespace { #ifdef WITH_SIMSIMD #include #endif #if defined(__GNUC__) && !defined(__clang__) #define FAST_MATH __attribute__((optimize("fast-math"))) #else #define FAST_MATH #endif OwnedFtVector ConvertToFtVector(string_view value) { // Value cannot be casted directly as it might be not aligned as a float (4 bytes). // Misaligned memory access is UB. 
size_t size = value.size() / sizeof(float); auto out = make_unique(size); memcpy(out.get(), value.data(), size * sizeof(float)); return OwnedFtVector{std::move(out), size}; } } // namespace // Euclidean vector distance: sqrt( sum: (u[i] - v[i])^2 ) FAST_MATH float L2Distance(const float* u, const float* v, size_t dims) { #ifdef WITH_SIMSIMD simsimd_distance_t distance = 0; simsimd_l2_f32(u, v, dims, &distance); return static_cast(distance); #else float sum = 0; for (size_t i = 0; i < dims; i++) sum += (u[i] - v[i]) * (u[i] - v[i]); return sqrt(sum); #endif } // Inner product distance: 1 - dot_product(u, v) // For normalized vectors, this is equivalent to cosine distance FAST_MATH float IPDistance(const float* u, const float* v, size_t dims) { #ifdef WITH_SIMSIMD // Use SimSIMD dot product and convert to inner product distance: 1 - dot(u, v). simsimd_distance_t dot = 0; simsimd_dot_f32(u, v, dims, &dot); return 1.0f - static_cast(dot); #else float sum_uv = 0; for (size_t i = 0; i < dims; i++) sum_uv += u[i] * v[i]; return 1.0f - sum_uv; #endif } // Cosine distance: 1 - (dot_product(u, v) / (||u|| * ||v||)) FAST_MATH float CosineDistance(const float* u, const float* v, size_t dims) { #ifdef WITH_SIMSIMD simsimd_distance_t distance = 0; simsimd_cos_f32(u, v, dims, &distance); return static_cast(distance); #else float sum_uv = 0, sum_uu = 0, sum_vv = 0; for (size_t i = 0; i < dims; i++) { sum_uv += u[i] * v[i]; sum_uu += u[i] * u[i]; sum_vv += v[i] * v[i]; } if (float denom = sum_uu * sum_vv; denom != 0.0f) return 1 - sum_uv / sqrt(denom); return 0.0f; #endif } OwnedFtVector BytesToFtVector(string_view value) { DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size(); return ConvertToFtVector(value); } std::optional BytesToFtVectorSafe(string_view value) { if (value.size() % sizeof(float)) { return std::nullopt; } return ConvertToFtVector(value); } float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) { switch (sim) { case 
VectorSimilarity::L2: return L2Distance(u, v, dims); case VectorSimilarity::IP: return IPDistance(u, v, dims); case VectorSimilarity::COSINE: return CosineDistance(u, v, dims); }; return 0.0f; } void InitSimSIMD() { #if defined(WITH_SIMSIMD) (void)simsimd_capabilities(); #endif } } // namespace dfly::search ================================================ FILE: src/core/search/vector_utils.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include "core/search/base.h" namespace dfly::search { // Initializes SimSIMD runtime if dynamic dispatch is enabled. void InitSimSIMD(); OwnedFtVector BytesToFtVector(std::string_view value); // Returns std::nullopt if value can not be converted to the vector // TODO: Remove unsafe version std::optional BytesToFtVectorSafe(std::string_view value); float L2Distance(const float* u, const float* v, size_t dims); float IPDistance(const float* u, const float* v, size_t dims); float CosineDistance(const float* u, const float* v, size_t dims); float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim); } // namespace dfly::search ================================================ FILE: src/core/segment_allocator.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/segment_allocator.h" #define MI_BUILD_RELEASE 1 #include #include "base/logging.h" namespace dfly { SegmentAllocator::SegmentAllocator(mi_heap_t* heap) : heap_(heap) { // 256GB constexpr size_t limit = 1ULL << 35; static_assert((1ULL << (kSegmentIdBits + kSegmentShift)) == limit); // mimalloc uses 32MiB segments and we might need change this code if it changes. 
static_assert(kSegmentShift == MI_SEGMENT_SHIFT); static_assert((~kSegmentAlignMask) == (MI_SEGMENT_MASK)); } void SegmentAllocator::ValidateMapSize() { if (address_table_.size() > (1u << kSegmentIdBits)) { // This can happen if we restrict dragonfly to small number of threads on high-memory machine, // for example. LOG(WARNING) << "address_table_ map is growing too large: " << address_table_.size(); } } bool SegmentAllocator::CanAllocate() { return address_table_.size() < (1u << kSegmentIdBits); } } // namespace dfly ================================================ FILE: src/core/segment_allocator.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include /*** * This class is tightly coupled with mimalloc segment allocation logic and is designed to provide * a compact pointer representation (4bytes ptr) over 64bit address space that gives you * 32GB of allocations with option to extend it to 32*256GB if needed. * */ namespace dfly { /** * @brief Tightly coupled with mi_malloc 2.x implementation. * Fetches 32MiB segment pointers from the allocated pointers. * Provides own indexing of small pointers to real address space using the segment ptrs/ */ class SegmentAllocator { // (2 ^ 10) total segments static constexpr uint32_t kSegmentIdBits = 10; static constexpr uint32_t kSegmentIdMask = (1u << kSegmentIdBits) - 1; // (2 ^ 25) total bytes per segment = 32MiB static constexpr uint32_t kSegmentShift = 25; // Segment range that we cover within a single segment. 
static constexpr uint64_t kSegmentAlignMask = ~((1ULL << kSegmentShift) - 1); public: using Ptr = uint32_t; SegmentAllocator(mi_heap_t* heap); bool CanAllocate(); uint8_t* Translate(Ptr p) const { return address_table_[p & kSegmentIdMask] + Offset(p); } std::pair Allocate(uint32_t size); void Free(Ptr ptr) { void* p = Translate(ptr); used_ -= mi_usable_size(p); mi_free(p); } mi_heap_t* heap() { return heap_; } size_t used() const { return used_; } private: static uint32_t Offset(Ptr p) { return (p >> kSegmentIdBits) * 8; } void ValidateMapSize(); std::vector address_table_; absl::flat_hash_map rev_indx_; mi_heap_t* heap_; size_t used_ = 0; }; inline auto SegmentAllocator::Allocate(uint32_t size) -> std::pair { void* ptr = mi_heap_malloc(heap_, size); if (!ptr) throw std::bad_alloc{}; uint64_t iptr = (uint64_t)ptr; uint64_t seg_ptr = iptr & kSegmentAlignMask; // could be speed up using last used seg_ptr. auto [it, inserted] = rev_indx_.emplace(seg_ptr, address_table_.size()); if (inserted) { ValidateMapSize(); address_table_.push_back((uint8_t*)seg_ptr); } uint32_t seg_offset = (iptr - seg_ptr) / 8; Ptr res = (seg_offset << kSegmentIdBits) | it->second; used_ += mi_good_size(size); return std::make_pair(res, (uint8_t*)ptr); } } // namespace dfly ================================================ FILE: src/core/size_tracking_channel.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include "util/fibers/simple_channel.h" namespace dfly { // SimpleQueue-like interface, but also keeps track over the size of Ts it owns. // It has a slightly less efficient TryPush() API as it forces construction of Ts even if they are // not pushed. // T must have a .size() method, which should return the heap-allocated size of T, excluding // anything included in sizeof(T). We could generalize this in the future. 
template > class SizeTrackingChannel { public: SizeTrackingChannel(size_t n, unsigned num_producers = 1) : queue_(n, num_producers) { } // Here and below, we must accept a T instead of building it from variadic args, as we need to // know its size in case it is added. size_t Push(T t) noexcept { size_t tsize = t.size(); size_t res = size_.fetch_add(tsize, std::memory_order_relaxed); queue_.Push(std::move(t)); return res + tsize; } bool TryPush(T t) noexcept { const size_t size = t.size(); if (queue_.TryPush(std::move(t))) { size_.fetch_add(size, std::memory_order_relaxed); return true; } return false; } bool Pop(T& dest) { if (queue_.Pop(dest)) { size_.fetch_sub(dest.size(), std::memory_order_relaxed); return true; } return false; } void StartClosing() { queue_.StartClosing(); } bool TryPop(T& dest) { if (queue_.TryPop(dest)) { size_.fetch_sub(dest.size(), std::memory_order_relaxed); return true; } return false; } bool IsClosing() const { return queue_.IsClosing(); } size_t GetSize() const { return queue_.Capacity() * sizeof(T) + size_.load(std::memory_order_relaxed); } private: util::fb2::SimpleChannel queue_; std::atomic size_ = 0; }; } // namespace dfly ================================================ FILE: src/core/small_string.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/small_string.h" #include #include #include #include "base/logging.h" #include "core/page_usage/page_usage_stats.h" #include "core/segment_allocator.h" namespace dfly { using namespace std; namespace { class XXH3_Deleter { public: void operator()(XXH3_state_t* ptr) const { XXH3_freeState(ptr); } }; struct TL { unique_ptr xxh_state; unique_ptr seg_alloc; }; thread_local TL tl; constexpr XXH64_hash_t kHashSeed = 24061983; // same as in compact_object.cc } // namespace void SmallString::InitThreadLocal(void* heap) { SegmentAllocator* ns = new SegmentAllocator((mi_heap_t*)heap); tl.seg_alloc.reset(ns); tl.xxh_state.reset(XXH3_createState()); XXH3_64bits_reset_withSeed(tl.xxh_state.get(), kHashSeed); } bool SmallString::CanAllocate(size_t size) { return size <= kMaxSize && tl.seg_alloc->CanAllocate(); } size_t SmallString::UsedThreadLocal() { return tl.seg_alloc ? tl.seg_alloc->used() : 0; } static_assert(sizeof(SmallString) == 16); size_t SmallString::Assign(std::string_view s) { DCHECK_GT(s.size(), kPrefLen); DCHECK(CanAllocate(s.size())); uint8_t* realptr = nullptr; // reallocate if we need a larger allocation or it becomes space-inefficient size_t heap_len = s.size() - kPrefLen; if (size_t available = MallocUsed(); available < heap_len || heap_len * 2 < available) { Free(); auto [sp, rp] = tl.seg_alloc->Allocate(heap_len); small_ptr_ = sp; realptr = rp; } else { realptr = tl.seg_alloc->Translate(small_ptr_); } size_ = s.size(); memcpy(prefix_, s.data(), kPrefLen); memcpy(realptr, s.data() + kPrefLen, heap_len); return mi_malloc_usable_size(realptr); } void SmallString::Free() { if (size_) tl.seg_alloc->Free(small_ptr_); size_ = 0; } uint16_t SmallString::MallocUsed() const { if (size_) return mi_malloc_usable_size(tl.seg_alloc->Translate(small_ptr_)); return 0; } bool SmallString::Equal(std::string_view o) const { if (size_ != o.size()) return false; if (size_ == 0) return true; if (memcmp(prefix_, o.data(), kPrefLen) != 0) return false; uint8_t* 
realp = tl.seg_alloc->Translate(small_ptr_); return memcmp(realp, o.data() + kPrefLen, size_ - kPrefLen) == 0; } bool SmallString::Equal(const SmallString& os) const { if (size_ != os.size_) return false; return Get() == os.Get(); } uint64_t SmallString::HashCode() const { array slice = Get(); XXH3_state_t* state = tl.xxh_state.get(); XXH3_64bits_reset_withSeed(state, kHashSeed); XXH3_64bits_update(state, slice[0].data(), slice[0].size()); XXH3_64bits_update(state, slice[1].data(), slice[1].size()); return XXH3_64bits_digest(state); } array SmallString::Get() const { DCHECK(size_); array dest; dest[0] = string_view{prefix_, kPrefLen}; uint8_t* ptr = tl.seg_alloc->Translate(small_ptr_); dest[1] = string_view{reinterpret_cast(ptr), size_ - kPrefLen}; return dest; } void SmallString::Get(char* out) const { auto strs = Get(); memcpy(out, strs[0].data(), strs[0].size()); memcpy(out + strs[0].size(), strs[1].data(), strs[1].size()); } void SmallString::Get(std::string* dest) const { dest->resize(size_); Get(dest->data()); } bool SmallString::DefragIfNeeded(PageUsage* page_usage) { uint8_t* cur_real_ptr = tl.seg_alloc->Translate(small_ptr_); if (!page_usage->IsPageForObjectUnderUtilized(tl.seg_alloc->heap(), cur_real_ptr)) return false; if (!CanAllocate(size_ - kPrefLen)) // Forced return false; auto [sp, rp] = tl.seg_alloc->Allocate(size_ - kPrefLen); memcpy(rp, cur_real_ptr, size_ - kPrefLen); tl.seg_alloc->Free(small_ptr_); small_ptr_ = sp; return true; } } // namespace dfly ================================================ FILE: src/core/small_string.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include namespace dfly { class PageUsage; // Efficient storage of strings longer than 10 bytes. 
// Requires explicit memory management class SmallString { static constexpr unsigned kPrefLen = 10; static constexpr unsigned kMaxSize = (1 << 8) - 1; public: static void InitThreadLocal(void* heap); static size_t UsedThreadLocal(); static bool CanAllocate(size_t size); // Returns malloc used. size_t Assign(std::string_view s); void Free(); bool Equal(std::string_view o) const; bool Equal(const SmallString& mps) const; uint64_t HashCode() const; uint16_t MallocUsed() const; std::array Get() const; void Get(char* out) const; void Get(std::string* dest) const; bool DefragIfNeeded(PageUsage* page_usage); size_t size() const { return size_; } uint8_t first_byte() const { return prefix_[0]; } private: // The string is stored broken up into two parts, the first one - in this array char prefix_[kPrefLen]; uint32_t small_ptr_; // 32GB capacity because we ignore 3 lsb bits (i.e. x8). uint16_t size_; // uint16_t - total size (including prefix) } __attribute__((packed)); } // namespace dfly ================================================ FILE: src/core/sorted_map.cc ================================================ // Copyright 2023, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #include "core/sorted_map.h" #include #include extern "C" { #include "redis/listpack.h" #include "redis/redis_aux.h" #include "redis/util.h" #include "redis/zmalloc.h" } #include #include "base/endian.h" #include "base/logging.h" using namespace std; namespace dfly { namespace detail { namespace { double GetObjScore(const void* obj) { sds s = (sds)obj; char* ptr = s + sdslen(s) + 1; return absl::bit_cast(absl::little_endian::Load64(ptr)); } void SetObjScore(void* obj, double score) { sds s = (sds)obj; char* ptr = s + sdslen(s) + 1; absl::little_endian::Store64(ptr, absl::bit_cast(score)); } // buf must be at least 10 chars long. void* BuildScoredKey(double score, char buf[]) { buf[0] = SDS_TYPE_5; // length 0. 
buf[1] = 0; absl::little_endian::Store64(buf + 2, absl::bit_cast(score)); void* key = buf + 1; return key; } // Copied from t_zset.c /* Returns 1 if the double value can safely be represented in long long without * precision loss, in which case the corresponding long long is stored in the out variable. */ static int double2ll(double d, long long* out) { #if (DBL_MANT_DIG >= 52) && (DBL_MANT_DIG <= 63) && (LLONG_MAX == 0x7fffffffffffffffLL) /* Check if the float is in a safe range to be casted into a * long long. We are assuming that long long is 64 bit here. * Also we are assuming that there are no implementations around where * double has precision < 52 bit. * * Under this assumptions we test if a double is inside a range * where casting to long long is safe. Then using two castings we * make sure the decimal part is zero. If all this is true we can use * integer without precision loss. * * Note that numbers above 2^52 and below 2^63 use all the fraction bits as real part, * and the exponent bits are positive, which means the "decimal" part must be 0. * i.e. all double values in that range are representable as a long without precision loss, * but not all long values in that range can be represented as a double. * we only care about the first part here. */ if (d < (double)(-LLONG_MAX / 2) || d > (double)(LLONG_MAX / 2)) return 0; long long ll = d; if (ll == d) { *out = ll; return 1; } #endif return 0; } /* Compare element in sorted set with given element. */ int zzlCompareElements(unsigned char* eptr, unsigned char* cstr, unsigned int clen) { unsigned char* vstr; unsigned int vlen; long long vlong; unsigned char vbuf[32]; int minlen, cmp; vstr = lpGetValue(eptr, &vlen, &vlong); if (vstr == NULL) { /* Store string representation of long long in buf. */ vlen = ll2string((char*)vbuf, sizeof(vbuf), vlong); vstr = vbuf; } minlen = (vlen < clen) ? 
vlen : clen; cmp = memcmp(vstr, cstr, minlen); if (cmp == 0) return vlen - clen; return cmp; } using double_conversion::DoubleToStringConverter; constexpr unsigned kConvFlags = DoubleToStringConverter::UNIQUE_ZERO; DoubleToStringConverter score_conv(kConvFlags, "inf", "nan", 'e', -6, 21, 6, 0); // Copied from redis code but uses double_conversion to encode double values. unsigned char* ZzlInsertAt(unsigned char* zl, unsigned char* eptr, std::string_view ele, double score) { unsigned char* sptr; char scorebuf[128]; unsigned scorelen = 0; long long lscore; int score_is_long = double2ll(score, &lscore); if (!score_is_long) { // Use double converter to get the shortest representation. double_conversion::StringBuilder sb(scorebuf, sizeof(scorebuf)); score_conv.ToShortest(score, &sb); scorelen = sb.position(); sb.Finalize(); DCHECK_EQ(scorelen, strlen(scorebuf)); } // Argument parsing converts empty strings to default initialized string views. // Such string views have a null data field, which if passed into lpAppend (via zzlInsertAt) // results in the replace operation being applied on the listpack. In addition to being wrong, it // also causes assertion failures. To circumvent this corner case we pass here a string view // pointing to an empty string on the stack, which has a non-null data field. if (ele.data() == nullptr) { ele = ""sv; } if (eptr == NULL) { zl = lpAppend(zl, (const unsigned char*)(ele.data()), ele.size()); if (score_is_long) zl = lpAppendInteger(zl, lscore); else zl = lpAppend(zl, (unsigned char*)scorebuf, scorelen); } else { /* Insert member before the element 'eptr'. */ zl = lpInsertString(zl, (const unsigned char*)ele.data(), ele.size(), eptr, LP_BEFORE, &sptr); /* Insert score after the member. 
*/ if (score_is_long) zl = lpInsertInteger(zl, lscore, sptr, LP_AFTER, NULL); else zl = lpInsertString(zl, (unsigned char*)scorebuf, scorelen, sptr, LP_AFTER, NULL); } return zl; } double ZzlStrtod(unsigned char* vstr, unsigned int vlen) { char buf[128]; if (vlen > sizeof(buf)) vlen = sizeof(buf); memcpy(buf, vstr, vlen); buf[vlen] = '\0'; return strtod(buf, NULL); } /* Return a listpack element as an SDS string. */ sds LpGetObject(const uint8_t* sptr) { unsigned char* vstr; unsigned int vlen; long long vlong; serverAssert(sptr != NULL); vstr = lpGetValue(const_cast(sptr), &vlen, &vlong); if (vstr) { return sdsnewlen((char*)vstr, vlen); } else { return sdsfromlonglong(vlong); } } // static representation of sds strings char kMinStrData[] = "\110" "minstring"; char kMaxStrData[] = "\110" "maxstring"; } // namespace double ZzlGetScore(const uint8_t* sptr) { unsigned char* vstr; unsigned int vlen; long long vlong; double score; DCHECK(sptr != NULL); vstr = lpGetValue(const_cast(sptr), &vlen, &vlong); if (vstr) { score = ZzlStrtod(vstr, vlen); } else { score = vlong; } return score; } /* Move to the previous entry based on the values in eptr and sptr. Both are * set to NULL when there is no prev entry. */ void ZzlPrev(const uint8_t* zl, uint8_t** eptr, uint8_t** sptr) { unsigned char *_eptr, *_sptr; serverAssert(*eptr != NULL && *sptr != NULL); _sptr = lpPrev(const_cast(zl), *eptr); if (_sptr != NULL) { _eptr = lpPrev(const_cast(zl), _sptr); DCHECK(_eptr != NULL); } else { /* No previous entry. */ _eptr = NULL; } *eptr = _eptr; *sptr = _sptr; } /* Move to next entry based on the values in eptr and sptr. Both are set to * NULL when there is no next entry. */ void ZzlNext(const uint8_t* zl, uint8_t** eptr, uint8_t** sptr) { unsigned char *_eptr, *_sptr; DCHECK(*eptr != NULL && *sptr != NULL); _eptr = lpNext(const_cast(zl), *sptr); if (_eptr != NULL) { _sptr = lpNext(const_cast(zl), _eptr); DCHECK(_sptr != NULL); } else { /* No next entry. 
*/ _sptr = NULL; } *eptr = _eptr; *sptr = _sptr; } /* Free a lex range structure, must be called only after zslParseLexRange() * populated the structure with success (C_OK returned). */ void ZslFreeLexRange(const zlexrangespec* spec) { if (spec->min != cminstring && spec->min != cmaxstring) sdsfree(spec->min); if (spec->max != cminstring && spec->max != cmaxstring) sdsfree(spec->max); } /* This is just a wrapper to sdscmp() that is able to * handle shared.minstring and shared.maxstring as the equivalent of * -inf and +inf for strings */ int sdscmplex(sds a, sds b) { if (a == b) return 0; if (a == cminstring || b == cmaxstring) return -1; if (a == cmaxstring || b == cminstring) return 1; return sdscmp(a, b); } int zslLexValueGteMin(sds value, const zlexrangespec* spec) { return spec->minex ? (sdscmplex(value, spec->min) > 0) : (sdscmplex(value, spec->min) >= 0); } int zslLexValueLteMax(sds value, const zlexrangespec* spec) { return spec->maxex ? (sdscmplex(value, spec->max) < 0) : (sdscmplex(value, spec->max) <= 0); } int ZzlLexValueGteMin(unsigned char* p, const zlexrangespec* spec) { sds value = LpGetObject(p); int res = zslLexValueGteMin(value, spec); sdsfree(value); return res; } int ZzlLexValueLteMax(unsigned char* p, const zlexrangespec* spec) { sds value = LpGetObject(p); int res = zslLexValueLteMax(value, spec); sdsfree(value); return res; } /* Returns if there is a part of the zset is in range. Should only be used * internally by zzlFirstInRange and zzlLastInRange. */ int zzlIsInRange(unsigned char* zl, const zrangespec* range) { unsigned char* p; double score; /* Test for ranges that will always be empty. */ if (range->min > range->max || (range->min == range->max && (range->minex || range->maxex))) return 0; p = lpSeek(zl, -1); /* Last score. */ if (p == NULL) return 0; /* Empty sorted set */ score = ZzlGetScore(p); if (!ZslValueGteMin(score, range)) return 0; p = lpSeek(zl, 1); /* First score. 
*/ serverAssert(p != NULL); score = ZzlGetScore(p); if (!ZslValueLteMax(score, range)) return 0; return 1; } /* Find pointer to the first element contained in the specified range. * Returns NULL when no element is contained in the range. */ unsigned char* ZzlFirstInRange(unsigned char* zl, const zrangespec* range) { unsigned char *eptr = lpSeek(zl, 0), *sptr; double score; /* If everything is out of range, return early. */ if (!zzlIsInRange(zl, range)) return NULL; while (eptr != NULL) { sptr = lpNext(zl, eptr); serverAssert(sptr != NULL); score = ZzlGetScore(sptr); if (ZslValueGteMin(score, range)) { /* Check if score <= max. */ if (ZslValueLteMax(score, range)) return eptr; return NULL; } /* Move to next element. */ eptr = lpNext(zl, sptr); } return NULL; } /* Find pointer to the last element contained in the specified range. * Returns NULL when no element is contained in the range. */ unsigned char* ZzlLastInRange(unsigned char* zl, const zrangespec* range) { unsigned char *eptr = lpSeek(zl, -2), *sptr; double score; /* If everything is out of range, return early. */ if (!zzlIsInRange(zl, range)) return NULL; while (eptr != NULL) { sptr = lpNext(zl, eptr); serverAssert(sptr != NULL); score = ZzlGetScore(sptr); if (ZslValueLteMax(score, range)) { /* Check if score >= min. */ if (ZslValueGteMin(score, range)) return eptr; return NULL; } /* Move to previous element by moving to the score of previous element. * When this returns NULL, we know there also is no element. */ sptr = lpPrev(zl, eptr); if (sptr != NULL) serverAssert((eptr = lpPrev(zl, sptr)) != NULL); else eptr = NULL; } return NULL; } /* Returns if there is a part of the zset is in range. Should only be used * internally by zzlFirstInRange and zzlLastInRange. */ int ZzlIsInLexRange(unsigned char* zl, const zlexrangespec* range) { unsigned char* p; /* Test for ranges that will always be empty. 
*/ int cmp = sdscmplex(range->min, range->max); if (cmp > 0 || (cmp == 0 && (range->minex || range->maxex))) return 0; p = lpSeek(zl, -2); /* Last element. */ if (p == NULL) return 0; if (!ZzlLexValueGteMin(p, range)) return 0; p = lpSeek(zl, 0); /* First element. */ serverAssert(p != NULL); if (!ZzlLexValueLteMax(p, range)) return 0; return 1; } /* Find pointer to the first element contained in the specified lex range. * Returns NULL when no element is contained in the range. */ unsigned char* ZzlFirstInLexRange(unsigned char* zl, const zlexrangespec* range) { unsigned char *eptr = lpSeek(zl, 0), *sptr; /* If everything is out of range, return early. */ if (!ZzlIsInLexRange(zl, range)) return NULL; while (eptr != NULL) { if (ZzlLexValueGteMin(eptr, range)) { /* Check if score <= max. */ if (ZzlLexValueLteMax(eptr, range)) return eptr; return NULL; } /* Move to next element. */ sptr = lpNext(zl, eptr); /* This element score. Skip it. */ serverAssert(sptr != NULL); eptr = lpNext(zl, sptr); /* Next element. */ } return NULL; } /* Find pointer to the last element contained in the specified lex range. * Returns NULL when no element is contained in the range. */ unsigned char* ZzlLastInLexRange(unsigned char* zl, const zlexrangespec* range) { unsigned char *eptr = lpSeek(zl, -2), *sptr; /* If everything is out of range, return early. */ if (!ZzlIsInLexRange(zl, range)) return NULL; while (eptr != NULL) { if (ZzlLexValueLteMax(eptr, range)) { /* Check if score >= min. */ if (ZzlLexValueGteMin(eptr, range)) return eptr; return NULL; } /* Move to previous element by moving to the score of previous element. * When this returns NULL, we know there also is no element. 
*/ sptr = lpPrev(zl, eptr); if (sptr != NULL) serverAssert((eptr = lpPrev(zl, sptr)) != NULL); else eptr = NULL; } return NULL; } unsigned char* ZzlDeleteRangeByLex(unsigned char* zl, const zlexrangespec* range, unsigned long* deleted) { unsigned char *eptr, *sptr; unsigned long num = 0; if (deleted != NULL) *deleted = 0; eptr = ZzlFirstInLexRange(zl, range); if (eptr == NULL) return zl; /* When the tail of the listpack is deleted, eptr will be NULL. */ while (eptr && (sptr = lpNext(zl, eptr)) != NULL) { if (ZzlLexValueLteMax(eptr, range)) { /* Delete both the element and the score. */ zl = lpDeleteRangeWithEntry(zl, &eptr, 2); num++; } else { /* No longer in range. */ break; } } if (deleted != NULL) *deleted = num; return zl; } unsigned char* ZzlDeleteRangeByScore(unsigned char* zl, const zrangespec* range, unsigned long* deleted) { unsigned char *eptr, *sptr; double score; unsigned long num = 0; if (deleted != NULL) *deleted = 0; eptr = ZzlFirstInRange(zl, range); if (eptr == NULL) return zl; /* When the tail of the listpack is deleted, eptr will be NULL. */ while (eptr && (sptr = lpNext(zl, eptr)) != NULL) { score = ZzlGetScore(sptr); if (ZslValueLteMax(score, range)) { /* Delete both the element and the score. */ zl = lpDeleteRangeWithEntry(zl, &eptr, 2); num++; } else { /* No longer in range. */ break; } } if (deleted != NULL) *deleted = num; return zl; } /* Insert (element,score) pair in listpack. This function assumes the element is * not yet present in the list. */ unsigned char* ZzlInsert(unsigned char* zl, std::string_view ele, double score) { unsigned char *eptr = NULL, *sptr = lpSeek(zl, -1); double s; // Optimization: check first whether the new element should be the last. if (sptr != NULL) { s = ZzlGetScore(sptr); if (s >= score) { // It should not be the last, so fallback to the forward iteration. 
eptr = lpSeek(zl, 0); } } while (eptr != NULL) { sptr = lpNext(zl, eptr); s = ZzlGetScore(sptr); if (s > score) { /* First element with score larger than score for element to be * inserted. This means we should take its spot in the list to * maintain ordering. */ return ZzlInsertAt(zl, eptr, ele, score); } else if (s == score) { /* Ensure lexicographical ordering for elements. */ if (zzlCompareElements(eptr, (unsigned char*)ele.data(), ele.size()) > 0) { return ZzlInsertAt(zl, eptr, ele, score); } } /* Move to next element. */ eptr = lpNext(zl, sptr); } /* Push on tail of list when it was not yet inserted. */ return ZzlInsertAt(zl, NULL, ele, score); } unsigned char* ZzlFind(unsigned char* lp, std::string_view ele, double* score) { uint8_t *sptr, *eptr = lpFirst(lp); if (eptr == nullptr) return nullptr; eptr = lpFind(lp, eptr, (unsigned char*)ele.data(), ele.size(), 1); if (eptr) { sptr = lpNext(lp, eptr); serverAssert(sptr != NULL); /* Matching element, pull out score. */ if (score != nullptr) *score = ZzlGetScore(sptr); return eptr; } return nullptr; } SortedMap::SortedMap() : score_map(new ScoreMap), score_tree(new ScoreTree(StatelessAllocator::resource())) { } SortedMap::~SortedMap() { delete score_tree; delete score_map; } // Three way comparison of q and key. // Compares scores first and then the keys, unless q.ignore_score is set. // In that case only keys are compared. // In order to support close/open intervals, we introduce a special flag for +inf strings. // So, in case of score equality (or if scores are ignored), q.str_is_infinite means q > key, // and 1 is returned. int SortedMap::ScoreSdsPolicy::KeyCompareTo::operator()(Query q, ScoreSds key) const { sds sdsa = (sds)q.item; if (!q.ignore_score) { double sa = GetObjScore(sdsa); double sb = GetObjScore(key); if (sa < sb) return -1; if (sa > sb) return 1; } // if q.str_is_infinite is set, it means q > key at this point. 
if (q.str_is_infinite) return 1; return sdscmp(sdsa, (sds)key); } /* Adds ele with the given score, or updates the score of an existing member, honoring the ZADD_IN_XX, ZADD_IN_NX and ZADD_IN_INCR bits of in_flags (ZADD_IN_GT / ZADD_IN_LT are not consulted here). The outcome is reported through *out_flags; *newscore is written only when the member was added or updated. Returns 0 only when INCR yields NaN, 1 otherwise. */ int SortedMap::AddElem(double score, std::string_view ele, int in_flags, int* out_flags, double* newscore) { // does not take ownership over ele. DCHECK(!isnan(score)); ScoreSds obj = nullptr; bool added = false; if (in_flags & ZADD_IN_XX) { obj = score_map->FindObj(ele); if (obj == nullptr) { *out_flags = ZADD_OUT_NOP; return 1; } } else { tie(obj, added) = score_map->AddOrSkip(ele, score); } if (added) { // Adding a new element. DCHECK_EQ(in_flags & ZADD_IN_XX, 0); *out_flags = ZADD_OUT_ADDED; *newscore = score; bool added = score_tree->Insert(obj); DCHECK(added); return 1; } // Updating an existing element. if ((in_flags & ZADD_IN_NX)) { // NX forbids touching an element that already exists. *out_flags = ZADD_OUT_NOP; return 1; } if (in_flags & ZADD_IN_INCR) { score += GetObjScore(obj); if (isnan(score)) { *out_flags = ZADD_OUT_NAN; return 0; } } // Update the score. CHECK(score_tree->Delete(obj)); SetObjScore(obj, score); CHECK(score_tree->Insert(obj)); *out_flags = ZADD_OUT_UPDATED; *newscore = score; return 1; } optional SortedMap::GetScore(std::string_view ele) const { ScoreSds obj = score_map->FindObj(ele); if (obj != nullptr) { return GetObjScore(obj); } return std::nullopt; } bool SortedMap::InsertNew(double score, std::string_view member) { DVLOG(2) << "InsertNew " << score << " " << member; auto [newk, added] = score_map->AddOrSkip(member, score); if (!added) return false; added = score_tree->Insert(newk); CHECK(added); return true; } optional SortedMap::GetRank(std::string_view ele, bool reverse) const { ScoreSds obj = score_map->FindObj(ele); if (obj == nullptr) return std::nullopt; optional rank = score_tree->GetRank(obj, reverse); DCHECK(rank); return *rank; } SortedMap::ScoredArray SortedMap::GetRange(const zrangespec& range, unsigned offset, unsigned limit, bool reverse) const { ScoredArray arr; if (score_tree->Size() <= offset || limit == 0) return arr; char buf[16]; if (reverse) { ScoreSds key = 
BuildScoredKey(range.max, buf); auto path = score_tree->LEQ(Query{key, false, !range.maxex}); if (path.Empty()) return arr; if (range.maxex && range.max == GetObjScore(path.Terminal())) { ++offset; } DCHECK_LE(GetObjScore(path.Terminal()), range.max); while (offset--) { if (!path.Prev()) return arr; } while (limit--) { ScoreSds ele = path.Terminal(); double score = GetObjScore(ele); if (range.min > score || (range.min == score && range.minex)) break; arr.emplace_back(string{(sds)ele, sdslen((sds)ele)}, score); if (!path.Prev()) break; } } else { ScoreSds key = BuildScoredKey(range.min, buf); auto path = score_tree->GEQ(Query{key, false, range.minex}); if (path.Empty()) return arr; while (offset--) { if (!path.Next()) return arr; } auto path2 = path; size_t num_elems = 0; // Count the number of elements in the range. while (limit--) { ScoreSds ele = path.Terminal(); double score = GetObjScore(ele); if (range.max < score || (range.max == score && range.maxex)) break; ++num_elems; if (!path.Next()) break; } // reserve enough space. 
arr.resize(num_elems); for (size_t i = 0; i < num_elems; ++i) { ScoreSds ele = path2.Terminal(); arr[i] = {string{(sds)ele, sdslen((sds)ele)}, GetObjScore(ele)}; path2.Next(); } } return arr; } SortedMap::ScoredArray SortedMap::GetLexRange(const zlexrangespec& range, unsigned offset, unsigned limit, bool reverse) const { if (score_tree->Size() <= offset || limit == 0) return {}; detail::BPTreePath path; ScoredArray arr; if (reverse) { if (range.max != cmaxstring) { path = score_tree->LEQ(Query{range.max, true}); if (path.Empty()) return {}; if (range.maxex && sdscmp((sds)path.Terminal(), range.max) == 0) { ++offset; } while (offset--) { if (!path.Prev()) return {}; } } else { path = score_tree->FromRank(score_tree->Size() - offset - 1); } while (limit--) { ScoreSds ele = path.Terminal(); if (range.min != cminstring) { int cmp = sdscmp((sds)ele, range.min); if (cmp < 0 || (cmp == 0 && range.minex)) break; } arr.emplace_back(string{(sds)ele, sdslen((sds)ele)}, GetObjScore(ele)); if (!path.Prev()) break; } } else { if (range.min != cminstring) { path = score_tree->GEQ(Query{range.min, true}); if (path.Empty()) return {}; if (range.minex && sdscmp((sds)path.Terminal(), range.min) == 0) { ++offset; } while (offset--) { if (!path.Next()) return {}; } } else { path = score_tree->FromRank(offset); } while (limit--) { ScoreSds ele = path.Terminal(); if (range.max != cmaxstring) { int cmp = sdscmp((sds)ele, range.max); if (cmp > 0 || (cmp == 0 && range.maxex)) break; } arr.emplace_back(string{(sds)ele, sdslen((sds)ele)}, GetObjScore(ele)); if (!path.Next()) break; } } return arr; } uint8_t* SortedMap::ToListPack() const { uint8_t* lp = lpNew(0); score_tree->Iterate(0, UINT32_MAX, [&](ScoreSds ele) { const std::string_view v{(sds)ele, sdslen((sds)ele)}; lp = ZzlInsertAt(lp, NULL, v, GetObjScore(ele)); return true; }); return lp; } bool SortedMap::Delete(std::string_view ele) const { ScoreSds obj = score_map->FindObj(ele); if (obj == nullptr) return false; 
CHECK(score_tree->Delete(obj)); CHECK(score_map->Erase(ele)); return true; } /* Sums the hash map's own accounting with NodeCount() * 256 for the tree (the 256 bytes per node is presumably an approximation of the node size — TODO confirm). */ size_t SortedMap::MallocSize() const { // TODO: add malloc used to BPTree. return score_map->SetMallocUsed() + score_map->ObjMallocUsed() + score_tree->NodeCount() * 256; } bool SortedMap::Reserve(size_t sz) { score_map->Reserve(sz); return true; } size_t SortedMap::DeleteRangeByRank(unsigned start, unsigned end) { DCHECK_LE(start, end); DCHECK_LT(end, score_tree->Size()); for (uint32_t i = start; i <= end; ++i) { /* Ideally, we would want to advance path to the next item and delete the previous one. * However, we can not do that because the path is invalidated after the * deletion. So we have to recreate the path for each item using the same rank. * Note: this could probably be improved, but doing so is much more complicated. */ auto path = score_tree->FromRank(start); sds ele = (sds)path.Terminal(); score_tree->Delete(path); score_map->Erase(ele); } return end - start + 1; } size_t SortedMap::DeleteRangeByScore(const zrangespec& range) { char buf[16] = {0}; size_t deleted = 0; while (!score_tree->Empty()) { ScoreSds min_key = BuildScoredKey(range.min, buf); auto path = score_tree->GEQ(Query{min_key, false, range.minex}); if (path.Empty()) break; ScoreSds item = path.Terminal(); double score = GetObjScore(item); if (range.minex) { DCHECK_GT(score, range.min); } else { DCHECK_GE(score, range.min); } if (score > range.max || (range.maxex && score == range.max)) break; score_tree->Delete(item); ++deleted; score_map->Erase((sds)item); } return deleted; } size_t SortedMap::DeleteRangeByLex(const zlexrangespec& range) { if (score_tree->Size() == 0) return 0; size_t deleted = 0; uint32_t rank = 0; if (range.min != cminstring) { auto path = score_tree->GEQ(Query{range.min, true}); if (path.Empty()) return {}; rank = path.Rank(); if (range.minex && sdscmp((sds)path.Terminal(), range.min) == 0) { ++rank; } } while (rank < score_tree->Size()) { auto path = score_tree->FromRank(rank); ScoreSds item = 
path.Terminal(); if (range.max != cmaxstring) { int cmp = sdscmp((sds)item, range.max); if (cmp > 0 || (cmp == 0 && range.maxex)) break; } ++deleted; score_tree->Delete(path); score_map->Erase((sds)item); } return deleted; } SortedMap::ScoredArray SortedMap::PopTopScores(unsigned count, bool reverse) { DCHECK_GT(count, 0u); DCHECK_EQ(score_map->UpperBoundSize(), score_tree->Size()); size_t sz = score_map->UpperBoundSize(); ScoredArray res; DCHECK_GT(sz, 0u); // Empty sets are not allowed. if (sz == 0 || count == 0) return res; if (count > sz) count = sz; res.reserve(count); auto cb = [&](ScoreSds obj) { res.emplace_back(string{(sds)obj, sdslen((sds)obj)}, GetObjScore(obj)); // We can not delete from score_tree because we are in the middle of the iteration. CHECK(score_map->Erase((sds)obj)); return true; // continue with the iteration. }; unsigned rank = 0; unsigned step = 0; if (reverse) { score_tree->IterateReverse(0, count - 1, std::move(cb)); rank = score_tree->Size() - 1; step = 1; } else { score_tree->Iterate(0, count - 1, std::move(cb)); } // We already deleted elements from score_map, so what's left is to delete from the tree. if (score_map->Empty()) { // Corner case optimization. score_tree->Clear(); } else { for (unsigned i = 0; i < res.size(); ++i) { auto path = score_tree->FromRank(rank); score_tree->Delete(path); rank -= step; } } return res; } size_t SortedMap::Count(const zrangespec& range) const { DCHECK_LE(range.min, range.max); if (score_tree->Size() == 0) return 0; // build min key. char buf[16]; ScoreSds range_key = BuildScoredKey(range.min, buf); auto path = score_tree->GEQ(Query{range_key, false, range.minex}); if (path.Empty()) return 0; ScoreSds bound = path.Terminal(); if (range.minex) { DCHECK_GT(GetObjScore(bound), range.min); } else { DCHECK_GE(GetObjScore(bound), range.min); } uint32_t min_rank = path.Rank(); // Now build the max key. 
// If we need to exclude the maximum score, set the key's string part to empty string, // otherwise set it to infinity. range_key = BuildScoredKey(range.max, buf); path = score_tree->GEQ(Query{range_key, false, !range.maxex}); if (path.Empty()) { return score_tree->Size() - min_rank; } bound = path.Terminal(); uint32_t max_rank = path.Rank(); if (range.maxex || GetObjScore(bound) > range.max) { if (max_rank <= min_rank) return 0; --max_rank; } // max_rank could be less than min_rank, for example, if the range is [a, a). return max_rank < min_rank ? 0 : max_rank - min_rank + 1; } /* Counts the members whose string lies inside the lexicographic range. Scores are ignored: every tree query below is built with ignore_score set. cminstring / cmaxstring act as the open-ended bounds. */ size_t SortedMap::LexCount(const zlexrangespec& range) const { if (score_tree->Size() == 0) return 0; // Ranges that will always be zero - (+inf, anything) or (anything, -inf) if (range.min == cmaxstring || range.max == cminstring) { return 0; } uint32_t min_rank = 0; detail::BPTreePath path; if (range.min != cminstring) { path = score_tree->GEQ(Query{range.min, true}); if (path.Empty()) return 0; min_rank = path.Rank(); if (range.minex && sdscmp((sds)path.Terminal(), range.min) == 0) { ++min_rank; if (min_rank >= score_tree->Size()) return 0; } } uint32_t max_rank = score_tree->Size() - 1; if (range.max != cmaxstring) { path = score_tree->GEQ(Query{range.max, true}); if (!path.Empty()) { max_rank = path.Rank(); // fix the max rank, if needed. int cmp = sdscmp((sds)path.Terminal(), range.max); DCHECK_GE(cmp, 0); if (cmp > 0 || range.maxex) { if (max_rank <= min_rank) return 0; --max_rank; } } } return max_rank < min_rank ? 
0 : max_rank - min_rank + 1; } bool SortedMap::Iterate(unsigned start_rank, unsigned len, bool reverse, std::function cb) const { DCHECK_GT(len, 0u); unsigned end_rank = start_rank + len - 1; bool success; if (reverse) { success = score_tree->IterateReverse( start_rank, end_rank, [&](ScoreSds obj) { return cb((sds)obj, GetObjScore(obj)); }); } else { success = score_tree->Iterate(start_rank, end_rank, [&](ScoreSds obj) { return cb((sds)obj, GetObjScore(obj)); }); } return success; } uint64_t SortedMap::Scan(uint64_t cursor, absl::FunctionRef cb) const { auto scan_cb = [&cb](const void* obj) { sds ele = (sds)obj; cb(string_view{ele, sdslen(ele)}, GetObjScore(obj)); }; return this->score_map->Scan(cursor, std::move(scan_cb)); } // taken from zsetConvert SortedMap* SortedMap::FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp) { uint8_t* zl = (uint8_t*)lp; unsigned char *eptr, *sptr; unsigned char* vstr; unsigned int vlen; long long vlong; void* ptr = res->allocate(sizeof(SortedMap), alignof(SortedMap)); SortedMap* zs = new (ptr) SortedMap; eptr = lpSeek(zl, 0); if (eptr != NULL) { sptr = lpNext(zl, eptr); CHECK(sptr != NULL); } while (eptr != NULL) { double score = ZzlGetScore(sptr); vstr = lpGetValue(eptr, &vlen, &vlong); if (vstr == NULL) { CHECK(zs->InsertNew(score, absl::StrCat(vlong))); } else { CHECK(zs->InsertNew(score, string_view{reinterpret_cast(vstr), vlen})); } ZzlNext(zl, &eptr, &sptr); } return zs; } bool SortedMap::DefragIfNeeded(PageUsage* page_usage) { auto cb = [this](sds old_obj, sds new_obj) { score_tree->ForceUpdate(old_obj, new_obj); }; bool reallocated = false; for (auto it = score_map->begin(); it != score_map->end(); ++it) { reallocated |= it.ReallocIfNeeded(page_usage, cb); } return reallocated; } std::optional SortedMap::GetRankAndScore(std::string_view ele, bool reverse) const { ScoreSds obj = score_map->FindObj(ele); if (obj == nullptr) return std::nullopt; optional rank = score_tree->GetRank(obj, reverse); DCHECK(rank); return 
SortedMap::RankAndScore{*rank, GetObjScore(obj)}; } } // namespace detail sds cminstring = detail::kMinStrData + 1; sds cmaxstring = detail::kMaxStrData + 1; } // namespace dfly ================================================ FILE: src/core/sorted_map.h ================================================ // Copyright 2023, Roman Gershman. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include #include "core/bptree_set.h" #include "core/score_map.h" extern "C" { /* Struct to hold an inclusive/exclusive range spec by score comparison. */ typedef struct { double min, max; int minex, maxex; /* are min or max exclusive? */ } zrangespec; /* Struct to hold an inclusive/exclusive range spec by lexicographic comparison. */ typedef struct { sds min, max; /* May be set to dfly::cminstring / dfly::cmaxstring for an open-ended bound */ int minex, maxex; /* are min or max exclusive? */ } zlexrangespec; } // extern "C" /* Input flags. */ #define ZADD_IN_NONE 0 #define ZADD_IN_INCR (1 << 0) /* Increment the score instead of setting it. */ #define ZADD_IN_NX (1 << 1) /* Don't touch elements already existing. */ #define ZADD_IN_XX (1 << 2) /* Only touch elements already existing. */ #define ZADD_IN_GT (1 << 3) /* Only update existing when new scores are higher. */ #define ZADD_IN_LT (1 << 4) /* Only update existing when new scores are lower. */ /* Output flags. */ #define ZADD_OUT_NOP (1 << 0) /* Operation not performed because of conditionals.*/ #define ZADD_OUT_NAN (1 << 1) /* The resulting score is not a number (NaN); nothing was changed. */ #define ZADD_OUT_ADDED (1 << 2) /* The element was new and was added. */ #define ZADD_OUT_UPDATED (1 << 3) /* The element already existed, score updated. */ namespace dfly { class PageUsage; // Copied from zset.h extern sds cmaxstring; extern sds cminstring; namespace detail { /** * @brief SortedMap is a sorted map implementation based on zset.h. It holds unique strings that * are ordered by score and lexicographically. 
The score is a double value and has higher priority. * The map is implemented as a B+ tree (score_tree) and a hash table (score_map). For the design it is based on see * zset.h and t_zset.c files in Redis. */ class SortedMap { public: using ScoredMember = std::pair; using ScoredArray = std::vector; using ScoreSds = void*; using RankAndScore = std::pair; SortedMap(); ~SortedMap(); SortedMap(const SortedMap&) = delete; SortedMap& operator=(const SortedMap&) = delete; bool Reserve(size_t sz); int AddElem(double score, std::string_view ele, int in_flags, int* out_flags, double* newscore); // Inserts a new element. Returns false if the element already exists. // No score update is performed in this case. bool InsertNew(double score, std::string_view member); /* Removes ele from both containers; returns false if it is not present. NOTE(review): declared const although it erases from score_tree and score_map through the owned pointers — confirm this is intentional. */ bool Delete(std::string_view ele) const; // Upper bound size of the set. // Note: Currently we do not allow member expiry in sorted sets, therefore it's exact. // But if we decide to add expire, this method will provide an approximation from above. size_t Size() const { return score_map->UpperBoundSize(); } size_t MallocSize() const; size_t DeleteRangeByRank(unsigned start, unsigned end); size_t DeleteRangeByScore(const zrangespec& range); size_t DeleteRangeByLex(const zlexrangespec& range); ScoredArray PopTopScores(unsigned count, bool reverse); std::optional GetScore(std::string_view ele) const; std::optional GetRank(std::string_view ele, bool reverse) const; std::optional GetRankAndScore(std::string_view ele, bool reverse) const; ScoredArray GetRange(const zrangespec& r, unsigned offs, unsigned len, bool rev) const; ScoredArray GetLexRange(const zlexrangespec& r, unsigned o, unsigned l, bool rev) const; size_t Count(const zrangespec& range) const; size_t LexCount(const zlexrangespec& range) const; // Runs cb for each element in the range [start_rank, start_rank + len). // Stops iteration if cb returns false. Returns false in this case. 
bool Iterate(unsigned start_rank, unsigned len, bool reverse, std::function cb) const; uint64_t Scan(uint64_t cursor, absl::FunctionRef cb) const; uint8_t* ToListPack() const; static SortedMap* FromListPack(PMR_NS::memory_resource* res, const uint8_t* lp); bool DefragIfNeeded(PageUsage* page_usage); private: struct Query { ScoreSds item; bool ignore_score; bool str_is_infinite; Query(ScoreSds key, bool ign_score = false, int is_inf = 0) : item(key), ignore_score(ign_score), str_is_infinite(is_inf != 0) { } }; struct ScoreSdsPolicy { using KeyT = ScoreSds; struct KeyCompareTo { int operator()(Query q, ScoreSds key) const; }; }; using ScoreTree = BPTree; // hash map from fields to scores. ScoreMap* score_map = nullptr; // sorted tree of (score,field) items. ScoreTree* score_tree = nullptr; }; // Used by CompactObject. unsigned char* ZzlInsert(unsigned char* zl, std::string_view ele, double score); unsigned char* ZzlFind(unsigned char* lp, std::string_view ele, double* score); // Used by SortedMap and ZsetFamily. double ZzlGetScore(const uint8_t* sptr); void ZzlNext(const uint8_t* zl, uint8_t** eptr, uint8_t** sptr); void ZzlPrev(const uint8_t* zl, uint8_t** eptr, uint8_t** sptr); void ZslFreeLexRange(const zlexrangespec* spec); uint8_t* ZzlLastInRange(uint8_t* zl, const zrangespec* range); uint8_t* ZzlFirstInRange(uint8_t* zl, const zrangespec* range); uint8_t* ZzlFirstInLexRange(uint8_t* zl, const zlexrangespec* range); uint8_t* ZzlLastInLexRange(uint8_t* zl, const zlexrangespec* range); int ZzlLexValueGteMin(uint8_t* p, const zlexrangespec* spec); int ZzlLexValueLteMax(uint8_t* p, const zlexrangespec* spec); uint8_t* ZzlDeleteRangeByLex(uint8_t* zl, const zlexrangespec* range, unsigned long* deleted); uint8_t* ZzlDeleteRangeByScore(uint8_t* zl, const zrangespec* range, unsigned long* deleted); inline int ZslValueGteMin(double value, const zrangespec* spec) { return spec->minex ? 
(value > spec->min) : (value >= spec->min); } inline int ZslValueLteMax(double value, const zrangespec* spec) { return spec->maxex ? (value < spec->max) : (value <= spec->max); } } // namespace detail } // namespace dfly ================================================ FILE: src/core/sorted_map_test.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/sorted_map.h" #include #include #include #include "base/gtest.h" #include "base/logging.h" #include "core/mi_memory_resource.h" #include "core/page_usage/page_usage_stats.h" extern "C" { #include "redis/zmalloc.h" } using namespace std; using absl::StrCat; using testing::ElementsAre; using testing::Pair; using testing::StrEq; namespace dfly { using detail::SortedMap; class SortedMapTest : public ::testing::Test { protected: static void SetUpTestSuite() { // configure redis lib zmalloc which requires mimalloc heap to work. auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); InitTLStatelessAllocMR(PMR_NS::get_default_resource()); } SortedMap sm_; }; TEST_F(SortedMapTest, Add) { int out_flags; double new_score; int res = sm_.AddElem(1.0, "a", 0, &out_flags, &new_score); EXPECT_EQ(1, res); EXPECT_EQ(ZADD_OUT_ADDED, out_flags); EXPECT_EQ(1, new_score); res = sm_.AddElem(2.0, "a", ZADD_IN_NX, &out_flags, &new_score); EXPECT_EQ(1, res); EXPECT_EQ(ZADD_OUT_NOP, out_flags); res = sm_.AddElem(2.0, "a", ZADD_IN_INCR, &out_flags, &new_score); EXPECT_EQ(1, res); EXPECT_EQ(ZADD_OUT_UPDATED, out_flags); EXPECT_EQ(3, new_score); sds ele = sdsnew("a"); EXPECT_EQ(3, sm_.GetScore(ele)); sdsfree(ele); } TEST_F(SortedMapTest, Scan) { for (unsigned i = 0; i < 972; ++i) { sm_.InsertNew(i, StrCat(i)); } uint64_t cursor = 0; unsigned cnt = 0; do { cursor = sm_.Scan(cursor, [&](string_view str, double score) { ++cnt; }); } while (cursor != 0); EXPECT_EQ(972, cnt); } TEST_F(SortedMapTest, InsertPop) { for (unsigned i = 
0; i < 256; ++i) { ASSERT_TRUE(sm_.InsertNew(1000, StrCat("a", i))); } vector vec; bool res = sm_.Iterate(1, 2, false, [&](sds ele, double score) { vec.push_back(ele); return true; }); EXPECT_TRUE(res); EXPECT_THAT(vec, ElementsAre(StrEq("a1"), StrEq("a10"))); sds s = sdsnew("a1"); EXPECT_EQ(1, sm_.GetRank(s, false)); EXPECT_EQ(254, sm_.GetRank(s, true)); sdsfree(s); auto top_scores = sm_.PopTopScores(3, false); EXPECT_THAT(top_scores, ElementsAre(Pair(StrEq("a0"), 1000), Pair(StrEq("a1"), 1000), Pair(StrEq("a10"), 1000))); top_scores = sm_.PopTopScores(3, true); EXPECT_THAT(top_scores, ElementsAre(Pair(StrEq("a99"), 1000), Pair(StrEq("a98"), 1000), Pair(StrEq("a97"), 1000))); } TEST_F(SortedMapTest, LexRanges) { for (unsigned i = 0; i < 100; ++i) { ASSERT_TRUE(sm_.InsertNew(1, StrCat("a", i))); } zlexrangespec range; range.max = sdsnew("a96"); range.min = sdsnew("a93"); range.maxex = 0; range.minex = 0; EXPECT_EQ(4, sm_.LexCount(range)); auto array = sm_.GetLexRange(range, 1, 1000, false); ASSERT_EQ(3, array.size()); EXPECT_THAT(array.front(), Pair("a94", 1)); range.maxex = 1; EXPECT_EQ(3, sm_.LexCount(range)); array = sm_.GetLexRange(range, 1, 1000, true); ASSERT_EQ(2, array.size()); EXPECT_THAT(array.front(), Pair("a94", 1)); range.minex = 1; EXPECT_EQ(2, sm_.LexCount(range)); array = sm_.GetLexRange(range, 1, 1000, false); ASSERT_EQ(1, array.size()); EXPECT_THAT(array.front(), Pair("a95", 1)); sdsfree(range.min); range.min = range.max; EXPECT_EQ(0, sm_.LexCount(range)); range.minex = 0; EXPECT_EQ(0, sm_.LexCount(range)); sdsfree(range.max); range.maxex = 0; range.min = cminstring; range.max = sdsnew("a"); EXPECT_EQ(0, sm_.LexCount(range)); sdsfree(range.max); range.max = sdsnew("a0"); EXPECT_EQ(1, sm_.LexCount(range)); range.maxex = 1; EXPECT_EQ(0, sm_.LexCount(range)); sdsfree(range.max); } TEST_F(SortedMapTest, ScoreRanges) { for (unsigned i = 0; i < 10; ++i) { ASSERT_TRUE(sm_.InsertNew(1, StrCat("a", i))); } for (unsigned i = 0; i < 10; ++i) { 
ASSERT_TRUE(sm_.InsertNew(2, StrCat("b", i))); } zrangespec range; range.max = 5; range.min = 1; range.maxex = 0; range.minex = 0; EXPECT_EQ(20, sm_.Count(range)); detail::SortedMap::ScoredArray array = sm_.GetRange(range, 0, 1000, false); ASSERT_EQ(20, array.size()); EXPECT_THAT(array.front(), Pair("a0", 1)); EXPECT_THAT(array.back(), Pair("b9", 2)); range.minex = 1; // exclude all the "1" scores. EXPECT_EQ(10, sm_.Count(range)); array = sm_.GetRange(range, 2, 1, false); ASSERT_EQ(1, array.size()); EXPECT_THAT(array.front(), Pair("b2", 2)); range.max = 1; range.minex = 0; range.min = -HUGE_VAL; EXPECT_EQ(10, sm_.Count(range)); array = sm_.GetRange(range, 2, 2, true); ASSERT_EQ(2, array.size()); EXPECT_THAT(array.back(), Pair("a6", 1)); range.maxex = 1; EXPECT_EQ(0, sm_.Count(range)); array = sm_.GetRange(range, 0, 2, true); ASSERT_EQ(0, array.size()); range.min = 3; array = sm_.GetRange(range, 0, 2, true); ASSERT_EQ(0, array.size()); } TEST_F(SortedMapTest, DeleteRange) { for (unsigned i = 0; i <= 100; ++i) { ASSERT_TRUE(sm_.InsertNew(i * 2, StrCat("a", i))); } zrangespec range; range.min = range.max = 200; range.minex = range.maxex = 1; EXPECT_EQ(0, sm_.DeleteRangeByScore(range)); range.min = 199; EXPECT_EQ(0, sm_.DeleteRangeByScore(range)); range.minex = 0; EXPECT_EQ(0, sm_.DeleteRangeByScore(range)); range.max = 199; range.min = 198; EXPECT_EQ(1, sm_.DeleteRangeByScore(range)); range.max = 197; range.min = 193; EXPECT_EQ(2, sm_.DeleteRangeByScore(range)); EXPECT_EQ(2, sm_.DeleteRangeByRank(0, 1)); zlexrangespec lex_range; lex_range.min = sdsnew("b"); lex_range.max = sdsnew("c"); EXPECT_EQ(0, sm_.DeleteRangeByLex(lex_range)); sdsfree(lex_range.min); sdsfree(lex_range.max); lex_range.min = cminstring; lex_range.max = cmaxstring; EXPECT_EQ(96, sm_.DeleteRangeByLex(lex_range)); } TEST_F(SortedMapTest, RangeBug) { constexpr size_t kArrLen = 80; for (unsigned i = 0; i < kArrLen; i++) { ASSERT_TRUE(sm_.InsertNew(i, StrCat("score", i))); } for (unsigned i = 0; i < 
kArrLen; i++) { zrangespec range; range.max = HUGE_VAL; range.min = i; range.minex = 0; range.maxex = 0; auto arr = sm_.GetRange(range, 0, 5, false); ASSERT_GT(arr.size(), 0) << i; } } uint64_t total_wasted_memory = 0; TEST_F(SortedMapTest, ReallocIfNeeded) { auto build_str = [](size_t i) { return to_string(i) + string(131, 'a'); }; auto count_waste = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, size_t block_size, void* arg) { size_t used = block_size * area->used; total_wasted_memory += area->committed - used; return true; }; for (size_t i = 0; i < 10'000; i++) { int out_flags; double new_val; auto str = build_str(i); sm_.AddElem(i, str, 0, &out_flags, &new_val); } for (size_t i = 0; i < 10'000; i++) { if (i % 10 == 0) continue; auto str = build_str(i); sds ele = sdsnew(str.c_str()); sm_.Delete(ele); sdsfree(ele); } mi_heap_collect(mi_heap_get_backing(), true); mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); size_t wasted_before = total_wasted_memory; PageUsage page_usage{CollectPageStats::NO, 9}; ASSERT_TRUE(sm_.DefragIfNeeded(&page_usage)); total_wasted_memory = 0; mi_heap_collect(mi_heap_get_backing(), true); mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); size_t wasted_after = total_wasted_memory; // Check we waste significanlty less now EXPECT_GT(wasted_before, wasted_after * 2); ASSERT_EQ(sm_.Size(), 1000); auto cb = [i = 0, build_str](sds ele, double score) mutable -> bool { EXPECT_EQ(std::string_view(ele), build_str(i * 10)); EXPECT_EQ((size_t)score, i * 10); ++i; return true; }; sm_.Iterate(0, 10000, false, cb); } } // namespace dfly ================================================ FILE: src/core/sse_port.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #if defined(__aarch64__) #define SSE2NEON_SUPPRESS_WARNINGS #include "base/sse2neon.h" #elif defined(__riscv) || defined(__riscv__) #include "base/sse2rvv.h" #elif defined(__s390x__) #include #else #include #include #endif namespace dfly { #ifndef __s390x__ inline __m128i mm_loadu_si128(const __m128i* ptr) { #if defined(__aarch64__) __m128i res; memcpy(&res, ptr, sizeof(res)); return res; // return vreinterpretq_m128i_s32(vld1q_s32((const int32_t *) p)); #else return _mm_loadu_si128(ptr); #endif } #endif } // namespace dfly ================================================ FILE: src/core/string_map.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/string_map.h" #include "base/endian.h" #include "base/logging.h" #include "core/compact_object.h" #include "core/page_usage/page_usage_stats.h" #include "core/sds_utils.h" extern "C" { #include "redis/zmalloc.h" } using namespace std; namespace dfly { namespace { /* The 64-bit word stored after the key holds the value pointer; its top bit (kValTtlBit) is set when the entry also carries a TTL, and kValMask clears that bit to recover the pointer. */ constexpr uint64_t kValTtlBit = 1ULL << 63; constexpr uint64_t kValMask = ~kValTtlBit; // Returns key, tagged value pair pair CreateEntry(string_view field, string_view value, uint32_t time_now, uint32_t ttl_sec) { // 8 additional bytes for a pointer to value. sds newkey; size_t meta_offset = field.size() + 1; sds sdsval = sdsnewlen(value.data(), value.size()); uint64_t sdsval_tag = uint64_t(sdsval); if (ttl_sec == UINT32_MAX) { // The layout is: // key, '\0', 8-byte pointer to value newkey = AllocSdsWithSpace(field.size(), 8); } else { // The layout is: // key, '\0', 8-byte pointer to value, 4-byte absolute time. // the value pointer is tagged with kValTtlBit. newkey = AllocSdsWithSpace(field.size(), 8 + 4); uint32_t at = time_now + ttl_sec; absl::little_endian::Store32(newkey + meta_offset + 8, at); // skip the value pointer. 
sdsval_tag |= kValTtlBit; } if (!field.empty()) { memcpy(newkey, field.data(), field.size()); } absl::little_endian::Store64(newkey + meta_offset, sdsval_tag); return {newkey, sdsval_tag}; } bool HasTtl(sds entry) { const uint64_t tag = absl::little_endian::Load64(entry + sdslen(entry) + 1); return (tag & kValTtlBit) != 0; } } // namespace StringMap::~StringMap() { Clear(); } bool StringMap::AddOrUpdate(std::string_view field, std::string_view value, uint32_t ttl_sec, bool keepttl) { sds prev = AddOrExchange(field, value, ttl_sec, keepttl); if (prev) { ObjDelete(prev, false); return false; } return true; } sds StringMap::AddOrExchange(std::string_view field, std::string_view value, uint32_t ttl_sec, bool keepttl) { const uint32_t computed_ttl = ComputeTtl(field, ttl_sec, keepttl); auto [newkey, sdsval_tag] = CreateEntry(field, value, time_now(), computed_ttl); auto prev_entry = static_cast(AddOrReplaceObj(newkey, sdsval_tag & kValTtlBit)); return prev_entry; } uint32_t StringMap::ComputeTtl(string_view field, uint32_t ttl_sec, bool keepttl) const { if (!keepttl) return ttl_sec; auto* prev = static_cast(FindInternal(&field, Hash(&field, 1), 1)); if (!prev) return ttl_sec; if (!HasTtl(prev)) return ttl_sec; return ObjExpireTime(prev) - time_now(); } bool StringMap::AddOrSkip(std::string_view field, std::string_view value, uint32_t ttl_sec) { uint64_t hashcode = Hash(&field, 1); void* obj = FindInternal(&field, hashcode, 1); // 1 - string_view if (obj) return false; auto [newkey, sdsval_tag] = CreateEntry(field, value, time_now(), ttl_sec); AddUnique(newkey, sdsval_tag & kValTtlBit, hashcode); return true; } bool StringMap::Erase(string_view key) { return EraseInternal(&key, 1); } StringMap::SdsEntry StringMap::Extract(string_view key) { return SdsEntry(static_cast(DetachInternal(const_cast(&key), 1)), DeleteEntry); } void StringMap::DeleteEntry(sds entry) { sds value = GetValue(entry); sdsfree(value); sdsfree(entry); } bool StringMap::Contains(string_view field) 
const { // 1 - means it's string_view. See ObjEqual for details. uint64_t hashcode = Hash(&field, 1); return FindInternal(&field, hashcode, 1) != nullptr; } optional> StringMap::RandomPair() { // Iteration may remove elements, and so we need to loop if we happen to reach the end while (true) { auto it = begin(); // It may be that begin() will invalidate all elements, getting us to an Empty() state if (Empty()) { break; } it += rand() % UpperBoundSize(); if (it != end()) { return std::make_pair(it->first, it->second); } } return nullopt; } void StringMap::RandomPairsUnique(unsigned int count, std::vector& keys, std::vector& vals, bool with_value) { unsigned int total_size = SizeSlow(); unsigned int index = 0; if (count > total_size) count = total_size; auto itr = begin(); uint32_t picked = 0, remaining = count; while (picked < count && itr != end()) { double random_double = ((double)rand()) / RAND_MAX; double threshold = ((double)remaining) / (total_size - index); if (random_double <= threshold) { keys.push_back(itr->first); if (with_value) { vals.push_back(itr->second); } remaining--; picked++; } ++itr; index++; } DCHECK(keys.size() == count); if (with_value) DCHECK(vals.size() == count); } void StringMap::RandomPairs(unsigned int count, std::vector& keys, std::vector& vals, bool with_value) { using RandomPick = std::pair; std::vector picks; unsigned int total_size = SizeSlow(); for (unsigned int i = 0; i < count; ++i) { RandomPick pick{rand() % total_size, i}; picks.push_back(pick); } std::sort(picks.begin(), picks.end(), [](auto& x, auto& y) { return x.first < y.first; }); unsigned int index = picks[0].first, pick_index = 0; auto itr = begin(); for (unsigned int i = 0; i < index; ++i) ++itr; keys.resize(count); if (with_value) vals.resize(count); while (itr != end() && pick_index < count) { auto [key, val] = *itr; while (pick_index < count && index == picks[pick_index].first) { int store_order = picks[pick_index].second; keys[store_order] = key; if (with_value) 
vals[store_order] = val; ++pick_index; } ++index; ++itr; } } sds StringMap::GetValue(sds key) { char* valptr = key + sdslen(key) + 1; const uint64_t val = absl::little_endian::Load64(valptr); return (sds)(kValMask & val); } pair StringMap::ReallocIfNeeded(void* obj, PageUsage* page_usage) { sds key = (sds)obj; size_t key_len = sdslen(key); auto* value_ptr = key + key_len + 1; uint64_t value_tag = absl::little_endian::Load64(value_ptr); sds value = (sds)(uint64_t(value_tag) & kValMask); bool realloced_value = false; // If the allocated value is underutilized, re-allocate it and update the pointer inside the key if (page_usage->IsPageForObjectUnderUtilized(value)) { size_t value_len = sdslen(value); sds new_value = sdsnewlen(value, value_len); memcpy(new_value, value, value_len); uint64_t new_value_tag = (uint64_t(new_value) & kValMask) | (value_tag & ~kValMask); absl::little_endian::Store64(value_ptr, new_value_tag); sdsfree(value); realloced_value = true; } if (!page_usage->IsPageForObjectUnderUtilized(key)) return {key, realloced_value}; size_t space_size = 8 /* value ptr */ + ((value_tag & kValTtlBit) ? 
4 : 0) /* optional expiry */; sds new_key = AllocSdsWithSpace(key_len, space_size); memcpy(new_key, key, key_len + 1 /* \0 */ + space_size); sdsfree(key); return {new_key, true}; } uint64_t StringMap::Hash(const void* obj, uint32_t cookie) const { DCHECK_LT(cookie, 2u); if (cookie == 0) { sds s = (sds)obj; return CompactObj::HashCode(string_view{s, sdslen(s)}); } const string_view* sv = (const string_view*)obj; return CompactObj::HashCode(*sv); } bool StringMap::ObjEqual(const void* left, const void* right, uint32_t right_cookie) const { DCHECK_LT(right_cookie, 2u); sds s1 = (sds)left; if (right_cookie == 0) { sds s2 = (sds)right; if (sdslen(s1) != sdslen(s2)) { return false; } return sdslen(s1) == 0 || memcmp(s1, s2, sdslen(s1)) == 0; } const string_view* right_sv = (const string_view*)right; string_view left_sv{s1, sdslen(s1)}; return left_sv == (*right_sv); } size_t StringMap::ObjectAllocSize(const void* obj) const { sds s1 = (sds)obj; size_t res = zmalloc_usable_size(sdsAllocPtr(s1)); sds val = GetValue(s1); res += zmalloc_usable_size(sdsAllocPtr(val)); return res; } uint32_t StringMap::ObjExpireTime(const void* obj) const { sds str = (sds)obj; const char* valptr = str + sdslen(str) + 1; uint64_t val = absl::little_endian::Load64(valptr); DCHECK(val & kValTtlBit); if (val & kValTtlBit) { return absl::little_endian::Load32(valptr + 8); } // Should not reach. return UINT32_MAX; } void StringMap::ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) { return SdsUpdateExpireTime(obj, time_now() + ttl_sec, 8); } void StringMap::ObjDelete(void* obj, bool has_ttl) const { sds s1 = (sds)obj; sds value = GetValue(s1); sdsfree(value); sdsfree(s1); } void* StringMap::ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const { uint32_t ttl_sec = add_ttl ? 0 : (has_ttl ? 
ObjExpireTime(obj) : UINT32_MAX); sds str = (sds)obj; auto pair = detail::SdsPair(str, GetValue(str)); // Use explicit string_view constructor with length to preserve null characters string_view key_sv(pair->first, sdslen(pair->first)); string_view value_sv(pair->second, sdslen(pair->second)); auto [newkey, sdsval_tag] = CreateEntry(key_sv, value_sv, time_now(), ttl_sec); return (void*)newkey; } detail::SdsPair StringMap::iterator::BreakToPair(void* obj) { sds f = (sds)obj; return detail::SdsPair(f, GetValue(f)); } bool StringMap::iterator::ReallocIfNeeded(PageUsage* page_usage) { auto* ptr = curr_entry_; if (ptr->IsLink()) { ptr = ptr->AsLink(); } DCHECK(!ptr->IsEmpty()); DCHECK(ptr->IsObject()); auto* obj = ptr->GetObject(); auto [new_obj, realloced] = static_cast(owner_)->ReallocIfNeeded(obj, page_usage); ptr->SetObject(new_obj); return realloced; } } // namespace dfly ================================================ FILE: src/core/string_map.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include "core/dense_set.h" extern "C" { #include "redis/sds.h" } namespace dfly { class PageUsage; namespace detail { class SdsPair { public: SdsPair(sds k, sds v) : first(k), second(v) { } SdsPair* operator->() { return this; } const SdsPair* operator->() const { return this; } operator std::pair() const { return {{first, sdslen(first)}, {second, sdslen(second)}}; } const sds first; const sds second; }; }; // namespace detail class StringMap : public DenseSet { public: explicit StringMap(void* unused = nullptr) { } ~StringMap(); class iterator : private DenseSet::IteratorBase { static detail::SdsPair BreakToPair(void* obj); public: iterator() : IteratorBase() { } explicit iterator(const IteratorBase& o) : IteratorBase(o) { } iterator(DenseSet* owner) : IteratorBase(owner, false) { } detail::SdsPair operator->() const { void* ptr = curr_entry_->GetObject(); return BreakToPair(ptr); } detail::SdsPair operator*() const { void* ptr = curr_entry_->GetObject(); return BreakToPair(ptr); } // Try reducing memory fragmentation of the value by re-allocating. Returns true if // re-allocation happened. bool ReallocIfNeeded(PageUsage* page_usage); iterator& operator++() { Advance(); return *this; } // Advances at most `n` steps, but stops at end. iterator& operator+=(unsigned int n) { for (unsigned int i = 0; i < n; ++i) { if (curr_entry_ == nullptr) { break; } Advance(); } return *this; } bool operator==(const iterator& b) const { if (owner_ == nullptr && b.owner_ == nullptr) { // to allow comparison with end() return true; } return owner_ == b.owner_ && curr_entry_ == b.curr_entry_; } bool operator!=(const iterator& b) const { return !(*this == b); } using IteratorBase::ExpiryTime; using IteratorBase::HasExpiry; using IteratorBase::SetExpiryTime; }; // Adds a new field or updates its value. Returns true if added, false if updated. 
bool AddOrUpdate(std::string_view field, std::string_view value, uint32_t ttl_sec = UINT32_MAX, bool keepttl = false); // Like AddOrUpdate but on update returns the previous sds entry // instead of deleting it. Caller must free the returned entry via DeleteEntry(). // Returns nullptr if a new field was added. sds AddOrExchange(std::string_view field, std::string_view value, uint32_t ttl_sec = UINT32_MAX, bool keepttl = false); // Returns true if field was added // false, if already exists. In that case no update is done. bool AddOrSkip(std::string_view field, std::string_view value, uint32_t ttl_sec = UINT32_MAX); bool Erase(std::string_view s1); using SdsEntry = std::unique_ptr; // Removes and returns the sds entry for the given key without freeing it. // Returns nullptr if the key was not found. SdsEntry Extract(std::string_view s1); // Frees a StringMap sds entry (key + embedded value). static void DeleteEntry(sds entry); bool Contains(std::string_view s1) const; /// @brief Returns value of the key or an empty iterator if key not found. /// @param key /// @return sds iterator Find(std::string_view member) { return iterator{FindIt(&member, 1)}; } iterator begin() { return iterator{this}; } iterator end() { return iterator{}; } // Returns a random key value pair. // Returns key only if value is a nullptr. std::optional> RandomPair(); // Randomly selects count of key value pairs. The selections are unique. // if count is larger than the total number of key value pairs, returns // every pair. // Executes at O(n) (i.e. slow for large sets). void RandomPairsUnique(unsigned int count, std::vector& keys, std::vector& vals, bool with_value); // Randomly selects count of key value pairs. The select key value pairs // are allowed to have duplications. // Executes at O(n) (i.e. slow for large sets). 
void RandomPairs(unsigned int count, std::vector& keys, std::vector& vals, bool with_value); static sds GetValue(sds key); private: // If keepttl is specified, performs a lookup for given field and computes ttl by comparing // existing expiry against time_now(). If keepttl is false, or field is not found, or it expires, // or the field has no ttl, returns ttl_sec. set_time() must have been called before computing // ttl. uint32_t ComputeTtl(std::string_view field, uint32_t ttl_sec, bool keepttl) const; // Reallocate key and/or value if their pages are underutilized. // Returns new pointer (stays same if key utilization is enough) and if reallocation happened. std::pair ReallocIfNeeded(void* obj, PageUsage* page_usage); uint64_t Hash(const void* obj, uint32_t cookie) const final; bool ObjEqual(const void* left, const void* right, uint32_t right_cookie) const final; size_t ObjectAllocSize(const void* obj) const final; uint32_t ObjExpireTime(const void* obj) const final; void ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) override; void ObjDelete(void* obj, bool has_ttl) const override; void* ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const final; }; } // namespace dfly ================================================ FILE: src/core/string_map_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#include "core/string_map.h"

// NOTE(review): the header names of the bare #include directives below, and the template
// arguments further down (e.g. on std::unique_ptr / vector), are missing from this extract.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "base/logging.h"
#include "core/compact_object.h"
#include "core/detail/stateless_allocator.h"
#include "core/page_usage/page_usage_stats.h"

extern "C" {
#include "redis/zmalloc.h"
}

namespace dfly {

using namespace std;

// Fixture that creates a fresh StringMap per test and checks in TearDown that the
// thread-local zmalloc counter is back to zero.
class StringMapTest : public ::testing::Test {
 protected:
  static void SetUpTestSuite() {
    auto* tlh = mi_heap_get_backing();
    init_zmalloc_threadlocal(tlh);
    InitTLStatelessAllocMR(PMR_NS::get_default_resource());
  }

  static void TearDownTestSuite() {
    mi_heap_collect(mi_heap_get_backing(), true);

    // Logs an error for every heap area visited after the final collection.
    auto cb_visit = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block,
                       size_t block_size, void* arg) {
      LOG(ERROR) << "Unfreed allocations: block_size " << block_size
                 << ", allocated: " << area->used * block_size;
      return true;
    };

    mi_heap_visit_blocks(mi_heap_get_backing(), false /* do not visit all blocks*/, cb_visit,
                         nullptr);
  }

  StringMapTest() : mi_alloc_(mi_heap_get_backing()) {
  }

  void SetUp() override {
    sm_.reset(new StringMap(&mi_alloc_));
  }

  void TearDown() override {
    sm_.reset();
    EXPECT_EQ(zmalloc_used_memory_tl, 0);
  }

  MiMemoryResource mi_alloc_;
  std::unique_ptr sm_;
};

// Insert, lookup, iteration, update and skip on a single field.
TEST_F(StringMapTest, Basic) {
  EXPECT_TRUE(sm_->AddOrUpdate("foo", "bar"));
  EXPECT_TRUE(sm_->Contains("foo"));
  auto it = sm_->Find("foo");
  EXPECT_STREQ("bar", it->second);

  it = sm_->begin();
  EXPECT_STREQ("foo", it->first);
  EXPECT_STREQ("bar", it->second);
  ++it;

  EXPECT_TRUE(it == sm_->end());

  for (const auto& k_v : *sm_) {
    EXPECT_STREQ("foo", k_v.first);
    EXPECT_STREQ("bar", k_v.second);
  }

  // Replacing with a longer value must be reflected in the accounted memory.
  size_t sz = sm_->ObjMallocUsed();
  EXPECT_FALSE(sm_->AddOrUpdate("foo", "baraaaaaaaaaaaa2"));
  EXPECT_GT(sm_->ObjMallocUsed(), sz);
  it = sm_->begin();
  EXPECT_STREQ("baraaaaaaaaaaaa2", it->second);

  // AddOrSkip must leave the existing value untouched.
  EXPECT_FALSE(sm_->AddOrSkip("foo", "bar2"));
  EXPECT_STREQ("baraaaaaaaaaaaa2", it->second);
}

// Looking up a missing key on an empty map must not crash.
TEST_F(StringMapTest, EmptyFind) {
  sm_->Find("bar");
}

// Expired fields are treated as absent by AddOrUpdate and by iteration.
TEST_F(StringMapTest, Ttl) {
  EXPECT_TRUE(sm_->AddOrUpdate("bla", "val1", 1));
  EXPECT_FALSE(sm_->AddOrUpdate("bla", "val2", 1));
  sm_->set_time(1);
  EXPECT_TRUE(sm_->AddOrUpdate("bla", "val2", 1));
  EXPECT_EQ(1u, sm_->UpperBoundSize());

  EXPECT_FALSE(sm_->AddOrSkip("bla", "val3", 2));

  // set ttl to 2, meaning that the key will expire at time 3.
  EXPECT_TRUE(sm_->AddOrSkip("bla2", "val3", 2));
  EXPECT_TRUE(sm_->Contains("bla2"));

  sm_->set_time(3);
  auto it = sm_->begin();
  EXPECT_TRUE(it == sm_->end());
}

// Advancing an iterator over a map whose fields have all expired reaches end().
TEST_F(StringMapTest, IterateExpired) {
  EXPECT_TRUE(sm_->AddOrUpdate("k1", "v1", 1));
  EXPECT_TRUE(sm_->AddOrUpdate("k2", "v2", 1));
  sm_->set_time(1);
  auto it = sm_->begin();
  it += 1;
  EXPECT_EQ(it, sm_->end());
}

// SetExpiryTime on a field that already has an expiry overwrites it.
TEST_F(StringMapTest, SetFieldExpireHasExpiry) {
  EXPECT_TRUE(sm_->AddOrUpdate("k1", "v1", 5));
  auto k = sm_->Find("k1");
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 5);
  k.SetExpiryTime(1);
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 1);
}

// SetExpiryTime on a field without an expiry adds one.
TEST_F(StringMapTest, SetFieldExpireNoHasExpiry) {
  EXPECT_TRUE(sm_->AddOrUpdate("k1", "v1"));
  auto k = sm_->Find("k1");
  EXPECT_FALSE(k.HasExpiry());
  k.SetExpiryTime(1);
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 1);
}

// Regression: expiry set through the iterator must survive inserting many more fields.
TEST_F(StringMapTest, Bug3973) {
  for (unsigned i = 0; i < 8; i++) {
    EXPECT_TRUE(sm_->AddOrUpdate(to_string(i), "val"));
  }
  for (unsigned i = 0; i < 8; i++) {
    auto k = sm_->Find(to_string(i));
    ASSERT_FALSE(k.HasExpiry());
    k.SetExpiryTime(1);
    EXPECT_EQ(k.ExpiryTime(), 1);
  }

  for (unsigned i = 100; i < 1000; i++) {
    EXPECT_TRUE(sm_->AddOrUpdate(to_string(i), "val"));
  }

  // make sure the first 8 keys have expiry set
  for (unsigned i = 0; i < 8; i++) {
    auto k = sm_->Find(to_string(i));
    ASSERT_TRUE(k.HasExpiry());
    EXPECT_EQ(k.ExpiryTime(), 1);
  }
}

// Regression: fields whose expiry was added through the iterator can still be updated.
TEST_F(StringMapTest, Bug3984) {
  for (unsigned i = 0; i < 6; i++) {
    EXPECT_TRUE(sm_->AddOrUpdate(to_string(i), "val"));
  }
  for (unsigned i = 0; i < 6; i++) {
    auto k = sm_->Find(to_string(i));
    ASSERT_FALSE(k.HasExpiry());
    k.SetExpiryTime(1);
    EXPECT_EQ(k.ExpiryTime(), 1);
  }

  for (unsigned i = 0; i < 6; i++) {
    EXPECT_FALSE(sm_->AddOrUpdate(to_string(i), "val"));
  }
}

// Accumulator written by the count_waste heap visitor in the ReallocIfNeeded test.
unsigned total_wasted_memory = 0;

// Fragments the heap by erasing 9 of every 10 fields, then checks that
// iterator::ReallocIfNeeded reduces committed-but-unused memory and keeps the data intact.
TEST_F(StringMapTest, ReallocIfNeeded) {
  auto build_str = [](size_t i) { return to_string(i) + string(131, 'a'); };

  // Adds (committed - used) of each visited heap area to total_wasted_memory.
  auto count_waste = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block,
                        size_t block_size, void* arg) {
    size_t used = block_size * area->used;
    total_wasted_memory += area->committed - used;
    return true;
  };

  for (size_t i = 0; i < 10'000; i++)
    sm_->AddOrUpdate(build_str(i), build_str(i + 1), i * 10 + 1);

  for (size_t i = 0; i < 10'000; i++) {
    if (i % 10 == 0)
      continue;
    sm_->Erase(build_str(i));
  }

  mi_heap_collect(mi_heap_get_backing(), true);
  mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr);
  size_t wasted_before = total_wasted_memory;

  size_t underutilized = 0;
  PageUsage page_usage{CollectPageStats::NO, 0.9};
  for (auto it = sm_->begin(); it != sm_->end(); ++it) {
    underutilized += page_usage.IsPageForObjectUnderUtilized(it->first);
    it.ReallocIfNeeded(&page_usage);
  }
  // Check there are underutilized pages
  CHECK_GT(underutilized, 0u);

  total_wasted_memory = 0;
  mi_heap_collect(mi_heap_get_backing(), true);
  mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr);
  size_t wasted_after = total_wasted_memory;

  // Check we waste significantly less now
  EXPECT_GT(wasted_before, wasted_after * 2);

  EXPECT_EQ(sm_->UpperBoundSize(), 1000);
  for (size_t i = 0; i < 1000; i++)
    EXPECT_EQ(sm_->Find(build_str(i * 10))->second, build_str(i * 10 + 1));
}

// Adding an expiry grows the accounted size; re-adding with a ttl keeps that size.
TEST_F(StringMapTest, ExpiryChangesSize) {
  sm_->AddOrUpdate("field", "value");
  const size_t old_size = sm_->ObjMallocUsed();

  auto it = sm_->Find("field");
  it.SetExpiryTime(1);

  const size_t new_size = sm_->ObjMallocUsed();
  EXPECT_LT(old_size, new_size);

  sm_->AddOrUpdate("field", "value", 1);
  EXPECT_EQ(new_size, sm_->ObjMallocUsed());
}

// Interaction of ttl_sec == UINT32_MAX with the keepttl flag.
TEST_F(StringMapTest, ExpiryWithMaxAndKeepTTL) {
  sm_->AddOrUpdate("field", "value", 100);
  auto k = sm_->Find("field");
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 100);

  // ttl is copied from prev. if max value is supplied
  sm_->AddOrUpdate("field", "value", UINT32_MAX, true);
  k = sm_->Find("field");
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 100);

  // max ttl value results in no expiry without keepttl
  sm_->AddOrUpdate("field", "value", UINT32_MAX);
  EXPECT_FALSE(sm_->Find("field").HasExpiry());

  // No prev. expiry, supplied ttl_sec value is used
  sm_->AddOrUpdate("field", "value", 10, true);
  k = sm_->Find("field");
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 10);

  // object removed while adding due to expiry
  sm_->set_time(11);
  sm_->AddOrUpdate("field", "value", UINT32_MAX, true);
  k = sm_->Find("field");
  EXPECT_FALSE(k.HasExpiry());
}

// Extract removes the field from the map and returns its entry.
TEST_F(StringMapTest, ExtractExisting) {
  sm_->AddOrUpdate("f1", "v1");
  sm_->AddOrUpdate("f2", "v2");
  EXPECT_EQ(sm_->UpperBoundSize(), 2u);

  auto entry = sm_->Extract("f1");
  ASSERT_TRUE(entry);

  // Verify the extracted entry has the correct value
  sds val = StringMap::GetValue(entry.get());
  EXPECT_EQ(string_view(val, sdslen(val)), "v1");

  // Verify it was removed from the map
  EXPECT_EQ(sm_->UpperBoundSize(), 1u);
  EXPECT_FALSE(sm_->Contains("f1"));
  EXPECT_TRUE(sm_->Contains("f2"));
}

// Extract of a missing key returns an empty entry and leaves the map unchanged.
TEST_F(StringMapTest, ExtractNonExisting) {
  sm_->AddOrUpdate("f1", "v1");

  auto entry = sm_->Extract("no_such_key");
  EXPECT_FALSE(entry);
  EXPECT_EQ(sm_->UpperBoundSize(), 1u);
}

TEST_F(StringMapTest, AddOrExchangeNew) {
  // Adding a new field returns nullptr (no previous entry)
  sds prev = sm_->AddOrExchange("f1", "v1");
  EXPECT_EQ(prev, nullptr);
  EXPECT_TRUE(sm_->Contains("f1"));
  EXPECT_STREQ(sm_->Find("f1")->second, "v1");
}

// AddOrExchange on an existing field returns the old entry, which the caller must free.
TEST_F(StringMapTest, AddOrExchangeReplace) {
  sm_->AddOrUpdate("f1", "old_value");
  EXPECT_EQ(sm_->UpperBoundSize(), 1u);

  sds prev = sm_->AddOrExchange("f1", "new_value");
  ASSERT_NE(prev, nullptr);

  // Verify the extracted entry has the old value
  sds val = StringMap::GetValue(prev);
  EXPECT_EQ(string_view(val, sdslen(val)), "old_value");

  // Verify map now has the new value
  EXPECT_STREQ(sm_->Find("f1")->second, "new_value");
  EXPECT_EQ(sm_->UpperBoundSize(), 1u);

  StringMap::DeleteEntry(prev);
}

// AddOrExchange with a ttl: the old entry is returned, the new one carries the new expiry.
TEST_F(StringMapTest, AddOrExchangeWithTtl) {
  sm_->AddOrUpdate("f1", "v1", 100);

  sds prev = sm_->AddOrExchange("f1", "v2", 200);
  ASSERT_NE(prev, nullptr);

  sds val = StringMap::GetValue(prev);
  EXPECT_EQ(string_view(val, sdslen(val)), "v1");

  // Make sure new entry has correct value and ttl
  auto it = sm_->Find("f1");
  EXPECT_STREQ(it->second, "v2");
  EXPECT_TRUE(it.HasExpiry());
  EXPECT_EQ(it.ExpiryTime(), 200u);

  StringMap::DeleteEntry(prev);
}

// Extracting every other field leaves exactly the remaining ones in the map.
TEST_F(StringMapTest, ExtractMultiple) {
  for (unsigned i = 0; i < 20; i++) {
    sm_->AddOrUpdate(to_string(i), "val" + to_string(i));
  }
  EXPECT_EQ(sm_->UpperBoundSize(), 20u);

  // Extract every other entry
  vector extracted;
  for (unsigned i = 0; i < 20; i += 2) {
    auto entry = sm_->Extract(to_string(i));
    ASSERT_TRUE(entry);
    extracted.push_back(std::move(entry));
  }
  EXPECT_EQ(sm_->UpperBoundSize(), 10u);

  // Verify remaining entries
  for (unsigned i = 1; i < 20; i += 2) {
    EXPECT_TRUE(sm_->Contains(to_string(i)));
  }
}

}  // namespace dfly

================================================
FILE: src/core/string_set.cc
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "core/string_set.h" #include "absl/flags/flag.h" #include "core/compact_object.h" #include "core/page_usage/page_usage_stats.h" #include "core/sds_utils.h" extern "C" { #include "redis/sds.h" #include "redis/zmalloc.h" } #include "base/logging.h" using namespace std; namespace dfly { namespace { inline bool MayHaveTtl(sds s) { char* alloc_ptr = (char*)sdsAllocPtr(s); return sdslen(s) + 1 + 4 <= zmalloc_usable_size(alloc_ptr); } sds AllocImmutableWithTtl(uint32_t len, uint32_t at) { sds res = AllocSdsWithSpace(len, sizeof(at)); absl::little_endian::Store32(res + len + 1, at); // Save TTL return res; } } // namespace StringSet::~StringSet() { Clear(); } bool StringSet::Add(string_view src, uint32_t ttl_sec) { uint64_t hash = Hash(&src, 1); void* prev = FindInternal(&src, hash, 1); if (prev != nullptr) { return false; } sds newsds = MakeSetSds(src, ttl_sec); bool has_ttl = ttl_sec != UINT32_MAX; AddUnique(newsds, has_ttl, hash); return true; } unsigned StringSet::AddMany(absl::Span span, uint32_t ttl_sec, bool keepttl) { std::string_view views[kMaxBatchLen]; unsigned res = 0; if (BucketCount() < span.size()) { Reserve(span.size()); } while (span.size() >= kMaxBatchLen) { for (size_t i = 0; i < kMaxBatchLen; i++) views[i] = span[i]; span.remove_prefix(kMaxBatchLen); res += AddBatch(absl::MakeSpan(views), ttl_sec, keepttl); } if (span.size()) { for (size_t i = 0; i < span.size(); i++) views[i] = span[i]; res += AddBatch(absl::MakeSpan(views, span.size()), ttl_sec, keepttl); } return res; } unsigned StringSet::AddBatch(absl::Span span, uint32_t ttl_sec, bool keepttl) { uint64_t hash[kMaxBatchLen]; bool has_ttl = ttl_sec != UINT32_MAX; unsigned count = span.size(); unsigned res = 0; DCHECK_LE(count, kMaxBatchLen); for (size_t i = 0; i < count; i++) { hash[i] = CompactObj::HashCode(span[i]); Prefetch(hash[i]); } for (unsigned i = 0; i < count; ++i) { void* prev = FindInternal(&span[i], hash[i], 1); if (prev == nullptr) { ++res; sds field = MakeSetSds(span[i], 
ttl_sec); AddUnique(field, has_ttl, hash[i]); } else if (has_ttl && !keepttl) { ObjUpdateExpireTime(prev, ttl_sec); } } return res; } StringSet::iterator StringSet::GetRandomMember() { return iterator{DenseSet::GetRandomIterator()}; } std::optional StringSet::Pop() { sds str = (sds)PopInternal(); if (str == nullptr) { return std::nullopt; } std::string ret{str, sdslen(str)}; sdsfree(str); return ret; } uint32_t StringSet::Scan(uint32_t cursor, const std::function& func) const { return DenseSet::Scan(cursor, [func](const void* ptr) { func((sds)ptr); }); } uint64_t StringSet::Hash(const void* ptr, uint32_t cookie) const { DCHECK_LT(cookie, 2u); if (cookie == 0) { sds s = (sds)ptr; return CompactObj::HashCode(string_view{s, sdslen(s)}); } const string_view* sv = (const string_view*)ptr; return CompactObj::HashCode(*sv); } bool StringSet::ObjEqual(const void* left, const void* right, uint32_t right_cookie) const { DCHECK_LT(right_cookie, 2u); sds s1 = (sds)left; if (right_cookie == 0) { sds s2 = (sds)right; if (sdslen(s1) != sdslen(s2)) { return false; } return sdslen(s1) == 0 || memcmp(s1, s2, sdslen(s1)) == 0; } const string_view* right_sv = (const string_view*)right; string_view left_sv{s1, sdslen(s1)}; return left_sv == (*right_sv); } size_t StringSet::ObjectAllocSize(const void* s1) const { return zmalloc_usable_size(sdsAllocPtr((sds)s1)); } uint32_t StringSet::ObjExpireTime(const void* str) const { sds s = (sds)str; DCHECK(MayHaveTtl(s)); char* ttlptr = s + sdslen(s) + 1; return absl::little_endian::Load32(ttlptr); } void StringSet::ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) { return SdsUpdateExpireTime(obj, time_now() + ttl_sec, 0); } void StringSet::ObjDelete(void* obj, bool has_ttl) const { sdsfree((sds)obj); } void* StringSet::ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const { sds src = (sds)obj; string_view sv{src, sdslen(src)}; uint32_t ttl_sec = add_ttl ? 0 : (has_ttl ? 
ObjExpireTime(obj) : UINT32_MAX); return (void*)MakeSetSds(sv, ttl_sec); } sds StringSet::MakeSetSds(string_view src, uint32_t ttl_sec) const { if (ttl_sec != UINT32_MAX) { uint32_t at = time_now() + ttl_sec; sds newsds = AllocImmutableWithTtl(src.size(), at); if (!src.empty()) memcpy(newsds, src.data(), src.size()); return newsds; } return sdsnewlen(src.data(), src.size()); } // Does not release obj. Callers must deallocate with sdsfree explicitly pair StringSet::DuplicateEntryIfFragmented(void* obj, PageUsage* page_usage) { sds key = (sds)obj; if (!page_usage->IsPageForObjectUnderUtilized(key)) return {key, false}; size_t key_len = sdslen(key); bool has_ttl = MayHaveTtl(key); if (has_ttl) { sds res = AllocSdsWithSpace(key_len, sizeof(uint32_t)); std::memcpy(res, key, key_len + sizeof(uint32_t)); return {res, true}; } return {sdsnewlen(key, key_len), true}; } bool StringSet::iterator::ReallocIfNeeded(PageUsage* page_usage) { auto* ptr = curr_entry_; if (ptr->IsLink()) { ptr = ptr->AsLink(); } DCHECK(!ptr->IsEmpty()); DCHECK(ptr->IsObject()); auto* obj = ptr->GetObject(); auto [new_obj, realloced] = static_cast(owner_)->DuplicateEntryIfFragmented(obj, page_usage); if (realloced) { ptr->SetObject(new_obj); sdsfree((sds)obj); } return realloced; } } // namespace dfly ================================================ FILE: src/core/string_set.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include "core/dense_set.h" extern "C" { #include "redis/sds.h" } namespace dfly { class PageUsage; class StringSet : public DenseSet { public: StringSet() = default; ~StringSet(); // Returns true if elem was added. 
bool Add(std::string_view s1, uint32_t ttl_sec = UINT32_MAX); unsigned AddMany(absl::Span span, uint32_t ttl_sec, bool keepttl); bool Erase(std::string_view str) { return EraseInternal(&str, 1); } bool Contains(std::string_view s1) const { return FindInternal(&s1, Hash(&s1, 1), 1) != nullptr; } class iterator : private IteratorBase { public: using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; using value_type = sds; using pointer = sds*; using reference = sds&; explicit iterator(const IteratorBase& o) : IteratorBase(o) { } iterator() : IteratorBase() { } iterator(DenseSet* set) : IteratorBase(set, false) { } iterator& operator++() { Advance(); return *this; } bool operator==(const iterator& b) const { if (owner_ == nullptr && b.owner_ == nullptr) { // to allow comparison with end() return true; } return owner_ == b.owner_ && curr_entry_ == b.curr_entry_; } bool operator!=(const iterator& b) const { return !(*this == b); } value_type operator*() { return (value_type)curr_entry_->GetObject(); } value_type operator->() { return (value_type)curr_entry_->GetObject(); } using IteratorBase::ExpiryTime; using IteratorBase::HasExpiry; using IteratorBase::SetExpiryTime; // Try reducing memory fragmentation of the value by re-allocating. Returns true if // re-allocation happened. 
bool ReallocIfNeeded(PageUsage* page_usage); }; iterator begin() { return iterator{this}; } iterator end() { return iterator{}; } // See DenseSet::GetRandomIterator iterator GetRandomMember(); std::optional Pop(); uint32_t Scan(uint32_t, const std::function&) const; iterator Find(std::string_view member) { return iterator{FindIt(&member, 1)}; } protected: uint64_t Hash(const void* ptr, uint32_t cookie) const override; unsigned AddBatch(absl::Span span, uint32_t ttl_sec, bool keepttl); bool ObjEqual(const void* left, const void* right, uint32_t right_cookie) const override; size_t ObjectAllocSize(const void* s1) const override; uint32_t ObjExpireTime(const void* obj) const override; void ObjUpdateExpireTime(const void* obj, uint32_t ttl_sec) override; void ObjDelete(void* obj, bool has_ttl) const override; void* ObjectClone(const void* obj, bool has_ttl, bool add_ttl) const override; sds MakeSetSds(std::string_view src, uint32_t ttl_sec) const; private: std::pair DuplicateEntryIfFragmented(void* obj, PageUsage* page_usage); }; } // end namespace dfly ================================================ FILE: src/core/string_set_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/string_set.h" #include #include #include #include #include #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" #include "core/compact_object.h" #include "core/page_usage/page_usage_stats.h" #include "redis/sds.h" extern "C" { #include "redis/zmalloc.h" } namespace dfly { using namespace std; using absl::StrCat; class DenseSetAllocator : public PMR_NS::memory_resource { public: bool all_freed() const { return alloced_ == 0; } void* do_allocate(size_t bytes, size_t alignment) override { alloced_ += bytes; void* p = PMR_NS::new_delete_resource()->allocate(bytes, alignment); return p; } void do_deallocate(void* p, size_t bytes, size_t alignment) override { alloced_ -= bytes; return PMR_NS::new_delete_resource()->deallocate(p, bytes, alignment); } bool do_is_equal(const PMR_NS::memory_resource& other) const noexcept override { return PMR_NS::new_delete_resource()->is_equal(other); } private: size_t alloced_ = 0; }; class StringSetTest : public ::testing::Test { protected: static void SetUpTestSuite() { auto* tlh = mi_heap_get_backing(); init_zmalloc_threadlocal(tlh); InitTLStatelessAllocMR(PMR_NS::get_default_resource()); } static void TearDownTestSuite() { } void SetUp() override { ss_ = new StringSet; generator_.seed(0); } void TearDown() override { delete ss_; // ensure there are no memory leaks after every test EXPECT_TRUE(alloc_.all_freed()); EXPECT_EQ(zmalloc_used_memory_tl, 0); } StringSet* ss_; DenseSetAllocator alloc_; mt19937 generator_; }; TEST_F(StringSetTest, Basic) { EXPECT_TRUE(ss_->Add("foo"sv)); EXPECT_TRUE(ss_->Add("bar"sv)); EXPECT_FALSE(ss_->Add("foo"sv)); EXPECT_FALSE(ss_->Add("bar"sv)); EXPECT_TRUE(ss_->Contains("foo"sv)); EXPECT_TRUE(ss_->Contains("bar"sv)); EXPECT_EQ(2, ss_->UpperBoundSize()); } TEST_F(StringSetTest, StandardAddErase) { EXPECT_TRUE(ss_->Add("@@@@@@@@@@@@@@@@")); EXPECT_TRUE(ss_->Add("A@@@@@@@@@@@@@@@")); EXPECT_TRUE(ss_->Add("AA@@@@@@@@@@@@@@")); 
EXPECT_TRUE(ss_->Add("AAA@@@@@@@@@@@@@")); EXPECT_TRUE(ss_->Add("AAAAAAAAA@@@@@@@")); EXPECT_TRUE(ss_->Add("AAAAAAAAAA@@@@@@")); EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAA@")); EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAAA")); EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAAD")); EXPECT_TRUE(ss_->Add("BBBBBAAAAAAAAAAA")); EXPECT_TRUE(ss_->Add("BBBBBBBBAAAAAAAA")); EXPECT_TRUE(ss_->Add("CCCCCBBBBBBBBBBB")); // Remove link in the middle of chain EXPECT_TRUE(ss_->Erase("BBBBBBBBAAAAAAAA")); // Remove start of a chain EXPECT_TRUE(ss_->Erase("CCCCCBBBBBBBBBBB")); // Remove end of link EXPECT_TRUE(ss_->Erase("AAA@@@@@@@@@@@@@")); // Remove only item in chain EXPECT_TRUE(ss_->Erase("AA@@@@@@@@@@@@@@")); EXPECT_TRUE(ss_->Erase("AAAAAAAAA@@@@@@@")); EXPECT_TRUE(ss_->Erase("AAAAAAAAAA@@@@@@")); EXPECT_TRUE(ss_->Erase("AAAAAAAAAAAAAAA@")); } TEST_F(StringSetTest, DisplacedBug) { string_view vals[] = {"imY", "OVl", "NhH", "BCe", "YDL", "lpb", "nhF", "xod", "zYR", "PSa", "hce", "cTR"}; ss_->AddMany(absl::MakeSpan(vals), UINT32_MAX, false); ss_->Add("fIc"); ss_->Erase("YDL"); ss_->Add("fYs"); ss_->Erase("hce"); ss_->Erase("nhF"); ss_->Add("dye"); ss_->Add("xZT"); ss_->Add("LVK"); ss_->Erase("zYR"); ss_->Erase("fYs"); ss_->Add("ueB"); ss_->Erase("PSa"); ss_->Erase("OVl"); ss_->Add("cga"); ss_->Add("too"); ss_->Erase("ueB"); ss_->Add("HZe"); ss_->Add("oQn"); ss_->Erase("too"); ss_->Erase("HZe"); ss_->Erase("xZT"); ss_->Erase("cga"); ss_->Erase("cTR"); ss_->Erase("BCe"); ss_->Add("eua"); ss_->Erase("lpb"); ss_->Add("OXK"); ss_->Add("QmO"); ss_->Add("SzV"); ss_->Erase("QmO"); ss_->Add("jbe"); ss_->Add("BPN"); ss_->Add("OfH"); ss_->Add("Muf"); ss_->Add("CwP"); ss_->Erase("Muf"); ss_->Erase("xod"); ss_->Add("Cis"); ss_->Add("Xvd"); ss_->Erase("SzV"); ss_->Erase("eua"); ss_->Add("DGb"); ss_->Add("leD"); ss_->Add("MVX"); ss_->Add("HPq"); } static string random_string(mt19937& rand, unsigned len) { const string_view alpanum = "1234567890abcdefghijklmnopqrstuvwxyz"; string ret; ret.reserve(len); for (size_t i = 
0; i < len; ++i) { ret += alpanum[rand() % alpanum.size()]; } return ret; } TEST_F(StringSetTest, Resizing) { constexpr size_t num_strs = 4096; unordered_set strs; while (strs.size() != num_strs) { auto str = random_string(generator_, 10); strs.insert(str); } unsigned size = 0; for (auto it = strs.begin(); it != strs.end(); ++it) { const auto& str = *it; EXPECT_TRUE(ss_->Add(str, 1)); EXPECT_EQ(ss_->UpperBoundSize(), size + 1); // make sure we haven't lost any items after a grow // which happens every power of 2 if ((size & (size - 1)) == 0) { for (auto j = strs.begin(); j != it; ++j) { const auto& str = *j; auto it = ss_->Find(str); ASSERT_TRUE(it != ss_->end()); EXPECT_TRUE(it.HasExpiry()); EXPECT_EQ(it.ExpiryTime(), ss_->time_now() + 1); } } ++size; } } TEST_F(StringSetTest, SimpleScan) { unordered_set info = {"foo", "bar"}; unordered_set seen; for (auto str : info) { EXPECT_TRUE(ss_->Add(str)); } uint32_t cursor = 0; do { cursor = ss_->Scan(cursor, [&](const sds ptr) { sds s = (sds)ptr; string_view str{s, sdslen(s)}; EXPECT_TRUE(info.count(str)); seen.insert(str); }); } while (cursor != 0); EXPECT_TRUE(seen.size() == info.size() && equal(seen.begin(), seen.end(), info.begin())); } // Ensure REDIS scan guarantees are met TEST_F(StringSetTest, ScanGuarantees) { unordered_set to_be_seen = {"foo", "bar"}; unordered_set not_be_seen = {"AAA", "BBB"}; unordered_set maybe_seen = {"AA@@@@@@@@@@@@@@", "AAA@@@@@@@@@@@@@", "AAAAAAAAA@@@@@@@", "AAAAAAAAAA@@@@@@"}; unordered_set seen; auto scan_callback = [&](const sds ptr) { sds s = (sds)ptr; string_view str{s, sdslen(s)}; EXPECT_TRUE(to_be_seen.count(str) || maybe_seen.count(str)); EXPECT_FALSE(not_be_seen.count(str)); if (to_be_seen.count(str)) { seen.insert(str); } }; EXPECT_EQ(ss_->Scan(0, scan_callback), 0); for (auto str : not_be_seen) { EXPECT_TRUE(ss_->Add(str)); } for (auto str : not_be_seen) { EXPECT_TRUE(ss_->Erase(str)); } for (auto str : to_be_seen) { EXPECT_TRUE(ss_->Add(str)); } // should reach at least the 
first item in the set uint32_t cursor = ss_->Scan(0, scan_callback); for (auto str : maybe_seen) { EXPECT_TRUE(ss_->Add(str)); } while (cursor != 0) { cursor = ss_->Scan(cursor, scan_callback); } EXPECT_TRUE(seen.size() == to_be_seen.size()); } TEST_F(StringSetTest, IntOnly) { constexpr size_t num_ints = 8192; unordered_set numbers; for (size_t i = 0; i < num_ints; ++i) { numbers.insert(i); EXPECT_TRUE(ss_->Add(to_string(i))); } for (size_t i = 0; i < num_ints; ++i) { ASSERT_FALSE(ss_->Add(to_string(i))); } size_t num_remove = generator_() % 4096; unordered_set removed; for (size_t i = 0; i < num_remove; ++i) { auto remove_int = generator_() % num_ints; auto remove = to_string(remove_int); if (numbers.count(remove_int)) { ASSERT_TRUE(ss_->Contains(remove)) << remove_int; EXPECT_TRUE(ss_->Erase(remove)); numbers.erase(remove_int); } else { EXPECT_FALSE(ss_->Erase(remove)); } EXPECT_FALSE(ss_->Contains(remove)); removed.insert(remove); } size_t expected_seen = 0; auto scan_callback = [&](const sds ptr) { string str{ptr, sdslen(ptr)}; EXPECT_FALSE(removed.count(str)); if (numbers.count(atoi(str.data()))) { ++expected_seen; } }; uint32_t cursor = 0; do { cursor = ss_->Scan(cursor, scan_callback); // randomly throw in some new numbers uint32_t val = generator_(); VLOG(1) << "Val " << val; ss_->Add(to_string(val)); } while (cursor != 0); EXPECT_GE(expected_seen + removed.size(), num_ints); } TEST_F(StringSetTest, XtremeScanGrow) { unordered_set to_see, force_grow, seen; while (to_see.size() != 8) { to_see.insert(random_string(generator_, 10)); } while (force_grow.size() != 8192) { string str = random_string(generator_, 10); if (to_see.count(str)) { continue; } force_grow.insert(random_string(generator_, 10)); } for (auto& str : to_see) { EXPECT_TRUE(ss_->Add(str)); } auto scan_callback = [&](const sds ptr) { sds s = (sds)ptr; string_view str{s, sdslen(s)}; if (to_see.count(string(str))) { seen.insert(string(str)); } }; uint32_t cursor = ss_->Scan(0, scan_callback); // 
force approx 10 grows for (auto& s : force_grow) { EXPECT_TRUE(ss_->Add(s)); } while (cursor != 0) { cursor = ss_->Scan(cursor, scan_callback); } EXPECT_EQ(seen.size(), to_see.size()); } TEST_F(StringSetTest, Pop) { constexpr size_t num_items = 8; unordered_set to_insert; while (to_insert.size() != num_items) { auto str = random_string(generator_, 10); if (to_insert.count(str)) { continue; } to_insert.insert(str); EXPECT_TRUE(ss_->Add(str)); } while (!ss_->Empty()) { size_t size = ss_->UpperBoundSize(); auto str = ss_->Pop(); DCHECK(ss_->UpperBoundSize() == to_insert.size() - 1); DCHECK(str.has_value()); DCHECK(to_insert.count(str.value())); DCHECK_EQ(ss_->UpperBoundSize(), size - 1); to_insert.erase(str.value()); } DCHECK(ss_->Empty()); DCHECK(to_insert.empty()); } TEST_F(StringSetTest, Iteration) { ss_->Add("foo"); for (const sds ptr : *ss_) { LOG(INFO) << ptr; } ss_->Clear(); constexpr size_t num_items = 8192; unordered_set to_insert; while (to_insert.size() != num_items) { auto str = random_string(generator_, 10); if (to_insert.count(str)) { continue; } to_insert.insert(str); EXPECT_TRUE(ss_->Add(str)); } for (const sds ptr : *ss_) { string str{ptr, sdslen(ptr)}; EXPECT_TRUE(to_insert.count(str)); to_insert.erase(str); } EXPECT_EQ(to_insert.size(), 0); } TEST_F(StringSetTest, SetFieldExpireHasExpiry) { EXPECT_TRUE(ss_->Add("k1", 100)); auto k = ss_->Find("k1"); EXPECT_TRUE(k.HasExpiry()); EXPECT_EQ(k.ExpiryTime(), 100); k.SetExpiryTime(1); EXPECT_TRUE(k.HasExpiry()); EXPECT_EQ(k.ExpiryTime(), 1); } TEST_F(StringSetTest, SetFieldExpireNoHasExpiry) { EXPECT_TRUE(ss_->Add("k1")); auto k = ss_->Find("k1"); EXPECT_FALSE(k.HasExpiry()); k.SetExpiryTime(10); EXPECT_TRUE(k.HasExpiry()); EXPECT_EQ(k.ExpiryTime(), 10); } TEST_F(StringSetTest, Ttl) { EXPECT_TRUE(ss_->Add("bla"sv, 1)); EXPECT_FALSE(ss_->Add("bla"sv, 1)); auto it = ss_->Find("bla"sv); EXPECT_EQ(1u, it.ExpiryTime()); ss_->set_time(1); EXPECT_TRUE(ss_->Add("bla"sv, 1)); EXPECT_EQ(1u, ss_->UpperBoundSize()); 
for (unsigned i = 0; i < 100; ++i) { EXPECT_TRUE(ss_->Add(StrCat("foo", i), 1)); } EXPECT_EQ(101u, ss_->UpperBoundSize()); it = ss_->Find("foo50"); EXPECT_STREQ("foo50", *it); EXPECT_EQ(2u, it.ExpiryTime()); ss_->set_time(2); for (unsigned i = 0; i < 100; ++i) { EXPECT_TRUE(ss_->Add(StrCat("bar", i))); } it = ss_->Find("bar50"); EXPECT_FALSE(it.HasExpiry()); for (auto it = ss_->begin(); it != ss_->end(); ++it) { ASSERT_TRUE(absl::StartsWith(*it, "bar")) << *it; string str = *it; VLOG(1) << *it; } } TEST_F(StringSetTest, Grow) { for (size_t j = 0; j < 10; ++j) { for (size_t i = 0; i < 4098; ++i) { ss_->Reserve(generator_() % 256); auto str = random_string(generator_, 3); ss_->Add(str); } ss_->Clear(); } } TEST_F(StringSetTest, Reserve) { vector strs; for (size_t i = 0; i < 10; ++i) { strs.push_back(random_string(generator_, 10)); ss_->Add(strs.back()); } for (size_t j = 2; j < 20; j += 3) { ss_->Reserve(j * 20); for (size_t i = 0; i < 10; ++i) { ASSERT_TRUE(ss_->Contains(strs[i])); } } } TEST_F(StringSetTest, Fill) { for (size_t i = 0; i < 100; ++i) { ss_->Add(random_string(generator_, 10)); } StringSet s2; ss_->Fill(&s2); EXPECT_EQ(s2.UpperBoundSize(), ss_->UpperBoundSize()); for (sds str : *ss_) { EXPECT_TRUE(s2.Contains(str)); } } TEST_F(StringSetTest, ClearResetsObjMallocUsed) { // Add some items for (size_t i = 0; i < 100; ++i) { ss_->Add(random_string(generator_, 10)); } // Verify ObjMallocUsed() > 0 after adding items EXPECT_GT(ss_->ObjMallocUsed(), 0u); EXPECT_GT(ss_->UpperBoundSize(), 0u); // Clear the set ss_->Clear(); // Verify ObjMallocUsed() is reset to 0 after Clear EXPECT_EQ(ss_->ObjMallocUsed(), 0u); EXPECT_EQ(ss_->UpperBoundSize(), 0u); } TEST_F(StringSetTest, IterateEmpty) { for (const auto& s : *ss_) { // We're iterating to make sure there is no crash. 
However, if we got here, it's a bug CHECK(false) << "Found entry " << s << " in empty set"; } } static size_t MemUsed(StringSet& obj) { return obj.ObjMallocUsed() + obj.SetMallocUsed(); } void BM_Clone(benchmark::State& state) { vector strs; mt19937 generator(0); StringSet ss1, ss2; unsigned elems = state.range(0); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, 10); ss1.Add(str); } ss2.Reserve(ss1.UpperBoundSize()); while (state.KeepRunning()) { for (auto src : ss1) { ss2.Add(src); } state.PauseTiming(); ss2.Clear(); ss2.Reserve(ss1.UpperBoundSize()); state.ResumeTiming(); } } BENCHMARK(BM_Clone)->ArgName("elements")->Arg(32000); void BM_Fill(benchmark::State& state) { unsigned elems = state.range(0); vector strs; mt19937 generator(0); StringSet ss1, ss2; for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, 10); ss1.Add(str); } while (state.KeepRunning()) { ss1.Fill(&ss2); state.PauseTiming(); ss2.Clear(); state.ResumeTiming(); } } BENCHMARK(BM_Fill)->ArgName("elements")->Arg(32000); void BM_Clear(benchmark::State& state) { unsigned elems = state.range(0); mt19937 generator(0); StringSet ss; while (state.KeepRunning()) { state.PauseTiming(); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, 16); ss.Add(str); } state.ResumeTiming(); ss.Clear(); } } BENCHMARK(BM_Clear)->ArgName("elements")->Arg(32000); void BM_Add(benchmark::State& state) { vector strs; mt19937 generator(0); StringSet ss; unsigned elems = state.range(0); unsigned keySize = state.range(1); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, keySize); strs.push_back(str); } ss.Reserve(elems); size_t mem_used = 0; while (state.KeepRunning()) { for (auto& str : strs) ss.Add(str); state.PauseTiming(); mem_used += MemUsed(ss); ss.Clear(); ss.Reserve(elems); state.ResumeTiming(); } state.counters["Memory_Used"] = mem_used / state.iterations(); } BENCHMARK(BM_Add) ->ArgNames({"elements", "Key Size"}) 
->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_AddMany(benchmark::State& state) { vector strs; mt19937 generator(0); StringSet ss; unsigned elems = state.range(0); unsigned keySize = state.range(1); for (size_t i = 0; i < elems; ++i) { string str = random_string(generator, keySize); strs.push_back(str); } ss.Reserve(elems); vector svs; for (const auto& str : strs) { svs.push_back(str); } size_t mem_used = 0; while (state.KeepRunning()) { ss.AddMany(absl::MakeSpan(svs), UINT32_MAX, false); state.PauseTiming(); CHECK_EQ(ss.UpperBoundSize(), elems); mem_used += MemUsed(ss); ss.Clear(); ss.Reserve(elems); state.ResumeTiming(); } state.counters["Memory_Used"] = mem_used / state.iterations(); } BENCHMARK(BM_AddMany) ->ArgNames({"elements", "Key Size"}) ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_Erase(benchmark::State& state) { std::vector strs; mt19937 generator(0); StringSet ss; auto elems = state.range(0); auto keySize = state.range(1); for (long int i = 0; i < elems; ++i) { std::string str = random_string(generator, keySize); strs.push_back(str); ss.Add(str); } state.counters["Memory_Before_Erase"] = MemUsed(ss); size_t mem_used = 0; while (state.KeepRunning()) { for (auto& str : strs) { ss.Erase(str); } state.PauseTiming(); mem_used += MemUsed(ss); for (auto& str : strs) { ss.Add(str); } state.ResumeTiming(); } state.counters["Memory_After_Erase"] = mem_used / state.iterations(); } BENCHMARK(BM_Erase) ->ArgNames({"elements", "Key Size"}) ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); void BM_Get(benchmark::State& state) { std::vector strs; mt19937 generator(0); StringSet ss; auto elems = state.range(0); auto keySize = state.range(1); for (long int i = 0; i < elems; ++i) { std::string str = random_string(generator, keySize); strs.push_back(str); ss.Add(str); } while (state.KeepRunning()) { for (auto& str : strs) { ss.Find(str); } } } BENCHMARK(BM_Get) ->ArgNames({"elements", "Key Size"}) ->ArgsProduct({{1000, 10000, 
100000}, {10, 100, 1000}}); void BM_Grow(benchmark::State& state) { vector strs; mt19937 generator(0); StringSet src; unsigned elems = 1 << 18; for (size_t i = 0; i < elems; ++i) { src.Add(random_string(generator, 16), UINT32_MAX); strs.push_back(random_string(generator, 16)); } while (state.KeepRunning()) { state.PauseTiming(); StringSet tmp; src.Fill(&tmp); CHECK_EQ(tmp.BucketCount(), elems); state.ResumeTiming(); for (const auto& str : strs) { tmp.Add(str); if (tmp.BucketCount() > elems) { break; // we grew } } CHECK_GT(tmp.BucketCount(), elems); } } BENCHMARK(BM_Grow); void BM_Spop1000(benchmark::State& state) { mt19937 generator(0); StringSet src; unsigned elems = 1 << 14; for (size_t i = 0; i < elems; ++i) { src.Add(random_string(generator, 16), UINT32_MAX); } auto sparseness = state.range(0); while (state.KeepRunning()) { state.PauseTiming(); StringSet tmp; src.Fill(&tmp); tmp.Reserve(elems * sparseness); state.ResumeTiming(); for (int i = 0; i < 1000; ++i) { tmp.Pop(); } } } BENCHMARK(BM_Spop1000)->ArgName("sparseness")->ArgsProduct({{1, 4, 10}}); unsigned total_wasted_memory = 0; TEST_F(StringSetTest, ReallocIfNeeded) { auto build_str = [](size_t i) { return to_string(i) + string(131, 'a'); }; auto count_waste = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block, size_t block_size, void* arg) { size_t used = block_size * area->used; total_wasted_memory += area->committed - used; return true; }; for (size_t i = 0; i < 10'000; i++) ss_->Add(build_str(i)); for (size_t i = 0; i < 10'000; i++) { if (i % 10 == 0) continue; ss_->Erase(build_str(i)); } mi_heap_collect(mi_heap_get_backing(), true); mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr); size_t wasted_before = total_wasted_memory; size_t underutilized = 0; PageUsage page_usage{CollectPageStats::NO, 0.9}; for (auto it = ss_->begin(); it != ss_->end(); ++it) { underutilized += page_usage.IsPageForObjectUnderUtilized(*it); it.ReallocIfNeeded(&page_usage); } // Check 
// there are underutilized pages
  CHECK_GT(underutilized, 0u);

  // Re-measure the committed-but-unused heap memory after the re-allocation pass.
  total_wasted_memory = 0;
  mi_heap_collect(mi_heap_get_backing(), true);
  mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr);
  size_t wasted_after = total_wasted_memory;

  // Check we waste significantly less now
  EXPECT_GT(wasted_before, wasted_after * 2);

  EXPECT_EQ(ss_->UpperBoundSize(), 1000);

  // Every 10th string was kept; each must still be found with its original contents.
  for (size_t i = 0; i < 1000; i++)
    EXPECT_EQ(*ss_->Find(build_str(i * 10)), build_str(i * 10));
}

// Adds ten members with a TTL of 1, erases the first nine and checks that the remaining member
// ("9") still reports its expiry.
TEST_F(StringSetTest, TransferTTLFlagLinkToObjectOnDelete) {
  for (size_t i = 0; i < 10; i++) {
    EXPECT_TRUE(ss_->Add(absl::StrCat(i), 1));
  }
  for (size_t i = 0; i < 9; i++) {
    EXPECT_TRUE(ss_->Erase(absl::StrCat(i)));
  }
  auto it = ss_->Find("9"sv);
  EXPECT_TRUE(it.HasExpiry());
  EXPECT_EQ(1u, it.ExpiryTime());
}

// Fixture for the parameterized shrink tests; the parameter is the bucket count to shrink to.
class ShrinkTest : public StringSetTest, public ::testing::WithParamInterface {};

// Adds 1M random strings, reserves 1 << 22 buckets, shrinks to the parameterized bucket count
// and verifies the bucket count, the element count and that every string is still contained.
TEST_P(ShrinkTest, BasicShrink) {
  constexpr size_t num_strs = 1000000;
  size_t shrink_to = GetParam();

  vector strs;
  for (size_t i = 0; i < num_strs; ++i) {
    strs.push_back(random_string(generator_, 10));
    EXPECT_TRUE(ss_->Add(strs.back()));
  }

  // Grow to a larger size
  ss_->Reserve(1 << 22);
  size_t original_bucket_count = ss_->BucketCount();
  EXPECT_EQ(original_bucket_count, 1u << 22);

  // Shrink to the parameterized size
  ss_->Shrink(shrink_to);
  EXPECT_EQ(ss_->BucketCount(), shrink_to);
  EXPECT_EQ(ss_->UpperBoundSize(), num_strs);

  // Verify all elements are still accessible
  for (const auto& str : strs) {
    EXPECT_TRUE(ss_->Contains(str)) << "Missing: " << str;
  }
}

INSTANTIATE_TEST_SUITE_P(ShrinkSizes, ShrinkTest,
                         ::testing::Values(1u << 21,   // 2M buckets (sparse)
                                           1u << 20,   // 1M buckets (~1 per bucket)
                                           1u << 19),  // 512K buckets (~2 per bucket)
                         [](const auto& info) { return absl::StrCat("buckets_", info.param); });

TEST_F(StringSetTest, ShrinkWithTTL) {
  constexpr size_t num_strs = 1000000;

  // Track elements by their TTL category
  vector expired_strs;    // TTL 1-50, will expire
  vector surviving_strs;  // TTL 51-100, will survive
  vector
no_ttl_strs; // No TTL, will survive for (size_t i = 0; i < num_strs; ++i) { string str = random_string(generator_, 10); if (i % 3 == 0) { // No TTL EXPECT_TRUE(ss_->Add(str)); no_ttl_strs.push_back(str); } else if (i % 3 == 1) { // TTL 1-50 (will expire when time=50) uint32_t ttl = (i % 50) + 1; EXPECT_TRUE(ss_->Add(str, ttl)); expired_strs.push_back(str); } else { // TTL 51-100 (will survive when time=50) uint32_t ttl = (i % 50) + 51; EXPECT_TRUE(ss_->Add(str, ttl)); surviving_strs.push_back(str); } } // Grow to larger size ss_->Reserve(1 << 22); // Set time to 50 - this will expire elements with TTL <= 50 ss_->set_time(50); // Shrink ss_->Shrink(1 << 21); EXPECT_EQ(ss_->BucketCount(), 1u << 21); // Verify expired elements are gone for (const auto& str : expired_strs) { EXPECT_EQ(ss_->Find(str), ss_->end()) << "Should be expired: " << str; } // Verify surviving TTL elements are still accessible with correct TTL for (const auto& str : surviving_strs) { auto it = ss_->Find(str); ASSERT_NE(it, ss_->end()) << "Missing surviving TTL element: " << str; EXPECT_TRUE(it.HasExpiry()); EXPECT_GT(it.ExpiryTime(), 50u); } // Verify no-TTL elements are still accessible for (const auto& str : no_ttl_strs) { auto it = ss_->Find(str); ASSERT_NE(it, ss_->end()) << "Missing no-TTL element: " << str; EXPECT_FALSE(it.HasExpiry()); } } TEST_F(StringSetTest, ScanWithShrinkBetweenCalls) { // Test that cursor-based scanning works correctly when Grow and Shrink happen between Scan calls // This verifies SCAN guarantees: elements present at start and end of scan must be seen constexpr size_t num_strs = 1000000; vector strs; unordered_set must_see; // Add elements and track them for (size_t i = 0; i < num_strs; ++i) { strs.push_back(random_string(generator_, 10)); EXPECT_TRUE(ss_->Add(strs.back())); must_see.insert(strs.back()); } // Note initial bucket count (will be ~1M after adding 1M elements) size_t initial_bucket_count = ss_->BucketCount(); unordered_set seen; auto scan_callback = 
[&](const sds ptr) { string str{ptr, sdslen(ptr)}; seen.insert(str); }; // Start scanning BEFORE Grow uint32_t cursor = ss_->Scan(0, scan_callback); EXPECT_NE(cursor, 0u) << "Should not finish in one iteration"; // Grow to large size in the middle of scanning ss_->Reserve(1 << 22); EXPECT_EQ(ss_->BucketCount(), 1u << 22); EXPECT_GT(ss_->BucketCount(), initial_bucket_count); // Continue scanning a bit after Grow cursor = ss_->Scan(cursor, scan_callback); // Now Shrink in the middle of scanning - this is the key test // Elements that existed at scan start must still be visible ss_->Shrink(1 << 21); EXPECT_EQ(ss_->BucketCount(), 1u << 21); // Continue scanning with the same cursor constexpr int max_iterations = 1 << 22; int iterations = 0; while (cursor != 0 && iterations < max_iterations) { cursor = ss_->Scan(cursor, scan_callback); iterations++; } EXPECT_LT(iterations, max_iterations) << "Hit iteration limit"; EXPECT_EQ(cursor, 0u) << "Scan should complete"; // Verify all original elements were seen for (const auto& str : must_see) { EXPECT_TRUE(seen.count(str)) << "Missing element after shrink: " << str; } EXPECT_EQ(seen.size(), must_see.size()) << "Should see exactly all original elements"; } } // namespace dfly ================================================ FILE: src/core/task_queue.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#include "core/task_queue.h"

#include

#include "base/logging.h"

using namespace std;
using namespace util::fb2;

namespace dfly {

// Per-thread counter (declared `static __thread` in the class). It is incremented and
// decremented around the submissions in TaskQueue::Add() (after TryAdd fails) and
// TaskQueue::Await().
__thread unsigned TaskQueue::blocked_submitters_ = 0;

// Sizes the underlying queue and creates `start_size` consumer fiber slots (started later by
// Start()). `pool_max_size` is only validated against `start_size` here; it is not stored.
TaskQueue::TaskQueue(unsigned queue_size, unsigned start_size, unsigned pool_max_size)
    : queue_(queue_size), consumer_fibers_(start_size) {
  CHECK_GT(start_size, 0u);
  CHECK_LE(start_size, pool_max_size);
}

// Launches one high-priority fiber per consumer slot, named "<base_name>/<index>", each running
// queue_.Run(). No slot may already hold a joinable fiber.
void TaskQueue::Start(std::string_view base_name) {
  for (size_t i = 0; i < consumer_fibers_.size(); ++i) {
    auto& fb = consumer_fibers_[i];
    CHECK(!fb.IsJoinable());

    string name = absl::StrCat(base_name, "/", i);
    fb = Fiber(Fiber::Opts{.priority = FiberPriority::HIGH, .name = name},
               [this] { queue_.Run(); });
  }
}

// Shuts the queue down first, then joins every consumer fiber that needs joining.
void TaskQueue::Shutdown() {
  queue_.Shutdown();
  for (auto& fb : consumer_fibers_)
    fb.JoinIfNeeded();
}

}  // namespace dfly

================================================
FILE: src/core/task_queue.h
================================================
// Copyright 2024, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
#pragma once

#include "util/fibers/fiberqueue_threadpool.h"
#include "util/fibers/fibers.h"

namespace dfly {

/**
 * MPSC task-queue that is handled by a single consumer thread.
 * The queue is just a wrapper around FiberQueue that manages its fiber itself.
 *
 * NOTE(review): the implementation keeps a vector of consumer fibers (`consumer_fibers_`), one
 * per `pool_start_size`, each running FiberQueue::Run() - so "its fiber" is really a small pool.
 */
class TaskQueue {
 public:
  // TODO: to add a mechanism to moderate pool size. Currently it's static with pool_start_size.
  TaskQueue(unsigned queue_size, unsigned pool_start_size, unsigned pool_max_size);

  // Forwards `f` to FiberQueue::TryAdd and returns its result.
  // NOTE(review): presumably non-blocking, returning false when the queue is full - confirm.
  template  bool TryAdd(F&& f) {
    return queue_.TryAdd(std::forward(f));
  }

  // Returns true if task queue was blocked when adding the task.
  template  bool Add(F&& f) {
    // Fast path: the task was accepted without blocking.
    if (queue_.TryAdd(std::forward(f)))
      return false;

    // Slow path: count this caller as a blocked submitter for the duration of the blocking Add.
    // NOTE(review): `f` is forwarded a second time here; this relies on TryAdd not consuming `f`
    // when it fails - confirm against FiberQueue.
    ++blocked_submitters_;
    auto res = queue_.Add(std::forward(f));
    --blocked_submitters_;
    return res;
  }

  // Submits `f` to the queue, waits until a consumer has executed it and returns its result.
  template  auto Await(F&& f) -> decltype(f()) {
    util::fb2::Done done;
    using ResultType = decltype(f());
    util::detail::ResultMover mover;

    ++blocked_submitters_;
    // `mover` is captured by reference: it stays alive because this function does not return
    // before done.Wait() below observes the Notify() issued at the end of the task.
    Add([&mover, f = std::forward(f), done]() mutable {
      mover.Apply(f);
      done.Notify();
    });
    --blocked_submitters_;
    done.Wait();

    return std::move(mover).get();
  }

  /**
   * @brief Start running consumer loop in the caller thread by spawning fibers.
   * Returns immediately.
   */
  void Start(std::string_view base_name);

  /**
   * @brief Notifies Run() function to empty the queue and to exit and waits for the consumer
   * fiber to finish.
   */
  void Shutdown();

  // Current value of the calling thread's blocked-submitters counter.
  static unsigned blocked_submitters() {
    return blocked_submitters_;
  }

 private:
  util::fb2::FiberQueue queue_;
  std::vector consumer_fibers_;

  static __thread unsigned blocked_submitters_;
};

}  // namespace dfly

================================================
FILE: src/core/tiering_types.cc
================================================
// Copyright 2026, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/tiering_types.h"

#include "redis/redis_aux.h"

namespace dfly::tiering {

// Builds the serialization descriptor for `pv`:
//  - OBJ_STRING with allocated storage: the raw string, tagged ExternalRep::STRING;
//  - OBJ_HASH whose encoding is kEncodingListPack: the RObjPtr() blob, tagged SERIALIZED_MAP;
//  - anything else: a default-constructed (empty) descriptor.
auto FragmentRef::GetDescr(const CompactValue* pv) -> SerializationDescr {
  switch (pv->ObjType()) {
    case OBJ_STRING: {
      if (!pv->HasAllocated())
        return {};
      auto strs = pv->GetRawString();
      return {strs, CompactObj::ExternalRep::STRING};
    }
    case OBJ_HASH: {
      if (pv->Encoding() == kEncodingListPack) {
        return {static_cast(pv->RObjPtr()), CompactObj::ExternalRep::SERIALIZED_MAP};
      }
      return {};
    }
    default:
      return {};
  };
}

// Returns the cool record of the referenced value when it is both external and cool.
TieredCoolRecord* FragmentRef::GetCoolRecord() const {
  return std::visit(
      [](auto* pv) -> TieredCoolRecord* {
        return pv->IsExternal() && pv->IsCool() ?
pv->GetCool().record : nullptr;
      },
      val_);
}

}  // namespace dfly::tiering

================================================
FILE: src/core/tiering_types.h
================================================
// Copyright 2026, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include

#include "core/compact_object.h"

namespace dfly::tiering {

// TieredCoolRecord is part of the cooling cache. It allows offloading values to disk
// while still keeping some of them in-memory to avoid disk reads in case they are requested again
// soon after offloading. When a value is moved to the cold storage, only the external reference
// is kept.
// NOTE(review): the original sentence read "TieredCoolRecord and only the external reference is
// kept" and was missing a verb - presumably the TieredCoolRecord is released at that point;
// confirm.
// When the value is warmed up, the record is removed from the cool storage and the value is read
// back to memory.
struct TieredCoolRecord : public ::boost::intrusive::list_base_hook<
                              boost::intrusive::link_mode> {
  uint64_t key_hash;  // Allows searching the entry in the dbslice.
  CompactValue value;
  uint16_t db_index;
  uint32_t page_index;
};
// Guards the record layout: any change to the struct size has to be made deliberately.
static_assert(sizeof(TieredCoolRecord) == 48);

// Thin handle over a value pointer stored in `val_`. Every accessor dispatches with std::visit
// to the pointed-to object, so the same interface works for each alternative of the variant.
class FragmentRef {
 public:
  // Describes how this fragment should be serialized for offloading.
  // Used by stashing flow.
  struct SerializationDescr {
    std::variant, uint8_t*> blob;
    CompactObj::ExternalRep rep = CompactObj::ExternalRep::STRING;
  };

  // Implicit conversions from a value reference/pointer are intentional (hence NOLINT).
  FragmentRef(CompactValue& pv) : val_(&pv) {  // NOLINT
  }

  FragmentRef(CompactValue* pv) : val_(pv) {  // NOLINT
  }

  // True when the referenced value reports IsExternal().
  bool IsOffloaded() const {
    return std::visit([](auto* pv) { return pv->IsExternal(); }, val_);
  }

  // Resets offloaded state for this fragment.
  void ClearOffloaded() {
    std::visit([](auto* pv) { pv->RemoveExternal(); }, val_);
  }

  // True when the referenced value has its stash-pending flag set.
  bool HasStashPending() const {
    return std::visit([](auto* pv) { return pv->HasStashPending(); }, val_);
  }

  // Clears the stash-pending flag of the referenced value.
  void ClearStashPending() {
    std::visit([](auto* pv) { pv->SetStashPending(false); }, val_);
  }

  // Object type of the referenced value.
  CompactObjType ObjType() const {
    return std::visit([](auto* pv) { return pv->ObjType(); }, val_);
  }

  // Determine required byte size and encoding type based on value.
  SerializationDescr GetSerializationDescr() const {
    return std::visit([](auto* pv) { return GetDescr(pv); }, val_);
  }

  // Returns a pointer to TieredCoolRecord if this fragment is cool, and null otherwise.
  TieredCoolRecord* GetCoolRecord() const;

  // Returns the external slice of the offloaded value. Only valid if IsOffloaded() is true.
  std::pair GetExternalSlice() const {
    return std::visit([](auto* pv) { return pv->GetExternalSlice(); }, val_);
  }

 private:
  static SerializationDescr GetDescr(const CompactValue* pv);

  // TODO: to support more types, for example Node* from qlist.h.
  std::variant val_;
};

}  // namespace dfly::tiering

================================================
FILE: src/core/top_keys.cc
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/top_keys.h"

#include

#include "absl/numeric/bits.h"
#include "absl/random/distributions.h"
#include "base/logging.h"

namespace dfly {

using namespace std;

// Allocates buckets * depth cells and clamps min_key_count_to_record to its minimum of 2.
TopKeys::TopKeys(Options options)
    : options_(options), fingerprints_(options_.buckets * options_.depth) {
  if (options_.min_key_count_to_record < 2) {
    options_.min_key_count_to_record = 2;
  }
}

// Records one use of `key`: for each of the `depth` arrays, picks a bucket from the running
// fingerprint and either claims an empty cell, increments a matching cell, or probabilistically
// decays a cell that belongs to a different fingerprint.
void TopKeys::Touch(std::string_view key) {
  // (Re)assigns a cell to `fingerprint` with a count of 1 and no stored key.
  auto ResetCell = [&](Cell& cell, uint64_t fingerprint) {
    cell.fingerprint = fingerprint;
    cell.count = 1;
    cell.key.clear();
  };

  uint64_t fingerprint = XXH3_64bits(key.data(), key.size());
  constexpr uint64_t kPrime = 0xff51afd7ed558ccd;

  for (uint64_t id = 0; id < options_.depth; ++id) {
    // The bucket is chosen from the current fingerprint; the fingerprint is then re-mixed so
    // that each array uses a different value, which is also the value stored in and compared
    // against the cell.
    const unsigned bucket = fingerprint % options_.buckets;
    fingerprint *= kPrime;
    Cell& cell = GetCell(id, bucket);
    if (cell.count == 0) {
      // No fingerprint in cell.
      ResetCell(cell, fingerprint);
    } else if (cell.fingerprint == fingerprint) {
      // Same fingerprint, simply increment count.

      // We could make sure that, if !cell.key.empty(), then key == cell.key here. However,
      // what do we do in case they are different?
      ++cell.count;

      // The key string is only copied into the cell once its count reaches the threshold.
      if (cell.count >= options_.min_key_count_to_record && cell.key.empty()) {
        cell.key = key;
      }
    } else {
      // Different fingerprint, apply exponential decay.
      // The decrement happens with probability decay_base^(-count), so cells with a high count
      // are rarely decayed.
      const double rand = absl::Uniform(bitgen_, 0, 1.0);
      if (rand < std::pow(options_.decay_base, -static_cast(cell.count))) {
        --cell.count;
        // Once fully decayed, the cell is taken over by the new fingerprint.
        if (cell.count == 0) {
          ResetCell(cell, fingerprint);
        }
      }
    }
  }
}

// Collects every cell that stores a key. A key can appear in several arrays; the largest of its
// counts is reported.
absl::flat_hash_map TopKeys::GetTopKeys() const {
  absl::flat_hash_map results;
  for (unsigned array = 0; array < options_.depth; ++array) {
    for (unsigned bucket = 0; bucket < options_.buckets; ++bucket) {
      const Cell& cell = GetCell(array, bucket);
      if (!cell.key.empty()) {
        auto [it, added] = results.emplace(cell.key, cell.count);
        if (!added && it->second < cell.count) {
          it->second = cell.count;
        }
      }
    }
  }
  return results;
}

// Cells are stored row-major: array `d` occupies indices [d * buckets, (d + 1) * buckets).
TopKeys::Cell& TopKeys::GetCell(uint32_t d, uint32_t bucket) {
  DCHECK(d < options_.depth);
  DCHECK(bucket < options_.buckets);
  return fingerprints_[d * options_.buckets + bucket];
}

// Const overload of GetCell(); same indexing as above.
const TopKeys::Cell& TopKeys::GetCell(uint32_t d, uint32_t bucket) const {
  DCHECK(d < options_.depth);
  DCHECK(bucket < options_.buckets);
  return fingerprints_[d * options_.buckets + bucket];
}

}  // end of namespace dfly

================================================
FILE: src/core/top_keys.h
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include
#include
#include

#include "base/random.h"

namespace dfly {

// INTERNAL USE ONLY: This class is an optimized, O(1) probabilistic hot-key tracker designed
// specifically to run on the database's hot path (e.g., tracking hot keys using DEBUG TOPK).
// It cannot and should not be used for user-facing Redis TOPK commands. It intentionally
// omits a Min-Heap (preventing instant eviction reporting), does not support arbitrary
// increments, and does not use PMR allocators (which are required for strict memory
// tracking and RDB serialization of user data).
//
// For the public Redis TOPK module API, use the `TOPK` class defined in `core/topk.h`.
// // TopKeys is a utility class that helps determine the most frequently used keys. // Based on: HeavyKeeper paper, https://www.usenix.org/conference/atc18/presentation/gong // // Usage: // - Instantiate this class with proper options (see below) // - For every used key k, call Touch(k) // - At some point(s) in time, call GetTopKeys() to get an estimated list of top keys along with // their approximate count (i.e. how many times Touch() was invoked for them). // // Notes: // - This class implements a slightly modified version of HeavyKeeper, a data structure designed // for a similar problem domain. The modification made is to store the keys directly within the // tables, when they meet a certain threshold, instead of using a min-heap. // - This class is statistical in nature. Do *not* expect accurate counts. // - When misconfigured, real top keys may be missing from GetTopKeys(). This can occur when there // are too few buckets, or when min_key_count_to_record is too high, depending on actual usage. class TopKeys { TopKeys(const TopKeys&) = delete; TopKeys& operator=(const TopKeys&) = delete; public: struct Options { // HeavyKeeper options uint32_t buckets = 1 << 16; uint32_t depth = 4; // What is the minimum times Touch() has to be called for a given key in order for the key to be // saved. Use lower values when load is low, or higher values when load is high. The cost of a // low value for high load is frequent string copying and memory allocation. // Min value: 2 uint32_t min_key_count_to_record = 50; double decay_base = 1.08; }; explicit TopKeys(Options options); void Touch(std::string_view key); absl::flat_hash_map GetTopKeys() const; private: // Each cell consists of a key-fingerprint, a count, and potentially the key itself, when it's // above options_.min_key_count_to_record. 
struct Cell { uint64_t fingerprint = 0; uint64_t count = 0; std::string key; }; Cell& GetCell(uint32_t d, uint32_t bucket); const Cell& GetCell(uint32_t d, uint32_t bucket) const; Options options_; base::Xoroshiro128p bitgen_; // fingerprints_'s size is options_.buckets * options_.arrays. Always access fields via GetCell(). std::vector fingerprints_; }; } // end of namespace dfly ================================================ FILE: src/core/top_keys_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/top_keys.h" #include #include #include "base/gtest.h" #include "base/logging.h" using ::testing::Pair; using ::testing::UnorderedElementsAre; namespace dfly { TEST(TopKeysTest, Basic) { TopKeys top_keys({.min_key_count_to_record = 2}); top_keys.Touch("key1"); top_keys.Touch("key1"); top_keys.Touch("key2"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 2))); } TEST(TopKeysTest, MultiTouch) { TopKeys top_keys({.min_key_count_to_record = 2}); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre()); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 2))); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 3))); } TEST(TopKeysTest, MinKeyCountToRecord) { TopKeys top_keys({.min_key_count_to_record = 3}); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre()); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre()); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 3))); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 4))); top_keys.Touch("key1"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 5))); } TEST(TopKeysTest, MultiKeys) { TopKeys 
top_keys({.min_key_count_to_record = 2}); for (int i = 0; i < 2; ++i) { top_keys.Touch("key1"); top_keys.Touch("key2"); } top_keys.Touch("key3"); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 2), Pair("key2", 2))); } TEST(TopKeysTest, BucketCollision) { TopKeys top_keys({.buckets = 1, .min_key_count_to_record = 1}); for (int i = 0; i < 5; ++i) { top_keys.Touch("key1"); } EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 5))); for (int i = 0; i < 100; ++i) { top_keys.Touch("key2"); } auto top_keys_table = top_keys.GetTopKeys(); EXPECT_EQ(top_keys_table.size(), 1); EXPECT_LE(top_keys_table["key2"], 100); EXPECT_GE(top_keys_table["key2"], 50); // Touching "key1" should *not* replace "key2". top_keys.Touch("key1"); EXPECT_FALSE(top_keys.GetTopKeys().contains("key1")); } TEST(TopKeysTest, BucketCollisionAggressiveDecay) { TopKeys top_keys({.buckets = 1, .min_key_count_to_record = 2, .decay_base = 1.0}); for (int i = 0; i < 5; ++i) { top_keys.Touch("key1"); } EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 5))); for (int i = 0; i < 100; ++i) { top_keys.Touch("key2"); } EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key2", 96))); } TEST(TopKeysTest, BucketCollisionHesitantDecay) { TopKeys top_keys({.buckets = 1, .min_key_count_to_record = 2, .decay_base = 1000.0}); for (int i = 0; i < 5; ++i) { top_keys.Touch("key1"); } EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 5))); for (int i = 0; i < 100; ++i) { top_keys.Touch("key2"); } // "key2" will never replace "key1", as the decay practically never happens (1000^-5) EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key1", 5))); } TEST(TopKeysTest, SavedByMultipleArrays) { // This test is not trivial. It tests that having multiple arrays inside TopKeys saves keys in // case of collision. 
The way it does it is by inserting an arbitrary key (= "key"), and then (at // runtime) finding another key which *does* collide with that key. // // Once we've found such a key, we create another TopKeys instance, but this time with 10 arrays // which should mean that for some hash value, the keys won't be present in the same bucket. std::string collision_key; TopKeys::Options options( {.buckets = 2, .depth = 1, .min_key_count_to_record = 2, .decay_base = 1}); { TopKeys top_keys(options); // Insert some key top_keys.Touch("key"); top_keys.Touch("key"); // Find a key with a collision int i = 0; while (true) { collision_key = absl::StrCat("key", i); top_keys.Touch(collision_key); if (!top_keys.GetTopKeys().contains(collision_key)) { break; } ++i; } } options.depth = 10; { TopKeys top_keys(options); // Insert some key top_keys.Touch("key"); top_keys.Touch("key"); // Insert collision key, expect result to be present top_keys.Touch(collision_key); top_keys.Touch(collision_key); EXPECT_THAT(top_keys.GetTopKeys(), UnorderedElementsAre(Pair("key", 2), Pair(collision_key, 2))); } } } // end of namespace dfly ================================================ FILE: src/core/topk.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/topk.h" #include #include #include #include #include #include "absl/random/distributions.h" #include "base/logging.h" #include "base/random.h" namespace dfly { namespace { const std::array& GetDefaultDecayTable() { static const auto table = [] { std::array t{}; for (size_t i = 0; i < TOPK::kDecayLookupSize; ++i) { t[i] = std::pow(TOPK::kDefaultDecay, static_cast(i)); } return t; }(); return table; } } // namespace TOPK::TOPK(PMR_NS::memory_resource* mr, uint32_t k, uint32_t width, uint32_t depth, double decay) : k_(k), width_(width), depth_(depth), decay_(decay), counters_(static_cast(width) * depth, 0, PMR_NS::polymorphic_allocator(mr)), min_heap_(PMR_NS::polymorphic_allocator(mr)) { DCHECK(mr != nullptr); DCHECK_GT(k_, 0u); DCHECK_GT(width_, 0u); DCHECK_GT(depth_, 0u); DCHECK_GE(decay_, 0.0); DCHECK_LE(decay_, 1.0); min_heap_.reserve(k_); if (std::abs(decay_ - TOPK::kDefaultDecay) < TOPK::kDecayEpsilon) { // default decay value: use shared static table to save memory and initialization time decay_lookup_ = &GetDefaultDecayTable(); } else { // custom decay value: build a dedicated table for this instance custom_decay_table_ = std::make_unique>(); for (size_t i = 0; i < TOPK::kDecayLookupSize; ++i) { (*custom_decay_table_)[i] = std::pow(decay_, static_cast(i)); } decay_lookup_ = custom_decay_table_.get(); } } TOPK::TOPK(TOPK&& other) noexcept : k_(std::exchange(other.k_, 0)), width_(std::exchange(other.width_, 0)), depth_(std::exchange(other.depth_, 0)), decay_(std::exchange(other.decay_, 0.0)), decay_lookup_(std::exchange(other.decay_lookup_, nullptr)), custom_decay_table_(std::move(other.custom_decay_table_)), counters_(std::move(other.counters_)), min_heap_(std::move(other.min_heap_)) { } TOPK& TOPK::operator=(TOPK&& other) noexcept { if (this != &other) { k_ = std::exchange(other.k_, 0); width_ = std::exchange(other.width_, 0); depth_ = std::exchange(other.depth_, 0); decay_ = std::exchange(other.decay_, 0.0); decay_lookup_ = 
std::exchange(other.decay_lookup_, nullptr); custom_decay_table_ = std::move(other.custom_decay_table_); counters_ = std::move(other.counters_); min_heap_ = std::move(other.min_heap_); } return *this; } uint64_t TOPK::Hash(std::string_view item, uint32_t row) const { auto full_hash = XXH3_64bits_withSeed(item.data(), item.size(), row); // Lemire's Fast Range Reduction avoids the expensive CPU integer division penalty of the modulo // (%) operator. The main principle: multiplication is much faster than division, so we multiply // a 32-bit slice of the hash by the width, and then shift right by 32 bits to get the bucket // index. See: https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ uint32_t hash32 = static_cast(full_hash); uint64_t bucket = (static_cast(hash32) * width_) >> 32; DCHECK_LT(bucket, width_); return bucket; } double TOPK::ComputeDecayProbability(uint32_t count) const { DCHECK(decay_lookup_); DCHECK_GT(count, 0u); const auto& table = *decay_lookup_; if (count < kDecayLookupSize) { return table[count]; } // If the probability is already less than kDecayEpsilon, the chance of decay is // statistically zero (see ShouldDecay). Skip the expensive std::pow extrapolation entirely. if (table[TOPK::kDecayLookupSize - 1] < TOPK::kDecayEpsilon) { return 0.0; } // Extrapolate probabilities for counts that exceed our lookup table's max index. 
// Let M = the maximum table index (kDecayLookupSize - 1) // Let Q = the quotient (count / M) // Let R = the remainder (count % M) // // Using the Laws of Exponents, we break down decay^count: // decay^count = decay^((Q * M) + R) = (decay^M)^Q * decay^R // // This translates directly to reusing our cached table: // std::pow(table[M], Q) * table[R] uint32_t quotient = count / (TOPK::kDecayLookupSize - 1); uint32_t remainder = count % (TOPK::kDecayLookupSize - 1); double base = table[TOPK::kDecayLookupSize - 1]; return std::pow(base, static_cast(quotient)) * table[remainder]; } bool TOPK::ShouldDecay(uint32_t current_count) const { if (current_count == 0) return false; // Exponential decay probability: decay^count thread_local base::Xoroshiro128p bitgen; double prob = ComputeDecayProbability(current_count); return absl::Uniform(bitgen, 0.0, 1.0) < prob; } void TOPK::HeapifyUp(size_t index) { DCHECK_LT(index, min_heap_.size()); // Restores the min-heap property by shifting the element at 'index' upward. // Triggered in two cases: // 1. Initial insertion: A new item is appended to the array and needs to bubble up. // 2. Count decrease: An existing item's count drops (becomes smaller), floating higher. while (index > 0) { size_t parent = (index - 1) / 2; if (min_heap_[parent].count <= min_heap_[index].count) { break; // Heap property satisfied } // Swap with parent std::swap(min_heap_[parent], min_heap_[index]); index = parent; } } void TOPK::HeapifyDown(size_t index) { DCHECK_LT(index, min_heap_.size()); // Restores the min-heap property by shifting the element at 'index' downward. // Triggered in two cases: // 1. Root replacement/removal: The minimum item is evicted/replaced and the new root must sink. // 2. Count increase: An existing item's count grows (becomes heavier), sinking lower. 
size_t size = min_heap_.size(); while (true) { size_t left = (2 * index) + 1; size_t right = (2 * index) + 2; size_t smallest = index; if ((left < size) && (min_heap_[left].count) < (min_heap_[smallest].count)) { smallest = left; } if ((right < size) && (min_heap_[right].count) < (min_heap_[smallest].count)) { smallest = right; } if (smallest == index) { break; // Heap property satisfied } // Swap with smallest child std::swap(min_heap_[smallest], min_heap_[index]); index = smallest; } } size_t TOPK::GetCounterIndex(std::string_view item, uint32_t row) const { DCHECK_LT(row, depth_); // Note: // - bucket is mathematically guaranteed to be in the range [0, width_ - 1] // - The max possible idx is depth * width - 1, which is within the bounds of our counters_ // vector uint64_t bucket = Hash(item, row); size_t idx = static_cast(row) * width_ + bucket; DCHECK_LT(idx, counters_.size()); return idx; } uint32_t TOPK::Count(std::string_view item) const { uint32_t min_count = std::numeric_limits::max(); for (uint32_t row = 0; row < depth_; ++row) { size_t idx = GetCounterIndex(item, row); min_count = std::min(min_count, counters_[idx]); } return min_count; } std::optional TOPK::IncrementInternal(std::string_view item, uint32_t increment) { uint32_t min_count = std::numeric_limits::max(); // Update counters using HeavyKeeper logic for (uint32_t row = 0; row < depth_; ++row) { size_t idx = GetCounterIndex(item, row); // HeavyKeeper: decay and increment are mutually exclusive. // - With probability decay^count, the counter is decremented (colliding items suppress each // other). // - Otherwise, the counter is incremented for the item being added. 
if ((counters_[idx] > 0) && ShouldDecay(counters_[idx])) { --counters_[idx]; } else { counters_[idx] = static_cast( std::min(static_cast(counters_[idx]) + increment, static_cast(std::numeric_limits::max()))); } // Count-Min Sketch property: The minimum counter across all rows is the // most accurate, as it has suffered the fewest hash collisions. min_count = std::min(min_count, counters_[idx]); } return UpdateHeap(item, min_count); } std::optional TOPK::Add(std::string_view item) { return IncrementInternal(item, 1); } std::optional TOPK::IncrBy(std::string_view item, uint32_t increment) { if (increment < 1) { return std::nullopt; } return IncrementInternal(item, increment); } std::vector TOPK::List() const { std::vector result; result.reserve(min_heap_.size()); for (const auto& heap_item : min_heap_) { result.push_back({heap_item.key, heap_item.count}); } // Sort by count (descending) for output std::sort(result.begin(), result.end(), [](const TopKItem& a, const TopKItem& b) { return a.count > b.count; }); return result; } std::optional TOPK::UpdateHeap(std::string_view item, uint32_t new_count) { // Fast path: O(K) linear scan. // For small K, this avoids hash map overhead. Short keys benefit from SSO // (Small String Optimization), keeping memory contiguous and cache-friendly. // TODO: Benchmark to find the crossover point where larger K OR long strings (SSO not applicable) // justify re-introducing a hash map. for (size_t i = 0; i < min_heap_.size(); ++i) { if (min_heap_[i].key == item) { uint32_t old_count = min_heap_[i].count; min_heap_[i].count = new_count; if (new_count > old_count) { HeapifyDown(i); } else if (new_count < old_count) { HeapifyUp(i); } return std::nullopt; } } // Fast reject: item doesn't qualify for the heap. Just exit without any memory allocations or // modifications. if ((min_heap_.size() >= k_) && (new_count <= min_heap_.front().count)) { return std::nullopt; } DCHECK_LE(min_heap_.size(), k_); // Slow path: item will enter the heap. 
Now allocate. std::string item_str(item); if (min_heap_.size() < k_) { // Heap not full, add the item, no eviction needed size_t new_idx = min_heap_.size(); min_heap_.push_back({std::move(item_str), new_count}); HeapifyUp(new_idx); return std::nullopt; } // Heap is full, evict minimum and add new item DCHECK_EQ(min_heap_.size(), k_); std::string old_key = std::move(min_heap_[0].key); min_heap_[0] = {std::move(item_str), new_count}; HeapifyDown(0); return old_key; } size_t TOPK::MallocUsed() const { size_t size = 0; // Custom decay table (only for non-default decay values) if (custom_decay_table_) { size += sizeof(std::array); } // Counter array size += counters_.capacity() * sizeof(uint32_t); // Heap items - calculate actual string sizes size += min_heap_.capacity() * sizeof(HeapItem); for (const auto& item : min_heap_) { size += item.key.capacity(); } return size; } TOPK::SerializedData TOPK::Serialize() const { SerializedData data; data.k = k_; data.width = width_; data.depth = depth_; data.decay = decay_; // Serialize heap items data.heap_items.reserve(min_heap_.size()); for (const auto& heap_item : min_heap_) { data.heap_items.push_back({heap_item.key, heap_item.count}); } // Serialize counter array data.counters.assign(counters_.begin(), counters_.end()); return data; } void TOPK::Deserialize(const SerializedData& data) { DCHECK_EQ(data.counters.size(), static_cast(width_) * depth_); DCHECK_LE(data.heap_items.size(), k_); DCHECK_EQ(data.k, k_); DCHECK_EQ(data.width, width_); DCHECK_EQ(data.depth, depth_); DCHECK_EQ(data.decay, decay_); // Clear existing data min_heap_.clear(); // Restore counters counters_.assign(data.counters.begin(), data.counters.end()); // Restore heap min_heap_.reserve(data.heap_items.size()); for (const auto& item : data.heap_items) { min_heap_.push_back({item.item, item.count}); } // Rebuild heap property std::make_heap(min_heap_.begin(), min_heap_.end(), std::greater()); } } // namespace dfly 
================================================ FILE: src/core/topk.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include #include #include "base/pmr/memory_resource.h" namespace dfly { class TOPKTest; // // TOPK: User-Facing API Data Structure // // This class implements the data structure required to support the public Redis // TOPK module API (e.g., TOPK.RESERVE, TOPK.ADD, TOPK.INCRBY). // // WHY WE HAVE TWO TOP-K IMPLEMENTATIONS: // Dragonfly maintains two separate Top-K tracking structures to protect the // performance of the database's hot path: // 1. `TopKeys` (src/core/top_keys.h): An internal-only, hyper-optimized O(1) // tracker that runs on every single database command to detect hot keys. // It intentionally lacks a min-heap and uses standard memory allocation to // maximize raw speed and minimize instruction cache pollution. // 2. `TOPK` (this file): The user-facing implementation. To comply with the Redis // API contract, this class MUST support instant eviction reporting (requiring an // O(log K) Min-Heap), arbitrary increments, and PMR allocators for strict // memory limit tracking and RDB snapshot serialization. // // Forcing the internal tracker to support Min-Heaps and PMR would severely // degrade overall database throughput, hence the strict separation of concerns. // // Algorithm Deviation Note: // While heavily inspired by the HeavyKeeper algorithm, this is NOT a strict // implementation. The original HeavyKeeper paper requires storing a // (fingerprint, count) pair in each cell so that decay only penalizes a specific // item. This implementation uses a bare `uint32_t` counter grid, making it closer // to a Count-Min Sketch coupled with a Min-Heap and a decay heuristic. 
This // design safely overestimates counts (which is acceptable for Top-K bounds) // while simplifying PMR memory layout and RDB serialization. // // TODO: Full PMR Integration for String Ownership // Currently, min_heap_ and counters_ use the provided memory_resource, ensuring the // dominant allocations are tracked. However, the std::string keys inside HeapItem // use the default heap. // Future optimization: Upgrade HeapItem to use PMR_NS::string with proper // uses_allocator construction. class TOPK { friend class TOPKTest; public: // Initializes a Top-K tracking sketch with the specified dimensions. // // mr: Pointer to the memory resource used for allocations (MUST NOT be null). // k: Maximum number of most frequent items to maintain in the min-heap. // width: Number of counter buckets per row in the hash grid (default: 8). // depth: Number of independent hash functions (rows) used (default: 7). // decay: Probability multiplier for exponential decay (must be 0.0 to 1.0, default: 0.9). TOPK(PMR_NS::memory_resource* mr, uint32_t k, uint32_t width = kDefaultWidth, uint32_t depth = kDefaultDepth, double decay = kDefaultDecay); TOPK(const TOPK&) = delete; TOPK& operator=(const TOPK&) = delete; TOPK(TOPK&& other) noexcept; TOPK& operator=(TOPK&& other) noexcept; ~TOPK() = default; static constexpr double kDefaultDecay = 0.9; static constexpr uint32_t kDefaultWidth = 8; static constexpr uint32_t kDefaultDepth = 7; static constexpr double kDecayEpsilon = 1e-9; // Size is 4097 so that (kDecayLookupSize - 1) equals exactly 4096 (2^12). // This allows the C++ compiler to optimize the division and modulo operations // in the extrapolation hot-path into very-fast bitwise shifts & ANDs. static constexpr size_t kDecayLookupSize = 4097; // Represents an item in the Top-K list with its estimated count struct TopKItem { std::string item; uint32_t count; }; // Inserts a single item into the Top-K sketch, incrementing its estimated frequency by 1. 
// // Returns: The string of the evicted item if this insertion caused a resident // item to be displaced from the Top-K min-heap, or std::nullopt // if no eviction occurred. std::optional Add(std::string_view item); // Increments an item's estimated frequency by a specific amount. // // If 'increment' is 0, this operation is a safe no-op and returns std::nullopt. // Otherwise, returns the string of the evicted item if this operation caused // a resident item to be displaced from the Top-K min-heap, or std::nullopt. std::optional IncrBy(std::string_view item, uint32_t increment); // Queries whether an item currently resides in the Top-K min-heap. [[nodiscard]] bool Query(std::string_view item) const { return IsInHeap(item); } // Estimates the frequency count for an item using the underlying sketch. // Returns the minimum counter value across all hash rows (Count-Min Sketch estimate). [[nodiscard]] uint32_t Count(std::string_view item) const; // Retrieves the complete list of current Top-K high-frequency items. // // Returns: A vector of TopKItem structures (containing the key and its count), // sorted in descending order by estimated frequency (highest first). [[nodiscard]] std::vector List() const; // -------------------------------------------------------------------------- // Accessors for Top-K Configuration Parameters // -------------------------------------------------------------------------- // Returns the maximum capacity (K) of the Top-K min-heap. [[nodiscard]] uint32_t K() const { return k_; } // Returns the number of items currently tracked in the Top-K heap. [[nodiscard]] size_t Size() const { return min_heap_.size(); } // Returns the width (number of columns/buckets) of the Count-Min Sketch array. [[nodiscard]] uint32_t Width() const { return width_; } // Returns the depth (number of rows/hash functions) of the Count-Min Sketch array. 
[[nodiscard]] uint32_t Depth() const { return depth_; } // Returns the exponential decay probability base used by the HeavyKeeper algorithm. [[nodiscard]] double Decay() const { return decay_; } // Calculates the total heap memory dynamically allocated by this Top-K instance, // including sketch counters, min-heap allocations, and hash map overhead. // // Returns: Total memory usage in bytes. [[nodiscard]] size_t MallocUsed() const; // -------------------------------------------------------------------------- // Serialization and Persistence // -------------------------------------------------------------------------- // Pod-like structure to hold the exact internal state of the Top-K instance. struct SerializedData { uint32_t k; uint32_t width; uint32_t depth; double decay; std::vector heap_items; std::vector counters; }; // Extracts the current structural state of the sketch for RDB persistence. [[nodiscard]] SerializedData Serialize() const; // Reconstructs the internal state of the sketch from a previously serialized dataset. void Deserialize(const SerializedData& data); private: struct HeapItem { std::string key; uint32_t count; // Min heap comparator bool operator>(const HeapItem& other) const { return count > other.count; } }; // Hash function for bucket selection in row [[nodiscard]] uint64_t Hash(std::string_view item, uint32_t row) const; // Exponential decay logic [[nodiscard]] bool ShouldDecay(uint32_t current_count) const; // Updates the min-heap with the new count for the given item. // Returns the evicted item's key if the heap is at capacity and a new item displaces an existing // one. Otherwise, returns std::nullopt. 
std::optional UpdateHeap(std::string_view item, uint32_t new_count); // Check if an item is in the Top-K heap [[nodiscard]] bool IsInHeap(std::string_view item) const { for (const auto& heap_item : min_heap_) { if (heap_item.key == item) return true; } return false; } // Hashes the item for a specific row and calculates its flattened 1D index // within the counters_ array. Maps the 2D Count-Min Sketch grid (depth x width) // into a single contiguous block of memory for better CPU cache locality. size_t GetCounterIndex(std::string_view item, uint32_t row) const; // Shared increment logic std::optional IncrementInternal(std::string_view item, uint32_t increment); // Compute decay probability using lookup table or extrapolation double ComputeDecayProbability(uint32_t count) const; // Heap maintenance functions // O(log k) ops void HeapifyUp(size_t index); void HeapifyDown(size_t index); uint32_t k_; // Number of top items to track uint32_t width_; // Hash table width (buckets per row) uint32_t depth_; // Hash table depth (number of rows) double decay_; // Decay constant (0.0-1.0, typically 0.9) // Pointer to the active decay lookup table. For the default decay (0.9), this points to // a process-wide shared static table (32KB, allocated once). For custom (non-default) decay // values, it points to custom_decay_table_ below. This pattern can help to avoid embedding a 32KB // array in every TOPK object. // Assumption: >99% of TOPK instances will use the default decay, so // this optimization can significantly reduce memory usage and improve startup performance by // avoiding the need to build a custom table for each instance. const std::array* decay_lookup_ = nullptr; // Heap-allocated table for non-default decay values. Null for the common case (decay=0.9). 
std::unique_ptr> custom_decay_table_; // HeavyKeeper data structures // Hash table: width × depth matrix of counters std::vector> counters_; // Min heap: vector of top-K items maintained as a min heap std::vector> min_heap_; }; } // namespace dfly ================================================ FILE: src/core/topk_test.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "core/topk.h" #include #include #include #include #include #include #include "base/gtest.h" namespace dfly { using namespace std; class TOPKTest : public ::testing::Test { protected: // Use decay=0 to disable probabilistic decay, making tests deterministic. // With decay=0, ShouldDecay always returns false (0^count = 0 for count>0), // so counters only grow and are never decremented by colliding items. // Having a decay != 0 will cause probabilistic flakiness in tests, as items may be randomly // evicted due to decay rather than true count comparisons. TOPKTest() : topk_(PMR_NS::get_default_resource(), 5, 100, 5, 0.0) { } double ComputeDecayProbability(TOPK* topk, uint32_t count) const { return topk->ComputeDecayProbability(count); } TOPK topk_; }; // --------------------------------------------------------------------------- // Construction & Configuration // --------------------------------------------------------------------------- // Verify K(), Width(), Depth(), Decay() return the exact values passed to the constructor. TEST(TOPKBasic, ConstructorStoresParameters) { TOPK topk(PMR_NS::get_default_resource(), 10, 200, 7, 0.85); EXPECT_EQ(topk.K(), 10u); EXPECT_EQ(topk.Width(), 200u); EXPECT_EQ(topk.Depth(), 7u); EXPECT_DOUBLE_EQ(topk.Decay(), 0.85); } // Verify that default decay reuses the static process-wide table (saving memory), // while a custom decay value allocates its own ~32KB lookup table. 
TEST(TOPKBasic, DecayTableMemoryAllocation) {
  TOPK default_topk(PMR_NS::get_default_resource(), 5, 100, 5, TOPK::kDefaultDecay);
  TOPK custom_topk(PMR_NS::get_default_resource(), 5, 100, 5, 0.75);

  size_t default_mem = default_topk.MallocUsed();
  size_t custom_mem = custom_topk.MallocUsed();

  // Test that the custom one uses strictly more memory. This is an ASSERT so that the unsigned
  // subtraction below cannot wrap around.
  ASSERT_LT(default_mem, custom_mem);

  // Test that the difference in memory is exactly the size of the custom decay array: both
  // instances are built with identical k/width/depth, so the table is the only difference.
  const size_t expected_table_size = sizeof(std::array<double, TOPK::kDecayLookupSize>);
  EXPECT_EQ(custom_mem - default_mem, expected_table_size);
}

// Move-construct a populated TOPK; source should be emptied and destination should hold the items.
TEST_F(TOPKTest, MoveConstructorTransfersOwnership) {
  topk_.Add("alpha");
  topk_.Add("beta");

  TOPK moved(std::move(topk_));

  EXPECT_EQ(moved.K(), 5u);
  // Both items were added while the heap (K=5) still had room, so both must have travelled.
  auto list = moved.List();
  EXPECT_EQ(list.size(), 2u);

  // Source is zeroed out and no longer holds any heap items.
  EXPECT_EQ(topk_.K(), 0u);
  EXPECT_EQ(topk_.Size(), 0u);
}

// Move-assign a populated TOPK into another; verify same post-conditions as move constructor.
TEST(TOPKBasic, MoveAssignmentTransfersOwnership) {
  TOPK src(PMR_NS::get_default_resource(), 3, 50, 3, 0.0);
  src.Add("x");
  src.Add("y");

  TOPK dst(PMR_NS::get_default_resource(), 1, 10, 1, 0.0);
  dst = std::move(src);

  EXPECT_EQ(dst.K(), 3u);
  EXPECT_EQ(dst.Width(), 50u);
  auto list = dst.List();
  EXPECT_EQ(list.size(), 2u);

  EXPECT_EQ(src.K(), 0u);
}

// ---------------------------------------------------------------------------
// Add
// ---------------------------------------------------------------------------

// Add exactly K distinct items; List() should return exactly K items with no evictions.
TEST_F(TOPKTest, AddFillsHeapUpToK) {
  for (uint32_t i{}; i < topk_.K(); ++i) {
    auto evicted = topk_.Add(absl::StrCat("item", i));
    EXPECT_FALSE(evicted.has_value()) << "Unexpected eviction at i=" << i;
  }
  EXPECT_EQ(topk_.List().size(), topk_.K());
}

// Each Add() while the heap has room returns std::nullopt.
// Note: adding a K+1th item with the same count as the minimum also returns nullopt, // because the fast-reject path correctly requires new_count > min to trigger an eviction. TEST_F(TOPKTest, AddReturnsNulloptWhileHeapNotFull) { for (uint32_t i{}; i < topk_.K(); ++i) { EXPECT_EQ(topk_.Add(absl::StrCat("item", i)), nullopt); } } // After filling the heap, IncrBy a new item with a large count to force an eviction. TEST_F(TOPKTest, AddEvictsMinimumWhenHeapFull) { // Fill the heap with K items, each added once (count=1). for (uint32_t i{}; i < topk_.K(); ++i) { topk_.Add(absl::StrCat("filler", i)); } // Force a new item in with a large count; it must evict the minimum. auto evicted = topk_.IncrBy("heavy_hitter", 1000); EXPECT_TRUE(evicted.has_value()); } // After filling the heap, adding an item whose count can't exceed the minimum shouldn't evict. TEST_F(TOPKTest, AddDoesNotEvictWhenNewItemScoreTooLow) { // Fill the heap with items pumped to high counts. for (uint32_t i{}; i < topk_.K(); ++i) { topk_.IncrBy(absl::StrCat("big", i), 1000); } // Single add of a brand-new item (count=1) won't beat any existing item. auto evicted = topk_.Add("tiny_newcomer"); EXPECT_FALSE(evicted.has_value()); } // Adding the same item repeatedly increases its count in the heap. // Because decay=0.0 and there are no collisions, the count must be exactly 100. TEST_F(TOPKTest, AddSameItemRepeatedlyIncreasesCount) { for (int i{}; i < 100; ++i) { topk_.Add("repeat"); } auto list = topk_.List(); bool found = false; for (const auto& item : list) { if (item.item == "repeat") { EXPECT_EQ(item.count, 100u); found = true; } } EXPECT_TRUE(found); } // --------------------------------------------------------------------------- // IncrBy // --------------------------------------------------------------------------- // IncrBy with increment=0 must return nullopt and not modify state. 
TEST_F(TOPKTest, IncrByZeroReturnsNullopt) {
  topk_.Add("existing");
  const auto count_before = topk_.Count("existing");

  // A zero increment reports no eviction...
  auto outcome = topk_.IncrBy("existing", 0);
  EXPECT_EQ(outcome, nullopt);

  // ...and leaves the estimate where it was.
  const auto count_after = topk_.Count("existing");
  EXPECT_EQ(count_before, count_after);
}

// IncrBy(item, 1) should behave the same as Add(item) — both increment by 1.
TEST(TOPKBasic, IncrByOneBehavesLikeAdd) {
  TOPK via_add(PMR_NS::get_default_resource(), 3, 100, 5, 0.0);
  TOPK via_incr(PMR_NS::get_default_resource(), 3, 100, 5, 0.0);

  via_add.Add("x");
  via_incr.IncrBy("x", 1);

  EXPECT_EQ(via_add.Count("x"), via_incr.Count("x"));
}

// A single IncrBy with a large increment should immediately promote the item into the heap,
// evicting the current minimum.
TEST_F(TOPKTest, IncrByLargeValueCausesImmediateEviction) {
  for (uint32_t idx = 0; idx < topk_.K(); ++idx) {
    topk_.Add(absl::StrCat("base", idx));
  }

  auto displaced = topk_.IncrBy("newcomer", 10000);

  EXPECT_TRUE(displaced.has_value());
  EXPECT_TRUE(topk_.Query("newcomer"));
}

// IncrBy on an item already in the heap should increase its count without eviction.
TEST_F(TOPKTest, IncrByExistingHeapItemUpdatesCount) {
  topk_.IncrBy("item_a", 50);
  const auto count_before = topk_.Count("item_a");

  auto displaced = topk_.IncrBy("item_a", 100);
  EXPECT_EQ(displaced, nullopt);

  const auto count_after = topk_.Count("item_a");
  EXPECT_GT(count_after, count_before);
}

// ---------------------------------------------------------------------------
// Query
// ---------------------------------------------------------------------------

// All K items currently in the heap should return true from Query.
TEST_F(TOPKTest, QueryReturnsTrueForHeapItems) {
  for (uint32_t idx = 0; idx < topk_.K(); ++idx) {
    string name = absl::StrCat("key", idx);
    topk_.Add(name);
    EXPECT_TRUE(topk_.Query(name)) << name << " should be in heap";
  }
}

// Items that were never inserted should return false from Query.
TEST_F(TOPKTest, QueryReturnsFalseForNonHeapItems) {
  EXPECT_FALSE(topk_.Query("never_seen"));
  EXPECT_FALSE(topk_.Query("also_absent"));
  EXPECT_FALSE(topk_.Query("nope"));
}

// An item that was once in the heap but got evicted should return false from Query.
TEST_F(TOPKTest, QueryReturnsFalseForEvictedItems) {
  // The future victim enters first with count = 1.
  string doomed = "low0";
  topk_.Add(doomed);

  // The remaining K-1 slots (K=5) go to heavier items.
  for (uint32_t idx = 1; idx < topk_.K(); ++idx) {
    topk_.IncrBy(absl::StrCat("heavier", idx), 50);
  }

  // At this point the victim is still tracked.
  EXPECT_TRUE(topk_.Query(doomed));

  // A massive newcomer forces an eviction.
  topk_.IncrBy("massive", 10000);

  // The victim must be gone now.
  EXPECT_FALSE(topk_.Query(doomed));
}

// Mixed: item in heap vs item not in heap.
TEST_F(TOPKTest, QueryMixedBatch) {
  topk_.IncrBy("inheap", 100);

  EXPECT_TRUE(topk_.Query("inheap"));
  EXPECT_FALSE(topk_.Query("notheap"));
}

// ---------------------------------------------------------------------------
// Count
// ---------------------------------------------------------------------------

// Items never inserted should return count 0.
TEST_F(TOPKTest, CountReturnsZeroForUnseen) {
  EXPECT_EQ(topk_.Count("never_added"), 0u);
  EXPECT_EQ(topk_.Count("also_missing"), 0u);
}

// Items that have been added should return a count >= 1.
TEST_F(TOPKTest, CountReturnsNonZeroForSeenItems) {
  topk_.Add("seen");
  EXPECT_GE(topk_.Count("seen"), 1u);
}

// The count from Count() for a heap item should match the count reported in List().
TEST_F(TOPKTest, CountForHeapItemMatchesListCount) {
  topk_.IncrBy("match_me", 50);
  auto count_val = topk_.Count("match_me");

  auto list = topk_.List();
  bool found = false;
  for (const auto& item : list) {
    if (item.item == "match_me") {
      EXPECT_EQ(item.count, count_val);
      found = true;
    }
  }
  EXPECT_TRUE(found);
}

// ---------------------------------------------------------------------------
// List
// ---------------------------------------------------------------------------

// List() returns an empty vector on a freshly constructed TOPK.
TEST(TOPKBasic, ListEmptyOnConstruction) {
  TOPK fresh(PMR_NS::get_default_resource(), 5, 100, 5, 0.0);
  EXPECT_TRUE(fresh.List().empty());
}

// List() output is sorted in descending order by count.
TEST_F(TOPKTest, ListReturnsSortedByCountDescending) {
  topk_.IncrBy("low", 10);
  topk_.IncrBy("mid", 50);
  topk_.IncrBy("high", 100);

  auto list = topk_.List();

  // 1. Guarantee the items were actually returned; indexing below relies on it.
  ASSERT_EQ(list.size(), 3u);

  // 2. Match the deterministic order exactly.
  EXPECT_EQ(list[0].item, "high");
  EXPECT_EQ(list[0].count, 100u);
  EXPECT_EQ(list[1].item, "mid");
  EXPECT_EQ(list[1].count, 50u);
  EXPECT_EQ(list[2].item, "low");
  EXPECT_EQ(list[2].count, 10u);
}

// After inserting more than K distinct items, List().size() == K.
TEST_F(TOPKTest, ListNeverExceedsKItems) {
  for (int i{}; i < 100; ++i) {
    topk_.IncrBy(absl::StrCat("x", i), (i + 1) * 10);
  }

  // We inserted 100 items. The heap MUST be exactly full.
  EXPECT_EQ(topk_.List().size(), topk_.K());
}

// ---------------------------------------------------------------------------
// Decay & ComputeDecayProbability
// ---------------------------------------------------------------------------

// For count < kDecayLookupSize, ComputeDecayProbability equals std::pow(decay, count).
TEST_F(TOPKTest, ProbabilityBelowTableSize) {
  double decay_val = 0.85;
  TOPK topk(PMR_NS::get_default_resource(), 5, 100, 5, decay_val);

  // ComputeDecayProbability enforces DCHECK_GT(count, 0u), so we start at 1.
  for (uint32_t count = 1; count < TOPK::kDecayLookupSize; ++count) {
    // The cast needs an explicit target type; double matches std::pow(double, double).
    double expected = std::pow(decay_val, static_cast<double>(count));
    // EXPECT_DOUBLE_EQ allows up to 4 ULPs of rounding difference.
    EXPECT_DOUBLE_EQ(ComputeDecayProbability(&topk, count), expected);
  }
}

// For count >= kDecayLookupSize, the extrapolation path should not crash or produce NaN.
TEST(TOPKBasic, ProbabilityAboveTableSizeNoCrash) {
  TOPK topk(PMR_NS::get_default_resource(), 3, 10, 3, 0.999);

  // Push the counter safely above kDecayLookupSize.
  topk.IncrBy("big", 5000);

  // Now call Add so the decay check runs against a counter of about 5000.
  // It shouldn't crash, segfault, or produce NaN.
  for (int i = 0; i < 10; ++i) {
    topk.Add("big");
  }

  // Just verify the state isn't corrupted (count is still around 5000).
  EXPECT_GT(topk.Count("big"), 4000u);
}

// For an extremely large count with a small decay, probability drops to effectively zero.
// This means ShouldDecay always returns false for very high counts, so counters aren't decremented.
TEST(TOPKBasic, VeryHighCountApproachesZero) {
  // decay=0.5: 0.5^4096 is astronomically small (< kDecayEpsilon). The extrapolation
  // path should return 0.0, meaning no decay fires for counts above the table range.
  TOPK topk(PMR_NS::get_default_resource(), 3, 10, 3, 0.5);
  topk.IncrBy("stable", 10000);
  auto count_before = topk.Count("stable");

  // Adding more items should not decay "stable"'s counter because the decay
  // probability for such high counts is effectively zero.
  for (int i{}; i < 100; ++i) {
    topk.Add(absl::StrCat("other", i));
  }

  auto count_after = topk.Count("stable");
  // Count may increase from hash collisions but should never decrease.
  EXPECT_GE(count_after, count_before);
}

// With decay=0.0, the decay probability is always 0 (0^n = 0 for n>0),
// so counters should grow monotonically.
TEST(TOPKBasic, ZeroDecayNeverDecays) {
  TOPK sketch(PMR_NS::get_default_resource(), 3, 50, 3, 0.0);

  sketch.IncrBy("mono", 100);
  const auto after_first = sketch.Count("mono");

  sketch.IncrBy("mono", 50);
  const auto after_second = sketch.Count("mono");

  EXPECT_GE(after_second, after_first);
  EXPECT_EQ(after_second, 150u);
}

// With decay=1.0, every non-zero counter has ShouldDecay probability exactly 1.0 (1^n = 1).
// Because this implementation uses no fingerprints (unlike the original HeavyKeeper paper),
// decay fires even when re-adding the same item to its own non-zero counter.
// The counter therefore oscillates: 0 → 1 (add to zero-counter) → 0 (decay fires) → repeat.
// It is mathematically impossible for the counter to exceed 1.
TEST(TOPKBasic, DecayOneAlwaysDecays) {
  TOPK sketch(PMR_NS::get_default_resource(), 3, 10, 3, 1.0);

  for (int round = 0; round < 1000; ++round) {
    sketch.Add("suppressed");
  }

  // Because decay is 100%, the counter just oscillates between 0 and 1.
  // It is mathematically impossible for it to exceed 1.
  EXPECT_LE(sketch.Count("suppressed"), 1u);
}

// ---------------------------------------------------------------------------
// MallocUsed
// ---------------------------------------------------------------------------

// MallocUsed() after filling the heap should be larger than right after construction.
TEST(TOPKBasic, MallocUsedIncreaseWithHeapGrowth) {
  TOPK sketch(PMR_NS::get_default_resource(), 5, 100, 5, 0.0);
  const size_t bytes_empty = sketch.MallocUsed();

  for (int idx = 0; idx < 5; ++idx) {
    sketch.IncrBy(absl::StrCat("item_with_a_long_name_", idx), 100);
  }

  const size_t bytes_filled = sketch.MallocUsed();
  EXPECT_GT(bytes_filled, bytes_empty);
}

// ---------------------------------------------------------------------------
// Serialize / Deserialize
// ---------------------------------------------------------------------------

// After Serialize() + Deserialize(), K(), Width(), Depth(), Decay() are unchanged.
TEST_F(TOPKTest, SerializeRoundTripPreservesConfiguration) { topk_.IncrBy("a", 10); auto data = topk_.Serialize(); TOPK restored(PMR_NS::get_default_resource(), data.k, data.width, data.depth, data.decay); restored.Deserialize(data); EXPECT_EQ(restored.K(), topk_.K()); EXPECT_EQ(restored.Width(), topk_.Width()); EXPECT_EQ(restored.Depth(), topk_.Depth()); EXPECT_DOUBLE_EQ(restored.Decay(), topk_.Decay()); } // After round-trip, List() returns the same items with the same counts. TEST_F(TOPKTest, SerializeRoundTripPreservesHeapItems) { topk_.IncrBy("alpha", 100); topk_.IncrBy("beta", 50); topk_.IncrBy("gamma", 25); auto data = topk_.Serialize(); TOPK restored(PMR_NS::get_default_resource(), data.k, data.width, data.depth, data.decay); restored.Deserialize(data); auto orig_list = topk_.List(); auto rest_list = restored.List(); ASSERT_EQ(orig_list.size(), rest_list.size()); for (size_t i{}; i < orig_list.size(); ++i) { EXPECT_EQ(orig_list[i].item, rest_list[i].item); EXPECT_EQ(orig_list[i].count, rest_list[i].count); } } // After round-trip, Count() returns the same estimated frequencies. TEST_F(TOPKTest, SerializeRoundTripPreservesCounters) { topk_.IncrBy("foo", 42); topk_.IncrBy("bar", 77); auto data = topk_.Serialize(); TOPK restored(PMR_NS::get_default_resource(), data.k, data.width, data.depth, data.decay); restored.Deserialize(data); EXPECT_EQ(topk_.Count("foo"), restored.Count("foo")); EXPECT_EQ(topk_.Count("bar"), restored.Count("bar")); } // After Deserialize(), subsequent Add() calls work correctly and evictions are reported. TEST_F(TOPKTest, DeserializeRebuildsValidHeapProperty) { for (uint32_t i{}; i < topk_.K(); ++i) { topk_.IncrBy(absl::StrCat("pre", i), 10); } auto data = topk_.Serialize(); TOPK restored(PMR_NS::get_default_resource(), data.k, data.width, data.depth, data.decay); restored.Deserialize(data); // The restored heap is full (K items). A heavy new item should evict the minimum. 
auto evicted = restored.IncrBy("post_restore_big", 10000); EXPECT_TRUE(evicted.has_value()); EXPECT_TRUE(restored.Query("post_restore_big")); } // Serializing a fresh TOPK produces empty heap_items and a zero-filled counters vector. TEST(TOPKBasic, SerializeEmptyTOPK) { TOPK topk(PMR_NS::get_default_resource(), 5, 100, 5, 0.0); auto data = topk.Serialize(); EXPECT_TRUE(data.heap_items.empty()); EXPECT_EQ(data.counters.size(), 100u * 5); for (auto c : data.counters) { EXPECT_EQ(c, 0u); } } // --------------------------------------------------------------------------- // PMR Allocator // --------------------------------------------------------------------------- // Explicitly passing get_default_resource() works correctly without crashing. TEST(TOPKBasic, PMRExplicitDefaultResourceWorks) { TOPK topk(PMR_NS::get_default_resource(), 5, 100, 5, 0.9); topk.Add("works"); EXPECT_EQ(topk.List().size(), 1u); } // --------------------------------------------------------------------------- // Statistical / Accuracy // --------------------------------------------------------------------------- // Verify that the Top-K correctly identifies "Hot" items even when // the sketch is flooded with "Cold" noise (many items seen only once). // // SETUP: // 1. We disable Decay (decay=0.0) to make the test 100% predictable (no RNG). // 2. We use IncrBy to give 5 "Hot" items a guaranteed high score of 1000. // 3. We use Add to insert 200 "Cold" items once each (score of 1). // // WHY INCRBY? // In a real-world scenario with decay, an item's count eventually hits a // "ceiling" where decay and growth balance out. By using IncrBy and decay=0, // we bypass that math to ensure our "Hot" items are strictly, // deterministically larger than the noise. TEST(TOPKBasic, TopKItemsIdentifiedUnderHeavyLoad) { TOPK topk(PMR_NS::get_default_resource(), 5, 500, 5, 0.0); // Hot items get a large, deterministic count via IncrBy. 
for (int h{}; h < 5; ++h) { topk.IncrBy(absl::StrCat("hot", h), 1000); } // Cold items are each seen only once. for (int c{}; c < 200; ++c) { topk.Add(absl::StrCat("cold", c)); } auto list = topk.List(); ASSERT_EQ(list.size(), 5u); // All 5 hot items should be present in the top-K list. for (int h{}; h < 5; ++h) { string hot_key = absl::StrCat("hot", h); bool found{}; for (const auto& item : list) { if (item.item == hot_key) { found = true; break; } } EXPECT_TRUE(found) << hot_key << " should be in the top-K list"; } } // With k=1, only the single most-frequent item survives in the heap. // Uses decay=0.0 and IncrBy so "dominant" has a deterministically high count // that minor items (each added once, count=1) can never exceed. TEST(TOPKBasic, KEqualsOneTracksOnlyTopItem) { TOPK topk(PMR_NS::get_default_resource(), 1, 500, 5, 0.0); // "dominant" gets a large, fixed count. topk.IncrBy("dominant", 1000); // Minor items are each seen only once; count=1 < 1000, so none can displace dominant. for (int i{}; i < 50; ++i) { topk.Add(absl::StrCat("minor", i)); } auto list = topk.List(); ASSERT_EQ(list.size(), 1u); EXPECT_EQ(list[0].item, "dominant"); } // --------------------------------------------------------------------------- // Deserialization Heap Repair // --------------------------------------------------------------------------- // Deserialize() must call std::make_heap to restore the min-heap invariant even when // heap_items are stored out-of-order in the RDB snapshot (e.g. saved in List() order). TEST(TOPKBasic, DeserializeRestoresHeapProperty) { TOPK::SerializedData data; data.k = 5; data.width = 100; data.depth = 5; data.decay = 0.0; data.counters.resize(500, 0); // Items deliberately out of min-heap order: smallest must end up at the root. 
data.heap_items.push_back({"heavy", 1000}); data.heap_items.push_back({"medium", 500}); data.heap_items.push_back({"light", 10}); TOPK restored(PMR_NS::get_default_resource(), 5, 100, 5, 0.0); restored.Deserialize(data); // List() sorts descending — correct only if make_heap built a valid heap. auto list = restored.List(); ASSERT_EQ(list.size(), 3u); EXPECT_EQ(list[0].item, "heavy"); EXPECT_EQ(list[1].item, "medium"); EXPECT_EQ(list[2].item, "light"); // Heap is not yet full (3 of 5 slots used), so fill it to capacity. restored.IncrBy("filler1", 20); restored.IncrBy("filler2", 30); // Now heap is full (5 items: light=10, filler1=20, filler2=30, medium=500, heavy=1000). // A new item with count > 10 must evict "light" — the min-heap root. auto evicted = restored.IncrBy("newcomer", 50); ASSERT_TRUE(evicted.has_value()); EXPECT_EQ(evicted.value(), "light"); } // --------------------------------------------------------------------------- // Counter Saturation (Overflow Prevention) // --------------------------------------------------------------------------- // IncrBy must saturate at UINT32_MAX rather than wrapping around to 0. // A wrap-around would trick the heap into evicting a top item — a correctness // and security issue (malicious TOPK.INCRBY with a huge increment). TEST_F(TOPKTest, CounterSaturationPreventsOverflow) { const uint32_t max_val = numeric_limits::max(); topk_.IncrBy("max_item", max_val); EXPECT_EQ(topk_.Count("max_item"), max_val); // Adding more must not wrap the counter back to a small number. topk_.IncrBy("max_item", 100); EXPECT_EQ(topk_.Count("max_item"), max_val); } // --------------------------------------------------------------------------- // Death Tests (DCHECKs active in debug builds only) // --------------------------------------------------------------------------- #ifndef NDEBUG // k=0 violates DCHECK_GT(k_, 0u) in the constructor. 
TEST(TOPKDeathTest, ZeroKCrashes) { EXPECT_DEBUG_DEATH(TOPK(PMR_NS::get_default_resource(), 0, 100, 5, 0.9), "k_ > 0"); } // width=0 violates DCHECK_GT(width_, 0u) in the constructor. TEST(TOPKDeathTest, ZeroWidthCrashes) { EXPECT_DEBUG_DEATH(TOPK(PMR_NS::get_default_resource(), 5, 0, 5, 0.9), "width_ > 0"); } // decay=1.5 violates DCHECK_LE(decay_, 1.0) in the constructor. TEST(TOPKDeathTest, DecayAboveOneCrashes) { EXPECT_DEBUG_DEATH(TOPK(PMR_NS::get_default_resource(), 5, 100, 5, 1.5), "decay_ <= 1.0"); } // Deserializing data with a mismatched k violates DCHECK_EQ(data.k, k_). TEST(TOPKDeathTest, DeserializeDimensionMismatchCrashes) { TOPK topk(PMR_NS::get_default_resource(), 5, 100, 5, 0.9); TOPK::SerializedData bad; bad.k = 10; // Mismatch: object was constructed with k=5. bad.width = 100; bad.depth = 5; bad.decay = 0.9; bad.counters.resize(500, 0); EXPECT_DEBUG_DEATH(topk.Deserialize(bad), "data.k == k_"); } // Deserializing data with a mismatched decay violates DCHECK_EQ(data.decay, decay_). TEST(TOPKDeathTest, DeserializeDecayMismatchCrashes) { TOPK topk(PMR_NS::get_default_resource(), 5, 100, 5, 0.9); TOPK::SerializedData bad; bad.k = 5; bad.width = 100; bad.depth = 5; bad.decay = 0.5; // Mismatch: object was constructed with decay=0.9. bad.counters.resize(500, 0); EXPECT_DEBUG_DEATH(topk.Deserialize(bad), "data.decay == decay_"); } #endif } // namespace dfly ================================================ FILE: src/core/tx_queue.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "core/tx_queue.h" #include "base/logging.h" namespace dfly { TxQueue::TxQueue(std::function sf) : score_fun_(sf), vec_(32) { for (size_t i = 0; i < vec_.size(); ++i) { vec_[i].next = i + 1; } } auto TxQueue::Insert(Transaction* t) -> Iterator { if (next_free_ >= vec_.size()) { Grow(); } DCHECK_LT(next_free_, vec_.size()); DCHECK_EQ(FREE_TAG, vec_[next_free_].tag); Iterator res = next_free_; vec_[next_free_].u.trans = t; vec_[next_free_].tag = TRANS_TAG; DVLOG(1) << "Insert " << next_free_ << " " << t; LinkFree(score_fun_(t)); return res; } auto TxQueue::Insert(uint64_t val) -> Iterator { if (next_free_ >= vec_.size()) { Grow(); } DCHECK_LT(next_free_, vec_.size()); Iterator res = next_free_; vec_[next_free_].u.uval = val; vec_[next_free_].tag = UINT_TAG; LinkFree(val); return res; } void TxQueue::LinkFree(uint64_t weight) { uint32_t taken = next_free_; next_free_ = vec_[taken].next; if (size_ == 0) { head_ = taken; vec_[head_].next = vec_[head_].prev = head_; } else { uint32_t cur = vec_[head_].prev; while (true) { if (Rank(vec_[cur]) < weight) { Link(cur, taken); break; } if (cur == head_) { Link(vec_[head_].prev, taken); head_ = taken; break; } cur = vec_[cur].prev; } } ++size_; } void TxQueue::Grow() { size_t start = vec_.size(); DVLOG(1) << "Grow from " << start << " to " << start * 2; vec_.resize(start * 2); for (size_t i = start; i < vec_.size(); ++i) { vec_[i].next = i + 1; } } void TxQueue::Remove(Iterator it) { DCHECK_GT(size_, 0u); DCHECK_LT(it, vec_.size()); DCHECK_NE(FREE_TAG, vec_[it].tag); DVLOG(1) << "Remove " << it << " " << vec_[it].u.trans; Iterator next = kEnd; if (size_ > 1) { Iterator prev = vec_[it].prev; next = vec_[it].next; vec_[prev].next = next; vec_[next].prev = prev; } --size_; vec_[it].next = next_free_; vec_[it].tag = FREE_TAG; next_free_ = it; if (head_ == it) { head_ = next; } } uint64_t TxQueue::Rank(const QRecord& r) const { switch (r.tag) { case UINT_TAG: return r.u.uval; case TRANS_TAG: return score_fun_(r.u.trans); 
} return 0; } } // namespace dfly ================================================ FILE: src/core/tx_queue.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include namespace dfly { class Transaction; // TxQueue implemmented as a circular doubly-linked list. class TxQueue { void Link(uint32_t p, uint32_t n) { uint32_t next = vec_[p].next; vec_[n].next = next; vec_[n].prev = p; vec_[p].next = n; vec_[next].prev = n; } public: // uint64_t is used for unit-tests. using ValueType = std::variant; using Iterator = uint32_t; enum { kEnd = Iterator(-1) }; TxQueue(std::function score_fun = nullptr); // returns iterator to that item the list Iterator Insert(Transaction* t); Iterator Insert(uint64_t val); void Remove(Iterator); ValueType At(Iterator it) const { switch (vec_[it].tag) { case TRANS_TAG: return vec_[it].u.trans; case UINT_TAG: return vec_[it].u.uval; } return 0u; } ValueType Front() const { return At(head_); } void PopFront() { Remove(head_); } size_t size() const { return size_; } bool Empty() const { return size_ == 0; } //! returns the score of the tail record. Can be called only if !Empty(). uint64_t TailScore() const { return Rank(vec_[vec_[head_].prev]); } //! returns the score of the head record. Can be called only if !Empty(). uint64_t HeadScore() const { return Rank(vec_[head_]); } //! Can be called only if !Empty(). Iterator Head() const { return head_; } // Returns the next iterator, it's circular so it always returns a valid // iterator. Can be called only if !Empty(). 
Iterator Next(Iterator it) const { return vec_[it].next; } private: enum { TRANS_TAG = 0, UINT_TAG = 11, FREE_TAG = 12 }; void Grow(); void LinkFree(uint64_t rank); struct QRecord { union { Transaction* trans; uint64_t uval; } u; uint32_t tag : 8; uint32_t next : 24; uint32_t prev; QRecord() : tag(FREE_TAG), prev(kEnd) { } }; static_assert(sizeof(QRecord) == 16, ""); uint64_t Rank(const QRecord& r) const; std::function score_fun_; std::vector vec_; uint32_t next_free_ = 0, head_ = kEnd; size_t size_ = 0; TxQueue(const TxQueue&) = delete; }; } // namespace dfly ================================================ FILE: src/core/zstd_test.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include #include #include #include #include "base/logging.h" namespace dfly { using namespace std; constexpr unsigned kLevel = 1; class ZStdTest : public ::testing::Test { protected: string Compress(const string& src, const ZSTD_CDict* cdict) { ZSTD_CCtx* cctx = ZSTD_createCCtx(); size_t c_buffer_size = ZSTD_compressBound(src.size()); string res(c_buffer_size, '\0'); size_t compressed_size = ZSTD_compress_usingCDict(cctx, res.data(), c_buffer_size, src.c_str(), src.size(), cdict); ZSTD_freeCCtx(cctx); res.resize(compressed_size); return res; } string Decompress(const string& src, const ZSTD_DDict* ddict, size_t decompressed_size) { string res(decompressed_size, '\0'); ZSTD_DCtx* dctx = ZSTD_createDCtx(); size_t decompressed_size_actual = ZSTD_decompress_usingDDict( dctx, res.data(), decompressed_size, src.c_str(), src.size(), ddict); CHECK_EQ(decompressed_size, decompressed_size_actual); ZSTD_freeDCtx(dctx); return res; } string CompressNoDict(const string& src) { ZSTD_CCtx* cctx = ZSTD_createCCtx(); size_t c_buffer_size = ZSTD_compressBound(src.size()); string res(c_buffer_size, '\0'); size_t compressed_size = ZSTD_compressCCtx(cctx, res.data(), c_buffer_size, src.c_str(), src.size(), 
kLevel); ZSTD_freeCCtx(cctx); res.resize(compressed_size); return res; } }; // Dictionary works well for small messages where we do not have enough data to reference // previous stream to have significant savings. // For large messages, it may not be less beneficial. TEST_F(ZStdTest, Dict) { const char* kRandomPieces[] = {"ABCD", "EFGH", "IJKL", "MNOP", "QRST", "UVWX", "YZAB", "CDEF"}; string dict_source; random_device rd; for (unsigned i = 0; i < 1000; ++i) { dict_source += kRandomPieces[rd() % ABSL_ARRAYSIZE(kRandomPieces)]; } LOG(INFO) << "Creating CDICT from " << dict_source.size() << " bytes of random data"; ZSTD_CDict* cdict = ZSTD_createCDict(dict_source.data(), dict_source.size(), 7); ASSERT_TRUE(cdict); size_t actual_dict_size = ZSTD_sizeof_CDict(cdict); LOG(INFO) << "ZSTD_CDict created, size: " << actual_dict_size << " bytes"; ZSTD_DDict* ddict = ZSTD_createDDict(dict_source.data(), dict_source.size()); ASSERT_TRUE(ddict); size_t actual_ddict_size = ZSTD_sizeof_DDict(ddict); LOG(INFO) << "ZSTD_DDict created, size: " << actual_ddict_size << " bytes"; // 3. Data to compress std::string data_to_compress; for (unsigned j = 0; j < 30; ++j) { data_to_compress += kRandomPieces[rd() % ABSL_ARRAYSIZE(kRandomPieces)]; } size_t data_to_compress_size = data_to_compress.size(); // 4. Compress data string compressed = Compress(data_to_compress, cdict); LOG(INFO) << "Compressed data size: " << compressed.size() << " bytes vs " << data_to_compress_size << " bytes of original data"; string compress_no_dict = CompressNoDict(data_to_compress); LOG(INFO) << "Compressed data size without dict: " << compress_no_dict.size() << " bytes"; // 5. Decompress data string decompressed = Decompress(compressed, ddict, data_to_compress_size); ASSERT_EQ(data_to_compress, decompressed); // 7. 
Free memory ZSTD_freeCDict(cdict); ZSTD_freeDDict(ddict); } } // namespace dfly ================================================ FILE: src/external_libs.cmake ================================================ add_third_party( lua GIT_REPOSITORY https://github.com/dragonflydb/lua GIT_TAG Dragonfly-5.4.6a CONFIGURE_COMMAND echo BUILD_IN_SOURCE 1 BUILD_COMMAND ${DFLY_TOOLS_MAKE} all INSTALL_COMMAND cp /liblua.a ${THIRD_PARTY_LIB_DIR}/lua/lib/ COMMAND cp /lualib.h /lua.h /lauxlib.h /luaconf.h ${THIRD_PARTY_LIB_DIR}/lua/include ) if (APPLE OR ${CMAKE_SYSTEM_NAME} MATCHES "FreeBSD") set(SED_REPL sed "-i" '') else() set(SED_REPL sed "-i") endif() add_third_party( dconv GIT_REPOSITORY https://github.com/google/double-conversion # URL https://github.com/google/double-conversion/archive/refs/tags/v3.3.1.tar.gz GIT_TAG 0604b4c PATCH_COMMAND ${SED_REPL} "/static const std::ctype/d" /double-conversion/string-to-double.cc COMMAND ${SED_REPL} "/std::use_facet/double-conversion/string-to-double.cc COMMAND ${SED_REPL} "s/cType.tolower/std::tolower/g" /double-conversion/string-to-double.cc LIB libdouble-conversion.a ) add_third_party( reflex URL https://github.com/Genivia/RE-flex/archive/refs/tags/v5.2.2.tar.gz PATCH_COMMAND autoreconf -fi CONFIGURE_COMMAND /configure --disable-avx2 --prefix=${THIRD_PARTY_LIB_DIR}/reflex CXX=${CMAKE_CXX_COMPILER} CC=${CMAKE_C_COMPILER} ) set(REFLEX "${THIRD_PARTY_LIB_DIR}/reflex/bin/reflex") add_third_party( jsoncons GIT_REPOSITORY https://github.com/dragonflydb/jsoncons GIT_TAG Dragonfly1.5.0 GIT_SHALLOW 1 CMAKE_PASS_FLAGS "-DJSONCONS_BUILD_TESTS=OFF -DJSONCONS_HAS_POLYMORPHIC_ALLOCATOR=ON" LIB "none" ) add_third_party( lz4 URL https://github.com/lz4/lz4/archive/refs/tags/v1.10.0.tar.gz BUILD_IN_SOURCE 1 CONFIGURE_COMMAND echo skip BUILD_COMMAND ${DFLY_TOOLS_MAKE} lib-release INSTALL_COMMAND ${DFLY_TOOLS_MAKE} install BUILD_SHARED=no PREFIX=${THIRD_PARTY_LIB_DIR}/lz4 ) set(MIMALLOC_ROOT_DIR ${THIRD_PARTY_LIB_DIR}/mimalloc2) set(MIMALLOC_INCLUDE_DIR 
${MIMALLOC_ROOT_DIR}/include) set(MIMALLOC_PATCH_DIR ${CMAKE_CURRENT_LIST_DIR}/../patches/mimalloc-v2.2.4) set(MIMALLOC_C_FLAGS "-O3 -g -DMI_STAT=1 -DNDEBUG") file(MAKE_DIRECTORY ${MIMALLOC_INCLUDE_DIR}) ExternalProject_Add(mimalloc2_project URL https://github.com/microsoft/mimalloc/archive/refs/tags/v2.2.4.tar.gz DOWNLOAD_DIR ${THIRD_PARTY_DIR}/mimalloc2 SOURCE_DIR ${THIRD_PARTY_DIR}/mimalloc2 # INSTALL_DIR ${MIMALLOC_ROOT_DIR} UPDATE_COMMAND "" PATCH_COMMAND patch -p1 -d ${THIRD_PARTY_DIR}/mimalloc2/ -i ${MIMALLOC_PATCH_DIR}/0_base.patch COMMAND patch -p1 -d ${THIRD_PARTY_DIR}/mimalloc2/ -i ${MIMALLOC_PATCH_DIR}/1_add_stat_type.patch COMMAND patch -p1 -d ${THIRD_PARTY_DIR}/mimalloc2/ -i ${MIMALLOC_PATCH_DIR}/2_return_stat.patch COMMAND patch -p1 -d ${THIRD_PARTY_DIR}/mimalloc2/ -i ${MIMALLOC_PATCH_DIR}/3_track_full_size.patch COMMAND patch -p1 -d ${THIRD_PARTY_DIR}/mimalloc2/ -i ${MIMALLOC_PATCH_DIR}/4_fix_heap_collect.patch BUILD_COMMAND make mimalloc-static INSTALL_COMMAND make install # Copy internal types like mi_page_usage_stats_s and mi_heap_s COMMAND cp -r /include/mimalloc ${MIMALLOC_INCLUDE_DIR}/ LOG_INSTALL ON LOG_DOWNLOAD ON LOG_CONFIGURE ON LOG_BUILD ON LOG_PATCH ON LOG_UPDATE ON DOWNLOAD_EXTRACT_TIMESTAMP YES CMAKE_GENERATOR "Unix Makefiles" # Add -DCMAKE_BUILD_TYPE=Debug -DCMAKE_C_FLAGS=-O0 to debug, and set BUILD_BYPRODUCTS to # libmimalloc-debug.a BUILD_BYPRODUCTS ${MIMALLOC_ROOT_DIR}/lib/libmimalloc.a CMAKE_ARGS -DCMAKE_ARCHIVE_OUTPUT_DIRECTORY:PATH=${MIMALLOC_ROOT_DIR}/lib -DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=${MIMALLOC_ROOT_DIR}/lib -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_CXX_COMPILER:STRING=${CMAKE_CXX_COMPILER} -DMI_INSTALL_TOPLEVEL=ON -DMI_OVERRIDE=OFF -DMI_NO_PADDING=ON -DMI_BUILD_TESTS=OFF -DMI_BUILD_SHARED=OFF -DMI_BUILD_OBJECT=OFF -DCMAKE_C_FLAGS=${MIMALLOC_C_FLAGS} -DCMAKE_INSTALL_PREFIX:PATH=${MIMALLOC_ROOT_DIR} ) add_library(TRDP::mimalloc2 STATIC IMPORTED) add_dependencies(TRDP::mimalloc2 mimalloc2_project) 
set_target_properties(TRDP::mimalloc2 PROPERTIES IMPORTED_LOCATION ${MIMALLOC_ROOT_DIR}/lib/libmimalloc.a INTERFACE_INCLUDE_DIRECTORIES ${MIMALLOC_ROOT_DIR}/include) add_third_party( croncpp URL https://github.com/mariusbancila/croncpp/archive/refs/tags/v2023.03.30.tar.gz LIB "none" ) if (WITH_SEARCH) add_third_party( uni-algo URL https://github.com/uni-algo/uni-algo/archive/refs/tags/v1.0.0.tar.gz CMAKE_PASS_FLAGS "-DCMAKE_CXX_STANDARD:STRING=20" ) add_third_party( hnswlib GIT_REPOSITORY https://github.com/dragonflydb/hnswlib.git # HEAD of dragonfly branch GIT_TAG d07dd1da2bf48b85d2f03b8396193ad7120f75c2 BUILD_COMMAND echo SKIP INSTALL_COMMAND cp -R /hnswlib ${THIRD_PARTY_LIB_DIR}/hnswlib/include/ LIB "none" ) endif() add_third_party( fast_float URL https://github.com/fastfloat/fast_float/archive/refs/tags/v5.2.0.tar.gz LIB "none" ) add_third_party( flatbuffers URL https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.tar.gz CMAKE_PASS_FLAGS "-DFLATBUFFERS_BUILD_TESTS=OFF -DFLATBUFFERS_LIBCXX_WITH_CLANG=OFF -DFLATBUFFERS_BUILD_FLATC=OFF" ) add_third_party( hdr_histogram GIT_REPOSITORY https://github.com/HdrHistogram/HdrHistogram_c/ GIT_TAG 652d51bcc36744fd1a6debfeb1a8a5f58b14022c CMAKE_PASS_FLAGS "-DHDR_LOG_REQUIRED=OFF -DHDR_HISTOGRAM_BUILD_PROGRAMS=OFF -DHDR_HISTOGRAM_INSTALL_SHARED=OFF" LIB libhdr_histogram_static.a ) if(WITH_SIMSIMD) # Compute integer macros for native half-precision support. set(SIMSIMD_NATIVE_F16_VAL 0) set(SIMSIMD_NATIVE_BF16_VAL 0) if(SIMSIMD_NATIVE_F16) set(SIMSIMD_NATIVE_F16_VAL 1) set(SIMSIMD_NATIVE_BF16_VAL 1) endif() # Build statically via add_third_party using the C shim with dynamic dispatch. 
add_third_party(
  simsimd
  URL https://github.com/ashvardanian/SimSIMD/archive/refs/tags/v6.5.3.tar.gz
  BUILD_IN_SOURCE 1
  CONFIGURE_COMMAND echo skip
  # Compiles the single C shim (c/lib.c) and archives it into libsimsimd.a.
  # NOTE(review): the ExternalProject placeholders were lost in this dump (paths read
  # `-I/include`, `/c/lib.c`, `/lib.o`, ...). They are restored as <SOURCE_DIR>, which is
  # the build directory as well because BUILD_IN_SOURCE is set - confirm against upstream.
  # The shell line continuations are backslash-newline; a backslash followed by a space is
  # not a valid escape inside a CMake quoted argument.
  BUILD_COMMAND bash -c "\
    mkdir -p ${THIRD_PARTY_LIB_DIR}/simsimd/lib && \
    ${CMAKE_C_COMPILER} -O3 -fPIC -DNDEBUG \
      -DSIMSIMD_DYNAMIC_DISPATCH=1 \
      -DSIMSIMD_NATIVE_F16=${SIMSIMD_NATIVE_F16_VAL} \
      -DSIMSIMD_NATIVE_BF16=${SIMSIMD_NATIVE_BF16_VAL} \
      -I<SOURCE_DIR>/include -c <SOURCE_DIR>/c/lib.c -o <SOURCE_DIR>/lib.o && \
    ar rcs <SOURCE_DIR>/libsimsimd.a <SOURCE_DIR>/lib.o"
  INSTALL_COMMAND bash -c "\
    mkdir -p ${THIRD_PARTY_LIB_DIR}/simsimd/include ${THIRD_PARTY_LIB_DIR}/simsimd/lib && \
    cp -R <SOURCE_DIR>/include/* ${THIRD_PARTY_LIB_DIR}/simsimd/include/ && \
    cp <SOURCE_DIR>/libsimsimd.a ${THIRD_PARTY_LIB_DIR}/simsimd/lib/"
  LIB libsimsimd.a
)
endif()

# Header-only dependencies are exposed as INTERFACE targets that only carry an include
# directory and a dependency on the corresponding external project.
add_library(TRDP::jsoncons INTERFACE IMPORTED)
add_dependencies(TRDP::jsoncons jsoncons_project)
set_target_properties(TRDP::jsoncons PROPERTIES
                      INTERFACE_INCLUDE_DIRECTORIES "${JSONCONS_INCLUDE_DIR}")

add_library(TRDP::croncpp INTERFACE IMPORTED)
add_dependencies(TRDP::croncpp croncpp_project)
set_target_properties(TRDP::croncpp PROPERTIES
                      INTERFACE_INCLUDE_DIRECTORIES "${CRONCPP_INCLUDE_DIR}")

if (WITH_SEARCH)
  add_library(TRDP::hnswlib INTERFACE IMPORTED)
  add_dependencies(TRDP::hnswlib hnswlib_project)
  set_target_properties(TRDP::hnswlib PROPERTIES
                        INTERFACE_INCLUDE_DIRECTORIES "${HNSWLIB_INCLUDE_DIR}")
endif()

add_library(TRDP::fast_float INTERFACE IMPORTED)
add_dependencies(TRDP::fast_float fast_float_project)
set_target_properties(TRDP::fast_float PROPERTIES
                      INTERFACE_INCLUDE_DIRECTORIES "${FAST_FLOAT_INCLUDE_DIR}")

================================================
FILE: src/facade/CMakeLists.txt
================================================
# RESP/Redis protocol parsers, usable without the rest of the facade.
add_library(dfly_parser_lib redis_parser.cc resp_expr.cc resp_parser.cc resp_srv_parser.cc)
cxx_link(dfly_parser_lib base strings_lib redis_lib)

# Connection handling layer (listener, connection, reply builders, argument parsing, ...).
add_library(dfly_facade dragonfly_listener.cc dragonfly_connection.cc facade.cc
  memcache_parser.cc reply_builder.cc op_status.cc parsed_command.cc service_interface.cc
  reply_capture.cc cmd_arg_parser.cc tls_helpers.cc socket_utils.cc disk_backed_queue.cc)

# TLS_LIB stays unset (expands to nothing in cxx_link below) unless SSL support is enabled.
if (DF_USE_SSL)
  set(TLS_LIB tls_lib)
  target_compile_definitions(dfly_facade PRIVATE DFLY_USE_SSL)
endif()

cxx_link(dfly_facade dfly_parser_lib http_server_lib fibers2 ${TLS_LIB}
  TRDP::mimalloc2 TRDP::dconv redis_lib)

# Shared helpers for the facade unit tests.
add_library(facade_test facade_test.cc resp_expr_test_utils.cc)
cxx_link(facade_test dfly_facade gtest_main_ext)

helio_cxx_test(memcache_parser_test dfly_facade LABELS DFLY)
helio_cxx_test(redis_parser_test facade_test LABELS DFLY)
helio_cxx_test(resp_srv_parser_test facade_test LABELS DFLY)
helio_cxx_test(reply_builder_test facade_test LABELS DFLY)
helio_cxx_test(resp_parser_test facade_test LABELS DFLY)
helio_cxx_test(cmd_arg_parser_test facade_test LABELS DFLY)
helio_cxx_test(disk_backed_queue_test facade_test LABELS DFLY)

add_executable(ok_backend ok_main.cc)
cxx_link(ok_backend dfly_facade)

add_executable(resp_validator resp_validator.cc)
cxx_link(resp_validator dfly_parser_lib)

================================================
FILE: src/facade/README.md
================================================
## A facade library

The library is responsible for opening dragonfly-like TCP client connections. I call it facade because "client" term is often abused. It should be separated from the rest of dragonfly server logic and should be self-contained, i.e no redis-lib or server dependencies are allowed.

================================================
FILE: src/facade/cmd_arg_parser.cc
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "facade/cmd_arg_parser.h" #include #include "base/logging.h" #include "facade/error.h" namespace facade { void CmdArgParser::ExpectTag(std::string_view tag) { if (cur_i_ >= args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); return; } auto idx = cur_i_++; auto val = ToSV(args_[idx]); if (!absl::EqualsIgnoreCase(val, tag)) { Report(INVALID_NEXT, idx); } } CmdArgParser::ErrorInfo CmdArgParser::TakeError() { return std::exchange(error_, {}); } ErrorReply CmdArgParser::ErrorInfo::MakeReply() const { DCHECK(operator bool()); switch (type) { case INVALID_INT: return ErrorReply{kInvalidIntErr}; case INVALID_FLOAT: return ErrorReply{kInvalidFloatErr}; default: return ErrorReply{kSyntaxErr}; }; return ErrorReply{kSyntaxErr}; } CmdArgParser::~CmdArgParser() { DCHECK(!error_) << "Parsing error occured but not checked"; // TODO DCHECK(!HasNext()) << "Not all args were processed"; } } // namespace facade ================================================ FILE: src/facade/cmd_arg_parser.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include "facade/facade_types.h" namespace facade { // Helper class for numerical range restriction during parsing template struct FInt { decltype(min) value = {}; operator decltype(min)() { return value; } static_assert(std::is_same_v, "inconsistent types"); static constexpr auto kMin = min; static constexpr auto kMax = max; }; template constexpr bool is_fint = false; template constexpr bool is_fint> = true; // Utility class for easily parsing command options from argument lists. 
struct CmdArgParser { enum ErrorType { NO_ERROR, OUT_OF_BOUNDS, SHORT_OPT_TAIL, INVALID_INT, INVALID_FLOAT, INVALID_CASES, INVALID_NEXT, UNPROCESSED, CUSTOM_ERROR // should be the last one }; struct ErrorInfo { int type = NO_ERROR; size_t index = 0; operator bool() const { return type != ErrorType::NO_ERROR; } ErrorReply MakeReply() const; }; public: CmdArgParser(ArgSlice args) : args_{args} { } // Debug asserts sure error was consumed ~CmdArgParser(); // Get next value without consuming it std::string_view Peek() { return SafeSV(cur_i_); } // Consume next value template auto Next() { if (cur_i_ + sizeof...(Ts) >= args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); return std::conditional_t>(); } if constexpr (sizeof...(Ts) == 0) { auto idx = cur_i_++; return Convert(idx); } else { std::tuple res; NextImpl<0>(&res); cur_i_ += sizeof...(Ts) + 1; return res; } } // returns next value if exists or default value template auto NextOrDefault(T default_value = {}) { return HasNext() ? Next() : default_value; } // check next value ignoring case and consume it void ExpectTag(std::string_view tag); // Consume next value template auto MapNext(Cases&&... cases) { if (cur_i_ >= args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); return typename decltype(MapImpl(std::string_view(), std::forward(cases)...))::value_type{}; } auto idx = cur_i_++; auto res = MapImpl(SafeSV(idx), std::forward(cases)...); if (!res) { Report(INVALID_CASES, idx); return typename decltype(res)::value_type{}; } return *res; } // Consume next value if can map it and return mapped result or return nullopt template auto TryMapNext(Cases&&... cases) -> std::optional>> { if (cur_i_ >= args_.size()) { return std::nullopt; } auto res = MapImpl(SafeSV(cur_i_), std::forward(cases)...); cur_i_ = res ? cur_i_ + 1 : cur_i_; return res; } // Check if the next value is equal to a specific tag. If equal, its consumed. template bool Check(std::string_view tag, Args*... 
args) { if (cur_i_ + sizeof...(Args) >= args_.size()) return false; std::string_view arg = SafeSV(cur_i_); if (!absl::EqualsIgnoreCase(arg, tag)) return false; ((*args = Convert(++cur_i_)), ...); ++cur_i_; return true; } // Skip specified number of arguments CmdArgParser& Skip(size_t n) { if (cur_i_ + n > args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); } else { cur_i_ += n; } return *this; } // Expect no more arguments and return if no error has occured bool Finalize() { if (HasNext()) { Report(UNPROCESSED, cur_i_); return false; } return !HasError(); } // Return remaining arguments ArgSlice Tail() const { return args_.subspan(cur_i_); } // Return true if arguments are left and no errors occured bool HasNext() { return cur_i_ < args_.size() && !error_; } bool HasError() const { return bool(error_); } ErrorInfo TakeError(); bool HasAtLeast(size_t i) const { return cur_i_ + i <= args_.size() && !error_; } size_t GetCurrentIndex() const { return cur_i_; } // Custom error_type should start from CUSTOM_ERROR void Report(int error_type) { // we use previous index, because the check was done outside and it's done after element is // processed Report(error_type, cur_i_ - 1); } private: void Report(int error_type, size_t idx) { if (!error_) { error_ = {error_type, idx}; cur_i_ = args_.size(); } } template std::optional> MapImpl(std::string_view arg, std::string_view tag, T&& value, Cases&&... 
cases) { if (absl::EqualsIgnoreCase(arg, tag)) return std::forward(value); if constexpr (sizeof...(cases) > 0) return MapImpl(arg, cases...); return std::nullopt; } template void NextImpl(Tuple* t) { std::get(*t) = Convert>(cur_i_ + shift); if constexpr (constexpr auto next = shift + 1; next < std::tuple_size_v) NextImpl(t); } template T Convert(size_t idx) { static_assert( std::is_arithmetic_v || std::is_constructible_v || is_fint, "incorrect type"); if constexpr (std::is_arithmetic_v) { return Num(idx); } else if constexpr (std::is_constructible_v) { return static_cast(SafeSV(idx)); } else if constexpr (is_fint) { return {ConvertFInt(idx)}; } } template FInt ConvertFInt(size_t idx) { auto res = Num(idx); if (res < min || res > max) { Report(INVALID_INT, idx); return {}; } return {res}; } std::string_view SafeSV(size_t i) const { using namespace std::literals::string_view_literals; if (i >= args_.size()) return ""sv; return args_[i].empty() ? ""sv : ToSV(args_[i]); } template T Num(size_t idx) { auto arg = SafeSV(idx); T out; if constexpr (std::is_same_v) { if (absl::SimpleAtof(arg, &out)) return out; } else if constexpr (std::is_same_v) { if (absl::SimpleAtod(arg, &out)) return out; } else if constexpr (std::is_integral_v && sizeof(T) >= sizeof(int32_t)) { if (absl::SimpleAtoi(arg, &out)) return out; } else if constexpr (std::is_integral_v && sizeof(T) < sizeof(int32_t)) { int32_t tmp; if (absl::SimpleAtoi(arg, &tmp)) { out = tmp; // out can not store the whole tmp if (tmp == out) return out; } } if constexpr (std::is_floating_point_v) { Report(INVALID_FLOAT, idx); } else { Report(INVALID_INT, idx); } return {}; } private: size_t cur_i_ = 0; ArgSlice args_; ErrorInfo error_; }; } // namespace facade ================================================ FILE: src/facade/cmd_arg_parser_test.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#include "facade/cmd_arg_parser.h"

// NOTE(review): text inside angle brackets was stripped from this dump. The two bare #include
// directives below and the template arguments of several calls in this file (Next(),
// implicit_cast(), Next>(), absl::Span, std::vector) are therefore incomplete; they cannot be
// recovered from what is visible here and are left untouched - restore from upstream.
#include
#include

#include "facade/memcache_parser.h"

using namespace testing;
using namespace std;

namespace facade {

// Fixture that owns the argument storage a parser under test points into.
class CmdArgParserTest : public testing::Test {
 public:
  // Copies `args` into storage_, rebuilds arg_vec_ as slices over those copies and returns a
  // parser over arg_vec_. Each call overwrites the storage used by the previous parser.
  CmdArgParser Make(absl::Span args) {
    storage_.assign(args.begin(), args.end());
    arg_vec_.clear();
    for (auto& s : storage_)
      arg_vec_.push_back(MutableSlice{s.data(), s.size()});
    return CmdArgParser{absl::MakeSpan(arg_vec_)};
  }

 private:
  CmdArgVec arg_vec_;
  std::vector storage_;
};

// Consumes string, string-view and numeric values one by one, then a pair in one call.
TEST_F(CmdArgParserTest, BasicTypes) {
  auto parser = Make({"STRING", "VIEW", "11", "22", "33", "44"});

  EXPECT_TRUE(parser.HasNext());

  EXPECT_EQ(parser.Next(), "STRING"s);
  EXPECT_EQ(parser.Next(), "VIEW"sv);

  EXPECT_EQ(parser.Next(), 11u);
  EXPECT_EQ(parser.Next(), 22u);

  auto [a, b] = parser.Next();
  EXPECT_EQ(a, 33u);
  EXPECT_EQ(b, 44u);

  EXPECT_FALSE(parser.HasNext());
  EXPECT_FALSE(parser.HasError());
}

// Next() on an empty argument list yields an empty value and an OUT_OF_BOUNDS error at 0.
TEST_F(CmdArgParserTest, BoundError) {
  auto parser = Make({});

  EXPECT_EQ(absl::implicit_cast(parser.Next()), ""sv);

  auto err = parser.TakeError();
  EXPECT_TRUE(err);
  EXPECT_EQ(err.type, CmdArgParser::OUT_OF_BOUNDS);
  EXPECT_EQ(err.index, 0);
}

#ifndef __APPLE__
// A non-numeric argument read as a number yields 0 and an INVALID_INT error at its index.
TEST_F(CmdArgParserTest, IntError) {
  auto parser = Make({"NOTANINT"});

  EXPECT_EQ(parser.Next(), 0u);

  auto err = parser.TakeError();
  EXPECT_TRUE(err);
  EXPECT_EQ(err.type, CmdArgParser::INVALID_INT);
  EXPECT_EQ(err.index, 0);
}
#endif

// Check() consumes the tag only when it matches; a mismatch leaves the cursor in place.
TEST_F(CmdArgParserTest, Check) {
  auto parser = Make({"TAG", "TAG_2", "22"});

  EXPECT_FALSE(parser.Check("NOT_TAG"));
  EXPECT_TRUE(parser.Check("TAG"));

  EXPECT_FALSE(parser.Check("NOT_TAG_2"));
  EXPECT_TRUE(parser.Check("TAG_2"));
  EXPECT_EQ(parser.Next(), 22);
}

// ExpectTag() matches ignoring case and records an error on a mismatch.
TEST_F(CmdArgParserTest, NextStatement) {
  auto parser = Make({"TAG", "tag_2", "tag_3"});

  parser.ExpectTag("TAG");
  EXPECT_FALSE(parser.TakeError());

  parser.ExpectTag("TAG_2");
  EXPECT_FALSE(parser.TakeError());

  parser.ExpectTag("TAG_2");
  EXPECT_TRUE(parser.TakeError());
}

// Check() with output pointers: fails without consuming when too few values follow the tag,
// and succeeds with an error recorded when a following value does not convert.
TEST_F(CmdArgParserTest, CheckTailFail) {
  auto parser = Make({"TAG", "11", "22", "TAG", "text"});

  int first;
  string_view second;
  EXPECT_TRUE(parser.Check("TAG", &first, &second));
  EXPECT_EQ(first, 11);
  EXPECT_EQ(second, "22");

  // Only one value ("text") follows the second TAG, so two outputs cannot be filled.
  EXPECT_FALSE(parser.Check("TAG", &first, &second));
  // "text" is not an int: the tag is consumed but a conversion error is recorded.
  EXPECT_TRUE(parser.Check("TAG", &first));
  EXPECT_TRUE(parser.TakeError());
}

// MapNext() returns the mapped value, or 0 plus INVALID_CASES at the argument's index.
TEST_F(CmdArgParserTest, Map) {
  auto parser = Make({"TWO", "NONE"});

  EXPECT_EQ(parser.MapNext("ONE", 1, "TWO", 2), 2);

  EXPECT_EQ(parser.MapNext("ONE", 1, "TWO", 2), 0);
  auto err = parser.TakeError();
  EXPECT_TRUE(err);
  EXPECT_EQ(err.type, CmdArgParser::INVALID_CASES);
  EXPECT_EQ(err.index, 1);
}

// TryMapNext() returns nullopt without recording an error or consuming the argument, so the
// same argument can be retried with another set of cases.
TEST_F(CmdArgParserTest, TryMapNext) {
  auto parser = Make({"TWO", "GREEN"});

  EXPECT_EQ(parser.TryMapNext("ONE", 1, "TWO", 2), std::make_optional(2));

  EXPECT_EQ(parser.TryMapNext("ONE", 1, "TWO", 2), std::nullopt);
  EXPECT_FALSE(parser.HasError());

  EXPECT_EQ(parser.TryMapNext("green", 1, "yellow", 2), std::make_optional(1));
  EXPECT_FALSE(parser.HasError());
}

// Tags are compared ignoring case; Skip(1) drops one argument.
TEST_F(CmdArgParserTest, IgnoreCase) {
  auto parser = Make({"hello", "marker", "taail", "world"});

  EXPECT_EQ(absl::implicit_cast(parser.Next()), "hello"sv);

  EXPECT_TRUE(parser.Check("MARKER"sv));
  parser.Skip(1);

  EXPECT_EQ(absl::implicit_cast(parser.Next()), "world"sv);
}

// Range-restricted integers: in-range values are returned, out-of-range ones yield 0 and an
// INVALID_INT error at the argument's index.
TEST_F(CmdArgParserTest, FixedRangeInt) {
  {
    auto parser = Make({"10", "-10", "12"});

    EXPECT_EQ((parser.Next>().value), 10);
    EXPECT_EQ((parser.Next>().value), -10);
    EXPECT_EQ((parser.Next>().value), 0);

    auto err = parser.TakeError();
    EXPECT_TRUE(err);
    EXPECT_EQ(err.type, CmdArgParser::INVALID_INT);
    EXPECT_EQ(err.index, 2);
  }

  {
    auto parser = Make({"-12"});
    EXPECT_EQ((parser.Next>().value), 0);

    auto err = parser.TakeError();
    EXPECT_TRUE(err);
    EXPECT_EQ(err.type, CmdArgParser::INVALID_INT);
    EXPECT_EQ(err.index, 0);
  }
}

}  // namespace facade

================================================
FILE: src/facade/command_id.h
================================================
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #pragma once #include #include #include namespace facade { class CommandId { public: /** * @brief Construct a new Command Id object * * When creating a new command use the https://github.com/redis/redis/tree/unstable/src/commands * files to find the right arguments. * * @param name * @param mask * @param arity - positive if command has fixed number of required arguments including * the command, negative if command has minimum number of required arguments, * but may have more. * @param first_key - position of first key in argument list * @param last_key - position of last key in argument list, * -1 means the last key index is (arg_length - 1), -2 means that the last key * index is (arg_length - 2). * @param acl_categories - bitfield for acl categories of the command */ CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, int8_t last_key, uint32_t acl_categories); std::string_view name() const { return name_; } int arity() const { return arity_; } uint32_t opt_mask() const { return opt_mask_; } int8_t first_key_pos() const { return first_key_; } int8_t last_key_pos() const { return last_key_; } uint32_t acl_categories() const { return acl_categories_; } void SetFamily(size_t fam) { family_ = fam; } void SetBitIndex(uint64_t bit) { bit_index_ = bit; } size_t GetFamily() const { return family_; } uint64_t GetBitIndex() const { return bit_index_; } // Returns true if the command can only be used by admin connections, false // otherwise. bool IsRestricted() const { return restricted_; } void SetRestricted(bool restricted) { restricted_ = restricted; } void SetFlag(uint32_t flag) { opt_mask_ |= flag; } protected: std::string name_; uint32_t opt_mask_; int8_t arity_; int8_t first_key_; int8_t last_key_; // Acl categories uint32_t acl_categories_; // Acl commands indices size_t family_; uint64_t bit_index_; // Whether the command can only be used by admin connections. 
bool restricted_ = false; }; } // namespace facade ================================================ FILE: src/facade/conn_context.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include namespace facade { class Connection; class ConnectionContext { public: explicit ConnectionContext(Connection* owner) : owner_(owner) { conn_closing = false; req_auth = false; replica_conn = false; authenticated = false; async_dispatch = false; sync_dispatch = false; paused = false; blocked = false; subscriptions = 0; } virtual ~ConnectionContext() { } Connection* conn() { return owner_; } const Connection* conn() const { return owner_; } virtual size_t UsedMemory() const { return 0; } // Noop. virtual void Unsubscribe(std::string_view channel) { } // connection state / properties. bool conn_closing : 1; bool req_auth : 1; bool replica_conn : 1; // whether it's a replica connection on the master side. bool authenticated : 1; bool async_dispatch : 1; // whether this connection is amid an async dispatch bool sync_dispatch : 1; // whether this connection is amid a sync dispatch bool paused = false; // whether this connection is paused due to CLIENT PAUSE // whether it's blocked on blocking commands like BLPOP, needs to be addressable bool blocked = false; // How many async subscription sources are active: monitor and/or pubsub - at most 2. uint8_t subscriptions; private: Connection* owner_; }; } // namespace facade ================================================ FILE: src/facade/connection_ref.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include namespace facade { class Connection; // Weak reference to a connection, invalidated upon connection close. 
// Used to dispatch async operations for the connection without worrying about pointer lifetime. struct ConnectionRef { public: // Get residing thread of connection. Thread-safe. unsigned LastKnownThreadId() const { return last_known_thread_id_; } // Get pointer to connection if still valid, nullptr if expired. // Can only be called from connection's thread. Validity is guaranteed // only until the next suspension point. Connection* Get() const; // Returns true if the reference expired. Thread-safe. bool IsExpired() const; // Returns client id.Thread-safe. uint32_t GetClientId() const; bool operator<(const ConnectionRef& other) const; bool operator==(const ConnectionRef& other) const; private: friend class Connection; ConnectionRef(const std::shared_ptr& ptr, unsigned thread_id, uint32_t client_id); std::weak_ptr ptr_; unsigned last_known_thread_id_; uint32_t client_id_; }; } // namespace facade ================================================ FILE: src/facade/disk_backed_queue.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // // See LICENSE for licensing terms. // #include "facade/disk_backed_queue.h" #include #include #include #include #include #include "base/flags.h" #include "base/logging.h" #include "facade/facade_types.h" #include "io/io.h" #include "util/fibers/uring_file.h" #include "util/fibers/uring_proactor.h" using facade::operator""_MB; ABSL_FLAG(std::string, disk_backpressure_folder, "/tmp/", "Folder to store disk-backed connection backpressure"); ABSL_FLAG(size_t, disk_backpressure_file_max_bytes, 50_MB, "Maximum size of the backing file. 
When max size is reached, connection will " "stop offloading backpressure to disk and block on client read."); namespace facade { DiskBackedQueue::DiskBackedQueue(uint32_t conn_id) : max_backing_size_(absl::GetFlag(FLAGS_disk_backpressure_file_max_bytes)), id_(conn_id) { } std::error_code DiskBackedQueue::Init() { std::string backing_name = absl::StrCat(absl::GetFlag(FLAGS_disk_backpressure_folder), id_); // Open a single O_RDWR file so the same fd serves writes, reads, and fallocate punch holes. // Kernel transparently handles buffering via the page cache. auto res = util::fb2::OpenLinux(backing_name, O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0600); if (!res) { return res.error(); } file_ = std::move(*res); VLOG(3) << "Created backing for connection " << this << " " << backing_name; return {}; } DiskBackedQueue::~DiskBackedQueue() { DCHECK_EQ(in_flight_callbacks_, 0ul); } std::error_code DiskBackedQueue::Close() { if (file_) { auto ec = file_->Close(); LOG_IF(WARNING, ec) << ec.message(); std::string backing = absl::StrCat(absl::GetFlag(FLAGS_disk_backpressure_folder), id_); int errc = unlink(backing.c_str()); LOG_IF(ERROR, errc != 0) << "Failed to unlink backing file: " << std::error_code{errc, std::system_category()}; return ec; } return {}; } // Check if backing file is empty, i.e. backing file has 0 bytes. bool DiskBackedQueue::Empty() const { return total_backing_bytes_ == 0; } bool DiskBackedQueue::HasEnoughBackingSpaceFor(size_t bytes) const { return (bytes + total_backing_bytes_) < max_backing_size_; } void DiskBackedQueue::MaybePunchHole() { // Punch holes over the aligned region we have fully read past so the OS can reclaim pages. // Both offset and length must be multiples of the filesystem block size: XFS returns EINVAL // otherwise, and ext4/tmpfs only zero partial blocks rather than freeing them. 
// We assume 4096-byte blocks (correct for virtually all deployments); a fully robust // implementation would query the actual block size via fstatfs(file_->GetFd(), &fsst) and // align to fsst.f_bsize instead. const size_t aligned_end = (next_read_offset_ / 4096) * 4096; if (aligned_end > punch_offset_) { int res = fallocate(file_->GetFd(), FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, punch_offset_, aligned_end - punch_offset_); DCHECK_EQ(res, 0) << "fallocate punch failed: " << strerror(errno); punch_offset_ = aligned_end; } } void DiskBackedQueue::PushAsync(io::Bytes bytes, AsyncPushCallback cb) { const size_t offset = write_offset_; const size_t size = bytes.size(); ++in_flight_callbacks_; file_->WriteAsync(bytes, offset, [this, size, cb = std::move(cb)](int res) { --in_flight_callbacks_; if (res < 0) { std::error_code ec{-res, std::system_category()}; VLOG(2) << "Failed to offload blob of size " << size << " to backing with error: " << ec; cb(ec); return; } write_offset_ += size; total_backing_bytes_ += size; VLOG(2) << "Offload connection " << this << " backpressure of " << size; cb({}); }); } void DiskBackedQueue::PopAsync(io::MutableBytes out, AsyncPopCallback cb) { const size_t to_read = std::min(total_backing_bytes_, out.size()); const size_t offset = next_read_offset_; ++in_flight_callbacks_; // Capture a subset of out for the actual read size io::MutableBytes read_buf = out.subspan(0, to_read); file_->ReadAsync(read_buf, offset, [this, to_read, offset, cb = std::move(cb)](int res) { --in_flight_callbacks_; if (res < 0) { std::error_code ec{-res, std::system_category()}; LOG(ERROR) << "Could not load item at offset " << offset << " of size " << to_read << " from disk with error: " << ec.value() << " " << ec.message(); cb(nonstd::make_unexpected(ec)); return; } size_t bytes_read = static_cast(res); next_read_offset_ += bytes_read; total_backing_bytes_ -= bytes_read; VLOG(2) << "Loaded item with offset " << offset << " of size " << bytes_read << " for 
connection " << this; MaybePunchHole(); cb(bytes_read); }); } } // namespace facade ================================================ FILE: src/facade/disk_backed_queue.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include "io/io.h" #include "util/fibers/uring_file.h" namespace facade { class DiskBackedQueue { public: explicit DiskBackedQueue(uint32_t conn_id); ~DiskBackedQueue(); std::error_code Init(); // Check if we can offload bytes to backing file. bool HasEnoughBackingSpaceFor(size_t bytes) const; using AsyncPushCallback = std::function; void PushAsync(io::Bytes bytes, AsyncPushCallback cb); using AsyncPopCallback = std::function)>; // Async read variant. Callback is invoked with Result containing bytes read or error. void PopAsync(io::MutableBytes out, AsyncPopCallback cb); // Check if backing file is empty, i.e. backing file has 0 bytes. bool Empty() const; std::error_code Close(); private: // Punch holes over the aligned region we have fully read past so the OS can reclaim pages. void MaybePunchHole(); // Single O_RDWR file used for both writes and reads, avoiding a separate fd for fallocate. std::unique_ptr file_; size_t write_offset_ = 0; size_t total_backing_bytes_ = 0; size_t next_read_offset_ = 0; // Tracks how far into the file holes have been punched (always 4096-aligned). size_t punch_offset_ = 0; // Read only constants const size_t max_backing_size_ = 0; // same as connection id. Used to uniquely identify the backed file const size_t id_ = 0; size_t in_flight_callbacks_ = 0; }; } // namespace facade ================================================ FILE: src/facade/disk_backed_queue_test.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "facade/disk_backed_queue.h" #include #include #include #include #include #include #include #include "base/flags.h" #include "base/gtest.h" #include "base/logging.h" #include "io/io.h" #include "util/fibers/pool.h" namespace dfly { namespace { using namespace facade; class DiskBackedQueueTest : public testing::Test { protected: void SetUp() override { pp_.reset(util::fb2::Pool::IOUring(16, 1)); pp_->Run(); } void TearDown() override { pp_->Stop(); pp_.reset(); } std::unique_ptr pp_; }; // Verifies that after reading >= 4096 bytes, punch_hole is called correctly // and disk space is reclaimed. TEST_F(DiskBackedQueueTest, PunchHoleReleasesSpace) { pp_->at(0)->Await([]() { // Use id=2 to avoid collision with ReadWrite test. DiskBackedQueue backing(2); ASSERT_FALSE(backing.Init()); // Write 3 pages (12288 bytes) so the punch logic is triggered on reads. std::string data(12288, 'x'); { util::fb2::Done done; backing.PushAsync(io::MutableBytes(reinterpret_cast(data.data()), data.size()), [&done](std::error_code ec) { ASSERT_FALSE(ec); done.Notify(); }); done.Wait(); } // Read all data back in 4096-byte chunks. std::string results; while (!backing.Empty()) { std::string buf(4096, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(out, [&done, &results, &buf](io::Result res) { ASSERT_TRUE(res); results.append(buf.data(), *res); done.Notify(); }); done.Wait(); } EXPECT_EQ(results, data); // After reading all 3 pages the punch should have freed the first 3 aligned pages. // SEEK_HOLE at offset 0 returns 0 when a hole starts at the beginning of the file. 
int check_fd = open("/tmp/2", O_RDONLY); ASSERT_GE(check_fd, 0); off_t hole_start = lseek(check_fd, 0, SEEK_HOLE); close(check_fd); EXPECT_EQ(hole_start, 0) << "Expected hole at start of file - punch_hole did not free space"; ASSERT_FALSE(backing.Close()); }); } // Verifies that reading across multiple pages advances the punch offset correctly so that // successive reads keep freeing space (not re-punching offset 0 or skipping blocks). TEST_F(DiskBackedQueueTest, PunchHoleAdvancesOffset) { pp_->at(0)->Await([]() { DiskBackedQueue backing(3); ASSERT_FALSE(backing.Init()); // Write 8 pages so we can do several reads and check the hole grows. std::string data(32768, 'y'); { util::fb2::Done done; backing.PushAsync(io::MutableBytes(reinterpret_cast(data.data()), data.size()), [&done](std::error_code ec) { ASSERT_FALSE(ec); done.Notify(); }); done.Wait(); } // Read exactly 4096 bytes (1 page). { std::string buf(4096, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(out, [&done](io::Result res) { ASSERT_TRUE(res); done.Notify(); }); done.Wait(); } // After 1 page read the hole should start at 0 and the first non-hole (data) should be at // offset 4096 (i.e., lseek SEEK_DATA starting from 0 skips the punched hole). int check_fd = open("/tmp/3", O_RDONLY); ASSERT_GE(check_fd, 0); off_t first_hole = lseek(check_fd, 0, SEEK_HOLE); off_t first_data = lseek(check_fd, 0, SEEK_DATA); close(check_fd); EXPECT_EQ(first_hole, 0) << "Hole should begin at offset 0 after first page read"; EXPECT_EQ(first_data, 4096) << "Non-hole data should start at 4096 after punching first page"; ASSERT_FALSE(backing.Close()); }); } // Verifies that unaligned writes and reads correctly punch holes at aligned boundaries. // Punch should only occur when we've fully read past 4096-byte boundaries. 
TEST_F(DiskBackedQueueTest, PunchHoleUnalignedReadsAndWrites) { pp_->at(0)->Await([]() { DiskBackedQueue backing(4); ASSERT_FALSE(backing.Init()); // Write 10000 bytes (not a multiple of 4096). // This is 2 full pages (8192 bytes) + 1808 partial bytes. std::string data(10000, 'z'); { util::fb2::Done done; backing.PushAsync(io::MutableBytes(reinterpret_cast(data.data()), data.size()), [&done](std::error_code ec) { ASSERT_FALSE(ec); done.Notify(); }); done.Wait(); } // Read 3000 bytes (unaligned, less than 1 page). // next_read_offset_ will be 3000, but aligned_end = (3000/4096)*4096 = 0. // So no punch should happen yet. std::string results; { std::string buf(3000, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(out, [&done, &results, &buf](io::Result res) { ASSERT_TRUE(res); results.append(buf.data(), *res); done.Notify(); }); done.Wait(); } // Check that no hole exists yet (first 3000 bytes read but not 4096-aligned). int check_fd = open("/tmp/4", O_RDONLY); ASSERT_GE(check_fd, 0); off_t hole_at_start = lseek(check_fd, 0, SEEK_HOLE); // SEEK_HOLE from offset 0 should jump to EOF if no hole exists at start. EXPECT_GT(hole_at_start, 0) << "No hole should exist yet after reading 3000 bytes"; close(check_fd); // Read another 2000 bytes (total read = 5000 bytes). // next_read_offset_ will be 5000, aligned_end = (5000/4096)*4096 = 4096. // Now the first page (0-4095) should be punched. { std::string buf(2000, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(out, [&done, &results, &buf](io::Result res) { ASSERT_TRUE(res); results.append(buf.data(), *res); done.Notify(); }); done.Wait(); } // Verify first page is now a hole. 
check_fd = open("/tmp/4", O_RDONLY); ASSERT_GE(check_fd, 0); off_t first_hole = lseek(check_fd, 0, SEEK_HOLE); off_t first_data = lseek(check_fd, 0, SEEK_DATA); EXPECT_EQ(first_hole, 0) << "Hole should start at offset 0 after reading past 4096 bytes"; EXPECT_EQ(first_data, 4096) << "Data should start at 4096 (second page)"; // Read another 3500 bytes (total read = 8500 bytes). // next_read_offset_ will be 8500, aligned_end = (8500/4096)*4096 = 8192. // Now the first two pages (0-8191) should be punched. { std::string buf(3500, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(out, [&done, &results, &buf](io::Result res) { ASSERT_TRUE(res); results.append(buf.data(), *res); done.Notify(); }); done.Wait(); } // Verify first two pages are holes. first_hole = lseek(check_fd, 0, SEEK_HOLE); first_data = lseek(check_fd, 0, SEEK_DATA); close(check_fd); EXPECT_EQ(first_hole, 0) << "Hole should start at offset 0"; EXPECT_EQ(first_data, 8192) << "Data should start at 8192 (third page)"; // Read remaining data and verify results match. 
while (!backing.Empty()) { std::string buf(4096, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(out, [&done, &results, &buf](io::Result res) { ASSERT_TRUE(res); results.append(buf.data(), *res); done.Notify(); }); done.Wait(); } EXPECT_EQ(results, data); ASSERT_FALSE(backing.Close()); }); } TEST_F(DiskBackedQueueTest, AsyncReadWrite) { pp_->at(0)->Await([]() { DiskBackedQueue backing(5 /* id */); EXPECT_FALSE(backing.Init()); std::string commands; for (size_t i = 0; i < 100; ++i) { auto cmd = absl::StrCat("SET FOO", i, " BAR"); commands += cmd; } // Async write all commands util::fb2::Fiber write_fiber = util::fb2::Fiber("writer", [&]() { for (size_t i = 0; i < 100; ++i) { auto cmd = absl::StrCat("SET FOO", i, " BAR"); auto bytes = io::MutableBytes(reinterpret_cast(cmd.data()), cmd.size()); util::fb2::Done done; backing.PushAsync(bytes, [&done](std::error_code ec) { EXPECT_FALSE(ec); done.Notify(); }); done.Wait(); } }); write_fiber.Join(); // Async read all results std::string results; util::fb2::Fiber read_fiber = util::fb2::Fiber("reader", [&]() { while (!backing.Empty()) { std::string buf(1024, 'c'); auto bytes = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done done; backing.PopAsync(bytes, [&done, &results, &buf](io::Result res) { EXPECT_TRUE(res); results.append(buf.data(), *res); done.Notify(); }); done.Wait(); } }); read_fiber.Join(); EXPECT_EQ(results.size(), commands.size()); EXPECT_EQ(results, commands); EXPECT_FALSE(backing.Close()); }); } TEST_F(DiskBackedQueueTest, AsyncPunchHole) { pp_->at(0)->Await([]() { DiskBackedQueue backing(6); ASSERT_FALSE(backing.Init()); // Write 3 pages (12288 bytes) asynchronously std::string data(12288, 'x'); util::fb2::Done write_done; backing.PushAsync(io::MutableBytes(reinterpret_cast(data.data()), data.size()), [&write_done](std::error_code ec) { ASSERT_FALSE(ec); write_done.Notify(); }); write_done.Wait(); // Async read 
all data back in 4096-byte chunks std::string results; while (!backing.Empty()) { std::string buf(4096, '\0'); auto out = io::MutableBytes(reinterpret_cast(buf.data()), buf.size()); util::fb2::Done read_done; backing.PopAsync(out, [&read_done, &results, &buf](io::Result res) { ASSERT_TRUE(res); results.append(buf.data(), *res); read_done.Notify(); }); read_done.Wait(); } EXPECT_EQ(results, data); // Verify punch hole freed space int check_fd = open("/tmp/6", O_RDONLY); ASSERT_GE(check_fd, 0); off_t hole_start = lseek(check_fd, 0, SEEK_HOLE); close(check_fd); EXPECT_EQ(hole_start, 0) << "Expected hole at start of file - async punch did not free space"; ASSERT_FALSE(backing.Close()); }); } } // namespace } // namespace dfly ================================================ FILE: src/facade/dragonfly_connection.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // // See LICENSE for licensing terms. // #include "facade/dragonfly_connection.h" #include #include #include #include #include #include #include #include #include "base/cycle_clock.h" #include "base/flag_utils.h" #include "base/flags.h" #include "base/histogram.h" #include "base/io_buf.h" #include "base/logging.h" #include "base/stl_util.h" #include "common/heap_size.h" #include "facade/conn_context.h" #include "facade/dragonfly_listener.h" #include "facade/facade_types.h" #include "facade/memcache_parser.h" #include "facade/redis_parser.h" #include "facade/reply_builder.h" #include "facade/resp_srv_parser.h" #include "facade/service_interface.h" #include "facade/socket_utils.h" #include "io/file.h" #include "strings/human_readable.h" #include "util/fiber_socket_base.h" #include "util/fibers/fibers.h" #include "util/fibers/proactor_base.h" #ifdef DFLY_USE_SSL #include "util/tls/tls_socket.h" #endif #ifdef __linux__ #include "util/fibers/uring_file.h" #include "util/fibers/uring_proactor.h" #include "util/fibers/uring_socket.h" #endif using namespace 
std; using facade::operator""_MB; ABSL_FLAG(bool, tcp_nodelay, true, "Configures dragonfly connections with socket option TCP_NODELAY"); ABSL_FLAG(bool, primary_port_http_enabled, true, "If true allows accessing http console on main TCP port"); ABSL_FLAG(uint16_t, admin_port, 0, "If set, would enable admin access to console on the assigned port. " "This supports both HTTP and RESP protocols"); ABSL_FLAG(string, admin_bind, "", "If set, the admin consol TCP connection would be bind the given address. " "This supports both HTTP and RESP protocols"); ABSL_FLAG(strings::MemoryBytesFlag, request_cache_limit, 64_MB, "Amount of memory to use for request cache in bytes - per IO thread."); ABSL_FLAG(strings::MemoryBytesFlag, pipeline_buffer_limit, 128_MB, "Amount of memory to use for storing pipeline requests - per IO thread." "Please note that clients that send excecissively huge pipelines, " "may deadlock themselves. See https://github.com/dragonflydb/dragonfly/discussions/3997" "for details."); ABSL_FLAG(uint32_t, pipeline_queue_limit, 10000, "Pipeline queue max length, the server will stop reading from the client socket" " once its pipeline queue crosses this limit, and will resume once it processes " "excessive requests. This is to prevent OOM states. Users of huge pipelines sizes " "may require increasing this limit to prevent the risk of deadlocking." "See https://github.com/dragonflydb/dragonfly/discussions/3997 for details"); ABSL_FLAG(strings::MemoryBytesFlag, publish_buffer_limit, 128_MB, "Amount of memory to use for storing pub commands in bytes - per IO thread"); ABSL_FLAG(uint32_t, pipeline_squash, 1, "Number of queued pipelined commands above which squashing is enabled, 0 means disabled"); // When changing this constant, also update `test_large_cmd` test in connection_test.py. 
ABSL_FLAG(uint32_t, max_multi_bulk_len, 1u << 16, "Maximum multi-bulk (array) length that is " "allowed to be accepted when parsing RESP protocol"); ABSL_FLAG(uint64_t, max_bulk_len, 2u << 30, "Maximum bulk length that is " "allowed to be accepted when parsing RESP protocol"); ABSL_FLAG(strings::MemoryBytesFlag, max_client_iobuf_len, 1u << 16, "Maximum io buffer length that is used to read client requests."); ABSL_FLAG(bool, migrate_connections, true, "When enabled, Dragonfly will try to migrate connections to the target thread on which " "they operate. Currently this is only supported for Lua script invocations, and can " "happen at most once per connection."); ABSL_FLAG(uint32_t, max_busy_read_usec, 200, "Maximum time we read and parse from " "a socket without yielding. In microseconds."); ABSL_FLAG(size_t, squashed_reply_size_limit, 0, "Max bytes allowed for squashing_current_reply_size. If this limit is reached, " "connections dispatching pipelines won't squash them."); ABSL_FLAG(bool, always_flush_pipeline, false, "if true will flush pipeline response after each pipeline squashing"); ABSL_FLAG(uint32_t, async_dispatch_quota, 100, "Maximum number of consecutive async dispatch messages to process before either " "yielding to I/O when the pipeline appears empty or forcibly processing a queued " "pipelined command to prevent starvation. Set to 0 to disable this mechanism."); ABSL_FLAG(uint32_t, pipeline_squash_limit, 1 << 30, "Limit on the size of a squashed pipeline. "); ABSL_FLAG(uint32_t, pipeline_wait_batch_usec, 0, "If non-zero, waits for this time for more I/O " " events to come for the connection in case there is only one command in the pipeline. 
"); ABSL_FLAG(bool, experimental_io_loop_v2, true, "new io loop"); using namespace util; using namespace std; using absl::GetFlag; using base::CycleClock; using nonstd::make_unexpected; namespace facade { namespace { void SendProtocolError(RespSrvParser::Result pres, SinkReplyBuilder* builder) { constexpr string_view res = "-ERR Protocol error: "sv; if (pres == RespSrvParser::BAD_BULKLEN) { builder->SendProtocolError(absl::StrCat(res, "invalid bulk length")); } else if (pres == RespSrvParser::BAD_ARRAYLEN) { builder->SendProtocolError(absl::StrCat(res, "invalid multibulk length")); } else { builder->SendProtocolError(absl::StrCat(res, "parse error")); } } // TODO: to implement correct matcher according to HTTP spec // https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html // One place to find a good implementation would be https://github.com/h2o/picohttpparser bool MatchHttp11Line(string_view line) { return (absl::StartsWith(line, "GET ") || absl::StartsWith(line, "POST ")) && absl::EndsWith(line, "HTTP/1.1"); } void UpdateIoBufCapacity(const io::IoBuf& io_buf, ConnectionStats* stats, absl::FunctionRef f) { const size_t prev_capacity = io_buf.Capacity(); f(); const size_t capacity = io_buf.Capacity(); if (prev_capacity != capacity) { VLOG(2) << "Grown io_buf to " << capacity; stats->read_buf_capacity += capacity - prev_capacity; } } size_t UsedMemoryInternal(const ParsedCommand& msg) { return msg.GetSize() + msg.HeapMemory(); } struct TrafficLogger { // protects agains closing the file while writing or data races when opening the file. // Also, makes sure that LogTraffic are executed atomically. fb2::Mutex mutex; unique_ptr log_file; void ResetLocked(); // Returns true if Write succeeded, false if it failed and the recording should be aborted. 
bool Write(string_view blob); bool Write(iovec* blobs, size_t len); }; void TrafficLogger::ResetLocked() { if (log_file) { std::ignore = log_file->Close(); log_file.reset(); } } // Returns true if Write succeeded, false if it failed and the recording should be aborted. bool TrafficLogger::Write(string_view blob) { auto ec = log_file->Write(io::Buffer(blob)); if (ec) { LOG(ERROR) << "Error writing to traffic log: " << ec; ResetLocked(); return false; } return true; } bool TrafficLogger::Write(iovec* blobs, size_t len) { auto ec = log_file->Write(blobs, len); if (ec) { LOG(ERROR) << "Error writing to traffic log: " << ec; ResetLocked(); return false; } return true; } thread_local TrafficLogger tl_traffic_logger{}; thread_local base::Histogram* io_req_size_hist = nullptr; thread_local const size_t reply_size_limit = absl::GetFlag(FLAGS_squashed_reply_size_limit); thread_local uint32 pipeline_wait_batch_usec = absl::GetFlag(FLAGS_pipeline_wait_batch_usec); void OpenTrafficLogger(string_view base_path) { unique_lock lk{tl_traffic_logger.mutex}; if (tl_traffic_logger.log_file) return; #ifdef __linux__ // Open file with append mode, without it concurrent fiber writes seem to conflict string path = absl::StrCat( base_path, "-", absl::Dec(ProactorBase::me()->GetPoolIndex(), absl::kZeroPad3), ".bin"); auto file = util::fb2::OpenWrite(path, io::WriteFile::Options{/*.append = */ false}); if (!file) { LOG(ERROR) << "Error opening a file " << path << " for traffic logging: " << file.error(); return; } tl_traffic_logger.log_file = unique_ptr{file.value()}; #else LOG(WARNING) << "Traffic logger is only supported on Linux"; #endif // Write version, incremental numbering :) uint8_t version[1] = {2}; std::ignore = tl_traffic_logger.log_file->Write(version); } void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, ServiceInterface::ContextInfo ci) { string_view cmd = args.Front(); if (absl::EqualsIgnoreCase(cmd, "debug"sv)) return; DVLOG(2) << "Recording " << 
cmd; char stack_buf[1024]; char* next = stack_buf; // We write id, timestamp, db_index, has_more, num_parts, part_len, part_len, part_len, ... // And then all the part blobs concatenated together. auto write_u32 = [&next](uint32_t i) { absl::little_endian::Store32(next, i); next += 4; }; // id write_u32(id); // timestamp absl::little_endian::Store64(next, absl::GetCurrentTimeNanos()); next += 8; // db_index write_u32(ci.db_index); // has_more, num_parts write_u32(has_more ? 1 : 0); write_u32(uint32_t(args.size())); // Grab the lock and check if the file is still open. lock_guard lk{tl_traffic_logger.mutex}; if (!tl_traffic_logger.log_file) return; // part_len, ... for (auto part : args) { if (size_t(next - stack_buf + 4) > sizeof(stack_buf)) { if (!tl_traffic_logger.Write(string_view{stack_buf, size_t(next - stack_buf)})) { return; } next = stack_buf; } write_u32(part.size()); } // Write the data itself. array blobs; unsigned index = 0; if (next != stack_buf) { blobs[index++] = iovec{.iov_base = stack_buf, .iov_len = size_t(next - stack_buf)}; } for (auto part : args) { if (auto blob_len = part.size(); blob_len > 0) { blobs[index++] = iovec{.iov_base = const_cast(part.data()), .iov_len = blob_len}; if (index >= blobs.size()) { if (!tl_traffic_logger.Write(blobs.data(), blobs.size())) { return; } index = 0; } } } if (index) { tl_traffic_logger.Write(blobs.data(), index); } } constexpr size_t kMinReadSize = 256; const char* kPhaseName[Connection::NUM_PHASES] = {"SETUP", "READ", "PROCESS", "SHUTTING_DOWN", "PRECLOSE"}; // Keeps track of total per-thread sizes of dispatch queues to limit memory taken up by messages // in these queues. struct QueueBackpressure { QueueBackpressure() { } // Block until subscriber memory usage is below limit, can be called from any thread. void EnsureBelowLimit(); // Checks if backpressure should be applied. // 'size' should be the total bytes currently consumed by all connections on this thread. 
// 'q_len' should be the length of the pipeline queue for the current connection. // // Returns true if EITHER: // 1. Thread-local: memory limit (on all thread's connections) is exceeded (protects server from // OOM). // 2. Per-Connection queue length limit is exceeded (protects against single-client abuse). bool IsPipelineBufferOverLimit(size_t size, uint32_t q_len) const { return size >= (pipeline_buffer_limit) || (q_len > pipeline_queue_max_len); } // Checks if usage has dropped below the limit in at least one criteria. // Used to determine if we should notify waiters. // 'size' should be the total bytes currently consumed by all connections on this thread. // 'q_len' should be the length of the pipeline queue for the current connection. // // Returns true if EITHER: // 1. Thread-Global memory is now under the limit (allows neighbors to wake up). // 2. Per-Connection queue length is now within the limit (allows self to wake up). bool IsPipelineBufferUnderLimit(size_t size, uint32_t q_len) const { return (size < pipeline_buffer_limit) || (q_len <= pipeline_queue_max_len); } // Used by publisher/subscriber actors to make sure we do not publish too many messages // into the queue. Thread-safe to allow safe access in EnsureBelowLimit. util::fb2::EventCount pubsub_ec; atomic_size_t subscriber_bytes = 0; // Used by pipelining/execution fiber to throttle the incoming pipeline messages. // Used together with pipeline_buffer_limit to limit the pipeline usage per thread. util::fb2::CondVarAny pipeline_cnd; size_t publish_buffer_limit = 0; // cached flag publish_buffer_limit size_t pipeline_cache_limit = 0; // cached flag pipeline_cache_limit size_t pipeline_buffer_limit = 0; // cached flag for buffer size in bytes uint32_t pipeline_queue_max_len = 256; // cached flag for pipeline queue max length. 
}; void QueueBackpressure::EnsureBelowLimit() { pubsub_ec.await( [this] { return subscriber_bytes.load(memory_order_relaxed) <= publish_buffer_limit; }); } // Global array for each io thread to keep track of the total memory usage of the dispatch queues. QueueBackpressure* thread_queue_backpressure = nullptr; QueueBackpressure& GetQueueBackpressure() { DCHECK(thread_queue_backpressure != nullptr); return thread_queue_backpressure[ProactorBase::me()->GetPoolIndex()]; } // A special accessor for accessing thread local ConnectionStats that is robust to fiber-thread // migrations. Compiler optimizations can cache a stale thread local pointer, and not refresh it // after HandleMigrateRequest() is called. This function should be used to force loading // the variable from memory every time, preventing such bugs. ConnectionStats& __attribute__((noinline)) GetLocalConnStats() { // https://stackoverflow.com/a/75622732 asm volatile(""); return tl_facade_stats->conn_stats; } thread_local uint64_t max_busy_read_cycles_cached = 1ULL << 32; thread_local bool always_flush_pipeline_cached = absl::GetFlag(FLAGS_always_flush_pipeline); thread_local uint32_t pipeline_squash_limit_cached = absl::GetFlag(FLAGS_pipeline_squash_limit); } // namespace thread_local vector Connection::pipeline_req_pool_; class PipelineCacheSizeTracker { public: bool CheckAndUpdateWatermark(size_t pipeline_sz) { const auto now = absl::Now(); const auto elapsed = now - last_check_; min_ = std::min(min_, pipeline_sz); if (elapsed < absl::Milliseconds(10)) { return false; } const bool watermark_reached = (min_ > 0); min_ = Limits::max(); last_check_ = absl::Now(); return watermark_reached; } private: using Limits = std::numeric_limits; absl::Time last_check_ = absl::Now(); size_t min_ = Limits::max(); }; thread_local PipelineCacheSizeTracker tl_pipe_cache_sz_tracker; size_t Connection::MessageHandle::UsedMemory() const { struct MessageSize { size_t operator()(const PubMessagePtr& msg) { return sizeof(PubMessage) 
+ (msg->channel.size() + msg->message.size()); } size_t operator()(const MonitorMessage& msg) { return msg.capacity(); } size_t operator()(const MigrationRequestMessage& msg) { return 0; } size_t operator()(const CheckpointMessage& msg) { return 0; // no access to internal type, memory usage negligible } size_t operator()(const InvalidationMessage& msg) { return 0; } }; return sizeof(MessageHandle) + visit(MessageSize{}, this->handle); } bool Connection::MessageHandle::IsReplying() const { return IsPubMsg() || holds_alternative(handle); } struct Connection::AsyncOperations { AsyncOperations(SinkReplyBuilder* b, Connection* me) : builder{b}, self(me) { } void operator()(const PubMessage& msg); void operator()(ParsedCommand& msg); void operator()(const MonitorMessage& msg); void operator()(const MigrationRequestMessage& msg); void operator()(CheckpointMessage msg); void operator()(const InvalidationMessage& msg); template void operator()(unique_ptr& ptr) { operator()(*ptr.get()); } SinkReplyBuilder* builder = nullptr; Connection* self = nullptr; }; void Connection::AsyncOperations::operator()(const MonitorMessage& msg) { RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; rbuilder->SendSimpleString(msg); } void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) { RedisReplyBuilder* rb = static_cast(builder); // Discard stale messages to not break the protocol after exiting "pubsub" mode. // Even after removing all subscriptions, we still can receive messages delayed // by inter-thread dispatches or backpressure. 
// TODO: filter messages from channels the client unsubscribed from if (self->cntx()->subscriptions == 0 && !base::_in(pub_msg.channel, {"unsubscribe", "punsubscribe"})) return; if (pub_msg.force_unsubscribe) { rb->StartCollection(3, CollectionType::PUSH); rb->SendBulkString("sunsubscribe"); rb->SendBulkString(pub_msg.channel); rb->SendLong(0); self->cntx()->Unsubscribe(pub_msg.channel); return; } unsigned i = 0; array arr; if (pub_msg.pattern.empty()) { arr[i++] = pub_msg.is_sharded ? "smessage" : "message"; } else { arr[i++] = "pmessage"; arr[i++] = pub_msg.pattern; } arr[i++] = pub_msg.channel; arr[i++] = pub_msg.message; rb->SendBulkStrArr(absl::Span{arr.data(), i}, CollectionType::PUSH); } void Connection::AsyncOperations::operator()(ParsedCommand& cmd) { DVLOG(2) << "Dispatching pipeline: " << cmd.Front(); ++self->local_stats_.cmds; self->service_->DispatchCommand(ParsedArgs{cmd}, &cmd, facade::AsyncPreference::ONLY_SYNC); self->last_interaction_ = time(nullptr); self->skip_next_squashing_ = false; } void Connection::AsyncOperations::operator()(const MigrationRequestMessage& msg) { // no-op } void Connection::AsyncOperations::operator()(CheckpointMessage msg) { VLOG(2) << "Decremented checkpoint at " << self->DebugInfo(); msg.bc->Dec(); } void Connection::AsyncOperations::operator()(const InvalidationMessage& msg) { RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; DCHECK(rbuilder->IsResp3()); rbuilder->StartCollection(2, facade::CollectionType::PUSH); rbuilder->SendBulkString("invalidate"); if (msg.invalidate_due_to_flush) { rbuilder->SendNull(); } else { string_view keys[] = {msg.key}; rbuilder->SendBulkStrArr(keys); } } namespace { thread_local absl::flat_hash_map g_libname_ver_map; void UpdateLibNameVerMap(const string& name, const string& ver, int delta) { string key = absl::StrCat(name, ":", ver); uint64_t& val = g_libname_ver_map[key]; val += delta; if (val == 0) { g_libname_ver_map.erase(key); } } } // namespace void Connection::Init(unsigned 
io_threads) { CHECK(thread_queue_backpressure == nullptr); thread_queue_backpressure = new QueueBackpressure[io_threads]; for (unsigned i = 0; i < io_threads; ++i) { auto& qbp = thread_queue_backpressure[i]; qbp.publish_buffer_limit = GetFlag(FLAGS_publish_buffer_limit); qbp.pipeline_cache_limit = GetFlag(FLAGS_request_cache_limit); qbp.pipeline_buffer_limit = GetFlag(FLAGS_pipeline_buffer_limit); qbp.pipeline_queue_max_len = GetFlag(FLAGS_pipeline_queue_limit); if (qbp.publish_buffer_limit == 0 || qbp.pipeline_cache_limit == 0 || qbp.pipeline_buffer_limit == 0 || qbp.pipeline_queue_max_len == 0) { LOG(ERROR) << "pipeline flag limit is 0"; exit(-1); } } } void Connection::Shutdown() { delete[] thread_queue_backpressure; thread_queue_backpressure = nullptr; } Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service) : io_buf_(kMinReadSize), protocol_(protocol), http_listener_(http_listener), ssl_ctx_(ctx), service_(service), flags_(0) { static atomic_uint32_t next_id{1}; constexpr size_t kReqSz = sizeof(ParsedCommand); static_assert(kReqSz <= 256); // TODO: to move parser initialization to where we initialize the reply builder. switch (protocol) { case Protocol::REDIS: redis_parser_.reset( new RespSrvParser(GetFlag(FLAGS_max_multi_bulk_len), GetFlag(FLAGS_max_bulk_len))); break; case Protocol::MEMCACHE: memcache_parser_ = make_unique(std::min(GetFlag(FLAGS_max_bulk_len), UINT32_MAX)); break; } creation_time_ = time(nullptr); last_interaction_ = creation_time_; id_ = next_id.fetch_add(1, memory_order_relaxed); migration_enabled_ = GetFlag(FLAGS_migrate_connections); // Create shared_ptr with empty value and associate it with `this` pointer (aliasing constructor). // We use it for reference counting and accessing `this` (without managing it). self_ = {make_shared(), this}; #ifdef DFLY_USE_SSL // Increment reference counter so Listener won't free the context while we're // still using it. 
if (ctx) { SSL_CTX_up_ref(ctx); } #endif UpdateLibNameVerMap(lib_name_, lib_ver_, +1); migration_allowed_to_register_ = false; } Connection::~Connection() { #ifdef DFLY_USE_SSL SSL_CTX_free(ssl_ctx_); #endif UpdateLibNameVerMap(lib_name_, lib_ver_, -1); } bool Connection::IsSending() const { return reply_builder_ && reply_builder_->IsSendActive(); } void Connection::MarkForClose() { if (reply_builder_) { reply_builder_->CloseConnection(); } request_shutdown_ = true; } // Called from Connection::Shutdown() right after socket_->Shutdown call. void Connection::OnShutdown() { VLOG(1) << "Connection::OnShutdown"; BreakOnce(POLLHUP); io_ec_ = make_error_code(errc::connection_aborted); io_event_.notify(); } void Connection::OnPreMigrateThread() { DVLOG(1) << "OnPreMigrateThread " << GetClientId(); CHECK(!cc_->conn_closing); DCHECK(!migration_in_process_); // CancelOnErrorCb is a preemption point, so we make sure the Migration start // is marked beforehand. migration_in_process_ = true; // Mark as not owned by any thread as it going through the dark hole self_.reset(); socket_->CancelOnErrorCb(); DCHECK(!async_fb_.IsJoinable()) << GetClientId(); DecreaseConnStats(); } void Connection::OnPostMigrateThread() { DVLOG(1) << "[" << id_ << "] OnPostMigrateThread"; // Once we migrated, we should rearm OnBreakCb callback. if (breaker_cb_ && socket()->IsOpen()) { socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); }); } if (ioloop_v2_ && socket_ && socket_->IsOpen() && migration_allowed_to_register_) { socket_->RegisterOnRecv([this](const FiberSocketBase::RecvNotification& n) { DoReadOnRecv(n); io_event_.notify(); }); } migration_in_process_ = false; self_ = {make_shared(), this}; // Recreate shared_ptr to self. DCHECK(!async_fb_.IsJoinable()); // If someone had sent Async during the migration, we must create async_fb_. 
if (HasPendingMessages()) {
  LaunchAsyncFiberIfNeeded();
}
IncreaseConnStats();
}

// Start-up hook for a connection: names the connection after its id, records whether the
// owning listener is the main interface, and, when the tcp_nodelay flag is set and the socket
// is not a UDS socket, enables TCP_NODELAY on the native handle.
void Connection::OnConnectionStart() {
  SetName(absl::StrCat(id_));

  // is null in unit-tests.
  if (const Listener* lsnr = static_cast(listener()); lsnr) {
    is_main_ = lsnr->IsMainInterface();
  }

  if (GetFlag(FLAGS_tcp_nodelay) && !socket_->IsUDS()) {
    int val = 1;
    int res = setsockopt(socket_->native_handle(), IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val));
    DCHECK_EQ(res, 0);
  }
}

// Serves the connection from start to finish.
//
// When built with DFLY_USE_SSL and ssl_ctx_ is set, the 5-byte TLS record header is validated
// on the raw socket and the TLS handshake is performed first; any failure there increments
// tls_accept_disconnects and returns. Afterwards the input is probed with CheckForHttpProto():
// an HTTP request is delegated to an HttpConnection, anything else gets a protocol-specific
// reply builder and is driven by ConnectionFlow(). If the probe returns an error or the socket
// is no longer open, nothing further is done.
void Connection::HandleRequests() {
  VLOG(1) << "[" << id_ << "] HandleRequests";
  DCHECK(tl_facade_stats);
  auto& conn_stats = tl_facade_stats->conn_stats;
  auto remote_ep = RemoteEndpointStr();

#ifdef DFLY_USE_SSL
  if (ssl_ctx_) {
    // Early TLS connection filter
    //
    // Before entering the expensive OpenSSL handshake we pre-read the 5-byte TLS Record Layer
    // header on the raw TCP socket. This serves two purposes:
    //
    // 1. Wrong-client detection:
    //    Clients that forgot to enable TLS (e.g. a plaintext Redis client connecting to the TLS
    //    port) will not send a valid TLS Record Layer header. We detect this immediately and
    //    reply with a human-readable "-ERR" message before disconnecting, instead of letting
    //    OpenSSL produce a cryptic handshake failure.
    //
    // 2. Zombie-connection rejection:
    //    Zombie connections open a TCP socket but never send any data. By demanding at least
    //    the 5-byte header before allocating any SSL state, we drop these cheaply on the raw
    //    socket instead of tying up an OpenSSL context and handshake state machine that will
    //    never complete.
    //
    // The pre-read header bytes are injected into the TlsSocket via InitSSL(), which writes them
    // into OpenSSL's internal BIO so that Accept() can drive the normal handshake from there.
    //
    // Reminder: TLS Record Layer header structure (universal across TLS 1.0 – 1.3):
    //   - Byte 0: ContentType (0x16 = Handshake)
    //   - Bytes 1–2: ProtocolVersion. While the minor version varies (0x01 for TLS 1.0,
    //     0x03 for TLS 1.2/1.3), the major version is consistently 0x03 for all
    //     modern TLS versions.
    //   - Bytes 3–4: Length (uint16 BE) — payload length, max 2^14 = 16384
    uint8_t buf[5];  // universal TLS Record Header size is 5 bytes
    auto read_sz = socket_->Read(io::MutableBytes(buf));
    if (!read_sz || *read_sz < sizeof(buf)) {
      // Either the read failed or fewer than 5 bytes were returned; log which one it was.
      auto msg = read_sz ? absl::StrCat(*read_sz, " < ", sizeof(buf)) : read_sz.error().message();
      LOG_EVERY_T(INFO, 1) << "Error reading from peer " << remote_ep << " " << msg
                           << ", socket state: " + dfly::GetSocketInfo(socket_->native_handle());
      conn_stats.tls_accept_disconnects++;
      return;
    }

    // Byte 0: ContentType must be 0x16 (Handshake).
    // Byte 1: major ProtocolVersion — always 0x03 for TLS 1.0 through TLS 1.3.
    // Byte 2: minor ProtocolVersion — 0x01 (TLS 1.0), 0x02 (TLS 1.1), 0x03 (TLS 1.2/1.3).
    // SSL 3.0 (0x00) is deprecated (RFC 7568) and rejected.
    if ((buf[0] != 0x16) || (buf[1] != 0x03) || (buf[2] < 0x01) || (buf[2] > 0x03)) {
      VLOG(1) << "Bad TLS header "
              << absl::StrCat(absl::Hex(buf[0], absl::kZeroPad2), absl::Hex(buf[1], absl::kZeroPad2),
                              absl::Hex(buf[2], absl::kZeroPad2));
      // Best-effort plaintext hint for the client; the write result is deliberately ignored.
      std::ignore = socket_->Write(io::Buffer("-ERR Bad TLS header, double check "
                                              "if you enabled TLS for your client.\r\n"));
      conn_stats.tls_accept_disconnects++;
      return;
    }

    // Must be done atomically before the preemption point in Accept so that at any
    // point in time, the socket_ is defined.
    {
      FiberAtomicGuard fg;
      unique_ptr tls_sock = make_unique(std::move(socket_));
      tls_sock->InitSSL(ssl_ctx_, buf);
      SetSocket(tls_sock.release());
    }
    FiberSocketBase::AcceptResult aresult = socket_->Accept();

    if (!aresult) {
      // This can flood the logs -- don't change
      LOG_EVERY_T(INFO, 1) << "Error handshaking " << aresult.error().message()
                           << ", socket state: " + dfly::GetSocketInfo(socket_->native_handle());
      conn_stats.tls_accept_disconnects++;
      return;
    }
    is_tls_ = 1;
    VLOG(1) << "TLS handshake succeeded";
  }
#endif

  io::Result http_res{false};
  http_res = CheckForHttpProto();

  // We need to check if the socket is open because the server might be
  // shutting down. During the shutdown process, the server iterates over
  // the connections of each shard and shuts down their socket. Since the
  // main listener dispatches the connection into the next proactor, we
  // allow a schedule order that first shuts down the socket and then calls
  // this function which triggers a DCHECK on the socket while it tries to
  // RegisterOnErrorCb. Furthermore, we can get away with one check here
  // because both Write and Recv internally check if the socket was shut
  // down and return with an error accordingly.
  if (http_res && socket_->IsOpen()) {
    cc_.reset(service_->CreateContext(this));
    if (*http_res) {
      VLOG(1) << "HTTP1.1 identified";
      is_http_ = true;
      HttpConnection http_conn{http_listener_};
      http_conn.SetSocket(socket_.get());
      http_conn.set_user_data(cc_.get());

      // We validate the http request using basic-auth inside HttpConnection::HandleSingleRequest.
      cc_->authenticated = true;
      // The bytes consumed while probing for HTTP are already in io_buf_; hand them to the
      // HTTP parser and then drop them from our buffer.
      auto ec = http_conn.ParseFromBuffer(io_buf_.InputBuffer());
      io_buf_.ConsumeInput(io_buf_.InputLen());
      if (!ec) {
        http_conn.HandleRequests();
      }

      // Release the ownership of the socket from http_conn so it would stay with
      // this connection.
      http_conn.ReleaseSocket();
    } else {  // non-http
      // ioloop_v2 not supported for TLS & redis connections yet.
      ioloop_v2_ =
          GetFlag(FLAGS_experimental_io_loop_v2) && !is_tls_ && protocol_ == Protocol::MEMCACHE;

      if (breaker_cb_) {
        socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); });
      }

      switch (protocol_) {
        case Protocol::REDIS:
          reply_builder_.reset(new RedisReplyBuilder(socket_.get()));
          break;
        case Protocol::MEMCACHE:
          reply_builder_.reset(new MCReplyBuilder(socket_.get()));
          break;
        default:
          break;
      }

      parsed_cmd_ = CreateParsedCommand();
      ConnectionFlow();

      socket_->CancelOnErrorCb();  // noop if nothing is registered.
      VLOG(1) << "Closed connection for peer "
              << GetClientInfo(fb2::ProactorBase::me()->GetPoolIndex());
      reply_builder_.reset();
      DestroyParsedQueue();
    }
    cc_.reset();
  }
}

// Returns the number of whole seconds elapsed since the reply builder's last send time while a
// send is active; returns 0 when there is no reply builder or no send is active.
unsigned Connection::GetSendWaitTimeSec() const {
  if (reply_builder_ && reply_builder_->IsSendActive()) {
    return (util::fb2::ProactorBase::GetMonotonicTimeNs() - reply_builder_->GetLastSendTimeNs()) /
           1'000'000'000;
  }

  return 0;
}

// Stores the breaker callback. HandleRequests() registers the socket error callback only when
// a breaker callback has been set.
void Connection::RegisterBreakHook(BreakerCb breaker_cb) {
  breaker_cb_ = std::move(breaker_cb);
}

// Flushes the reply builder. reply_builder_ must be set (DCHECK).
void Connection::FlushReplies() {  // NOLINT must not be const due to flush side effect
  DCHECK(reply_builder_);
  reply_builder_->Flush();
}

// Builds the client-info text as two halves, {before, after}, so that callers can insert
// additional fields (such as " tid=") between them.
//
// "before" holds id/addr/laddr/fd, then either " http=true" or " name=", the TLS protocol and
// cipher when is_tls_ is set (DFLY_USE_SSL builds), and " send-wait-time=" when IsSending().
// "after" holds irqmatch, pipeline sizes (only when the parsed queue is non-empty), age/idle,
// the tot-* counters, the context info (when cc_ is set) and the phase.
//
// If socket_ is null, logs at DFATAL and returns a default-constructed pair.
pair Connection::GetClientInfoBeforeAfterTid() const {
  if (!socket_) {
    LOG(DFATAL) << "unexpected null socket_ "
                << " phase " << unsigned(phase_) << ", is_http: " << unsigned(is_http_);
    return {};
  }

  CHECK_LT(unsigned(phase_), NUM_PHASES);

  string before;
  auto le = LocalBindStr();
  auto re = RemoteEndpointStr();
  time_t now = time(nullptr);

  // The getsockopt() result is not checked; cpu stays 0 if the call does not fill it in.
  int cpu = 0;
  socklen_t len = sizeof(cpu);
  getsockopt(socket_->native_handle(), SOL_SOCKET, SO_INCOMING_CPU, &cpu, &len);

#ifdef __APPLE__
  int my_cpu_id = -1;  // __APPLE__ does not have sched_getcpu()
#else
  int my_cpu_id = sched_getcpu();
#endif

  // Indexed by phase_; the static_asserts keep the table in sync with the phase enum.
  static constexpr string_view PHASE_NAMES[] = {"setup", "readsock", "process", "shutting_down",
                                                "preclose"};
  static_assert(NUM_PHASES == ABSL_ARRAYSIZE(PHASE_NAMES));
  static_assert(PHASE_NAMES[SHUTTING_DOWN] == "shutting_down");

  absl::StrAppend(&before, "id=", id_, " addr=", re, " laddr=", le);
  absl::StrAppend(&before, " fd=", socket_->native_handle());

  if (is_http_) {
    absl::StrAppend(&before, " http=true");
  } else {
    absl::StrAppend(&before, " name=", name_);
  }

#ifdef DFLY_USE_SSL
  if (is_tls_) {
    tls::TlsSocket* tls_sock = static_cast(socket_.get());
    string_view proto_version = SSL_get_version(tls_sock->ssl_handle());
    const SSL_CIPHER* cipher = SSL_get_current_cipher(tls_sock->ssl_handle());
    absl::StrAppend(&before, " tls=", proto_version, "|", SSL_CIPHER_get_name(cipher));
  }
#endif

  string after;
  // irqmatch is 1 when the socket's incoming CPU equals the CPU this code currently runs on.
  absl::StrAppend(&after, " irqmatch=", int(cpu == my_cpu_id));
  if (parsed_cmd_q_len_ > 0) {
    absl::StrAppend(&after, " pipeline=", parsed_cmd_q_len_);
    absl::StrAppend(&after, " pbuf=", parsed_cmd_q_bytes_);
  }
  absl::StrAppend(&after, " age=", now - creation_time_, " idle=", now - last_interaction_);
  string_view phase_name = PHASE_NAMES[phase_];

  absl::StrAppend(&after, " tot-cmds=", local_stats_.cmds,
                  " tot-net-in=", local_stats_.net_bytes_in,
                  " tot-read-calls=", local_stats_.read_cnt,
                  " tot-dispatches=", local_stats_.dispatch_entries_added);

  if (cc_) {
    string cc_info = service_->GetContextInfo(cc_.get()).Format();

    // reply_builder_ may be null if the connection is in the setup phase, for example.
    if (reply_builder_ && reply_builder_->IsSendActive())
      phase_name = "send";
    absl::StrAppend(&after, " ", cc_info);
  }
  absl::StrAppend(&after, " phase=", phase_name);

  if (IsSending()) {
    absl::StrAppend(&before, " send-wait-time=", GetSendWaitTimeSec());
  }

  return {std::move(before), std::move(after)};
}

// Returns the client-info string with " tid=<thread_id>" inserted between the two halves
// produced by GetClientInfoBeforeAfterTid(), followed by the lib-name and lib-ver fields.
string Connection::GetClientInfo(unsigned thread_id) const {
  auto [before, after] = GetClientInfoBeforeAfterTid();

  absl::StrAppend(&before, " tid=", thread_id);
  absl::StrAppend(&before, after);
  absl::StrAppend(&before, " lib-name=", lib_name_, " lib-ver=", lib_ver_);
  return before;
}

// Returns the client-info string without a tid field, followed by a fixed set of zero-valued
// placeholder fields (see the comment below for why they exist).
string Connection::GetClientInfo() const {
  auto [before, after] = GetClientInfoBeforeAfterTid();
  absl::StrAppend(&before, after);
  // The following are dummy fields and users should not rely on those unless
  // we decide to implement them.
  // This is only done because the redis pyclient parser for the field "client-info"
  // for the command ACL LOG hardcodes the expected values. This behaviour does not
  // conform to the actual expected values, since it's missing half of them.
  // That is, even for redis-server, issuing an ACL LOG command via redis-cli and the pyclient
  // will return different results! For example, the fields:
  // addr=127.0.0.1:57275
  // laddr=127.0.0.1:6379
  // are missing from the pyclient.

  absl::StrAppend(&before, " qbuf=0 ", "qbuf-free=0 ", "obl=0 ", "argv-mem=0 ");
  absl::StrAppend(&before, "oll=0 ", "omem=0 ", "tot-mem=0 ", "multi=0 ");
  absl::StrAppend(&before, "psub=0 ", "sub=0");
  return before;
}

// Returns the connection id.
uint32_t Connection::GetClientId() const {
  return id_;
}

// Returns whether the owning listener reports itself as a privileged interface.
// Unlike OnConnectionStart(), the listener pointer is dereferenced here without a null check.
bool Connection::IsPrivileged() const {
  return static_cast(listener())->IsPrivilegedInterface();
}

// Returns the value cached by OnConnectionStart() from the listener's IsMainInterface().
bool Connection::IsMain() const {
  return is_main_;
}

// True for connections on the main interface and for any memcache-protocol connection.
bool Connection::IsMainOrMemcache() const {
  return is_main_ || protocol_ == Protocol::MEMCACHE;
}

// Renames the current fiber to "DflyConn_<name>" and stores the name on the connection.
void Connection::SetName(string name) {
  util::ThisFiber::SetName(absl::StrCat("DflyConn_", name));
  name_ = std::move(name);
}

// Replaces the client library name. UpdateLibNameVerMap is called with -1 for the old
// (name, version) pair and with +1 for the new one.
void Connection::SetLibName(string name) {
  UpdateLibNameVerMap(lib_name_, lib_ver_, -1);
  lib_name_ = std::move(name);
  UpdateLibNameVerMap(lib_name_, lib_ver_, +1);
}

// Replaces the client library version. UpdateLibNameVerMap is called with -1 for the old
// (name, version) pair and with +1 for the new one.
void Connection::SetLibVersion(string version) {
  UpdateLibNameVerMap(lib_name_, lib_ver_, -1);
  lib_ver_ = std::move(version);
  UpdateLibNameVerMap(lib_name_, lib_ver_, +1);
}

// Returns a reference to g_libname_ver_map.
const absl::flat_hash_map& Connection::GetLibStatsTL() {
  return g_libname_ver_map;
}

// Probes the first bytes received on the socket to decide whether the peer speaks HTTP/1.1.
//
// Returns false without reading anything when the connection is neither privileged nor on the
// main interface, or when primary_port_http_enabled is off and the connection is not
// privileged. Otherwise data is received into io_buf_ (and left there for the caller) until
// the first '\n' is seen or at least 1024 bytes are buffered.
//
// Returns:
//   - an error if Recv() fails, or protocol_not_supported if the input starts with the bytes
//     22, 3 (a raw TLS handshake);
//   - false if the peer closed the connection, no newline was found within the limit, or the
//     first line is shorter than 10 bytes or does not end with '\r';
//   - otherwise the result of MatchHttp11Line() on the first line without its trailing '\r'.
io::Result Connection::CheckForHttpProto() {
  if (!IsPrivileged() && !IsMain()) {
    return false;
  }

  const bool primary_port_enabled = GetFlag(FLAGS_primary_port_http_enabled);
  if (!primary_port_enabled && !IsPrivileged()) {
    return false;
  }

  // Number of buffered bytes already scanned for '\n' in previous iterations.
  size_t last_len = 0;
  auto* peer = socket_.get();
  auto& conn_stats = tl_facade_stats->conn_stats;
  do {
    auto buf = io_buf_.AppendBuffer();
    DCHECK(!buf.empty());

    ::io::Result recv_sz = peer->Recv(buf);
    if (!recv_sz) {
      return make_unexpected(recv_sz.error());
    }
    if (recv_sz == 0) {  // Peer closed connection.
      return false;
    }
    io_buf_.CommitWrite(*recv_sz);
    string_view ib = io::View(io_buf_.InputBuffer());
    if (ib.size() >= 2 && ib[0] == 22 && ib[1] == 3) {
      // We matched the TLS handshake raw data, which means "peer" is a TCP socket.
      // Reject the connection.
      return make_unexpected(make_error_code(errc::protocol_not_supported));
    }

    // Only search the newly received part for the end of the first line.
    ib = ib.substr(last_len);
    size_t pos = ib.find('\n');
    if (pos != string_view::npos) {
      // View of the whole first line, excluding the '\n'.
      ib = io::View(io_buf_.InputBuffer().first(last_len + pos));
      if (ib.size() < 10 || ib.back() != '\r')
        return false;

      ib.remove_suffix(1);
      return MatchHttp11Line(ib);
    }
    last_len = io_buf_.InputLen();
    UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() { io_buf_.EnsureCapacity(128); });
  } while (last_len < 1024);

  return false;
}

// Runs the life cycle of a non-HTTP connection: accounts for the new connection in the stats,
// parses any input already buffered by the protocol probe, runs the main I/O loop
// (IoLoopV2 or IoLoop), and then tears the connection down: marks it closing, joins the async
// fiber, clears queued messages, notifies the service and reverts the stats. If parsing failed,
// a protocol error is sent and the socket is drained until the peer closes it.
// Requires reply_builder_ to be set (DCHECK).
void Connection::ConnectionFlow() {
  DCHECK(reply_builder_);
  auto& conn_stats = tl_facade_stats->conn_stats;

  // Register the new connection with the thread-local statistics.
  // At this point (connection birth), local queue stats/luggage are 0,
  // so only connection counts and buffer capacities are incremented.
  IncreaseConnStats();
  ++conn_stats.conn_received_cnt;

  ++local_stats_.read_cnt;
  local_stats_.net_bytes_in += io_buf_.InputLen();

  ParserStatus parse_status = OK;

  // At the start we read from the socket to determine the HTTP/Memstore protocol.
  // Therefore we may already have some data in the buffer.
  if (io_buf_.InputLen() > 0) {
    phase_ = PROCESS;
    if (redis_parser_) {
      parse_status = ParseRedis(10000);
    } else {
      DCHECK(memcache_parser_);
      parse_status = ParseLoop();
    }
  }

  error_code ec = reply_builder_->GetError();

  // Main loop.
  if (parse_status != ERROR && !ec) {
    UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() { io_buf_.EnsureCapacity(64); });
    variant res;
    if (ioloop_v2_) {
      // Everything above the IoLoopV2 is fiber blocking. A connection can migrate before
      // it reaches here and will cause a double RegisterOnRecv check fail. To avoid this,
      // a migration shall only call RegisterOnRecv if it reached the main IoLoopV2 below.
      migration_allowed_to_register_ = true;
      res = IoLoopV2();
    } else {
      res = IoLoop();
    }

    // The loop result is either a socket error or the final parser status.
    if (holds_alternative(res)) {
      ec = get(res);
    } else {
      parse_status = get(res);
    }
  }

  // After the client disconnected.
  cc_->conn_closing = true;  // Signal dispatch to close.
  cnd_.notify_one();
  phase_ = SHUTTING_DOWN;
  VLOG(2) << "Before dispatch_fb.join()";
  async_fb_.JoinIfNeeded();
  VLOG(2) << "After dispatch_fb.join()";

  phase_ = PRECLOSE;

  ClearPipelinedMessages();
  DCHECK(!HasPendingMessages());

  service_->OnConnectionClose(cc_.get());

  // We have already cleared the queues above (ClearPipelinedMessages), so local queue stats
  // (dispatch_q_bytes_, etc.) represent 0 usage. DecreaseConnStats will safely subtract 0 for those
  // stats, while correctly removing this connection from the global connection counts and buffer
  // capacity tracking.
  DecreaseConnStats();

  if (ioloop_v2_) {
    socket_->ResetOnRecvHook();
  }

  // We wait for dispatch_fb to finish writing the previous replies before replying to the last
  // offending request.
  if (parse_status == ERROR) {
    VLOG(1) << "Error parser status " << parser_error_;

    if (redis_parser_) {
      SendProtocolError(RespSrvParser::Result(parser_error_), reply_builder_.get());
    } else {
      DCHECK(memcache_parser_);
      reply_builder_->SendProtocolError("bad command line format");
    }

    // Shut down the servers side of the socket to send a FIN to the client
    // then keep draining the socket (discarding any received data) until
    // the client closes the connection.
    //
    // Otherwise the clients write could fail (or block), so they would never
    // read the above protocol error (see issue #1327).
    // TODO: we have a bug that can potentially deadlock the code below.
    // If the socket does not close the socket on the other side, the while loop will never finish.
    // to reproduce: nc localhost 6379 and then run invalid sequence: *1 <enter> *1 <enter>
    error_code ec2 = socket_->Shutdown(SHUT_WR);
    LOG_IF(WARNING, ec2) << "Could not shutdown socket " << ec2;
    while (!ec2) {
      // Discard any received data.
      io_buf_.Clear();
      auto recv_sz = socket_->Recv(io_buf_.AppendBuffer());
      if (!recv_sz || *recv_sz == 0) {
        break;  // Peer closed connection.
      }
    }
  }

  // Report socket errors other than a plain connection close.
  if (ec && !FiberSocketBase::IsConnClosed(ec)) {
    string conn_info = service_->GetContextInfo(cc_.get()).Format();
    LOG_EVERY_T(WARNING, 1) << "Socket error for connection " << conn_info << " " << GetName()
                            << " during phase " << kPhaseName[phase_] << " : " << ec << " "
                            << ec.message();
  }
}

// Dispatches one parsed command, either synchronously on the calling fiber or by enqueueing it
// for the async fiber.
//
// has_more       - true when more input follows the command; used as the initial preference
//                  for asynchronous dispatch.
// invoke_cb      - executes the command synchronously.
// enqueue_cmd_cb - enqueues the command for asynchronous execution.
//
// Does nothing when the connection is closing. If the command would go the async route while
// the pipeline buffer is over its limit, the caller is blocked on the backpressure condition
// variable until sync dispatch becomes possible, the buffer drops under the limit, or the
// connection starts closing.
void Connection::DispatchSingle(bool has_more, absl::FunctionRef invoke_cb,
                                absl::FunctionRef enqueue_cmd_cb) {
  // Unconditional return when closing:
  // else, non-throttled connections skip the check below and enqueue data even if they are closing.
  // No one will read that data anyway.
  if (cc_->conn_closing)
    return;

  // Sync dispatch requires: no async dispatch in progress, nothing queued, no subscriptions.
  auto can_dispatch_sync_fn = [this]() {
    return !cc_->async_dispatch && !HasPendingMessages() && (cc_->subscriptions == 0);
  };

  bool optimize_for_async = has_more;
  bool can_dispatch_sync = can_dispatch_sync_fn();
  QueueBackpressure& qbp = GetQueueBackpressure();
  ConnectionStats* conn_stats = &tl_facade_stats->conn_stats;

  if ((optimize_for_async || !can_dispatch_sync) &&
      qbp.IsPipelineBufferOverLimit(conn_stats->pipeline_queue_bytes, parsed_cmd_q_len_)) {
    conn_stats->pipeline_throttle_count++;
    LOG_EVERY_T(WARNING, 10)
        << "Pipeline buffer over limit."
        << ", Thread pipeline_queue_bytes: " << conn_stats->pipeline_queue_bytes
        << ", Thread pipeline_queue_entries: " << conn_stats->pipeline_queue_entries
        << ", Connection parsed_cmd_q_bytes_: " << parsed_cmd_q_bytes_
        << ", Connection parsed commands queue size: " << parsed_cmd_q_len_
        << ", consider increasing pipeline_buffer_limit/pipeline_queue_limit";

    fb2::NoOpLock noop;
    qbp.pipeline_cnd.wait(noop, [this, &qbp, &can_dispatch_sync_fn] {
      // Wait until at least one is true:
      // 1) Connection is closing.
      // 2) Can dispatch synchronously.
      // 3) Not over limits (for an async dispatch).
      bool can_dispatch_sync = can_dispatch_sync_fn();
      if (can_dispatch_sync)
        return true;
      bool over_limits = qbp.IsPipelineBufferOverLimit(
          tl_facade_stats->conn_stats.pipeline_queue_bytes, parsed_cmd_q_len_);
      return !over_limits || cc_->conn_closing;
    });

    // prefer synchronous dispatching to save memory.
    optimize_for_async = false;
    last_interaction_ = time(nullptr);
  }

  // Avoid sync dispatch if we can interleave with an ongoing async dispatch.
  can_dispatch_sync = can_dispatch_sync_fn();

  // Dispatch async if we're handling a pipeline or if we can't dispatch sync.
  if (optimize_for_async || !can_dispatch_sync) {
    LaunchAsyncFiberIfNeeded();
    enqueue_cmd_cb();
  } else {
    ShrinkPipelinePool();  // Gradually release pipeline request pool.

    {
      // sync_dispatch is observed by the AsyncFiber wait predicate, which does not wake up
      // while it is set.
      ++local_stats_.cmds;
      cc_->sync_dispatch = true;
      invoke_cb();
      cc_->sync_dispatch = false;
    }
    last_interaction_ = time(nullptr);

    // We might have blocked the dispatch queue from processing, wake it up.
    if (HasPendingMessages())
      cnd_.notify_one();
  }
}

// Parses as many RESP commands as possible from io_buf_ and dispatches each of them.
//
// max_busy_cycles - running-time budget (in cycles) after which the fiber yields between
//                   commands.
// enqueue_only    - when true every parsed command is enqueued for asynchronous execution;
//                   otherwise DispatchSingle() decides between sync and async.
//
// The parser result is stored in parser_error_. Returns OK when the last parse succeeded,
// NEED_MORE when the parser reported INPUT_PENDING, and ERROR for any other parser result.
Connection::ParserStatus Connection::ParseRedis(unsigned max_busy_cycles, bool enqueue_only) {
  uint32_t consumed = 0;
  RespSrvParser::Result result = RespSrvParser::OK;

  auto dispatch_sync = [this] {
    service_->DispatchCommand(ParsedArgs{*parsed_cmd_}, parsed_cmd_,
                              facade::AsyncPreference::ONLY_SYNC);
  };

  auto dispatch_async = [this]() -> void {
    PipelineMessagePtr ptr = GetFromPoolOrCreate();
    // parsed_cmd_ holds the parsed arguments. Move it to 'cmd' to be enqueued and set it with a new
    // empty ParsedCommand for the next parse.
    auto* cmd = std::exchange(parsed_cmd_, ptr.release());
    EnqueueParsedCommand(cmd);
  };

  io::Bytes read_buffer = io_buf_.InputBuffer();
  // Keep track of total bytes consumed/parsed. The do/while{} loop below preempts,
  // and InputBuffer() size might change between preemption points. There is a corner case,
  // that ConsumeInput() will strip a portion of the request which makes the test_publish_stuck
  // test fail.
  // TODO(kostas): follow up on this
  size_t total_consumed = 0;
  do {
    DCHECK(parsed_cmd_);
    result = redis_parser_->Parse(read_buffer, &consumed, parsed_cmd_);
    request_consumed_bytes_ += consumed;
    total_consumed += consumed;
    if (result == RespSrvParser::OK) {
      DCHECK(!parsed_cmd_->empty());
      DVLOG(2) << "Got Args with first token " << parsed_cmd_->Front();

      // request_consumed_bytes_ accumulates across calls until a full request is parsed.
      if (io_req_size_hist)
        io_req_size_hist->Add(request_consumed_bytes_);
      request_consumed_bytes_ = 0;
      bool has_more = consumed < read_buffer.size();

      if (tl_traffic_logger.log_file && IsMain() /* log only on the main interface */) {
        LogTraffic(id_, has_more, *parsed_cmd_, service_->GetContextInfo(cc_.get()));
      }

      if (enqueue_only)
        dispatch_async();
      else
        DispatchSingle(has_more, dispatch_sync, dispatch_async);
    }
    if (result != RespSrvParser::OK && result != RespSrvParser::INPUT_PENDING) {
      // We do not expect that a replica sends an invalid command so we log if it happens.
      LOG_IF(WARNING, cntx()->replica_conn)
          << "Redis parser error: " << result << " during parse: " << io::View(read_buffer);
    }
    read_buffer.remove_prefix(consumed);

    // We must yield from time to time to allow other fibers to run.
    // Specifically, if a client sends a huge chunk of data resulting in a very long pipeline,
    // we want to yield to allow AsyncFiber to actually execute on the pending pipeline.
    if (ThisFiber::GetRunningTimeCycles() > max_busy_cycles) {
      GetLocalConnStats().num_read_yields++;
      ThisFiber::Yield();
    }
  } while (RespSrvParser::OK == result && read_buffer.size() > 0 && !reply_builder_->GetError());

  io_buf_.ConsumeInput(total_consumed);

  parser_error_ = result;
  if (result == RespSrvParser::OK)
    return OK;

  if (result == RespSrvParser::INPUT_PENDING) {
    DCHECK_EQ(read_buffer.size(), 0u);
    return NEED_MORE;
  }

  VLOG(1) << "Parser error " << result;
  return ERROR;
}

// Batch-oriented parse loop: repeatedly parses a batch (ParseMCBatch for memcache, otherwise
// ParseRedisBatch), executes it and replies, for as long as the last parse produced commands
// and input remains. Returns ERROR as soon as ExecuteBatch() or ReplyBatch() fails, OK if the
// last parse produced commands, and NEED_MORE otherwise.
auto Connection::ParseLoop() -> ParserStatus {
  auto parse_func =
      protocol_ == Protocol::MEMCACHE ? &Connection::ParseMCBatch : &Connection::ParseRedisBatch;

  bool commands_parsed = false;
  do {
    commands_parsed = (this->*parse_func)();
    if (!ExecuteBatch())
      return ERROR;
    if (!ReplyBatch())
      return ERROR;
  } while (commands_parsed && io_buf_.InputLen() > 0);

  return commands_parsed ? OK : NEED_MORE;
}

// Socket error callback (registered in HandleRequests). For a positive event mask it marks the
// connection as closing, calls BreakOnce(mask) and wakes the dispatch fiber. Non-positive
// masks are ignored, and an event arriving without a context is logged as an error.
void Connection::OnBreakCb(int32_t mask) {
  if (mask <= 0)
    return;  // we cancelled the poller, which means we do not need to break from anything.

  if (!cc_) {
    LOG(ERROR) << "Unexpected event " << mask;
    return;
  }

  DCHECK(reply_builder_) << "[" << id_ << "] " << phase_ << " " << migration_in_process_;

  VLOG(1) << "[" << id_ << "] Got event " << mask << " " << phase_ << " "
          << reply_builder_->IsSendActive() << " " << reply_builder_->GetError();

  cc_->conn_closing = true;
  BreakOnce(mask);
  cnd_.notify_one();  // Notify dispatch fiber.
}

// Executes a pending migration request, if any. The async fiber (when joinable) is asked to
// stop via a MigrationRequestMessage and joined first. The migration is then performed only if
// the connection has no subscriptions and did not start closing in the meantime.
void Connection::HandleMigrateRequest() {
  if (cc_->conn_closing || !migration_request_) {
    return;
  }

  ProactorBase* dest = migration_request_;

  if (async_fb_.IsJoinable()) {
    SendAsync({MigrationRequestMessage{}});
    async_fb_.Join();
  }

  // We don't support migrating with subscriptions as it would require moving thread local
  // handles. We can't check above, as the queue might have contained a subscribe request.

  if (cc_->subscriptions == 0) {
    // RegisterOnErrorCb might be called on POLLHUP and the join above is a preemption point.
    // So, it could be the case that after this fiber wakes up the connection might be closing.
    if (cc_->conn_closing) {
      return;
    }

    tl_facade_stats->conn_stats.num_migrations++;
    migration_request_ = nullptr;

    // We need to return early as the socket is closing and IoLoop will clean up.
    // The reason that this is true is because of the following DCHECK
    DCHECK(!async_fb_.IsJoinable());

    // which can never trigger since we Joined on the async_fb_ above and we are
    // atomic in respect to our proactor meaning that no other fiber will
    // launch the DispatchFiber.
    std::ignore = !this->Migrate(dest);
  }
}

// Performs one Recv() into the free part of io_buf_ and returns its result unchanged.
// On a successful non-empty read the bytes are committed to io_buf_ and the per-thread and
// per-connection read counters are updated. last_interaction_ is refreshed in every case.
io::Result Connection::HandleRecvSocket() {
  phase_ = READ_SOCKET;
  auto& conn_stats = tl_facade_stats->conn_stats;

  io::MutableBytes append_buf = io_buf_.AppendBuffer();
  DCHECK(!append_buf.empty());
  ::io::Result recv_sz = socket_->Recv(append_buf);
  last_interaction_ = time(nullptr);

  // In case the socket was closed orderly, we get 0 bytes read.
  if (recv_sz && *recv_sz) {
    size_t commit_sz = *recv_sz;
    io_buf_.CommitWrite(commit_sz);
    conn_stats.io_read_bytes += commit_sz;
    local_stats_.net_bytes_in += commit_sz;
    ++conn_stats.io_read_cnt;
    ++local_stats_.read_cnt;
  }
  return recv_sz;
}

// Blocking receive/parse loop. Each iteration handles a pending migration request, reads from
// the socket, parses and dispatches the input, and grows io_buf_ (bounded by the
// max_client_iobuf_len flag) when a request is only partially available.
//
// Returns the socket or reply-builder error if one occurred; otherwise the parser status at
// the point the loop ended (peer closed, socket no longer open, or parse failure).
variant Connection::IoLoop() {
  error_code ec;
  ParserStatus parse_status = OK;

  size_t max_iobfuf_len = GetFlag(FLAGS_max_client_iobuf_len);

  auto* peer = socket_.get();
  recv_buf_.res_len = 0;

  do {
    HandleMigrateRequest();
    auto recv_sz = HandleRecvSocket();
    if (!recv_sz) {
      LOG_IF(WARNING, cntx()->replica_conn) << "HandleRecvSocket() error: " << recv_sz.error();
      return recv_sz.error();
    }
    if (*recv_sz == 0) {
      // Zero bytes read: the socket was closed orderly.
      break;
    }

    phase_ = PROCESS;
    // Sampled before parsing: the last read filled the buffer completely.
    bool is_iobuf_full = io_buf_.AppendLen() == 0;

    if (redis_parser_) {
      parse_status = ParseRedis(max_busy_read_cycles_cached);
    } else {
      DCHECK(memcache_parser_);
      parse_status = ParseLoop();
    }

    if (reply_builder_->GetError()) {
      return reply_builder_->GetError();
    }

    if (parse_status == NEED_MORE) {
      parse_status = OK;

      size_t capacity = io_buf_.Capacity();
      if (capacity < max_iobfuf_len) {
        size_t parser_hint = 0;
        if (redis_parser_)
          parser_hint = redis_parser_->parselen_hint();  // Could be done for MC as well.

        // If we got a partial request and we managed to parse its
        // length, make sure we have space to store it instead of
        // increasing space incrementally.
        // (Note: The buffer object is only working in power-of-2 sizes,
        // so there's no danger of accidental O(n^2) behavior.)
        if (parser_hint > capacity) {
          auto& conn_stats = GetLocalConnStats();
          UpdateIoBufCapacity(io_buf_, &conn_stats,
                              [&]() { io_buf_.Reserve(std::min(max_iobfuf_len, parser_hint)); });
        }

        // If we got a partial request because iobuf was full, grow it up to
        // a reasonable limit to save on Recv() calls.
        if (is_iobuf_full && capacity < max_iobfuf_len / 2) {
          auto& conn_stats = GetLocalConnStats();
          // Last io used most of the io_buf to the end.
          UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() {
            io_buf_.Reserve(capacity * 2);  // Valid growth range.
          });
        }

        if (io_buf_.AppendLen() == 0U) {
          // it can happen with memcached but not for RedisParser, because RedisParser fully
          // consumes the passed buffer
          LOG_EVERY_T(WARNING, 10)
              << "Maximum io_buf length reached, consider to increase max_client_iobuf_len flag";
        }
      }
    } else if (parse_status != OK) {
      break;
    }
  } while (peer->IsOpen());

  return parse_status;
}

// Decides whether the async fiber should terminate after popping `msg`.
// Returns true only when `msg` is a migration request and nothing else is pending. If other
// messages are still pending, a migration request is re-queued (unless one is already in
// dispatch_q_) and false is returned. Any other message type yields false.
bool Connection::ShouldEndAsyncFiber(const MessageHandle& msg) {
  if (!holds_alternative(msg.handle)) {
    return false;
  }

  if (!HasPendingMessages()) {
    // Migration requests means we should terminate this function (and allow the fiber to
    // join), so that we can re-launch the fiber in the new thread.
    // We intentionally return and not break in order to keep the connection open.
    return true;
  }

  // There shouldn't be any other migration requests in the queue, but it's worth checking
  // as otherwise it would lead to an endless loop.
  bool has_migration_req =
      any_of(dispatch_q_.begin(), dispatch_q_.end(),
             [](const MessageHandle& msg) { return holds_alternative(msg.handle); });
  if (!has_migration_req) {
    SendAsync({MigrationRequestMessage{}});
  }

  return false;
}

// Executes up to pipeline_squash_limit_cached queued commands in one DispatchManyCommands()
// call, flushes the replies when no new commands arrived meanwhile (or when always-flush is
// configured), and then releases the commands that were processed. If fewer commands than
// requested were processed, squashing is skipped on the next round.
void Connection::SquashPipeline() {
  DCHECK_EQ(GetPendingMessageCount(), parsed_cmd_q_len_);
  DCHECK_EQ(reply_builder_->GetProtocol(), Protocol::REDIS);  // Only Redis is supported.

  unsigned pipeline_count = std::min(parsed_cmd_q_len_, pipeline_squash_limit_cached);
  auto& conn_stats = tl_facade_stats->conn_stats;
  uint64_t start = CycleClock::Now();

  // Define a "Feeder" Lambda
  // This lambda advances a temporary pointer exec_cmd_ptr to feed the execution engine.
  // We do not modify parsed_to_execute_ yet, in case execution throws/fails.
  auto exec_cmd_ptr{parsed_to_execute_};
  auto get_next_fn = [&exec_cmd_ptr]() mutable -> ParsedArgs {
    DCHECK(exec_cmd_ptr);
    return ParsedArgs{*std::exchange(exec_cmd_ptr, exec_cmd_ptr->next)};
  };

  // async_dispatch is a guard to prevent concurrent writes into reply_builder_, hence
  // it must guard the Flush() as well.
  cc_->async_dispatch = true;

  DispatchManyResult result =
      service_->DispatchManyCommands(get_next_fn, pipeline_count, reply_builder_.get(), cc_.get());

  local_stats_.cmds += result.processed;
  last_interaction_ = time(nullptr);

  uint32_t num_dispatched_cmds = result.processed;
  uint64_t flush_start_cycle_cnt = CycleClock::Now();
  //
  // TODO: to investigate if always flushing will improve P99 latency because otherwise we
  // wait for the next batch to finish before fully flushing the current response.
  if (parsed_cmd_q_len_ == pipeline_count ||
      always_flush_pipeline_cached) {  // Flush if no new commands appeared
    reply_builder_->Flush();
    reply_builder_->SetBatchMode(false);  // in case the next dispatch is sync
  } else {
    conn_stats.skip_pipeline_flushing++;
  }

  cc_->async_dispatch = false;

  if (result.account_in_stats) {
    conn_stats.pipeline_dispatch_calls++;
    conn_stats.pipeline_dispatch_commands += num_dispatched_cmds;
    conn_stats.pipeline_dispatch_flush_usec +=
        CycleClock::ToUsec(CycleClock::Now() - flush_start_cycle_cnt);
  }

  // Release the processed commands from the head of the parsed-command list.
  auto* current{parsed_head_};
  for (size_t i = 0; (i < num_dispatched_cmds) && current; ++i) {
    auto* next{current->next};
    if (result.account_in_stats) {
      conn_stats.pipelined_wait_latency += CycleClock::ToUsec(start - current->parsed_cycle);
    }
    ReleaseParsedCommand(current, result.account_in_stats /* is_pipelined */);
    current = next;
  }

  parsed_head_ = current;
  if (!parsed_head_) {
    parsed_tail_ = nullptr;
  }
  parsed_to_execute_ = parsed_head_;

  // If interrupted due to pause, fall back to regular dispatch
  skip_next_squashing_ = (num_dispatched_cmds != pipeline_count);
}

// Drops everything that is still queued for this connection without executing commands:
// dispatch_q_ entries are un-accounted from the stats (checkpoint messages are still visited),
// and every parsed command is released after waiting for any in-flight deferred reply.
// Finally all backpressure waiters are woken up.
void Connection::ClearPipelinedMessages() {
  AsyncOperations async_op{reply_builder_.get(), this};

  // First, clear dispatch queue
  // Recycle messages even from disconnecting client to keep properly track of memory stats
  // As well as to avoid pubsub backpressure leakage.
  for (auto& msg : dispatch_q_) {
    FiberAtomicGuard guard;  // don't suspend when concluding to avoid getting new messages
    if (msg.IsCheckPoint())
      visit(async_op, msg.handle);  // to not miss checkpoints
    UpdateDispatchStats(msg, false /* subtract */);
  }
  dispatch_q_.clear();

  // Second, drain the pending pipeline queue: release memory and update stats without executing
  // commands.
  while (parsed_head_) {
    auto* curr{parsed_head_};
    parsed_head_ = parsed_head_->next;
    // Wait for the in-flight async commands processing by consumer to finish before recycling.
    if (curr->IsDeferredReply() && !curr->CanReply()) {
      curr->Blocker()->Wait();
    }
    ReleaseParsedCommand(curr, false);
  }
  DCHECK_EQ(parsed_cmd_q_len_, 0u);
  DCHECK_EQ(parsed_cmd_q_bytes_, 0u);
  parsed_tail_ = nullptr;
  parsed_to_execute_ = nullptr;

  QueueBackpressure& qbp = GetQueueBackpressure();
  qbp.pipeline_cnd.notify_all();
  qbp.pubsub_ec.notifyAll();
}

// Returns a brace-enclosed, human-readable snapshot of the connection state for diagnostics:
// id, phase, dispatch/closing flags and paused/blocked state (only when cc_ is set), async
// fiber joinability, queue sizes, and age/idle times in seconds.
string Connection::DebugInfo() const {
  string info = "{";

  absl::StrAppend(&info, "id=", id_, ", ");
  absl::StrAppend(&info, "phase=", phase_, ", ");
  if (cc_) {
    // In some rare cases cc_ can be null, see https://github.com/dragonflydb/dragonfly/pull/3873
    absl::StrAppend(&info, "dispatch(s/a)=", cc_->sync_dispatch, " ", cc_->async_dispatch, ", ");
    absl::StrAppend(&info, "closing=", cc_->conn_closing, ", ");
  }
  absl::StrAppend(&info, "df:joinable=", async_fb_.IsJoinable(), ", ");
  absl::StrAppend(&info, "dq:size=", dispatch_q_.size(), ", ");
  absl::StrAppend(&info, "pq:parsed_cmd_q_len=", parsed_cmd_q_len_, ", ");
  absl::StrAppend(&info, "pq:is_empty=", (parsed_head_ == nullptr), ", ");
  if (cc_) {
    // "p" = paused, "b" = blocked.
    absl::StrAppend(&info, "state=");
    if (cc_->paused)
      absl::StrAppend(&info, "p");
    if (cc_->blocked)
      absl::StrAppend(&info, "b");
  }

  time_t now = time(nullptr);
  absl::StrAppend(&info, " age=", now - creation_time_, " idle=", now - last_interaction_, "}");
  return info;
}

// Handles one message popped from dispatch_q_ on the async fiber.
//
// msg      - the message to process; its stats contribution is subtracted on every exit path.
// async_op - visitor applied to the message handle.
//
// Returns true when the async fiber must terminate (see ShouldEndAsyncFiber), false otherwise.
bool Connection::ProcessAdminMessage(MessageHandle* msg, AsyncOperations* async_op) {
  // Guard: Automatically subtract stats when this scope exits (via return or exception).
  absl::Cleanup stats_guard = [this, msg] { UpdateDispatchStats(*msg, false /* subtract */); };

  bool is_replying = msg->IsReplying();

  // Pre-execution Flush
  // If this is a non-replying control message (e.g. Migration) and it's the last item,
  // we MUST flush the buffer now. Otherwise, previous pipelined replies might wait
  // indefinitely or be lost if the fiber terminates.
  if (!HasPendingMessages() && !is_replying) {
    reply_builder_->Flush();
  }

  // Fiber Termination Check
  if (ShouldEndAsyncFiber(*msg)) {
    CHECK(!HasPendingMessages()) << DebugInfo();
    GetQueueBackpressure().pipeline_cnd.notify_all();
    return true;  // Signal to terminate AsyncFiber
  }

  // Execution
  auto replies_recorded_before = reply_builder_->RepliesRecorded();
  cc_->async_dispatch = true;
  std::visit(*async_op, msg->handle);
  cc_->async_dispatch = false;

  // Post-execution Flush
  // We force a flush if the message is supposed to reply (e.g. PubSub) but didn't write to the
  // buffer (e.g. subscription filter), and the queues are empty.
  if (!HasPendingMessages() && is_replying &&
      (replies_recorded_before == reply_builder_->RepliesRecorded())) {
    reply_builder_->Flush();
  }

  return false;
}

// Pops the command at the head of the parsed-command list, executes it synchronously on the
// async fiber, releases it, and flushes the reply builder when nothing else is pending.
// Requires a non-empty list (DCHECK).
void Connection::ProcessPipelineCommand() {
  DCHECK(parsed_head_ && parsed_to_execute_) << DebugInfo();

  // Unlink the command before dispatching it.
  auto* cmd = parsed_to_execute_;
  parsed_to_execute_ = cmd->next;
  parsed_head_ = parsed_to_execute_;
  if (!parsed_head_) {
    parsed_tail_ = nullptr;
  }

  tl_facade_stats->conn_stats.pipelined_wait_latency +=
      CycleClock::ToUsec(CycleClock::Now() - cmd->parsed_cycle);

  cc_->async_dispatch = true;
  local_stats_.cmds++;
  service_->DispatchCommand(ParsedArgs{*cmd}, cmd, facade::AsyncPreference::ONLY_SYNC);
  last_interaction_ = time(nullptr);
  skip_next_squashing_ = false;
  cc_->async_dispatch = false;

  ReleaseParsedCommand(cmd, true);

  // If we drained the pipeline and no admin messages are waiting, flush.
  if (!HasPendingMessages()) {
    reply_builder_->Flush();
  }
}

// AsyncFiber acts as the consumer for all asynchronous connection tasks.
//
// It operates on a producer-consumer model where the InputLoop parses socket data
// and routes it into two distinct streams:
// 1. Data Path: Pipelined commands are queued in a Parsed Commands linked list
// 2. Control Path: Admin events (Migrations, Checkpoints, PubSub) use a deque (dispatch_q_)
//
// AsyncFiber drains these queues according to system prioritization, ensuring
// high-priority events are handled promptly while preventing priority inversion
// during thread migrations. For simple requests, the InputLoop may bypass this
// fiber and dispatch synchronously to minimize latency.
//
// The loop runs until the reply builder reports an error or the connection is closing; it
// returns early, without marking the connection as closing, when ProcessAdminMessage() asks
// the fiber to terminate.
void Connection::AsyncFiber() {
  ThisFiber::SetName("AsyncFiber");

  AsyncOperations async_op{reply_builder_.get(), this};

  size_t squashing_threshold = GetFlag(FLAGS_pipeline_squash);

  uint64_t prev_epoch = fb2::FiberSwitchEpoch();
  fb2::NoOpLock noop_lk;
  QueueBackpressure& qbp = GetQueueBackpressure();
  auto& conn_stats = tl_facade_stats->conn_stats;
  // Number of consecutive dispatch_q_ messages processed since the pipeline was last served.
  uint32_t dispatch_q_cmd_processed = 0;
  uint32_t async_dispatch_quota = GetFlag(FLAGS_async_dispatch_quota);

  while (!reply_builder_->GetError()) {
    DCHECK_EQ(socket()->proactor(), ProactorBase::me());
    cnd_.wait(noop_lk, [this] {
      if (cc_->conn_closing)
        return true;

      // If we are currently executing a synchronous dispatch (e.g. inside IoLoop),
      // we must wait until it finishes to avoid race conditions.
      if (cc_->sync_dispatch)
        return false;

      // For Memcache, we ONLY wake up for Admin messages (dispatch_q_) as we process
      // parsed_head_ in the connection fiber. For RESP, we wake up for both queues.
      if (protocol_ == Protocol::MEMCACHE) {
        return !dispatch_q_.empty();
      }
      return HasPendingMessages();
    });

    if (cc_->conn_closing)
      break;

    // We really want to have batching in the builder if possible. This is especially
    // critical in situations where Nagle's algorithm can introduce unwanted high
    // latencies. However we can only batch if we're sure that there are more commands
    // on the way that will trigger a flush. To know if there are, we sometimes yield before
    // executing the last command in the queue and let the producer fiber push more commands if it
    // wants to.
    // As an optimization, we only yield if the fiber was not suspended since the last dispatch.
    uint64_t cur_epoch = fb2::FiberSwitchEpoch();
    if ((GetPendingMessageCount() == 1) && (cur_epoch == prev_epoch)) {
      if (pipeline_wait_batch_usec > 0) {
        ThisFiber::SleepFor(chrono::microseconds(pipeline_wait_batch_usec));
      } else {
        ThisFiber::Yield();
      }
      DVLOG(2) << "After yielding to producer, parsed_cmd_q_len_=" << parsed_cmd_q_len_
               << " dispatch_q size=" << dispatch_q_.size();
      if (cc_->conn_closing)
        break;
    }
    prev_epoch = cur_epoch;

    reply_builder_->SetBatchMode(GetPendingMessageCount() > 1);

    // Sampled before processing so that a drop below the limit can be detected afterwards.
    bool subscriber_over_limit =
        conn_stats.dispatch_queue_subscriber_bytes >= qbp.publish_buffer_limit;

    // The below if/else conditionally choose between 3 message processing policies:
    // 1. Pipeline squashing
    // 2. Process pipeline queue
    // 3. Process admin queue
    //
    // Special case: if the dispatch queue accumulated a big number of commands,
    // we can try to squash them
    // It is only enabled if the threshold is reached and the whole dispatch queue
    // consists only of commands (no pubsub or monitor messages)
    bool squashing_enabled = squashing_threshold > 0;
    bool threshold_reached = parsed_cmd_q_len_ > squashing_threshold;

    if (squashing_enabled && threshold_reached && dispatch_q_.empty() && !skip_next_squashing_ &&
        !IsReplySizeOverLimit()) {
      // 1. Pipeline squashing
      SquashPipeline();
      dispatch_q_cmd_processed = 0;
    } else {
      MessageHandle msg;

      // If the front message is a Migration Request, but we still have pipeline data
      // (parsed_head_), we must block the migration and process the pipeline messages first.
      bool is_migration_req =
          !dispatch_q_.empty() && std::holds_alternative(dispatch_q_.front().handle);

      // If the quota is reached but the pipeline appears empty, we must yield to the IoLoop
      // (producer). This allows the discovery and parsing of commands potentially sitting in the
      // TCP buffer. Without this yield, AsyncFiber would monopolize the CPU, starving the IoLoop
      // and remaining blind to pending pipeline data.
      bool quota_reached =
          (async_dispatch_quota > 0) && (dispatch_q_cmd_processed >= async_dispatch_quota);
      if (quota_reached && (parsed_head_ == nullptr)) {
        ThisFiber::Yield();
        // If it is STILL empty after IoLoop got a chance to run, the client hasn't sent anything.
        // Reset the counter so we don't yield on every single loop.
        if (parsed_head_ == nullptr) {
          dispatch_q_cmd_processed = 0;
        }
      }

      // We prioritize pipeline execution over the admin queue in two distinct cases (Pipeline queue
      // must be non-empty for both cases):
      // 1. A migration is requested (Redis only), but we must drain the existing
      //    pipeline first.
      // 2. The dispatch quota was reached, forcing a pipeline execution to prevent
      //    starvation.
      bool prefer_pipeline_execution = false;
      if (parsed_head_ != nullptr) {
        prefer_pipeline_execution =
            quota_reached || (is_migration_req && (protocol_ == Protocol::REDIS));
      }

      if (dispatch_q_.empty() || prefer_pipeline_execution) {
        // 2. Process pipeline Queue
        VLOG_IF(1, prefer_pipeline_execution)
            << "[" << id_ << "] Preferring pipeline execution over admin queue. "
            << "Migration requested: " << is_migration_req
            << ", dispatch quota reached: " << quota_reached
            << ", async_dispatch_quota: " << async_dispatch_quota
            << ", dispatch_q_cmd_processed: " << dispatch_q_cmd_processed;
        ProcessPipelineCommand();
        dispatch_q_cmd_processed = 0;
      } else {
        // 3. Process admin Queue
        msg = std::move(dispatch_q_.front());
        dispatch_q_.pop_front();
        dispatch_q_cmd_processed++;

        // Execute and check if we need to terminate the fiber
        if (ProcessAdminMessage(&msg, &async_op)) {
          return;  // don't set conn closing flag
        }
      }
    }

    // Notify waiters if backpressure constraints are relieved.
    // 1. Global memory (bytes) is under limit -> Wakes up neighbors on this thread.
    // 2. Local queue (length) is under limit -> Wakes up this connection's producer.
    if (qbp.IsPipelineBufferUnderLimit(conn_stats.pipeline_queue_bytes, parsed_cmd_q_len_) ||
        !HasPendingMessages()) {
      qbp.pipeline_cnd.notify_all();
    }

    if (subscriber_over_limit &&
        conn_stats.dispatch_queue_subscriber_bytes < qbp.publish_buffer_limit)
      qbp.pubsub_ec.notify();
  }

  DCHECK(cc_->conn_closing || reply_builder_->GetError());
  cc_->conn_closing = true;
  qbp.pipeline_cnd.notify_all();

  // If shutdown was requested, we need to break the receive call in case the i/o fiber
  // is blocked there. With io loop v2, we can have a different mechanism to break from recv flow.
  if (request_shutdown_) {
    ShutdownSelfBlocking();
  }
}

// Releases at most one cached pipeline request from the pool per call, and only when the
// thread-local cache-size tracker allows it; the cache byte counter is reduced accordingly.
void Connection::ShrinkPipelinePool() {
  if (pipeline_req_pool_.empty())
    return;

  auto& conn_stats = tl_facade_stats->conn_stats;
  if (tl_pipe_cache_sz_tracker.CheckAndUpdateWatermark(pipeline_req_pool_.size())) {
    conn_stats.pipeline_cmd_cache_bytes -= UsedMemoryInternal(*pipeline_req_pool_.back());
    pipeline_req_pool_.pop_back();
  }
}

// Returns a parsed-command object for the next request: a freshly created one when the pool is
// empty, otherwise the most recently pooled one, reset and re-initialised for this connection.
Connection::PipelineMessagePtr Connection::GetFromPoolOrCreate() {
  if (pipeline_req_pool_.empty())
    return PipelineMessagePtr{CreateParsedCommand()};

  auto& conn_stats = tl_facade_stats->conn_stats;
  auto ptr = std::move(pipeline_req_pool_.back());
  pipeline_req_pool_.pop_back();
  // The object leaves the cache, so it no longer counts towards the cached bytes.
  conn_stats.pipeline_cmd_cache_bytes -= UsedMemoryInternal(*ptr);
  ptr->ResetForReuse();
  ptr->Init(reply_builder_.get(), cc_.get());
  ptr->ConfigureMCExtension(protocol_ == Protocol::MEMCACHE);
  return ptr;
}

// Forwards to the base-class util::Connection::Shutdown().
void Connection::ShutdownSelfBlocking() {
  util::Connection::Shutdown();
}

bool Connection::Migrate(util::fb2::ProactorBase* dest) {
  // Migrate is used only by replication, so it doesn't have properties of full-fledged
  // connections
  CHECK(!cc_->async_dispatch);
  CHECK_EQ(cc_->subscriptions, 0);  // are bound to thread local caches
  CHECK_EQ(self_.use_count(), 1u);  // references cache our thread and backpressure
  //
  if (ioloop_v2_ && socket_ && socket_->IsOpen()) {
    socket_->ResetOnRecvHook();
  }

  // Migrate is only used by DFLY Thread and Flow command which
both check against // the result of Migration and handle it explicitly in their flows so this can act // as a weak if condition instead of a crash prone CHECK. if (async_fb_.IsJoinable() || cc_->conn_closing) { return false; } listener()->Migrate(this, dest); // After we migrate, it could be the case the connection was shut down. We should // act accordingly. if (!socket()->IsOpen()) { return false; } return true; } Connection::WeakRef Connection::Borrow() { DCHECK(self_); return {self_, unsigned(socket_->proactor()->GetPoolIndex()), id_}; } void Connection::ShutdownThreadLocal() { pipeline_req_pool_.clear(); } bool Connection::IsCurrentlyDispatching() const { if (!cc_) return false; return cc_->async_dispatch || cc_->sync_dispatch; } void Connection::SendPubMessageAsync(PubMessage msg) { SendAsync({make_unique(std::move(msg))}); } void Connection::SendMonitorMessageAsync(string msg) { SendAsync({MonitorMessage{std::move(msg)}}); } void Connection::SendCheckpoint(fb2::BlockingCounter bc, bool ignore_paused, bool ignore_blocked) { if (!IsCurrentlyDispatching()) return; if (cc_->paused && ignore_paused) return; if (cc_->blocked && ignore_blocked) return; VLOG(2) << "Sent checkpoint to " << DebugInfo(); bc->Add(1); SendAsync({CheckpointMessage{bc}}); } void Connection::SendInvalidationMessageAsync(InvalidationMessage msg) { SendAsync({std::move(msg)}); } void Connection::LaunchAsyncFiberIfNeeded() { if (!async_fb_.IsJoinable() && !migration_in_process_) { VLOG(1) << "[" << id_ << "] LaunchAsyncFiberIfNeeded "; async_fb_ = fb2::Fiber(fb2::Launch::post, "connection_dispatch", [this]() { AsyncFiber(); }); } } // SendAsync is now strictly for the Control Path (Admin/Events). // Pipeline commands are handled separately via EnqueueParsedCommand to maintain // clean separation between Data and Control paths. // Note: Should never block - the callers may run in as a brief callback. 
// Enqueues `msg` on dispatch_q_ and wakes the AsyncFiber when needed (see the notification
// rules at the end of the function).
void Connection::SendAsync(MessageHandle msg) {
  DCHECK(cc_);
  DCHECK(listener());
  DCHECK_EQ(ProactorBase::me(), socket_->proactor());

  auto& conn_stats = tl_facade_stats->conn_stats;

  // "Closing" connections might be still processing commands, as we don't interrupt them.
  // So we still want to deliver control messages to them (like checkpoints) if
  // async_fb_ is running (joinable).
  if (cc_->conn_closing && (!msg.IsCheckPoint() || !async_fb_.IsJoinable()))
    return;

  // If we launch while closing, it won't be awaited. Control messages will be processed on cleanup.
  if (!cc_->conn_closing) {
    LaunchAsyncFiberIfNeeded();
  }

  DCHECK_NE(phase_, PRECLOSE);  // No more messages are processed after this point

  // Close MONITOR connection if we overflow limits.
  // We must check the Thread-Global memory usage of BOTH:
  // 1. The Control Path (dispatch_queue_bytes)
  // 2. The Data Path (pipeline_queue_bytes)
  if (msg.IsMonitor()) {
    if (GetQueueBackpressure().IsPipelineBufferOverLimit(
            conn_stats.dispatch_queue_bytes + conn_stats.pipeline_queue_bytes,
            GetPendingMessageCount())) {
      cc_->conn_closing = true;
      request_shutdown_ = true;
      // We don't shutdown here. The reason is that TLS socket is preemptive
      // and SendAsync is atomic.
      cnd_.notify_one();
      return;
    }
  }

  local_stats_.dispatch_entries_added++;
  UpdateDispatchStats(msg, true /* add */);

  msg.dispatch_cycle = CycleClock::Now();

  // Admin Queueing Rules:
  // Checkpoints go to the front (after existing checkpoints), while all others to the back.
  bool had_pending_messages = HasPendingMessages();  // check the queues before enqueuing
  if (msg.IsCheckPoint()) {
    // Skip past the checkpoints already at the head so that checkpoints keep their order.
    auto it = dispatch_q_.begin();
    while (it < dispatch_q_.end() && it->IsCheckPoint())
      ++it;
    dispatch_q_.insert(it, std::move(msg));
  } else {
    dispatch_q_.push_back(std::move(msg));
  }

  // Control Path Notification:
  // We need to wake up the AsyncFiber only if it is currently sleeping.
  // 1. Memcache: Sleeps if dispatch_q_ is empty. Must notify on 0->1 transition.
  // 2. Redis: Sleeps if BOTH queues are empty. If pipeline has items, it's already awake.
  bool should_notify = false;

  if (protocol_ == Protocol::REDIS) {
    if (!had_pending_messages) {
      should_notify = true;
    }
  } else {  // MEMCACHE
    should_notify = (dispatch_q_.size() == 1);
  }

  if (should_notify && !cc_->sync_dispatch) {
    cnd_.notify_one();
  }
}

// Adds (add == true) or subtracts (add == false) the memory footprint of `msg` to/from the
// thread-local and per-connection dispatch-queue counters. Pub/sub messages are additionally
// tracked in the subscriber byte counters.
void Connection::UpdateDispatchStats(const MessageHandle& msg, bool add) {
  size_t mem = msg.UsedMemory();
  auto& qbp = GetQueueBackpressure();
  auto& conn_stats = tl_facade_stats->conn_stats;

  if (add) {
    conn_stats.dispatch_queue_entries++;
    conn_stats.dispatch_queue_bytes += mem;
    dispatch_q_bytes_ += mem;
    if (msg.IsPubMsg()) {
      qbp.subscriber_bytes.fetch_add(mem, std::memory_order_relaxed);
      conn_stats.dispatch_queue_subscriber_bytes += mem;
      dispatch_q_subscriber_bytes_ += mem;
    }
  } else {
    DCHECK_GT(conn_stats.dispatch_queue_entries, 0u);
    DCHECK_GE(conn_stats.dispatch_queue_bytes, mem);
    conn_stats.dispatch_queue_entries--;
    conn_stats.dispatch_queue_bytes -= mem;
    dispatch_q_bytes_ -= mem;
    if (msg.IsPubMsg()) {
      DCHECK_GE(conn_stats.dispatch_queue_subscriber_bytes, mem);
      DCHECK_GE(qbp.subscriber_bytes.load(std::memory_order_relaxed), mem);
      qbp.subscriber_bytes.fetch_sub(mem, std::memory_order_relaxed);
      conn_stats.dispatch_queue_subscriber_bytes -= mem;
      dispatch_q_subscriber_bytes_ -= mem;
    }
  }
}

// Returns "<address>:<port>" of the local endpoint, or "unix-domain-socket" for UDS sockets.
std::string Connection::LocalBindStr() const {
  if (socket_->IsUDS())
    return "unix-domain-socket";

  auto le = socket_->LocalEndpoint();
  return absl::StrCat(le.address().to_string(), ":", le.port());
}

// Returns the local endpoint address without the port, or "unix-domain-socket" for UDS sockets.
std::string Connection::LocalBindAddress() const {
  if (socket_->IsUDS())
    return "unix-domain-socket";

  auto le = socket_->LocalEndpoint();
  return le.address().to_string();
}

// Returns "<address>:<port>" of the remote endpoint, or "unix-domain-socket" for UDS sockets.
std::string Connection::RemoteEndpointStr() const {
  if (socket_->IsUDS())
    return "unix-domain-socket";

  auto re = socket_->RemoteEndpoint();
  return absl::StrCat(re.address().to_string(), ":", re.port());
}

// Returns the remote endpoint address without the port, or "unix-domain-socket" for UDS sockets.
std::string Connection::RemoteEndpointAddress() const {
  if (socket_->IsUDS())
    return "unix-domain-socket";

  auto re = socket_->RemoteEndpoint();
  return re.address().to_string();
}

// Returns the raw pointer held by cc_.
facade::ConnectionContext* Connection::cntx() {
  return cc_.get();
}

// Records `dest` as the pending migration target and disables further requests.
// Ignored when there is no context, or when migration is disabled and `force` is false.
void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest, bool force) {
  if ((!force && !migration_enabled_) || cc_ == nullptr) {
    return;
  }

  // Connections can migrate at most once.
  migration_enabled_ = false;
  migration_request_ = dest;
}

// Forwards `path` to OpenTrafficLogger.
void Connection::StartTrafficLogging(string_view path) {
  OpenTrafficLogger(path);
}

// Resets tl_traffic_logger while holding its mutex.
void Connection::StopTrafficLogging() {
  lock_guard lk(tl_traffic_logger.mutex);
  tl_traffic_logger.ResetLocked();
}

bool Connection::IsHttp() const {
  return is_http_;
}

// Estimates the memory attributed to this connection: the object itself, the heap usage of
// its name, parsers, context and reply builder, the cached parsed command (if any), and a
// fixed allowance for the fiber stack.
size_t Connection::GetMemoryUsage() const {
  size_t mem = sizeof(*this) + cmn::HeapSize(name_) + cmn::HeapSize(memcache_parser_) +
               cmn::HeapSize(redis_parser_) + cmn::HeapSize(cc_) + cmn::HeapSize(reply_builder_);

  // parsed_cmd_ can be null when dispatching a command, or for http connections.
  if (parsed_cmd_) {
    mem += UsedMemoryInternal(*parsed_cmd_);
  }

  // We add a hardcoded 9k value to accommodate for the part of the Fiber stack that is in use.
  // The allocated stack is actually larger (~130k), but only a small fraction of that (9k
  // according to our checks) is actually part of the RSS.
  mem += 9'000;

  return mem;
}

// Adds this connection's counters (connection count, read buffer capacity, dispatch and
// pipeline queue entries/bytes, subscriber bytes) to the thread-local connection stats.
void Connection::IncreaseConnStats() {
  DCHECK(tl_facade_stats);
  auto& conn_stats = tl_facade_stats->conn_stats;

  if (IsMainOrMemcache())
    ++conn_stats.num_conns_main;
  else
    ++conn_stats.num_conns_other;

  conn_stats.read_buf_capacity += io_buf_.Capacity();
  conn_stats.dispatch_queue_entries += dispatch_q_.size();
  conn_stats.dispatch_queue_bytes += dispatch_q_bytes_;
  conn_stats.pipeline_queue_entries += parsed_cmd_q_len_;
  conn_stats.pipeline_queue_bytes += parsed_cmd_q_bytes_;

  if (dispatch_q_subscriber_bytes_ > 0) {
    auto& qbp = GetQueueBackpressure();
    conn_stats.dispatch_queue_subscriber_bytes += dispatch_q_subscriber_bytes_;
    qbp.subscriber_bytes.fetch_add(dispatch_q_subscriber_bytes_, std::memory_order_relaxed);
  }
}

// Subtracts the same set of counters that IncreaseConnStats adds, with debug checks that
// none of them would underflow.
void Connection::DecreaseConnStats() {
  DCHECK(tl_facade_stats);
  auto& conn_stats = tl_facade_stats->conn_stats;

  if (IsMainOrMemcache()) {
    DCHECK_GT(conn_stats.num_conns_main, 0u);
    --conn_stats.num_conns_main;
  } else {
    DCHECK_GT(conn_stats.num_conns_other, 0u);
    --conn_stats.num_conns_other;
  }

  DCHECK_GE(conn_stats.read_buf_capacity, io_buf_.Capacity());
  conn_stats.read_buf_capacity -= io_buf_.Capacity();

  DCHECK_GE(conn_stats.dispatch_queue_entries, dispatch_q_.size());
  conn_stats.dispatch_queue_entries -= dispatch_q_.size();

  DCHECK_GE(conn_stats.dispatch_queue_bytes, dispatch_q_bytes_);
  conn_stats.dispatch_queue_bytes -= dispatch_q_bytes_;

  if (dispatch_q_subscriber_bytes_ > 0) {
    auto& qbp = GetQueueBackpressure();
    DCHECK_GE(conn_stats.dispatch_queue_subscriber_bytes, dispatch_q_subscriber_bytes_);
    conn_stats.dispatch_queue_subscriber_bytes -= dispatch_q_subscriber_bytes_;
    DCHECK_GE(qbp.subscriber_bytes.load(std::memory_order_relaxed), dispatch_q_subscriber_bytes_);
    qbp.subscriber_bytes.fetch_sub(dispatch_q_subscriber_bytes_, std::memory_order_relaxed);
  }

  DCHECK_GE(conn_stats.pipeline_queue_entries, parsed_cmd_q_len_);
  conn_stats.pipeline_queue_entries -= parsed_cmd_q_len_;

  DCHECK_GE(conn_stats.pipeline_queue_bytes, parsed_cmd_q_bytes_);
  conn_stats.pipeline_queue_bytes -= parsed_cmd_q_bytes_;
}

// Invokes breaker_cb_ with `ev_mask` if it is set. The callback is moved out of the member
// before the call, so a later BreakOnce finds it empty.
void Connection::BreakOnce(uint32_t ev_mask) {
  if (breaker_cb_) {
    DVLOG(1) << "[" << id_ << "] Connection::breaker_cb_ " << ev_mask;
    auto fun = std::move(breaker_cb_);
    DCHECK(!breaker_cb_);
    fun(ev_mask);
  }
}

// Returns true when reply_size_limit is non-zero and the thread-local squashing reply size
// currently exceeds it. Logs when the limit is exceeded.
bool Connection::IsReplySizeOverLimit() const {
  std::atomic& reply_sz = tl_facade_stats->reply_stats.squashing_current_reply_size;
  size_t current = reply_sz.load(std::memory_order_acquire);
  const bool over_limit = reply_size_limit != 0 && current > 0 && current > reply_size_limit;
  if (over_limit) {
    LOG_EVERY_T(INFO, 10) << "Commands squashing current reply size is overlimit: " << current
                          << "/" << reply_size_limit
                          << ". Falling back to single command dispatch (instead of squashing)";
    // Used by testing. Should not be used in production, therefore debug log level 5.
    DVLOG(5) << "Commands squashing current reply size is overlimit: " << current << "/"
             << reply_size_limit
             << ". Falling back to single command dispatch (instead of squashing)";
  }
  return over_limit;
}

// Returns true when ParseRedis reports ParserStatus::OK.
bool Connection::ParseRedisBatch() {
  return ParseRedis(max_busy_read_cycles_cached, true) == ParserStatus::OK;
}

// Parses memcache commands from io_buf_ and enqueues each one, including commands that
// failed to parse (their error reply is stored on the queued command).
// Returns false when the parser reports INPUT_PENDING; returns true once the queue holds
// 128 or more commands or the input buffer is drained.
bool Connection::ParseMCBatch() {
  CHECK(io_buf_.InputLen() > 0);

  do {
    if (parsed_cmd_ == nullptr) {
      // Happens with pipelined commands after the first one.
      PipelineMessagePtr ptr = GetFromPoolOrCreate();
      parsed_cmd_ = ptr.release();
    }
    uint32_t consumed = 0;
    memcache_parser_->set_last_unix_time(time(nullptr));
    MemcacheParser::Result result = memcache_parser_->Parse(io::View(io_buf_.InputBuffer()),
                                                            &consumed, parsed_cmd_->mc_command());

    io_buf_.ConsumeInput(consumed);

    DVLOG(2) << "mc_result " << unsigned(result) << " consumed: " << consumed << " type "
             << unsigned(parsed_cmd_->mc_command()->type);
    if (result == MemcacheParser::INPUT_PENDING)
      return false;

    // We push the command to the parsed queue even in case of parse errors,
    // so that we can reply in order.
    EnqueueParsedCommand(parsed_cmd_);
    parsed_cmd_ = nullptr;  // ownership transferred.

    if (result != MemcacheParser::OK) {
      // We can not just reply directly to parse error, as we may have pipelined commands before.
      // Fill the reply_payload into parsed_tail_ with the error and continue parsing.
      memcache_parser_->Reset();

      // TODO(vlad): Use Proper SendError calls instead of SendSimpleString and error building
      auto client_error = [](string_view msg) { return absl::StrCat("CLIENT_ERROR ", msg); };

      // parsed_tail_ is the command that was just enqueued above.
      parsed_tail_->SetDeferredReply();
      switch (result) {
        case MemcacheParser::UNKNOWN_CMD:
          parsed_tail_->SendSimpleString("ERROR");
          break;
        case MemcacheParser::PARSE_ERROR:
          parsed_tail_->SendSimpleString(client_error("bad data chunk"));
          break;
        case MemcacheParser::BAD_DELTA:
          parsed_tail_->SendSimpleString(client_error("invalid numeric delta argument"));
          break;
        default:
          parsed_tail_->SendSimpleString(client_error("bad command line format"));
          break;
      }
    }
  } while (parsed_cmd_q_len_ < 128 && io_buf_.InputLen() > 0);
  return true;
}

// Dispatches queued commands starting at parsed_to_execute_ until the queue is exhausted or
// a dispatch returns WOULD_BLOCK. Returns false if the reply builder reports an error,
// true otherwise.
bool Connection::ExecuteBatch() {
  auto& conn_stats = tl_facade_stats->conn_stats;

  // Pops the current head, releases it, and returns the new head (possibly null).
  auto advance_head = [this]() -> ParsedCommand* {
    auto* cmd = parsed_head_;
    parsed_head_ = cmd->next;
    ReleaseParsedCommand(cmd, parsed_head_ != nullptr /* is_pipelined */);
    return parsed_head_;
  };

  auto dispatch = protocol_ == Protocol::MEMCACHE ? &ServiceInterface::DispatchMC
                                                  : &ServiceInterface::DispatchCommandSimple;

  // Execute sequentially all parsed commands.
  // `cmd` is a reference to parsed_to_execute_, so assigning to it advances the member.
  for (auto& cmd = parsed_to_execute_; cmd != nullptr;) {
    if (reply_builder_->GetError())
      return false;

    bool is_head = cmd == parsed_head_;

    // parser errors are stored as deferred replies
    if (cmd->IsDeferredReply() && cmd->CanReply()) {
      if (is_head) {
        cmd->SendReply();
        cmd = advance_head();
      } else {
        cmd = cmd->next;
      }
      continue;
    }

    // We must continue with async execution if we already have executing commands
    auto mode = is_head ? AsyncPreference::PREFER_ASYNC : AsyncPreference::ONLY_ASYNC;
    if (!ioloop_v2_)  // only v2 loop supports any async commands so far
      mode = AsyncPreference::ONLY_SYNC;

    auto dispatch_res = (service_->*dispatch)(cmd, mode);

    // Enforce the pipeline invariant between the IO loop (producer) and AsyncFiber (consumer).
    // To prevent stream corruption, the command state must satisfy ONE of these rules:
    // 1. It is the head command (safely writes to the socket directly).
    // 2. It did not stall the pipeline (dispatch_res != WOULD_BLOCK) and therefore
    //    must have buffered its reply locally (is_deferred == true).
    // 3. It stalled the pipeline because it requires synchronous execution
    //    (dispatch_res == WOULD_BLOCK) and therefore must NOT have buffered
    //    a reply (is_deferred == false).
    bool is_deferred = cmd->IsDeferredReply();
    DCHECK(is_head || (is_deferred == (dispatch_res != DispatchResult::WOULD_BLOCK)))
        << "Pipeline contract breach! Invalid state for non-head command. "
        << "DispatchResult: " << static_cast(dispatch_res) << ", IsDeferred: " << is_deferred
        << ", Command Type: " << cmd->mc_command()->type;

    if (dispatch_res == DispatchResult::WOULD_BLOCK)
      break;  // Sync command. Wait for current async commands to finish

    conn_stats.pipeline_dispatch_commands++;
    if (is_head)
      conn_stats.pipeline_dispatch_calls++;

    if (cmd->IsDeferredReply()) {
      cmd = cmd->next;
    } else {
      DCHECK(is_head);       // only head can execute sync
      cmd = advance_head();  // advance it
    }
  }

  if (parsed_head_ == nullptr)
    parsed_tail_ = nullptr;

  return true;
}

// Sends the replies of already-dispatched commands, i.e. those between parsed_head_ and
// parsed_to_execute_, stopping at the first command that cannot reply yet. Flushes the
// reply builder at the end. Returns false if the reply builder reports an error.
bool Connection::ReplyBatch() {
  reply_builder_->SetBatchMode(true);
  // `cmd` is a reference to parsed_head_, so advancing it pops the queue head.
  for (auto& cmd = parsed_head_; cmd != parsed_to_execute_;) {
    if (!cmd->CanReply())
      break;
    current_wait_.reset();  // we must free waiter before proceeding with other commands

    cmd->SendReply();

    auto* prev = exchange(cmd, cmd->next);
    ReleaseParsedCommand(prev, cmd != parsed_to_execute_ /* is_pipelined */);
    if (reply_builder_->GetError())
      return false;
  }

  if (parsed_head_ == nullptr)
    parsed_tail_ = nullptr;

  reply_builder_->SetBatchMode(false);
  reply_builder_->Flush();
  return !reply_builder_->GetError();
}

// Allocates a parsed command through the service and initializes it for this connection.
ParsedCommand* Connection::CreateParsedCommand() {
  auto* res = service_->AllocateParsedCommand();
  res->Init(reply_builder_.get(), cc_.get());
  res->ConfigureMCExtension(protocol_ == Protocol::MEMCACHE);
  return res;
}

// Appends `cmd` to the parsed-command list, updates the queue counters, and (Redis only,
// when no sync dispatch is active) notifies cnd_.
void Connection::EnqueueParsedCommand(ParsedCommand* cmd) {
  DCHECK(cmd);
  cmd->next = nullptr;

  auto& conn_stats = tl_facade_stats->conn_stats;

  cmd->parsed_cycle = base::CycleClock::Now();
  if (parsed_head_ == nullptr) {
    parsed_head_ = cmd;
    parsed_to_execute_ = cmd;
  } else {
    parsed_tail_->next = cmd;
    if (parsed_to_execute_ == nullptr) {
      // we've executed all the parsed commands so far.
      parsed_to_execute_ = cmd;
    }
  }
  parsed_tail_ = cmd;

  size_t used_mem = cmd->UsedMemory();
  parsed_cmd_q_len_++;
  parsed_cmd_q_bytes_ += used_mem;
  local_stats_.dispatch_entries_added++;
  conn_stats.pipeline_queue_entries++;
  conn_stats.pipeline_queue_bytes += used_mem;

  // AsyncFiber for Memcache only wakes up on dispatch_q_, notify only redis as this is the parse
  // commands queue.
  if ((!cc_->sync_dispatch) && (protocol_ == Protocol::REDIS)) {
    cnd_.notify_one();
  }
}

// Removes `cmd` from the queue accounting and recycles it: it becomes parsed_cmd_ if that
// slot is empty, otherwise it is pooled when the cache limit allows, or deleted.
// When `is_pipelined` is true, pipeline latency statistics are updated as well.
void Connection::ReleaseParsedCommand(ParsedCommand* cmd, bool is_pipelined) {
  size_t used_mem = cmd->UsedMemory();
  auto& conn_stats = tl_facade_stats->conn_stats;

  DCHECK_GT(parsed_cmd_q_len_, 0u);
  DCHECK_GE(parsed_cmd_q_bytes_, used_mem);
  DCHECK_GT(conn_stats.pipeline_queue_entries, 0u);
  DCHECK_GE(conn_stats.pipeline_queue_bytes, used_mem);

  parsed_cmd_q_len_--;
  parsed_cmd_q_bytes_ -= used_mem;
  conn_stats.pipeline_queue_entries--;
  conn_stats.pipeline_queue_bytes -= used_mem;

  if (is_pipelined) {
    conn_stats.pipelined_cmd_cnt++;
    // Latency is measured from the parsed_cycle timestamp set in EnqueueParsedCommand.
    uint64_t latency_usec = CycleClock::ToUsec(CycleClock::Now() - cmd->parsed_cycle);
    conn_stats.pipelined_cmd_latency += latency_usec;
    conn_stats.pipelined_latency_hist.Add(latency_usec);
    // Decay the histogram every kPipelineLatencyDecayPeriod samples to
    // approximate a moving-window distribution; older observations contribute
    // half as much after each decay period.
    constexpr uint64_t kPipelineLatencyDecayPeriod = 1 << 14;  // 16384
    if ((conn_stats.pipelined_latency_hist.count() & (kPipelineLatencyDecayPeriod - 1)) == 0) {
      conn_stats.pipelined_latency_hist.Decay();
    }
  }

  if (parsed_cmd_ == nullptr) {
    parsed_cmd_ = cmd;
    parsed_cmd_->ResetForReuse();
  } else {
    // If we are over the limit, destroy the command instead of caching it.
    size_t cmd_mem = UsedMemoryInternal(*cmd);
    QueueBackpressure& qbp = GetQueueBackpressure();
    if (conn_stats.pipeline_cmd_cache_bytes + cmd_mem <= qbp.pipeline_cache_limit) {
      conn_stats.pipeline_cmd_cache_bytes += cmd_mem;
      pipeline_req_pool_.emplace_back(cmd);
    } else {
      delete cmd;
    }
  }
}

// Releases every command still in the parsed queue, waiting for in-flight deferred commands
// to finish first, then deletes the cached parsed_cmd_.
// NOTE(review): parsed_to_execute_ is not reset here; presumably this only runs during
// connection teardown where it is no longer read — confirm against callers.
void Connection::DestroyParsedQueue() {
  while (parsed_head_ != nullptr) {
    auto* cmd = parsed_head_;
    parsed_head_ = cmd->next;

    // Being able to drop an in-flight transaction would require it keeping no pointers
    // at all to any context data - too costly for now! (maybe let it own the arguments?)
    if (cmd->IsDeferredReply() && !cmd->CanReply())
      cmd->Blocker()->Wait();  // explicitly wait for it to finish

    ReleaseParsedCommand(cmd, false);
  }
  parsed_tail_ = nullptr;
  CHECK_EQ(parsed_cmd_q_len_, 0u);
  CHECK_EQ(parsed_cmd_q_bytes_, 0u);

  delete parsed_cmd_;
  parsed_cmd_ = nullptr;
}

// Refreshes the current thread's backpressure limits and the cached flag values from the
// corresponding flags, then wakes waiters on this thread's pipeline_cnd.
void Connection::UpdateFromFlags() {
  unsigned tid = fb2::ProactorBase::me()->GetPoolIndex();
  thread_queue_backpressure[tid].pipeline_queue_max_len = GetFlag(FLAGS_pipeline_queue_limit);
  thread_queue_backpressure[tid].pipeline_buffer_limit = GetFlag(FLAGS_pipeline_buffer_limit);
  thread_queue_backpressure[tid].pipeline_cnd.notify_all();

  max_busy_read_cycles_cached = base::CycleClock::FromUsec(GetFlag(FLAGS_max_busy_read_usec));
  always_flush_pipeline_cached = GetFlag(FLAGS_always_flush_pipeline);
  pipeline_squash_limit_cached = GetFlag(FLAGS_pipeline_squash_limit);
  pipeline_wait_batch_usec = GetFlag(FLAGS_pipeline_wait_batch_usec);
}

// Returns the names of the flags that UpdateFromFlags reads.
std::vector Connection::GetMutableFlagNames() {
  return base::GetFlagNames(FLAGS_pipeline_queue_limit, FLAGS_pipeline_buffer_limit,
                            FLAGS_max_busy_read_usec, FLAGS_always_flush_pipeline,
                            FLAGS_pipeline_squash_limit, FLAGS_pipeline_wait_batch_usec);
}

// Writes the request-size histogram to `*hist` if tracking is enabled; otherwise leaves
// `*hist` untouched.
void Connection::GetRequestSizeHistogramThreadLocal(std::string* hist) {
  if (io_req_size_hist)
    *hist = io_req_size_hist->ToString();
}

// Allocates the request-size histogram when enabling and frees it when disabling.
// Repeated calls with the same value are no-ops.
void Connection::TrackRequestSize(bool enable) {
  if (enable && !io_req_size_hist) {
    io_req_size_hist = new base::Histogram;
  } else if (!enable && io_req_size_hist) {
    delete io_req_size_hist;
    io_req_size_hist = nullptr;
  }
}

// Calls EnsureBelowLimit() on the backpressure state of thread `tid`.
void Connection::EnsureMemoryBudget(unsigned tid) {
  thread_queue_backpressure[tid].EnsureBelowLimit();
}

// Stores a weak reference to the connection together with the thread id and client id.
ConnectionRef::ConnectionRef(const std::shared_ptr& ptr, unsigned thread_id, uint32_t client_id)
    : ptr_{ptr}, last_known_thread_id_{thread_id}, client_id_{client_id} {
}

// Returns the raw connection pointer, or nullptr if the weak reference has expired.
Connection* ConnectionRef::Get() const {
  auto sptr = ptr_.lock();

  // The connection can only be deleted on this thread, so
  // this pointer is valid until the next suspension.
  // Note: keeping a shared_ptr doesn't prolong the lifetime because
  // it doesn't manage the underlying connection. See definition of `self_`.
  return sptr.get();
}

bool Connection::WeakRef::IsExpired() const {
  return ptr_.expired();
}

uint32_t Connection::WeakRef::GetClientId() const {
  return client_id_;
}

// References are ordered and compared by client id only.
bool ConnectionRef::operator<(const ConnectionRef& other) const {
  return client_id_ < other.client_id_;
}

bool ConnectionRef::operator==(const ConnectionRef& other) const {
  return client_id_ == other.client_id_;
}

// Handles a receive notification: records an error in io_ec_, reads from the socket into
// io_buf_, or copies a provided buffer into io_buf_, depending on which alternative
// n.read_result holds.
// NOTE(review): the std::holds_alternative / std::get calls below appear without explicit
// type arguments in this text, so the alternative each branch tests cannot be determined
// from here — confirm against the original source.
void Connection::DoReadOnRecv(const util::FiberSocketBase::RecvNotification& n) {
  if (std::holds_alternative(n.read_result)) {
    io_ec_ = std::get(n.read_result);
    return;
  }

  using RecvNoti = util::FiberSocketBase::RecvNotification::RecvCompletion;
  if (std::holds_alternative(n.read_result)) {
    if (!std::get(n.read_result)) {
      io_ec_ = make_error_code(errc::connection_aborted);
      return;
    }

    if (io_buf_.AppendLen() == 0) {
      // We will regrow in IoLoopV2
      return;
    }

    io::MutableBytes buf = io_buf_.AppendBuffer();
    io::Result res = socket_->TryRecv(buf);

    if (res) {
      if (*res > 0) {
        // A recv call can return fewer bytes than requested even if the
        // socket buffer actually contains enough data to satisfy the full request.
        // TODO maybe worth looping here and try another recv call until it fails
        // with EAGAIN or EWOULDBLOCK. The problem there is that we need to handle
        // resizing if AppendBuffer is zero.
        io_buf_.CommitWrite(*res);
        return;
      }
      // *res == 0
      io_ec_ = make_error_code(errc::connection_aborted);
      return;
    }

    // error path (!res)
    auto ec = res.error();
    // EAGAIN and EWOULDBLOCK
    if (ec == errc::resource_unavailable_try_again || ec == errc::operation_would_block) {
      return;
    }

    io_ec_ = ec;
  } else if (std::holds_alternative(n.read_result)) {  // provided buffer.
    io::MutableBytes buf = std::get(n.read_result);
    UpdateIoBufCapacity(io_buf_, &tl_facade_stats->conn_stats,
                        [&]() { io_buf_.WriteAndCommit(buf.data(), buf.size()); });
  } else {
    LOG(FATAL) << "Should not reach here";
  }
}

// Grows io_buf_ when its capacity is below the max_client_iobuf_len flag: up to the parser's
// length hint, and/or by doubling when the previous read filled the buffer. Logs a warning
// when no append space is left afterwards.
void Connection::CheckIoBufCapacity(bool is_iobuf_full) {
  auto& conn_stats = tl_facade_stats->conn_stats;

  size_t max_io_buf_len = GetFlag(FLAGS_max_client_iobuf_len);
  // Snapshot taken before any Reserve() below; the later checks use this value.
  size_t capacity = io_buf_.Capacity();
  if (capacity < max_io_buf_len) {
    size_t parser_hint = 0;
    if (redis_parser_)
      parser_hint = redis_parser_->parselen_hint();  // Could be done for MC as well.

    // If we got a partial request and we managed to parse its
    // length, make sure we have space to store it instead of
    // increasing space incrementally.
    // (Note: The buffer object is only working in power-of-2 sizes,
    // so there's no danger of accidental O(n^2) behavior.)
    if (parser_hint > capacity) {
      UpdateIoBufCapacity(io_buf_, &conn_stats,
                          [&]() { io_buf_.Reserve(std::min(max_io_buf_len, parser_hint)); });
    }

    // If we got a partial request because iobuf was full, grow it up to
    // a reasonable limit to save on Recv() calls.
    if (is_iobuf_full && capacity < max_io_buf_len / 2) {
      // Last io used most of the io_buf to the end.
      UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() {
        io_buf_.Reserve(capacity * 2);  // Valid growth range.
      });
    }

    if (io_buf_.AppendLen() == 0U) {
      // it can happen with memcached but not for RedisParser, because RedisParser fully
      // consumes the passed buffer
      LOG_EVERY_T(WARNING, 10) << "Maximum io_buf length reached " << io_buf_.Capacity()
                               << ", consider to increase max_client_iobuf_len flag";
    }
  }
}

// Notification-driven I/O loop: registers a receive hook on the socket, then repeatedly
// waits for input or command completion, parses input, and executes/replies to queued
// commands while the socket is open. Returns an error code on I/O or reply errors,
// otherwise the last parser status.
variant Connection::IoLoopV2() {
  DCHECK(memcache_parser_) << "Not supported for redis yet";

  auto* peer = socket_.get();
  recv_buf_.res_len = 0;

  // Don't proceed with RegisterOnRecv() if socket is closed (possible cancellation)
  if (!peer->IsOpen())
    return ParserStatus::OK;

  if (fb2::ProactorBase::me()->GetKind() == fb2::ProactorBase::Kind::IOURING) {
#ifdef __linux__
    fb2::UringProactor* up = static_cast(fb2::ProactorBase::me());
    // Multishot receive is enabled only when the buffer ring for kRecvSockGid has a
    // non-zero entry size and the connection is not TLS.
    if (up->BufRingEntrySize(kRecvSockGid) > 0 && !is_tls_) {
      static_cast(peer)->EnableRecvMultishot();
    }
#endif
  }

  peer->RegisterOnRecv([this](const FiberSocketBase::RecvNotification& n) {
    DVLOG(2) << "Calling DoReadOnRecv iobuf_len: " << io_buf_.InputLen();
    DoReadOnRecv(n);
    io_event_.notify();
  });

  ParserStatus parse_status = OK;

  // Waiter that is passed to the current async command head to be notified on completion
  auto ioevent_cb = [this]() { io_event_.notify(); };
  util::fb2::detail::Waiter ioevent_waiter{ioevent_cb};  // takes callback by reference
  // Ensures current_wait_ is reset on every exit path from this function.
  absl::Cleanup waiter_cleanup = [this] { current_wait_.reset(); };

  do {
    HandleMigrateRequest();

    // Register completion for current head if its pending and we don't wait
    if (auto* cmd = parsed_head_;
        cmd && cmd != parsed_to_execute_ && !current_wait_.has_value()) {
      current_wait_.emplace(cmd, &ioevent_waiter);
    }

    if (io_buf_.InputLen() == 0) {
      // Poll again for readiness. The event handler registered above is edge triggered
      // We should read from the socket until EAGAIN or EWOULDBLOCK
      // to make sure we consume all available data.
      // See "Do I need to continuously read/write" question
      // under https://man7.org/linux/man-pages/man7/epoll.7.html
      // The exception is when we use io_uring with multishot recv enabled, in which case
      // we rely on the kernel to keep feeding us data until we multishot is disabled.
      DoReadOnRecv(FiberSocketBase::RecvNotification{true});
      // Wait until there is input, the head command can be executed or replied to,
      // or an I/O error was recorded.
      io_event_.await([this]() {
        // TODO: optimize CanReply with looking up waiter key
        bool cmd_executable = parsed_head_ && parsed_head_ == parsed_to_execute_;
        bool cmd_ready = !cmd_executable && parsed_head_ && parsed_head_->CanReply();
        return io_buf_.InputLen() > 0 || cmd_ready || cmd_executable || io_ec_;
      });
    }

    if (io_ec_) {
      LOG_IF(WARNING, cntx()->replica_conn) << "async io error: " << io_ec_;
      // Return the error and clear io_ec_ in one step.
      return std::exchange(io_ec_, {});
    }

    phase_ = PROCESS;
    // Captured before parsing; passed to CheckIoBufCapacity below.
    bool is_iobuf_full = io_buf_.AppendLen() == 0;

    if (io_buf_.InputLen() > 0) {
      parse_status = ParseLoop();
    } else {
      // No new input: make progress on the commands that are already queued.
      parse_status = NEED_MORE;
      if (parsed_head_) {
        if (parsed_head_ == parsed_to_execute_)
          ExecuteBatch();
        ReplyBatch();
      }
    }

    if (reply_builder_->GetError()) {
      return reply_builder_->GetError();
    }

    if (parse_status == NEED_MORE) {
      parse_status = OK;
      CheckIoBufCapacity(is_iobuf_full);
    } else if (parse_status != OK) {
      break;
    }
  } while (peer->IsOpen());

  return parse_status;
}

// Registers `w` for completion of `cmd` via its blocker and stores the returned key.
Connection::WaitEvent::WaitEvent(ParsedCommand* cmd, util::fb2::detail::Waiter* w)
    : key(cmd->Blocker()->OnCompletion(w)) {
}

// Zeroes selected thread-local connection counters, resets the reply stats, and clears the
// request-size histogram if it exists.
void ResetStats() {
  auto& cstats = tl_facade_stats->conn_stats;
  cstats.pipelined_cmd_cnt = 0;
  cstats.conn_received_cnt = 0;

  cstats.command_cnt_main = 0;
  cstats.command_cnt_other = 0;
  cstats.io_read_cnt = 0;
  cstats.io_read_bytes = 0;

  tl_facade_stats->reply_stats = {};
  if (io_req_size_hist)
    io_req_size_hist->Clear();
}

}  // namespace facade

================================================
FILE: src/facade/dragonfly_connection.h
================================================
// Copyright 2026, DragonflyDB authors.  All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

// NOTE(review): the targets of the following system includes are missing in this text
// (they look stripped by extraction) - restore them from the original source.
#include
#include
#include
#include
#include
#include
#include

#include "facade/connection_ref.h"
#include "facade/facade_types.h"
#include "facade/parsed_command.h"
#include "io/io_buf.h"
#include "util/connection.h"
#include "util/fibers/fibers.h"
#include "util/fibers/synchronization.h"

typedef struct ssl_ctx_st SSL_CTX;

// need to declare for older linux distributions like CentOS 7
#ifndef SO_INCOMING_CPU
#define SO_INCOMING_CPU 49
#endif

#ifndef SO_INCOMING_NAPI_ID
#define SO_INCOMING_NAPI_ID 56
#endif

// A smaller request storage size is used when building with address sanitizer.
#ifdef ABSL_HAVE_ADDRESS_SANITIZER
constexpr size_t kReqStorageSize = 88;
#else
constexpr size_t kReqStorageSize = 120;
#endif

namespace util {
class HttpListenerBase;
}  // namespace util

namespace facade {

struct ConnectionStats;
class ConnectionContext;
class ServiceInterface;
class SinkReplyBuilder;
class RespSrvParser;

// Connection represents an active connection for a client.
//
// It directly dispatches regular commands from the io-loop.
// For pipelined requests, monitor and pubsub messages it uses
// a separate dispatch queue that is processed on a separate fiber.
//
// NOTE(review): several template argument lists in this class (std::function,
// std::unique_ptr, std::shared_ptr, std::variant, std::holds_alternative, ...) are missing
// in this text - they look stripped by extraction.
class Connection : public util::Connection {
 public:
  static void Init(unsigned io_threads);
  static void Shutdown();
  static void ShutdownThreadLocal();

  Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx,
             ServiceInterface* service);
  ~Connection();

  // A callback called by Listener::OnConnectionStart in the same thread where
  // HandleRequests will run.
  void OnConnectionStart();

  using BreakerCb = std::function;
  using ShutdownCb = std::function;

  // PubSub message, either incoming message for active subscription or reply for new subscription.
  struct PubMessage {
    std::string pattern;                // non-empty for pattern subscriber
    std::shared_ptr buf;                // stores channel name and message
    std::string_view channel, message;  // channel and message parts from buf
    bool is_sharded = false;
    // Unsubscribe simultaneously when sending unsubscribe message. Used for cluster migrations
    bool force_unsubscribe = false;
  };

  // Monitor message, carries a simple payload with the registered event to be sent.
  struct MonitorMessage : public std::string {};

  // Migration request message, the async fiber stops to give way for thread migration.
  struct MigrationRequestMessage {};

  // Checkpoint message, used to track when the connection finishes executing the current command.
  struct CheckpointMessage {
    util::fb2::BlockingCounter bc;  // Decremented counter when processed
  };

  // Invalidation message for a single key; `invalidate_due_to_flush` marks invalidations
  // that were caused by a flush.
  struct InvalidationMessage {
    std::string key;
    bool invalidate_due_to_flush = false;
  };

  // Pipeline message, accumulated Redis command to be executed.
  using PipelineMessagePtr = std::unique_ptr;
  using PubMessagePtr = std::unique_ptr;

  // Variant wrapper around different message types
  struct MessageHandle {
    size_t UsedMemory() const;  // How many bytes this handle takes up in total.

    // Checkpoint messages put themselves at the front of the queue, but only in relative
    // order to the rest of the messages in the queue.
    bool IsCheckPoint() const {
      return std::holds_alternative(handle);
    }

    bool IsPubMsg() const {
      return std::holds_alternative(handle);
    }

    bool IsMonitor() const {
      return std::holds_alternative(handle);
    }

    bool IsReplying() const;  // control messages don't reply, messages carrying data do

    std::variant handle;

    // time when the message was dispatched to the dispatch queue as reported by
    // CycleClock::Now()
    uint64_t dispatch_cycle = 0;
  };

  // Handles are stored by value in the dispatch deque, so their size is bounded.
  static_assert(sizeof(MessageHandle) <= 80,
                "Big structs should use indirection to avoid wasting deque space!");

  // Connection life-cycle phase. NUM_PHASES is the number of phases, not a phase itself.
  enum Phase : uint8_t { SETUP, READ_SOCKET, PROCESS, SHUTTING_DOWN, PRECLOSE, NUM_PHASES };

  using WeakRef = ConnectionRef;

  // Add PubMessage to dispatch queue.
  // Virtual because behavior is overridden in test_utils.
  virtual void SendPubMessageAsync(PubMessage);

  // Add monitor message to dispatch queue.
void SendMonitorMessageAsync(std::string);

// If any dispatch is currently in progress, increment counter and send checkpoint message to
// decrement it once finished.
void SendCheckpoint(util::fb2::BlockingCounter bc, bool ignore_paused = false,
                    bool ignore_blocked = false);

// Add InvalidationMessage to dispatch queue.
virtual void SendInvalidationMessageAsync(InvalidationMessage);

// Register hook that is executed when the connection breaks.
void RegisterBreakHook(BreakerCb breaker_cb);

void FlushReplies();

// Manually shutdown self.
void ShutdownSelfBlocking();

// Migrate this connection to a different thread.
// Return true if Migrate succeeded
// Return false if dispatch_fb_ is active
// NOTE(review): no member named dispatch_fb_ is declared in this class; this presumably
// refers to the async fiber (async_fb_) - confirm and update the name.
bool Migrate(util::fb2::ProactorBase* dest);

// Borrow weak reference to connection. Can be called from any thread.
WeakRef Borrow();

bool IsCurrentlyDispatching() const;

std::string GetClientInfo(unsigned thread_id) const;
std::string GetClientInfo() const;

virtual std::string RemoteEndpointStr() const;  // virtual because overwritten in test_utils
std::string RemoteEndpointAddress() const;

std::string LocalBindStr() const;
std::string LocalBindAddress() const;

uint32_t GetClientId() const;

virtual bool IsPrivileged() const;  // virtual because overwritten in test_utils

bool IsMain() const;

// In addition to the listener role being main, also returns true if the protocol is Memcached.
// This method returns true for customer facing listeners.
bool IsMainOrMemcache() const;

void SetName(std::string name);
void SetLibName(std::string name);
void SetLibVersion(std::string version);

// Returns a map of 'libname:libver'->count, thread local data
// NOTE(review): the template arguments of absl::flat_hash_map are missing in this text.
static const absl::flat_hash_map& GetLibStatsTL();

// Returns the connection name stored in name_.
std::string_view GetName() const {
  return name_;
}

// Returns protocol type of this connection
Protocol GetProtocol() const {
  return protocol_;
}

// Returns memory usage of this connection's auxiliary members in bytes.
size_t GetMemoryUsage() const;

ConnectionContext* cntx();

// Requests that at some point, this connection will be migrated to `dest` thread.
// If force is false, the connection will migrate at most once,
// and only when the flag --migrate_connections is true.
void RequestAsyncMigration(util::fb2::ProactorBase* dest, bool force);

// Starts traffic logging in the calling thread. Must be a proactor thread.
// Each thread creates its own log file combining requests from all the connections in
// that thread. A noop if the thread is already logging.
static void StartTrafficLogging(std::string_view base_path);

// Stops traffic logging in this thread. A noop if the thread is not logging.
static void StopTrafficLogging();

// Get quick debug info for logs
std::string DebugInfo() const;

bool IsHttp() const;

static void UpdateFromFlags();  // Set values from flags
// NOTE(review): the element type of the returned std::vector is missing in this text.
static std::vector GetMutableFlagNames();  // Triggers UpdateFromFlags

static void TrackRequestSize(bool enable);
static void EnsureMemoryBudget(unsigned tid);
static void GetRequestSizeHistogramThreadLocal(std::string* hist);

// Seconds elapsed between last_interaction_ and the current time(nullptr) value.
unsigned idle_time() const {
  return time(nullptr) - last_interaction_;
}

unsigned GetSendWaitTimeSec() const;

// Current life-cycle phase of the connection.
Phase phase() const {
  return phase_;
}

bool IsSending() const;

// Notifies io_event_, the event the io loop waits on.
void Notify() {
  io_event_.notify();
}

void MarkForClose();

protected:
void OnShutdown() override;
void OnPreMigrateThread() override;
void OnPostMigrateThread() override;

std::unique_ptr cc_;  // Null for http connections

private:
// Result of a parsing step as used by the io loops.
enum ParserStatus : uint8_t { OK, NEED_MORE, ERROR };

struct AsyncOperations;

// Check protocol and handle connection.
void HandleRequests() final;

// Start dispatch fiber and run IoLoop.
void ConnectionFlow();

// Main loop reading client messages and passing requests to dispatch queue.
// NOTE(review): the alternatives of the std::variant return type are missing in this text.
std::variant IoLoop();

void DoReadOnRecv(const util::FiberSocketBase::RecvNotification& n);
void CheckIoBufCapacity(bool is_iobuf_full);

// Main loop reading client messages and passing requests to dispatch queue.
// NOTE(review): the alternatives of the std::variant return type are missing in this text.
std::variant IoLoopV2();

// Returns true if HTTP header is detected.
// NOTE(review): the template argument of io::Result is missing in this text.
io::Result CheckForHttpProto();

// Dispatches a single (Redis or MC) command.
// `has_more` should indicate whether the io buffer has more commands
// (pipelining in progress). Performs async dispatch if forced (already in async mode) or if
// has_more is true, otherwise uses synchronous dispatch.
void DispatchSingle(bool has_more, absl::FunctionRef invoke_cb,
                    absl::FunctionRef enqueue_cmd_cb);

// Handles events from the dispatch queue.
void AsyncFiber();

// Processes a single Admin/Control message from dispatch_q_.
// Returns true if the fiber should terminate (e.g. Migration).
bool ProcessAdminMessage(MessageHandle* msg, AsyncOperations* async_op);

// Processes the next Pipeline command from parsed_head_.
void ProcessPipelineCommand();

void SendAsync(MessageHandle msg);

// Updates Control Path statistics and backpressure counters for administrative
// events, monitor messages, and PubSub notifications.
// If add is true, stats are incremented, otherwise decremented.
void UpdateDispatchStats(const MessageHandle& msg, bool add);

ParserStatus ParseRedis(unsigned max_busy_cycles, bool enqueue_only = false);

void OnBreakCb(int32_t mask);

// Shrink pipeline pool by a little while handling regular commands.
void ShrinkPipelinePool();

// Returns non-null request ptr if pool has vacant entries.
PipelineMessagePtr GetFromPoolOrCreate();

void HandleMigrateRequest();
// NOTE(review): the template argument of io::Result is missing in this text.
io::Result HandleRecvSocket();

bool ShouldEndAsyncFiber(const MessageHandle& msg);

void LaunchAsyncFiberIfNeeded();  // Async fiber is started lazily

// Squashes pipelined commands from the dispatch queue to spread load over all threads
void SquashPipeline();

// Clear pipelined messages, dispatching only intrusive ones.
void ClearPipelinedMessages();

// NOTE(review): the element types of the returned std::pair are missing in this text.
std::pair GetClientInfoBeforeAfterTid() const;

void IncreaseConnStats();
void DecreaseConnStats();

void BreakOnce(uint32_t ev_mask);

// The read buffer with read data that needs to be parsed and processed.
// For io_uring bundles we may have available_bytes larger than slice.size()
// which means that there are more buffers available to read.
struct ReadBuffer {
  size_t available_bytes;
  io::Bytes slice;

  // Marks `len` bytes as processed: reduces available_bytes and drops the same number of
  // bytes from the front of slice.
  void Consume(size_t len) {
    available_bytes -= len;
    slice.remove_prefix(len);
  }
};

bool IsReplySizeOverLimit() const;

// Returns true if one or more commands were parsed from the read buffer,
// and false if no complete commands could be parsed (for example, when
// parsing is pending more input).
bool ParseMCBatch();
bool ParseRedisBatch();

// Call appropriate ParseBatch function, proceed with Execute and Reply all while input is
// remaining
ParserStatus ParseLoop();

// Loop over enqueued async commands and enqueue them for async execution.
// If async execution is not possible, handle them in synchronous mode one by one.
// Returns true on successful execution, false on reply builder error.
bool ExecuteBatch();

// Loop over finished async commands and let them reply.
// Returns true on successful execution, false on reply builder error.
bool ReplyBatch();

// Guard of the current subscription to a parsed commands async task blocker
struct WaitEvent {
  explicit WaitEvent(ParsedCommand* cmd, util::fb2::detail::Waiter* w);

  // NOTE(review): the template argument of std::optional is missing in this text.
  std::optional key;
};

ParsedCommand* CreateParsedCommand();
void EnqueueParsedCommand(ParsedCommand* cmd);

// Releases the command memory back to the pool.
// - Set is_pipelined=true if the command was successfully executed and should be counted
//   in latency/throughput stats.
// - Set is_pipelined=false if the command is being dropped/cleaned up without execution or should
//   not be counted in stats.
void ReleaseParsedCommand(ParsedCommand* cmd, bool is_pipelined);

void DestroyParsedQueue();

// Dispatch Queue - Queue for the Control Path.
// Handles asynchronous administrative tasks, events, and high-priority control
// messages (e.g., PubSub, Monitor, Migration requests, Checkpoints) processed
// by the AsyncFiber.
// NOTE(review): the element type of the deque is missing in this text.
std::deque dispatch_q_;       // dispatch queue
util::fb2::CondVarAny cnd_;   // dispatch queue waker
util::fb2::Fiber async_fb_;   // async fiber (if started)

size_t dispatch_q_bytes_ = 0;             // total bytes in dispatch queue
size_t dispatch_q_subscriber_bytes_ = 0;  // total bytes from subscribers in dispatch queue

// Pending io error; IoLoopV2 returns it and resets it.
std::error_code io_ec_;
// Event the io loop waits on; notified by Notify() and by the receive callback.
util::fb2::EventCount io_event_;
// NOTE(review): the template argument of std::optional is missing in this text.
std::optional current_wait_;

// how many bytes of the current request have been consumed
size_t request_consumed_bytes_ = 0;

util::FiberSocketBase::ProvidedBuffer recv_buf_;
io::IoBuf io_buf_;  // used in io loop and parsers
std::unique_ptr redis_parser_;
std::unique_ptr memcache_parser_;

ParsedCommand* parsed_cmd_ = nullptr;

// Parsed Commands Queue - Queue for the Data Path.
//
// Commands move through the following stages in a single linked list:
//  1) parsed but not yet dispatched     : [parsed_to_execute_, ..., parsed_tail_]
//  2) dispatched but not yet completed  : between parsed_head_ and parsed_to_execute_
//  3) completed (replies ready to send) : a prefix of [parsed_head_, ..., parsed_to_execute_)
//  4) replied and removed               : before parsed_head_ (no longer in the list)
//
// Logical order diagram:
//   head -> ... -> (dispatched, waiting for completion) -> ... -> parsed_to_execute_ -> ... ->
//   tail
//
// parsed_to_execute_ is advanced as commands are dispatched for execution.
// Executed (completed) commands are kept in the queue until their replies are sent,
// in order to preserve reply ordering.
// ReplyMCBatch walks from parsed_head_ up to (but not including) parsed_to_execute_,
// replies commands that have completed, and removes only those replied commands from
// the queue, advancing parsed_head_ accordingly.
// NOTE(review): no ReplyMCBatch is declared in this class; this presumably describes
// ReplyBatch() - confirm and update the name.
ParsedCommand* parsed_head_ = nullptr;
ParsedCommand* parsed_tail_ = nullptr;
ParsedCommand* parsed_to_execute_ = nullptr;

// Total number of commands in parsed command queue
size_t parsed_cmd_q_len_ = 0;

// Total bytes used by commands in parsed command queue
size_t parsed_cmd_q_bytes_ = 0;

// Returns true if there are any commands pending in the parsed command queue or dispatch queue.
bool HasPendingMessages() const {
  return parsed_head_ || !dispatch_q_.empty();
}

// Returns total count of commands pending in the parsed command queue and dispatch queue.
size_t GetPendingMessageCount() const {
  return parsed_cmd_q_len_ + dispatch_q_.size();
}

uint32_t id_;
Protocol protocol_;
Phase phase_ = SETUP;

// Per-connection counters.
struct {
  size_t read_cnt = 0;                // total number of read calls
  size_t net_bytes_in = 0;            // total number of bytes read
  size_t dispatch_entries_added = 0;  // total number of dispatch queue entries
  size_t cmds = 0;                    // total number of commands executed
} local_stats_;

std::unique_ptr reply_builder_;
util::HttpListenerBase* http_listener_;
SSL_CTX* ssl_ctx_;

ServiceInterface* service_;

time_t creation_time_, last_interaction_;
std::string name_;
std::string lib_name_;
std::string lib_ver_;

unsigned parser_error_ = 0;

BreakerCb breaker_cb_;

// Used to keep track of borrowed references. Does not really own itself
std::shared_ptr self_;

util::fb2::ProactorBase* migration_request_ = nullptr;

// Pooled pipeline messages per-thread
// Aggregated while handling pipelines, gradually released while handling regular commands.
static thread_local std::vector pipeline_req_pool_;

// The boolean bit-fields below share storage with flags_.
union {
  uint16_t flags_;
  struct {
    // a flag indicating whether the client has turned on client tracking.
bool tracking_enabled_ : 1;
bool skip_next_squashing_ : 1;  // Forcefully skip next squashing

// Connection migration vars, see RequestAsyncMigration() above.
bool migration_enabled_ : 1;
bool migration_in_process_ : 1;
bool is_http_ : 1;
// whether the connection is TLS. We can be sure our socket is TlsSocket
// if the flag is set.
bool is_tls_ : 1;
bool is_main_ : 1;
bool ioloop_v2_ : 1;  // whether this connection is running on ioloop v2
// If post migration is allowed to call RegisterRecv
bool migration_allowed_to_register_ : 1;
};  // anonymous struct of bit-fields
};  // anonymous union with flags_

bool request_shutdown_ = false;
};  // class Connection

}  // namespace facade

================================================
FILE: src/facade/dragonfly_listener.cc
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "facade/dragonfly_listener.h"

// NOTE(review): the targets of the bare includes in this file are missing in this text
// (they look stripped by extraction) - restore them from the original source.
#include
#include
#include
#include

#include "absl/functional/bind_front.h"
#include "facade/tls_helpers.h"

#ifdef DFLY_USE_SSL
#include
#endif

#include "base/flags.h"
#include "base/logging.h"
#include "facade/dragonfly_connection.h"
#include "facade/service_interface.h"
#include "util/proactor_pool.h"

using namespace std;

ABSL_FLAG(uint32_t, conn_io_threads, 0, "Number of threads used for handing server connections");
ABSL_FLAG(uint32_t, conn_io_thread_start, 0, "Starting thread id for handling server connections");
ABSL_FLAG(bool, tls, false, "");
ABSL_FLAG(bool, no_tls_on_admin_port, false, "Allow non-tls connections on admin port");
ABSL_FLAG(bool, enable_tcp_defer_accept, true, "Enable TCP_DEFER_ACCEPT option on server sockets");

ABSL_FLAG(bool, conn_use_incoming_cpu, false,
          "If true uses incoming cpu of a socket in order to distribute"
          " incoming connections");

ABSL_DECLARE_FLAG(std::string, tls_cert_file);
ABSL_DECLARE_FLAG(std::string, tls_key_file);
ABSL_DECLARE_FLAG(std::string, tls_ca_cert_file);
ABSL_DECLARE_FLAG(std::string, tls_ca_cert_dir);

ABSL_FLAG(uint32_t, tcp_keepalive, 300,
          "the period in seconds of inactivity after which keep-alives are triggerred,"
          "the duration until an inactive connection is terminated is twice the specified time");

ABSL_FLAG(uint32_t, tcp_user_timeout, 0,
          "the maximum period in milliseconds that transimitted data may stay unacknowledged "
          "before TCP aborts the connection. 0 means OS default timeout");

ABSL_DECLARE_FLAG(bool, primary_port_http_enabled);

// Disabled code: client-auth configuration enum, compiled out by `#if 0`.
#if 0
enum TlsClientAuth {
  CL_AUTH_NO = 0,
  CL_AUTH_YES = 1,
  CL_AUTH_OPTIONAL = 2,
};

facade::ConfigEnum tls_auth_clients_enum[] = {
    {"no", CL_AUTH_NO},
    {"yes", CL_AUTH_YES},
    {"optional", CL_AUTH_OPTIONAL},
};

static int tls_auth_clients_opt = CL_AUTH_YES;

CONFIG_enum(tls_auth_clients, "yes", "", tls_auth_clients_enum, tls_auth_clients_opt);
#endif

namespace facade {

// See dragonfly_listener.h
// NOTE(review): the template argument of std::atomic is missing in this text.
std::atomic g_shutdown_fast{false};

using namespace util;
using util::detail::SafeErrorMessage;

using absl::GetFlag;

namespace {

// Enables TCP keep-alive on `fd` and configures its timing from --tcp_keepalive:
// idle time = flag value, probe interval = max(flag value / 3, 1), probe count = 3.
// Returns false as soon as one of the setsockopt calls fails (errno is left as set by that
// call), true otherwise.
bool ConfigureKeepAlive(int fd) {
  int val = 1;
  if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &val, sizeof(val)) < 0)
    return false;

  val = absl::GetFlag(FLAGS_tcp_keepalive);
  // The idle-time option is named TCP_KEEPALIVE on Apple platforms and TCP_KEEPIDLE elsewhere.
#ifdef __APPLE__
  if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPALIVE, &val, sizeof(val)) < 0)
    return false;
#else
  if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &val, sizeof(val)) < 0)
    return false;
#endif

  /* Send next probes after the specified interval. Note that we set the
   * delay as interval / 3, as we send three probes before detecting
   * an error (see the next setsockopt call). */
  val = std::max(val / 3, 1);
  if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &val, sizeof(val)) < 0)
    return false;

  /* Consider the socket in error state after we send three ACK
   * probes without getting a reply. */
  val = 3;
  if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &val, sizeof(val)) < 0)
    return false;

  return true;
}

// Per-thread listener counters.
struct ListenerStats {
  size_t tls_allocated_bytes = 0;
  uint64_t refused_conn_maxclients_reached_cnt = 0;
};

thread_local ListenerStats listener_tl_stats;

// Number of live Listener objects; used to install the OpenSSL memory hooks once and to
// call OPENSSL_cleanup() when the last listener is destroyed.
atomic_int ssl_init_refcount = 0;

// Allocation hook handed to OpenSSL: allocates with mimalloc and adds the usable size of
// the new block to the calling thread's tls_allocated_bytes. `file`/`line` are unused.
void* OverriddenSSLMalloc(size_t size, const char* file, int line) {
  void* res = mi_malloc(size);
  listener_tl_stats.tls_allocated_bytes += mi_malloc_usable_size(res);
  return res;
}

// Reallocation hook handed to OpenSSL: adjusts tls_allocated_bytes by the difference
// between the usable sizes of the new and the previous block.
void* OverriddenSSLRealloc(void* addr, size_t size, const char* file, int line) {
  size_t prev_size = mi_malloc_usable_size(addr);
  void* res = mi_realloc(addr, size);
  listener_tl_stats.tls_allocated_bytes += mi_malloc_usable_size(res);
  listener_tl_stats.tls_allocated_bytes -= prev_size;
  return res;
}

// Free hook handed to OpenSSL: subtracts the usable size of the block from the calling
// thread's tls_allocated_bytes before releasing it.
void OverriddenSSLFree(void* addr, const char* file, int line) {
  listener_tl_stats.tls_allocated_bytes -= mi_malloc_usable_size(addr);
  mi_free(addr);
}

}  // namespace

// Creates a listener for `protocol` that serves requests through `si`.
// With DFLY_USE_SSL: the first listener installs the memory hooks above, OpenSSL is
// initialised, and the TLS context is built via ReconfigureTLS(); the process exits if that
// fails. An HTTP handler is attached for privileged listeners, and for the main listener
// when --primary_port_http_enabled is set.
Listener::Listener(Protocol protocol, ServiceInterface* si, Role role)
    : service_(si), role_(role), protocol_(protocol) {
#ifdef DFLY_USE_SSL
  // Only the first listener installs the allocation hooks.
  if (ssl_init_refcount.fetch_add(1) == 0) {
    CRYPTO_set_mem_functions(&OverriddenSSLMalloc, &OverriddenSSLRealloc, &OverriddenSSLFree);
  }

  // Always initialise OpenSSL so we can enable TLS at runtime.
  OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, nullptr);

  // Print this only for main interface
  if (IsMainInterface()) {
    std::string_view ssl_version = SSLeay_version(SSLEAY_VERSION);
    LOG(INFO) << "SSL version: " << ssl_version;
  }

  if (!ReconfigureTLS()) {
    exit(-1);
  }
#endif

  // We only set the HTTP interface for:
  // 1. Privileged users (on privileged listener)
  // 2. Main listener (if enabled)
  const bool is_main_enabled = GetFlag(FLAGS_primary_port_http_enabled);
  if (IsPrivilegedInterface() || (IsMainInterface() && is_main_enabled)) {
    // NOTE(review): the template argument of std::make_unique is missing in this text.
    http_base_ = std::make_unique>();
    http_base_->set_resource_prefix("http://static.dragonflydb.io/data-plane");
    si->ConfigureHttpHandlers(http_base_.get(), IsPrivilegedInterface());
  }
}

// Releases this listener's TLS context; the last listener to be destroyed also calls
// OPENSSL_cleanup().
Listener::~Listener() {
#ifdef DFLY_USE_SSL
  SSL_CTX_free(ctx_);
  if (ssl_init_refcount.fetch_sub(1) == 1) {
    OPENSSL_cleanup();
  }
#endif
}

// Allocates a facade Connection bound to this listener's protocol, HTTP handler, TLS
// context and service. `proactor` is not used. The caller receives a raw owning pointer.
util::Connection* Listener::NewConnection(ProactorBase* proactor) {
  return new Connection{protocol_, http_base_.get(), ctx_, service_};
}

// Applies socket options to the listening socket `fd`: SO_REUSEADDR, optionally
// TCP_DEFER_ACCEPT, keep-alive settings and (on Linux) TCP_USER_TIMEOUT.
// Failures are only logged; the function always returns an empty error_code.
error_code Listener::ConfigureServerSocket(int fd) {
  int val = 1;
  if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) < 0) {
    LOG(WARNING) << "Could not set reuse addr on socket " << SafeErrorMessage(errno);
  }

#ifdef TCP_DEFER_ACCEPT
  // TCP_DEFER_ACCEPT is only for Linux, and defined by Linux OS-Kernel
  if (GetFlag(FLAGS_enable_tcp_defer_accept)) {
    sockaddr_storage addr;
    socklen_t len = sizeof(addr);
    // TCP_DEFER_ACCEPT is only applicable to TCP (IPv4/IPv6) sockets, not Unix domain sockets
    // (UDS).
    // NOTE(review): the template argument of reinterpret_cast is missing in this text.
    if (getsockname(fd, reinterpret_cast(&addr), &len) == 0 &&
        (addr.ss_family == AF_INET || addr.ss_family == AF_INET6)) {
      // Instruct the kernel to defer waking up accept() until actual payload data arrives,
      // with a timeout of 1 second.
      // This provides a kernel-level shield against "Pure Zombie" storms - where malicious or
      // misconfigured clients complete the TCP 3-way handshake but never send data (or immediately
      // send FIN/RST). The kernel will silently clean up these empty connections without
      // consuming Dragonfly fibers or OpenSSL memory.
      // This imposes zero latency penalty on well-behaved clients, as the kernel instantly
      // yields the connection to user-space the moment their first byte (e.g., TLS ClientHello
      // or RESP command) arrives.
static constexpr int kDeferAcceptTimeoutSec = 1; if (setsockopt(fd, IPPROTO_TCP, TCP_DEFER_ACCEPT, &kDeferAcceptTimeoutSec, sizeof(kDeferAcceptTimeoutSec)) < 0) { LOG(WARNING) << "Could not set TCP_DEFER_ACCEPT " << SafeErrorMessage(errno); } } } #endif bool success = ConfigureKeepAlive(fd); #ifdef __linux__ int user_timeout = absl::GetFlag(FLAGS_tcp_user_timeout); if (setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT, &user_timeout, sizeof(int)) < 0) { LOG(WARNING) << "Could not set user timeout on socket " << SafeErrorMessage(errno); } #endif if (!success) { #ifndef __APPLE__ int myerr = errno; int socket_type; socklen_t length = sizeof(socket_type); // Ignore the error on UDS. if (getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &socket_type, &length) != 0 || socket_type != AF_UNIX) { LOG(WARNING) << "Could not configure keep alive " << SafeErrorMessage(myerr); } #endif } return error_code{}; } bool Listener::ReconfigureTLS() { #ifdef DFLY_USE_SSL SSL_CTX* prev_ctx = ctx_; const bool tls_on_privileged_port = !GetFlag(FLAGS_no_tls_on_admin_port); if (GetFlag(FLAGS_tls) && (!IsPrivilegedInterface() || tls_on_privileged_port)) { SSL_CTX* ctx = CreateSslCntx(facade::TlsContextRole::SERVER); if (!ctx) { return false; } ctx_ = ctx; } else { ctx_ = nullptr; } if (prev_ctx) { // SSL_CTX is reference counted so if other connections have a reference // to the context it won't be freed yet. SSL_CTX_free(prev_ctx); } #endif return true; } size_t Listener::TLSUsedMemoryThreadLocal() { return listener_tl_stats.tls_allocated_bytes; } uint64_t Listener::RefusedConnectionMaxClientsCount() { return listener_tl_stats.refused_conn_maxclients_reached_cnt; } void Listener::PreAcceptLoop(util::ProactorBase* pb) { } bool Listener::IsPrivilegedInterface() const { return role_ == Role::PRIVILEGED; } bool Listener::IsMainInterface() const { return role_ == Role::MAIN; } void Listener::PreShutdown() { // If NOW/FORCE requested, expedite shutdown without waiting. 
if (g_shutdown_fast.load(std::memory_order_acquire)) {
  return;
}

  // Otherwise: Iterate on all connections and allow them to finish their commands for
  // a short period.
  // Executed commands can be visible in snapshots or replicas, but if we close the client
  // connections too fast we might not send the acknowledgment for those commands.
  // This shouldn't take a long time: All clients should reject incoming commands
  // at this stage since we're in SHUTDOWN mode.
  // If a command is running for too long we give up and proceed.
  DispatchTracker tracker{
      {this}, nullptr, false /* paused connections */, false /* blocking connections*/};
  tracker.TrackAll();

  // Wait at most 10 milliseconds for the tracked dispatches to finish.
  if (!tracker.Wait(absl::Milliseconds(10))) {
    LOG(WARNING) << "Some commands are still being dispatched but didn't conclude in time. "
                    "Proceeding in shutdown.";
  }
}

// Intentionally empty: nothing to do after shutdown.
void Listener::PostShutdown() {
}

// Logs the client id and forwards the start notification to the facade connection.
// NOTE(review): the template argument of static_cast is missing in this text.
void Listener::OnConnectionStart(util::Connection* conn) {
  facade::Connection* facade_conn = static_cast(conn);
  VLOG(1) << "Opening connection " << facade_conn->GetClientId();

  facade_conn->OnConnectionStart();
}

// Logs the client id of the connection that is being closed.
// NOTE(review): the template argument of static_cast is missing in this text.
void Listener::OnConnectionClose(util::Connection* conn) {
  Connection* facade_conn = static_cast(conn);
  VLOG(1) << "Closing connection " << facade_conn->GetClientId();
}

// Counts the refusal on the calling thread and writes an error line to the socket.
void Listener::OnMaxConnectionsReached(util::FiberSocketBase* sock) {
  listener_tl_stats.refused_conn_maxclients_reached_cnt++;
  sock->Write(io::Buffer("-ERR max number of clients reached\r\n"));
}

// We can limit number of threads handling dragonfly connections.
//
// Chooses the proactor thread that will serve the connection behind `sock`:
//  - for non-UDS sockets with --conn_use_incoming_cpu set, the first thread mapped to the
//    CPU reported by SO_INCOMING_CPU, if there is one;
//  - otherwise round-robin (via next_id_) over the thread range described by
//    --conn_io_thread_start and --conn_io_threads, clamped to the pool size.
ProactorBase* Listener::PickConnectionProactor(util::FiberSocketBase* sock) {
  util::ProactorPool* pp = pool();

  // kuint32max means "no thread chosen yet".
  uint32_t res_id = kuint32max;

  if (!sock->IsUDS()) {
    int fd = sock->native_handle();

    int cpu, napi_id;
    socklen_t len = sizeof(cpu);

    // I suspect that the advantage of using SO_INCOMING_NAPI_ID is that
    // we can also track the affinity changes during the lifetime of the process
    // i.e. when a different CPU is assigned to handle the RX traffic.
    // On some distributions (WSL1, for example), SO_INCOMING_CPU is not supported.
    if (0 == getsockopt(fd, SOL_SOCKET, SO_INCOMING_CPU, &cpu, &len)) {
      VLOG(1) << "CPU for connection " << fd << " is " << cpu;
      // Avoid CHECK-ing for success, it sometimes fails on WSL
      // https://github.com/dragonflydb/dragonfly/issues/2090
      // The NAPI id is only logged; it does not influence the choice below.
      if (0 == getsockopt(fd, SOL_SOCKET, SO_INCOMING_NAPI_ID, &napi_id, &len)) {
        VLOG(1) << "NAPI for connection " << fd << " is " << napi_id;
      }

      if (GetFlag(FLAGS_conn_use_incoming_cpu)) {
        // We choose a thread that is running on the incoming CPU. Usually there is
        // a single thread per CPU. SO_INCOMING_CPU returns the CPU that the kernel
        // uses to steer the packets to. In order to make
        // conn_use_incoming_cpu effective, we should make sure that the receive packets are
        // steered to enough CPUs. This can be done by setting the RPS mask in
        // /sys/class/net/<dev>/queues/rx-<n>/rps_cpus. For more details, see
        // https://docs.kernel.org/networking/scaling.html#rps-configuration
        // Please note that if conn_use_incoming_cpu is true, connections will be handled only
        // on the CPUs that handle the softirqs for the incoming packets.
        // To avoid imbalance in CPU load, RPS tuning is strongly advised.
        // NOTE(review): the element type of the vector is missing in this text.
        const vector& ids = pool()->MapCpuToThreads(cpu);
        if (!ids.empty()) {
          res_id = ids[0];
        }
      }
    }
  }

  if (res_id == kuint32max) {
    uint32_t total = GetFlag(FLAGS_conn_io_threads);
    uint32_t start = GetFlag(FLAGS_conn_io_thread_start) % pp->size();
    // 0 means "no limit"; a range that would run past the pool is cut at the pool's end.
    if (total == 0 || total + start > pp->size()) {
      total = pp->size() - start;
    }
    res_id = start + (next_id_.fetch_add(1, std::memory_order_relaxed) % total);
  }

  return pp->at(res_id);
}

// Stores a copy of the listener list together with the issuer and the two ignore flags.
// NOTE(review): the template argument of absl::Span is missing in this text.
DispatchTracker::DispatchTracker(absl::Span listeners, facade::Connection* issuer,
                                 bool ignore_paused, bool ignore_blocked)
    : listeners_{listeners.begin(), listeners.end()},
      issuer_{issuer},
      ignore_paused_{ignore_paused},
      ignore_blocked_{ignore_blocked} {
}

// Sends checkpoints to the connections of all tracked listeners that live on the current
// thread.
void DispatchTracker::TrackOnThread() {
  for (auto* listener : listeners_) {
    listener->TraverseConnectionsOnThread(
        [this](unsigned thread_index, util::Connection* conn) { Handle(thread_index, conn); },
        UINT32_MAX, nullptr);
  }
}

// Waits up to `duration` for the checkpoint counter. When the wait times out and
// ignore_blocked_ is set, the counter is replaced, all connections are tracked again and
// the wait is repeated once. Returns the result of the last wait.
bool DispatchTracker::Wait(absl::Duration duration) {
  bool res = bc_->WaitFor(absl::ToChronoMilliseconds(duration));
  if (!res && ignore_blocked_) {
    LOG(INFO) << "Retrying DispatchTracker::Wait, as bc=" << bc_->DEBUG_Count();
    // We track all connections again because a connection might have become blocked since
    // the last time we tracked them.
    bc_ = BlockingCounter{0};
    TrackAll();
    res = bc_->WaitFor(absl::ToChronoMilliseconds(duration));
    LOG_IF(INFO, !res) << "DispatchTracker::Wait failed again, bc=" << bc_->DEBUG_Count();
  }
  return res;
}

// Sends checkpoints to the connections of all tracked listeners.
void DispatchTracker::TrackAll() {
  for (auto* listener : listeners_)
    listener->TraverseConnections(absl::bind_front(&DispatchTracker::Handle, this));
}

// Sends a checkpoint carrying bc_ to `conn` unless it is the issuer connection.
// `thread_index` is not used.
// NOTE(review): the template argument of static_cast is missing in this text.
void DispatchTracker::Handle(unsigned thread_index, util::Connection* conn) {
  if (auto* fconn = static_cast(conn); fconn != issuer_)
    fconn->SendCheckpoint(bc_, ignore_paused_, ignore_blocked_);
}

}  // namespace facade

================================================
FILE: src/facade/dragonfly_listener.h
================================================
// Copyright 2025, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

// NOTE(review): the targets of the following system includes are missing in this text
// (they look stripped by extraction) - restore them from the original source.
#include
#include
#include
#include
#include
#include

#include "facade/facade_types.h"
#include "util/fiber_socket_base.h"
#include "util/fibers/proactor_base.h"
#include "util/http/http_handler.h"
#include "util/listener_interface.h"

typedef struct ssl_ctx_st SSL_CTX;

namespace facade {

class ServiceInterface;
class Connection;

// Listener that creates facade connections for one protocol and one role.
class Listener : public util::ListenerInterface {
 public:
  // The Role PRIVILEGED is for admin port/listener
  // The Role MAIN is for the main listener on main port
  // The Role OTHER is for all the other listeners
  enum class Role { PRIVILEGED, MAIN, OTHER };

  Listener(Protocol protocol, ServiceInterface*, Role role = Role::OTHER);
  ~Listener();

  std::error_code ConfigureServerSocket(int fd) final;

  // Wait until all command dispatches that are currently in progress finish,
  // ignore commands from issuer connection.
  bool AwaitCurrentDispatches(absl::Duration timeout, util::Connection* issuer);

  // ReconfigureTLS MUST be called from the same proactor as the listener.
  bool ReconfigureTLS();

  // Returns thread-local dynamic memory usage by TLS.
// Dispatch tracker allows tracking the dispatch state of connections and blocking until all
// detected busy connections finished dispatching. Ignores issuer connection.
//
// Mostly used to detect when global state changes (takeover, pause, cluster config update) are
// visible to all commands and no commands are still running according to the old state / config.
//
// Typical sequence: construct, call TrackAll() or TrackOnThread() to send a checkpoint to every
// connection except the issuer, then call Wait().
//
// NOTE(review): the template arguments of absl::Span / std::vector below look like they were
// lost in extraction; confirm against the original header.
class DispatchTracker {
 public:
  // The listener list is copied into the tracker. `issuer` is the connection to skip;
  // `ignore_paused` and `ignore_blocked` are forwarded to each connection's checkpoint request.
  DispatchTracker(absl::Span, facade::Connection* issuer, bool ignore_paused,
                  bool ignore_blocked);

  void TrackAll();       // Track busy connection on all threads
  void TrackOnThread();  // Track busy connections on current thread

  // Wait until all tracked connections finished dispatching.
  // Returns true on success, false if timeout was reached.
  // With ignore_blocked set, a timeout causes one re-track and one more wait before giving up.
  bool Wait(absl::Duration timeout);

 private:
  // Per-connection callback: sends a checkpoint bound to bc_ unless `conn` is the issuer.
  void Handle(unsigned thread_index, util::Connection* conn);

  std::vector listeners_;
  facade::Connection* issuer_;
  util::fb2::BlockingCounter bc_{0};  // tracks number of pending checkpoints
  bool ignore_paused_;
  bool ignore_blocked_;
};
namespace facade {

// Builders for error messages that embed a command or option name.

// "wrong number of arguments for '<cmd, lower-cased>' command"
std::string WrongNumArgsError(std::string_view cmd);

// "CONFIG SET failed (possibly related to argument '<config_name>')."
std::string ConfigSetFailed(std::string_view config_name);

// "invalid expire time in '<cmd, lower-cased>' command"
std::string InvalidExpireTime(std::string_view cmd);

// "Unknown subcommand or wrong number of arguments for '<subcmd>'. Try <cmd> HELP."
std::string UnknownSubCmd(std::string_view subcmd, std::string_view cmd);

// Fixed error messages.
// Messages that start with '-' already carry their own error code (WRONGTYPE, NOSCRIPT, ...);
// the others are plain text.
// NOTE(review): presumably the reply builder adds a generic prefix to the plain ones — confirm.
inline constexpr char kSyntaxErr[] = "syntax error";
inline constexpr char kWrongTypeErr[] =
    "-WRONGTYPE Operation against a key holding the wrong kind of value";
inline constexpr char kWrongJsonTypeErr[] = "-WRONGTYPE wrong JSON type of path value";
inline constexpr char kKeyNotFoundErr[] = "no such key";
inline constexpr char kInvalidIntErr[] = "value is not an integer or out of range";
inline constexpr char kInvalidFloatErr[] = "value is not a valid float";
inline constexpr char kUintErr[] = "value is out of range, must be positive";
inline constexpr char kIncrOverflow[] = "increment or decrement would overflow";
inline constexpr char kDbIndOutOfRangeErr[] = "DB index is out of range";
inline constexpr char kInvalidDbIndErr[] = "invalid DB index";
inline constexpr char kScriptNotFound[] = "-NOSCRIPT No matching script. Please use EVAL.";
inline constexpr char kAuthRejected[] =
    "-WRONGPASS invalid username-password pair or user is disabled.";
inline constexpr char kExpiryOutOfRange[] = "expiry is out of range";
inline constexpr char kIndexOutOfRange[] = "index out of range";
inline constexpr char kOutOfMemory[] = "Out of memory";
inline constexpr char kInvalidNumericResult[] = "result is not a number";
inline constexpr char kClusterNotConfigured[] = "Cluster is not yet configured";
inline constexpr char kLoadingErr[] = "-LOADING Dragonfly is loading the dataset in memory";
inline constexpr char kUndeclaredKeyErr[] = "script tried accessing undeclared key";
inline constexpr char kInvalidDumpValueErr[] = "DUMP payload version or checksum are wrong";
inline constexpr char kInvalidJsonPathErr[] = "invalid JSON path";
inline constexpr char kJsonParseError[] = "failed to parse JSON";
inline constexpr char kNanOrInfDuringIncr[] = "increment would produce NaN or Infinity";
inline constexpr char kCrossSlotError[] = "-CROSSSLOT Keys in request don't hash to the same slot";
inline constexpr char kTieredIoError[] = "IO error when reading value from tiered storage";
inline constexpr char kInvalidHllError[] = "Key is not a valid HyperLogLog string value";

// Short machine-readable error tags.
// NOTE(review): presumably used as the `kind` of an error reply / as error-statistics keys —
// verify against the callers.
inline constexpr char kSyntaxErrType[] = "syntax_error";
inline constexpr char kScriptErrType[] = "script_error";
inline constexpr char kConfigErrType[] = "config_error";
inline constexpr char kSearchErrType[] = "search_error";
inline constexpr char kWrongTypeErrType[] = "wrong_type";
inline constexpr char kRestrictDenied[] = "restrict_denied";
inline constexpr char kNoGroupErrType[] = "no_group_error";
inline constexpr char kNoAuthErrType[] = "no_auth";

}  // namespace facade
// #include #include #include "base/logging.h" #include "facade/command_id.h" #include "facade/error.h" #include "facade/facade_stats.h" #include "facade/parsed_command.h" #include "facade/reply_builder.h" #include "facade/resp_expr.h" #include "strings/human_readable.h" namespace facade { using namespace std; #define ADD(x) (x) += o.x constexpr size_t kSizeConnStats = sizeof(ConnectionStats); ConnectionStats& ConnectionStats::operator+=(const ConnectionStats& o) { static_assert(kSizeConnStats == 272); ADD(read_buf_capacity); ADD(dispatch_queue_entries); ADD(dispatch_queue_bytes); ADD(pipeline_queue_entries); ADD(pipeline_queue_bytes); ADD(dispatch_queue_subscriber_bytes); ADD(pipeline_cmd_cache_bytes); ADD(io_read_cnt); ADD(io_read_bytes); ADD(command_cnt_main); ADD(command_cnt_other); ADD(pipelined_cmd_cnt); ADD(pipelined_cmd_latency); pipelined_latency_hist.Merge(o.pipelined_latency_hist); ADD(pipelined_wait_latency); ADD(conn_received_cnt); ADD(num_conns_main); ADD(num_conns_other); ADD(num_blocked_clients); ADD(num_read_yields); ADD(num_migrations); ADD(num_recv_provided_calls); ADD(pipeline_throttle_count); ADD(tls_accept_disconnects); ADD(handshakes_started); ADD(handshakes_completed); ADD(pipeline_dispatch_calls); ADD(pipeline_dispatch_commands); ADD(pipeline_dispatch_flush_usec); ADD(skip_pipeline_flushing); return *this; } ReplyStats::ReplyStats(ReplyStats&& other) noexcept { *this = other; } ReplyStats& ReplyStats::operator+=(const ReplyStats& o) { static_assert(sizeof(ReplyStats) == 80u + kSanitizerOverhead); ADD(io_write_cnt); ADD(io_write_bytes); for (const auto& k_v : o.err_count) { err_count[k_v.first] += k_v.second; } ADD(script_error_count); send_stats += o.send_stats; squashing_current_reply_size.fetch_add(o.squashing_current_reply_size.load(memory_order_relaxed), memory_order_relaxed); return *this; } #undef ADD ReplyStats& ReplyStats::operator=(const ReplyStats& o) { static_assert(sizeof(ReplyStats) == 80u + kSanitizerOverhead); if (this == &o) 
{ return *this; } send_stats = o.send_stats; io_write_cnt = o.io_write_cnt; io_write_bytes = o.io_write_bytes; err_count = o.err_count; script_error_count = o.script_error_count; squashing_current_reply_size.store(o.squashing_current_reply_size.load(memory_order_relaxed), memory_order_relaxed); return *this; } string WrongNumArgsError(string_view cmd) { return absl::StrCat("wrong number of arguments for '", absl::AsciiStrToLower(cmd), "' command"); } string InvalidExpireTime(string_view cmd) { return absl::StrCat("invalid expire time in '", absl::AsciiStrToLower(cmd), "' command"); } string UnknownSubCmd(string_view subcmd, string_view cmd) { return absl::StrCat("Unknown subcommand or wrong number of arguments for '", subcmd, "'. Try ", cmd, " HELP."); } string ConfigSetFailed(string_view config_name) { return absl::StrCat("CONFIG SET failed (possibly related to argument '", config_name, "')."); } const char* RespExpr::TypeName(Type t) { switch (t) { case STRING: return "string"; case INT64: return "int"; case DOUBLE: return "double"; case ARRAY: return "array"; case NIL_ARRAY: return "nil-array"; case NIL: return "nil"; case ERROR: return "error"; } ABSL_UNREACHABLE(); } CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, int8_t last_key, uint32_t acl_categories) : name_(name), opt_mask_(mask), arity_(arity), first_key_(first_key), last_key_(last_key), acl_categories_(acl_categories) { } } // namespace facade namespace std { using facade::ArgS; ostream& operator<<(ostream& os, facade::CmdArgList ras) { os << "["; if (!ras.empty()) { for (size_t i = 0; i < ras.size() - 1; ++i) { os << absl::CHexEscape(ArgS(ras, i)) << ","; } os << absl::CHexEscape(ArgS(ras, ras.size() - 1)); } os << "]"; return os; } ostream& operator<<(ostream& os, const facade::RespExpr& e) { using facade::RespExpr; using facade::ToSV; switch (e.type) { case RespExpr::INT64: os << "i" << get(e.u); break; case RespExpr::DOUBLE: os << "d" << get(e.u); break; case 
RespExpr::STRING: os << "'" << ToSV(get(e.u)) << "'"; break; case RespExpr::NIL: os << "nil"; break; case RespExpr::NIL_ARRAY: os << "[]"; break; case RespExpr::ARRAY: os << facade::RespSpan{*get(e.u)}; break; case RespExpr::ERROR: os << "e(" << ToSV(get(e.u)) << ")"; break; } return os; } ostream& operator<<(ostream& os, facade::RespSpan ras) { os << "["; if (!ras.empty()) { for (size_t i = 0; i < ras.size() - 1; ++i) { os << ras[i] << ","; } os << ras.back(); } os << "]"; return os; } ostream& operator<<(ostream& os, facade::Protocol p) { switch (p) { case facade::Protocol::REDIS: os << "REDIS"; break; case facade::Protocol::MEMCACHE: os << "MEMCACHE"; break; } return os; } } // namespace std ================================================ FILE: src/facade/facade_stats.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include "base/histogram.h" namespace facade { struct ConnectionStats { size_t read_buf_capacity = 0; // total capacity of input buffers // Count of pending messages in dispatch queue uint64_t dispatch_queue_entries = 0; // Memory used by pending messages in dispatch queue size_t dispatch_queue_bytes = 0; // Count of pending parsed commands in the pipeline queue (Data Path) uint64_t pipeline_queue_entries = 0; // Memory used by pending parsed commands in the pipeline queue (Data Path) size_t pipeline_queue_bytes = 0; // total size of all publish messages (subset of dispatch_queue_bytes) size_t dispatch_queue_subscriber_bytes = 0; size_t pipeline_cmd_cache_bytes = 0; uint64_t io_read_cnt = 0; size_t io_read_bytes = 0; uint64_t command_cnt_main = 0; uint64_t command_cnt_other = 0; uint64_t pipelined_cmd_cnt = 0; uint64_t pipelined_cmd_latency = 0; // in microseconds base::Histogram pipelined_latency_hist; // distribution of per-command latencies (usec) // in microseconds, time spent waiting for the pipelined 
// Accumulated statistics of Send() operations: how many there were and their summed duration.
struct SendStats {
  int64_t count = 0;
  int64_t total_duration = 0;

  // Field-wise accumulation of `other` into this object.
  SendStats& operator+=(const SendStats& other) {
    // Fails to compile when a field is added, as a reminder to extend the sums below.
    static_assert(sizeof(SendStats) == 16u);

    total_duration = total_duration + other.total_duration;
    count = count + other.count;
    return *this;
  }
};
// Bundles the connection-level and reply-level statistics into one object.
struct FacadeStats {
  ConnectionStats conn_stats;
  ReplyStats reply_stats;

  // Accumulates `other` into this object by delegating to the operator+= of each member.
  FacadeStats& operator+=(const FacadeStats& other) {
    conn_stats += other.conn_stats;
    reply_stats += other.reply_stats;
    return *this;
  }
};
double d = 0; if (!absl::SimpleAtod(e.GetString(), &d)) { *listener << "\nCan't parse as double: " << e.GetString(); return false; } e.type = RespExpr::DOUBLE; e.u = d; } else { *listener << "\nWrong type: " << RespExpr::TypeName(e.type); return false; } } if (type_ == RespExpr::STRING || type_ == RespExpr::ERROR) { RespExpr::Buffer ebuf = e.GetBuf(); std::string_view actual{reinterpret_cast(ebuf.data()), ebuf.size()}; if (type_ == RespExpr::ERROR && !absl::StrContains(actual, exp_str_)) { *listener << "Actual does not contain '" << exp_str_ << "'"; return false; } if (type_ == RespExpr::STRING && exp_str_ != actual) { *listener << "\nActual string: " << actual; return false; } } else if (type_ == RespExpr::INT64) { auto actual = get(e.u); if (exp_int_ != actual) { *listener << "\nActual : " << actual << " expected: " << exp_int_; return false; } } else if (type_ == RespExpr::DOUBLE) { auto actual = get(e.u); if (abs(exp_double_ - actual) > 0.0001) { *listener << "\nActual : " << actual << " expected: " << exp_double_; return false; } } else if (type_ == RespExpr::ARRAY) { size_t len = get(e.u)->size(); if (len != size_t(exp_int_)) { *listener << "Actual length " << len << ", expected: " << exp_int_; return false; } } return true; } void RespMatcher::DescribeTo(std::ostream* os) const { *os << "is "; switch (type_) { case RespExpr::STRING: case RespExpr::ERROR: *os << exp_str_; break; case RespExpr::INT64: *os << exp_str_; break; case RespExpr::ARRAY: *os << "array of length " << exp_int_; break; case RespExpr::DOUBLE: *os << exp_double_; break; default: *os << "TBD"; break; } } void RespMatcher::DescribeNegationTo(std::ostream* os) const { *os << "is not "; } bool RespTypeMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const { if (e.type != type_) { *listener << "\nWrong type: " << RespExpr::TypeName(e.type); return false; } return true; } void RespTypeMatcher::DescribeTo(std::ostream* os) const { *os << "is " << 
RespExpr::TypeName(type_); } void RespTypeMatcher::DescribeNegationTo(std::ostream* os) const { *os << "is not " << RespExpr::TypeName(type_); } void PrintTo(const RespExpr::Vec& vec, std::ostream* os) { *os << "Vec: ["; if (!vec.empty()) { for (size_t i = 0; i < vec.size() - 1; ++i) { *os << vec[i] << ","; } *os << vec.back(); } *os << "]\n"; } } // namespace facade ================================================ FILE: src/facade/facade_test.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include "facade/resp_expr.h" namespace facade { class RespMatcher { public: RespMatcher(std::string_view val, RespExpr::Type t = RespExpr::STRING) : type_(t), exp_str_(val) { } RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64) : type_(t), exp_int_(val) { } RespMatcher(double_t val, RespExpr::Type t = RespExpr::DOUBLE) : type_(t), exp_double_(val) { } using is_gtest_matcher = void; bool MatchAndExplain(RespExpr e, testing::MatchResultListener*) const; void DescribeTo(std::ostream* os) const; void DescribeNegationTo(std::ostream* os) const; private: RespExpr::Type type_; std::string exp_str_; int64_t exp_int_ = 0; double_t exp_double_ = 0; }; class RespTypeMatcher { public: RespTypeMatcher(RespExpr::Type type) : type_(type) { } using is_gtest_matcher = void; bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const; void DescribeTo(std::ostream* os) const; void DescribeNegationTo(std::ostream* os) const; private: RespExpr::Type type_; }; inline ::testing::PolymorphicMatcher ErrArg(std::string_view str) { return ::testing::MakePolymorphicMatcher(RespMatcher(str, RespExpr::ERROR)); } inline ::testing::PolymorphicMatcher IntArg(int64_t ival) { return ::testing::MakePolymorphicMatcher(RespMatcher(ival)); } inline ::testing::PolymorphicMatcher DoubleArg(double_t dval) { return 
::testing::MakePolymorphicMatcher(RespMatcher(dval)); } inline ::testing::PolymorphicMatcher ArrLen(size_t len) { return ::testing::MakePolymorphicMatcher(RespMatcher((int64_t)len, RespExpr::ARRAY)); } inline ::testing::PolymorphicMatcher ArgType(RespExpr::Type t) { return ::testing::MakePolymorphicMatcher(RespTypeMatcher(t)); } MATCHER_P(RespArray, value, "") { return ExplainMatchResult( testing::AllOf(ArgType(RespExpr::ARRAY), testing::Property(&RespExpr::GetVec, value)), arg, result_listener); } template auto RespElementsAre(const Args&... matchers) { return RespArray(::testing::ElementsAre(matchers...)); } inline bool operator==(const RespExpr& left, std::string_view s) { return left.type == RespExpr::STRING && ToSV(left.GetBuf()) == s; } inline bool operator==(const RespExpr& left, int64_t val) { return left.type == RespExpr::INT64 && left.GetInt() == val; } inline bool operator!=(const RespExpr& left, std::string_view s) { return !(left == s); } inline bool operator==(std::string_view s, const RespExpr& right) { return right == s; } inline bool operator!=(std::string_view s, const RespExpr& right) { return !(right == s); } void PrintTo(const RespExpr::Vec& vec, std::ostream* os); } // namespace facade ================================================ FILE: src/facade/facade_types.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// Extra bytes that AddressSanitizer instrumentation adds to some container layouts; used by
// sizeof() static_asserts elsewhere so they hold in both sanitized and regular builds.
//
// Detection: GCC (and MSVC) define __SANITIZE_ADDRESS__, Clang exposes
// __has_feature(address_sanitizer). The __has_feature test is nested inside its own
// defined() check so that preprocessors without __has_feature never evaluate it, and every
// path defines the constant exactly once.
#if defined(__SANITIZE_ADDRESS__)
constexpr size_t kSanitizerOverhead = 24u;
#elif defined(__has_feature)
#if __has_feature(address_sanitizer)
constexpr size_t kSanitizerOverhead = 24u;
#else
constexpr size_t kSanitizerOverhead = 0u;
#endif
#else
constexpr size_t kSanitizerOverhead = 0u;
#endif
// Per-command memcache option bits, packed into 16 bits.
// `raw` aliases all bits at once, so the whole set can be cleared with a single store
// (the parser does `cmd_flags.raw = 0` before parsing a new command).
// The letter next to each bit is the meta-command flag character that sets it.
struct MemcacheCmdFlags {
  MemcacheCmdFlags() : raw(0) {
  }

  union {
    uint16_t raw = 0;
    struct {
      uint16_t no_reply : 1;            // q
      uint16_t meta : 1;                // meta flags
      uint16_t base64 : 1;              // b
      uint16_t return_flags : 1;        // f
      uint16_t return_value : 1;        // v
      uint16_t return_ttl : 1;          // t
      uint16_t return_access_time : 1;  // l
      uint16_t return_hit : 1;          // h
      uint16_t return_cas : 1;          // c
    };
  };
};
// Binary size literals: 4_KB == 4 * 1024, 2_MB == 2 * 1024 * 1024.
// Arithmetic is done in unsigned long long, so overflow wraps modulo 2^64.
constexpr unsigned long long operator""_KB(unsigned long long x) {
  return x << 10;
}

constexpr unsigned long long operator""_MB(unsigned long long x) {
  return x << 20;
}
// Maps the first token of a request line to a command type.
//  - Tokens longer than two characters are looked up (case-sensitively) in the table of
//    classic text-protocol commands.
//  - Two-character tokens are meta commands of the form "m<x>".
//  - Anything else is INVALID.
MP::CmdType From(string_view token) {
  static absl::flat_hash_map<string_view, MP::CmdType> cmd_map{
      {"set", MP::SET},       {"add", MP::ADD},         {"replace", MP::REPLACE},
      {"append", MP::APPEND}, {"prepend", MP::PREPEND}, {"cas", MP::CAS},
      {"get", MP::GET},       {"gets", MP::GETS},       {"gat", MP::GAT},
      {"gats", MP::GATS},     {"stats", MP::STATS},     {"incr", MP::INCR},
      {"decr", MP::DECR},     {"delete", MP::DELETE},   {"flush_all", MP::FLUSHALL},
      {"quit", MP::QUIT},     {"version", MP::VERSION},
  };

  const size_t len = token.size();
  if (len < 2)
    return MP::INVALID;

  if (len > 2) {
    auto found = cmd_map.find(token);
    return found == cmd_map.end() ? MP::INVALID : found->second;
  }

  // Exactly two characters: META_COMMANDS.
  if (token[0] != 'm')
    return MP::INVALID;

  switch (token[1]) {
    case 's':
      return MP::META_SET;
    case 'g':
      return MP::META_GET;
    case 'd':
      return MP::META_DEL;
    case 'a':
      return MP::META_ARITHM;
    case 'n':
      return MP::META_NOOP;
    case 'e':
      return MP::META_DEBUG;
    default:
      return MP::INVALID;
  }
}
// Parses the arguments of commands that carry no value blob (get/gets/gat/gats, stats,
// delete, incr/decr, flush_all).
//
// `tokens` holds the arguments after the command name; the caller only invokes this with a
// non-empty list. `now` is used to turn a relative gat/gats expiry into an absolute one.
// On success the key(s) are pushed into res->backed_args and MP::OK is returned.
// Returns BAD_INT / BAD_DELTA for unparsable numbers and PARSE_ERROR for malformed input.
MP::Result ParseValueless(ArgSlice tokens, int64_t now, MP::Command* res) {
  const size_t num_tokens = tokens.size();
  size_t key_pos = 0;
  uint32_t expire_ts;

  // gat/gats: the first argument is the expiry, keys follow.
  if (res->type == MP::GAT || res->type == MP::GATS) {
    if (!absl::SimpleAtoi(tokens[0], &expire_ts)) {
      return MP::BAD_INT;
    }
    res->expire_ts = ToAbsolute(expire_ts, now);
    ++key_pos;
  }

  // We support only `flushall` or `flushall 0`
  if (key_pos < num_tokens && res->type == MP::FLUSHALL) {
    DCHECK_EQ(res->size(), 0u);
    int delay = 0;
    if (key_pos + 1 == num_tokens && absl::SimpleAtoi(tokens[key_pos], &delay) && delay == 0)
      return MP::OK;
    return MP::PARSE_ERROR;
  }

  // Every remaining command needs at least one argument (the key).
  if (key_pos >= num_tokens)
    return MP::PARSE_ERROR;

  // Classic retrieval commands always return value and flags; only the "s" variants add CAS.
  res->cmd_flags.return_cas = (res->type == MP::GETS || res->type == MP::GATS);
  res->cmd_flags.return_value = true;
  res->cmd_flags.return_flags = true;

  res->backed_args->PushArg(tokens[key_pos++]);

  if (key_pos < num_tokens && res->type == MP::STATS)
    return MP::PARSE_ERROR;  // we don't support additional arguments to stats for now

  // incr/decr: the argument right after the key is the delta and is not stored as an arg.
  if (res->type == MP::INCR || res->type == MP::DECR) {
    if (key_pos == num_tokens)
      return MP::PARSE_ERROR;

    if (!absl::SimpleAtoi(tokens[key_pos], &res->delta))
      return MP::BAD_DELTA;
    ++key_pos;
  }

  // Whatever is left (additional keys, or a trailing "noreply") is stored as-is.
  while (key_pos < num_tokens) {
    res->backed_args->PushArg(tokens[key_pos++]);
  }

  if (res->type >= MP::DELETE) {  // write commands
    // A trailing "noreply" is an option, not an argument: record it and drop it.
    // The size() > 1 guard keeps a lone key that happens to be "noreply".
    if (res->size() > 1 && res->backed_args->back() == "noreply") {
      res->cmd_flags.no_reply = true;
      res->backed_args->PopArg();
    }
  }

  return MP::OK;
}
return false; } return true; } return false; } // See https://raw.githubusercontent.com/memcached/memcached/refs/heads/master/doc/protocol.txt MP::Result ParseMeta(ArgSlice tokens, int64_t now, MP::Command* res, uint32_t max_value_len) { DCHECK(!tokens.empty()); if (res->type == MP::META_DEBUG) { LOG(ERROR) << "meta debug not yet implemented"; return MP::PARSE_ERROR; } if (tokens[0].size() > 250) return MP::PARSE_ERROR; res->cmd_flags.meta = true; res->flags = 0; res->expire_ts = 0; string_view arg0 = tokens[0]; tokens.remove_prefix(1); uint32_t bytes_len = 0; // We emulate the behavior by returning the high level commands. // TODO: we should reverse the interface in the future, so that a high level command // will be represented in MemcacheParser::Command by a meta command with flags. // high level commands should not be part of the interface in the future. switch (res->type) { case MP::META_GET: res->type = MP::GET; break; case MP::META_DEL: res->type = MP::DELETE; break; case MP::META_SET: if (tokens.empty()) return MP::PARSE_ERROR; if (!absl::SimpleAtoi(tokens[0], &bytes_len)) return MP::BAD_INT; if (bytes_len > max_value_len) { LOG_EVERY_T(WARNING, 1) << "Memcache value size " << bytes_len << " exceeds max_bulk_len " << max_value_len; return MP::PARSE_ERROR; } res->type = MP::SET; tokens.remove_prefix(1); break; case MP::META_ARITHM: res->type = MP::INCR; res->delta = 1; break; default: return MP::PARSE_ERROR; } string blob; uint32_t expire_ts; for (size_t i = 0; i < tokens.size(); ++i) { string_view token = tokens[i]; switch (token[0]) { case 'T': if (!absl::SimpleAtoi(token.substr(1), &expire_ts)) return MP::BAD_INT; res->expire_ts = ToAbsolute(expire_ts, now); if (res->type == MP::GET) res->type = MP::GAT; break; case 'b': if (token.size() != 1) return MP::PARSE_ERROR; if (!absl::Base64Unescape(arg0, &blob)) return MP::PARSE_ERROR; arg0 = blob; res->cmd_flags.base64 = true; break; case 'F': if (!absl::SimpleAtoi(token.substr(1), &res->flags)) return 
MP::BAD_INT; break; case 'M': if (token.size() != 2 || !ParseMetaMode(token[1], res)) return MP::PARSE_ERROR; break; case 'D': if (!absl::SimpleAtoi(token.substr(1), &res->delta)) return MP::BAD_INT; break; case 'q': res->cmd_flags.no_reply = true; break; case 'f': res->cmd_flags.return_flags = true; break; case 'v': res->cmd_flags.return_value = true; break; case 't': res->cmd_flags.return_ttl = true; break; case 'l': res->cmd_flags.return_access_time = true; break; case 'h': res->cmd_flags.return_hit = true; break; case 'c': res->cmd_flags.return_cas = true; break; default: LOG(WARNING) << "unknown meta flag: " << token; // not yet implemented return MP::PARSE_ERROR; } } res->backed_args->PushArg(arg0); if (MP::IsStoreCmd(res->type)) { res->backed_args->PushArg(bytes_len); } return MP::OK; } } // namespace auto MP::Parse(string_view str, uint32_t* consumed, Command* cmd) -> Result { DVLOG(1) << "Parsing memcache input: [" << str << "]"; *consumed = 0; if (val_len_to_read_ > 0) { return ConsumeValue(str, consumed, cmd); } cmd->cmd_flags.raw = 0; // re-initialize size_t pos = str.find('\n'); if (pos == string_view::npos) { // We need more data to parse the command. For get/gets commands this line can be very long. // we limit maximum buffer capacity in the higher levels using max_client_iobuf_len. tmp_buf_.append(str); *consumed = str.size(); return INPUT_PENDING; } *consumed = pos + 1; string_view main_cmd; if (tmp_buf_.empty()) { main_cmd = str.substr(0, pos); } else { tmp_buf_.append(str.substr(0, pos)); main_cmd = tmp_buf_; } // main_cmd is \n stripped, so it should end with \r. 
if (main_cmd.empty() || main_cmd.back() != '\r') { return PARSE_ERROR; } main_cmd.remove_suffix(1); // remove trailing \r // cas [noreply]\r\n // get *\r\n // ms *\r\n absl::InlinedVector tokens = absl::StrSplit(main_cmd, ' ', absl::SkipWhitespace()); Result res = ParseInternal(absl::MakeSpan(tokens), cmd); tmp_buf_.clear(); if (val_len_to_read_ > 0) return ConsumeValue(str.substr(pos + 1), consumed, cmd); return res; }; auto MP::ParseInternal(ArgSlice tokens_view, Command* cmd) -> Result { if (tokens_view.empty()) return PARSE_ERROR; cmd->type = From(tokens_view[0]); if (cmd->type == INVALID) { return UNKNOWN_CMD; } tokens_view.remove_prefix(1); cmd->backed_args->clear(); if (cmd->type <= CAS) { // Store command if (tokens_view.size() < 4 || tokens_view[0].size() > 250) { // key length limit return MP::PARSE_ERROR; } auto res = ParseStore(tokens_view, last_unix_time_, cmd, max_value_len_); if (res != MP::OK) return res; val_len_to_read_ = cmd->value().size() + 2; return MP::OK; } if (cmd->type >= META_SET) { if (tokens_view.empty()) return MP::PARSE_ERROR; auto res = ParseMeta(tokens_view, last_unix_time_, cmd, max_value_len_); if (res != MP::OK) return res; if (IsStoreCmd(cmd->type)) { val_len_to_read_ = cmd->value().size() + 2; res = MP::OK; } return res; } if (tokens_view.empty()) { if (base::_in(cmd->type, {MP::STATS, MP::FLUSHALL, MP::QUIT, MP::VERSION, MP::META_NOOP})) { return MP::OK; } return MP::PARSE_ERROR; } return ParseValueless(tokens_view, last_unix_time_, cmd); } auto MP::ConsumeValue(std::string_view str, uint32_t* consumed, Command* dest) -> Result { DCHECK_EQ(dest->size(), 2u); // key and value DCHECK_GT(val_len_to_read_, 0u); if (val_len_to_read_ > 2) { uint32_t need_copy = val_len_to_read_ - 2; uint32_t dest_len = dest->backed_args->elem_len(1); DCHECK_GE(dest_len, need_copy); // should be ensured during parsing char* start = dest->value_ptr() + (dest_len - need_copy); uint32_t to_fill = std::min(need_copy, str.size()); if (to_fill) { 
// Incremental parser for the memcache text protocol (classic and meta commands).
// Parse() can be fed partial input: an incomplete command line is buffered in tmp_buf_,
// and for store commands the value that follows the command line is copied into the
// second argument of Command::backed_args, possibly over several Parse() calls.
class MemcacheParser {
 public:
  // max_value_len caps the value length a store command may announce;
  // larger values are rejected with PARSE_ERROR.
  explicit MemcacheParser(uint32_t max_value_len = UINT32_MAX) : max_value_len_(max_value_len) {
  }

  // Numeric ranges are significant: the parser relies on SET..CAS being the store
  // commands (see IsStoreCmd) and on the meta commands having the highest values.
  enum CmdType : uint8_t {
    INVALID = 0,
    SET = 1,
    ADD = 2,
    REPLACE = 3,
    APPEND = 4,
    PREPEND = 5,
    CAS = 6,

    // Retrieval
    GET = 10,
    GETS = 11,
    GAT = 12,
    GATS = 13,
    STATS = 14,

    QUIT = 20,
    VERSION = 21,

    // The rest of write commands.
    DELETE = 31,
    INCR = 32,
    DECR = 33,
    FLUSHALL = 34,

    // META_COMMANDS
    META_NOOP = 50,
    META_SET = 51,
    META_DEL = 52,
    META_ARITHM = 53,
    META_GET = 54,
    META_DEBUG = 55,
  };

  // Parsed representation of a single command.
  // According to https://github.com/memcached/memcached/wiki/Commands#standard-protocol
  struct Command {
    Command() = default;
    Command(const Command&) = delete;
    Command(Command&&) noexcept = default;

    CmdType type = INVALID;

    // First argument, or an empty view if no arguments were parsed.
    std::string_view key() const {
      return backed_args->empty() ? std::string_view{} : backed_args->Front();
    }

    // For STORE commands, value is at index 1.
    // For both key and value we provide convenience accessors that return empty string_view
    // if not present.
    std::string_view value() const {
      return backed_args->size() < 2 ? std::string_view{} : backed_args->at(1);
    }

    // Number of parsed arguments.
    size_t size() const {
      return backed_args->size();
    }

    // Writable pointer to the value buffer (argument 1); the parser fills it
    // while consuming the value bytes.
    char* value_ptr() {  // NOLINT
      return backed_args->data(1);
    }

    union {
      uint64_t cas_unique = 0;  // for CAS COMMAND
      uint64_t delta;           // for DECR/INCR commands.
    };

    int64_t expire_ts = 0;  // unix time (expire_ts > month) in seconds

    // flags for STORE commands
    uint32_t flags = 0;

    MemcacheCmdFlags cmd_flags;

    // Does not own this object, only references it.
    cmn::BackedArguments* backed_args = nullptr;
  };

  // Guards against accidental growth of the struct.
  static_assert(sizeof(Command) == 40);

  enum Result : uint8_t {
    OK,
    INPUT_PENDING,  // more input is needed to finish the command or its value.
    UNKNOWN_CMD,
    BAD_INT,
    PARSE_ERROR,  // request parse error, but can continue parsing within the same connection.
    BAD_DELTA,
  };

  // True for the commands that carry a value: SET, ADD, REPLACE, APPEND, PREPEND, CAS.
  static bool IsStoreCmd(CmdType type) {
    return type >= SET && type <= CAS;
  }

  // Capacity of the internal buffer that holds a partially received command line.
  size_t UsedMemory() const {
    return tmp_buf_.capacity();
  }

  // Drops any partially parsed command line and pending value state.
  void Reset() {
    val_len_to_read_ = 0;
    tmp_buf_.clear();
  }

  // Parses `str` and stores the number of bytes taken from it in *consumed.
  // Returns INPUT_PENDING when the command line or the value is not complete yet;
  // the next call continues where this one stopped.
  Result Parse(std::string_view str, uint32_t* consumed, Command* res);

  // Sets the time that relative expiry values are resolved against.
  void set_last_unix_time(int64_t t) {
    last_unix_time_ = t;
  }

 private:
  // Copies value bytes into dest and then matches the trailing "\r\n".
  Result ConsumeValue(std::string_view str, uint32_t* consumed, Command* dest);

  // Parses an already tokenized command line.
  Result ParseInternal(ArgSlice tokens_view, Command* cmd);

  // Bytes still to consume for the current value, including the trailing "\r\n".
  // Zero when no value is pending.
  uint32_t val_len_to_read_ = 0;
  uint32_t max_value_len_ = UINT32_MAX;

  // Accumulates a command line that arrived without its terminating '\n'.
  std::string tmp_buf_;
  int64_t last_unix_time_ = 0;
};
// #include "facade/memcache_parser.h" #include #include "absl/strings/str_cat.h" #include "base/gtest.h" #include "base/logging.h" #include "facade/facade_test.h" using namespace testing; using namespace std; namespace facade { class MCParserTest : public testing::Test { protected: MCParserTest() { cmd_.backed_args = &backed_args_; } MemcacheParser::Result Parse(string_view input) { parser_.Reset(); return parser_.Parse(input, &consumed_, &cmd_); } vector ToArgs() const { return {cmd_.backed_args->begin(), cmd_.backed_args->end()}; } MemcacheParser parser_; cmn::BackedArguments backed_args_; MemcacheParser::Command cmd_; uint32_t consumed_; }; TEST_F(MCParserTest, Basic) { MemcacheParser::Result st = Parse("set a 1 20 3\r\n"); EXPECT_EQ(MemcacheParser::INPUT_PENDING, st); EXPECT_EQ("a", cmd_.key()); EXPECT_EQ(1, cmd_.flags); EXPECT_EQ(20, cmd_.expire_ts); EXPECT_EQ(3, cmd_.value().size()); EXPECT_EQ(MemcacheParser::SET, cmd_.type); st = Parse("quit\r\n"); EXPECT_EQ(MemcacheParser::OK, st); EXPECT_EQ(MemcacheParser::QUIT, cmd_.type); } TEST_F(MCParserTest, Incr) { MemcacheParser::Result st = Parse("incr a\r\n"); EXPECT_EQ(MemcacheParser::PARSE_ERROR, st); st = Parse("incr a 1\r\n"); EXPECT_EQ(MemcacheParser::OK, st); EXPECT_EQ(MemcacheParser::INCR, cmd_.type); EXPECT_EQ("a", cmd_.key()); EXPECT_EQ(1, cmd_.delta); EXPECT_FALSE(cmd_.cmd_flags.no_reply); st = Parse("incr a -1\r\n"); EXPECT_EQ(MemcacheParser::BAD_DELTA, st); st = Parse("decr b 10 noreply\r\n"); EXPECT_EQ(MemcacheParser::OK, st); EXPECT_EQ(MemcacheParser::DECR, cmd_.type); EXPECT_EQ(10, cmd_.delta); } TEST_F(MCParserTest, Stats) { MemcacheParser::Result st = Parse("stats foo\r\n"); EXPECT_EQ(MemcacheParser::OK, st); EXPECT_EQ(consumed_, 11); EXPECT_EQ(cmd_.type, MemcacheParser::STATS); EXPECT_EQ("foo", cmd_.key()); st = Parse("stats \r\n"); EXPECT_EQ(MemcacheParser::OK, st); EXPECT_EQ(consumed_, 9); EXPECT_EQ(cmd_.type, MemcacheParser::STATS); EXPECT_EQ(0, cmd_.size()); st = Parse("stats fpp bar\r\n"); 
// The trailing "noreply" token of a store command toggles cmd_flags.no_reply and does not
// leak into the following command.
TEST_F(MCParserTest, NoreplyBasic) {
  // With noreply. INPUT_PENDING: the value bytes have not been supplied yet.
  EXPECT_EQ(MemcacheParser::INPUT_PENDING, Parse("set mykey 1 2 3 noreply\r\n"));
  EXPECT_EQ(MemcacheParser::SET, cmd_.type);
  EXPECT_EQ("mykey", cmd_.key());
  EXPECT_EQ(1, cmd_.flags);
  EXPECT_EQ(2, cmd_.expire_ts);
  EXPECT_EQ(3, cmd_.value().size());
  EXPECT_TRUE(cmd_.cmd_flags.no_reply);

  // Without noreply.
  EXPECT_EQ(MemcacheParser::INPUT_PENDING, Parse("set mykey2 4 5 6\r\n"));
  EXPECT_EQ(MemcacheParser::SET, cmd_.type);
  EXPECT_EQ("mykey2", cmd_.key());
  EXPECT_EQ(4, cmd_.flags);
  EXPECT_EQ(5, cmd_.expire_ts);
  EXPECT_EQ(6, cmd_.value().size());
  EXPECT_FALSE(cmd_.cmd_flags.no_reply);
}
// A value may arrive in arbitrary fragments; the parser keeps its position between calls.
TEST_F(MCParserTest, ValueState) {
  // Command line plus the first 3 of 6 value bytes.
  MemcacheParser::Result status = Parse("ms key1 6\r\nabc");
  EXPECT_EQ(MemcacheParser::INPUT_PENDING, status);
  EXPECT_EQ(consumed_, 14);

  // Continuation calls go to parser_ directly: the fixture's Parse() resets the state.
  status = parser_.Parse("de", &consumed_, &cmd_);
  EXPECT_EQ(MemcacheParser::INPUT_PENDING, status);
  EXPECT_EQ(consumed_, 2);

  // Last value byte and the first half of the terminator.
  status = parser_.Parse("f\r", &consumed_, &cmd_);
  EXPECT_EQ(MemcacheParser::INPUT_PENDING, status);
  EXPECT_EQ(consumed_, 2);
  EXPECT_EQ(cmd_.value(), "abcdef");

  // The closing '\n' completes the command.
  status = parser_.Parse("\n", &consumed_, &cmd_);
  EXPECT_EQ(MemcacheParser::OK, status);
  EXPECT_EQ(consumed_, 1);
}
// Regression test: the "\r\n" that terminates the command line may be split across
// TCP packets, and this used to produce a parse error.
TEST_F(MCParserTest, SplitCRLFInCommandLine) {
  // First packet ends right after '\r'; the whole fragment is buffered.
  EXPECT_EQ(MemcacheParser::INPUT_PENDING, Parse("set k10 0 0 3 noreply\r"));
  EXPECT_EQ(consumed_, 22);

  // Second packet: the missing '\n', the value, and the beginning of the next command,
  // which must be left untouched.
  const MemcacheParser::Result second =
      parser_.Parse("\nd10\r\nget k11\r\n", &consumed_, &cmd_);
  EXPECT_EQ(MemcacheParser::OK, second);
  EXPECT_EQ(consumed_, 6);  // \n + d10\r\n
  EXPECT_EQ(cmd_.type, MemcacheParser::SET);
  EXPECT_EQ(cmd_.key(), "k10");
  EXPECT_EQ(cmd_.value(), "d10");
  EXPECT_TRUE(cmd_.cmd_flags.no_reply);
}
// A single "get" with many keys (separated by more than one space) must yield every key,
// in order.
TEST_F(MCParserNoreplyTest, LargeGetRequest) {
  constexpr size_t kNumKeys = 100;

  std::string large_request = "get";
  for (size_t i = 0; i < kNumKeys; ++i) {
    absl::StrAppend(&large_request, " mykey", i, " ");
  }
  absl::StrAppend(&large_request, "\r\n");

  RunTest(large_request, false);
  EXPECT_EQ(cmd_.type, MemcacheParser::CmdType::GET);

  const auto keys = ToArgs();
  // Without this check the per-key verification passes vacuously when keys are dropped.
  ASSERT_EQ(keys.size(), kNumKeys);

  // Explicit index loop instead of a stateful predicate: algorithms are allowed to copy
  // their predicate, and a per-element expectation pinpoints the failing key.
  for (size_t i = 0; i < keys.size(); ++i) {
    EXPECT_EQ(keys[i], absl::StrCat("mykey", i)) << "at index " << i;
  }
}
// #include "base/init.h" #include "facade/conn_context.h" #include "facade/dragonfly_connection.h" #include "facade/dragonfly_listener.h" #include "facade/reply_builder.h" #include "facade/service_interface.h" #include "util/accept_server.h" #include "util/fibers/pool.h" ABSL_FLAG(uint32_t, port, 6379, "server port"); using namespace util; using namespace std; using absl::GetFlag; namespace facade { namespace { struct CmdContext : public facade::ParsedCommand { void ReuseInternal() final { } }; class OkService : public ServiceInterface { public: DispatchResult DispatchCommand(ParsedArgs args, ParsedCommand* cmd, AsyncPreference) final { cmd->rb()->SendOk(); return DispatchResult::OK; } DispatchManyResult DispatchManyCommands(std::function arg_gen, unsigned count, SinkReplyBuilder* builder, ConnectionContext* cntx) final { for (unsigned i = 0; i < count; i++) { ParsedArgs args = arg_gen(); ParsedCommand* cmd = AllocateParsedCommand(); cmd->Init(builder, cntx); DispatchCommand(args, cmd, AsyncPreference::ONLY_SYNC); delete cmd; } DispatchManyResult result{ .processed = static_cast(count), .account_in_stats = true, }; return result; } DispatchResult DispatchMC(ParsedCommand* cmd, AsyncPreference) final { cmd->rb()->SendError(""); return DispatchResult::OK; } ConnectionContext* CreateContext(Connection* owner) final { return new ConnectionContext{owner}; } ParsedCommand* AllocateParsedCommand() final { return new CmdContext{}; } }; void RunEngine(ProactorPool* pool, AcceptServer* acceptor) { OkService service; Connection::Init(pool->size()); pool->Await([](auto*) { tl_facade_stats = new FacadeStats; }); acceptor->AddListener(GetFlag(FLAGS_port), new Listener{Protocol::REDIS, &service}); acceptor->Run(); acceptor->Wait(); } } // namespace } // namespace facade #ifdef __linux__ #define USE_URING 1 #else #define USE_URING 0 #endif int main(int argc, char* argv[]) { MainInitGuard guard(&argc, &argv); CHECK_GT(GetFlag(FLAGS_port), 0u); #if USE_URING unique_ptr 
pp(fb2::Pool::IOUring(1024)); #else unique_ptr pp(fb2::Pool::Epoll()); #endif pp->Run(); AcceptServer acceptor(pp.get()); facade::RunEngine(pp.get(), &acceptor); pp->Stop(); return 0; } ================================================ FILE: src/facade/op_status.cc ================================================ #include "facade/op_status.h" #include "base/logging.h" #include "facade/error.h" #include "facade/resp_expr.h" namespace facade { std::string_view StatusToMsg(OpStatus status) { switch (status) { case OpStatus::OK: return "OK"; case OpStatus::KEY_NOTFOUND: return kKeyNotFoundErr; case OpStatus::WRONG_TYPE: return kWrongTypeErr; case OpStatus::WRONG_JSON_TYPE: return kWrongJsonTypeErr; case OpStatus::OUT_OF_RANGE: return kIndexOutOfRange; case OpStatus::INVALID_FLOAT: return kInvalidFloatErr; case OpStatus::INVALID_INT: return kInvalidIntErr; case OpStatus::SYNTAX_ERR: return kSyntaxErr; case OpStatus::OUT_OF_MEMORY: return kOutOfMemory; case OpStatus::CORRUPTED_HLL: return "-INVALIDOBJ Corrupted HLL object detected."; case OpStatus::BUSY_GROUP: return "-BUSYGROUP Consumer Group name already exists"; case OpStatus::INVALID_NUMERIC_RESULT: return kInvalidNumericResult; case OpStatus::AT_LEAST_ONE_KEY: return "at least 1 input key is needed for this command"; case OpStatus::MEMBER_NOTFOUND: return kKeyNotFoundErr; case OpStatus::INVALID_JSON_PATH: return kInvalidJsonPathErr; case OpStatus::INVALID_JSON: return kJsonParseError; case OpStatus::NAN_OR_INF_DURING_INCR: return kNanOrInfDuringIncr; case OpStatus::IO_ERROR: return kTieredIoError; default: LOG(ERROR) << "Unsupported status " << status; return "Internal error"; } } } // namespace facade ================================================ FILE: src/facade/op_status.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// Outcome of a storage operation. The enumerators are kept in their original order
// because their numeric values are what gets streamed for an OpStatus.
enum class OpStatus : uint16_t {
  OK,
  KEY_EXISTS,
  KEY_NOTFOUND,
  KEY_MOVED,
  SKIPPED,
  INVALID_VALUE,
  CORRUPTED_HLL,
  OUT_OF_RANGE,
  WRONG_TYPE,
  WRONG_JSON_TYPE,
  TIMED_OUT,
  OUT_OF_MEMORY,
  INVALID_FLOAT,
  INVALID_INT,
  SYNTAX_ERR,
  BUSY_GROUP,
  STREAM_ID_SMALL,
  INVALID_NUMERIC_RESULT,
  CANCELLED,
  AT_LEAST_ONE_KEY,
  MEMBER_NOTFOUND,
  INVALID_JSON_PATH,
  INVALID_JSON,
  IO_ERROR,
  NAN_OR_INF_DURING_INCR,
};

// Value-less part of an operation result: it only carries the status.
class OpResultBase {
 public:
  // Intentionally implicit so that a bare OpStatus converts into a result.
  OpResultBase(OpStatus st = OpStatus::OK) : st_{st} {
  }

  // The stored status.
  OpStatus status() const {
    return st_;
  }

  // A result is "truthy" exactly when the operation succeeded.
  constexpr explicit operator bool() const {
    return st_ == OpStatus::OK;
  }

  // Same as the boolean conversion, spelled out.
  bool ok() const {
    return static_cast<bool>(*this);
  }

  bool operator==(OpStatus st) const {
    return status() == st;
  }

  const char* DebugFormat() const;

 private:
  OpStatus st_;
};
v_ : v; } V* operator->() { return &v_; } V& operator*() & { return v_; } V&& operator*() && { return std::move(v_); } const V* operator->() const { return &v_; } const V& operator*() const& { return v_; } private: V v_{}; }; template <> class OpResult : public OpResultBase { public: using OpResultBase::OpResultBase; }; inline bool operator==(OpStatus st, const OpResultBase& ob) { return ob.operator==(st); } std::string_view StatusToMsg(OpStatus status); } // namespace facade namespace std { template std::ostream& operator<<(std::ostream& os, const facade::OpResult& res) { os << res.status(); return os; } inline std::ostream& operator<<(std::ostream& os, const facade::OpStatus op) { os << int(op); return os; } } // namespace std ================================================ FILE: src/facade/parsed_command.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "facade/parsed_command.h" #include "base/logging.h" #include "core/overloaded.h" #include "facade/conn_context.h" #include "facade/dragonfly_connection.h" #include "facade/reply_builder.h" #include "facade/reply_capture.h" #include "facade/reply_payload.h" namespace facade { using namespace std; string MCRender::RenderNotFound() const { if (flags_.no_reply) return {}; return flags_.meta ? "NF" : "NOT_FOUND"; } string MCRender::RenderGetEnd() const { if (flags_.no_reply || flags_.meta) return {}; return "END"; } std::string MCRender::RenderStored(bool ok) const { if (flags_.no_reply) return {}; if (ok) return flags_.meta ? "HD" : "STORED"; return flags_.meta ? "NS" : "NOT_STORED"; } string MCRender::RenderMiss() const { if (flags_.no_reply || !flags_.meta) return {}; return "EN"; } string MCRender::RenderDeleted() const { if (flags_.no_reply) return {}; return flags_.meta ? 
"HD" : "DELETED"; } void ParsedCommand::ResetForReuse() { is_deferred_reply_ = false; reply_ = std::monostate{}; offsets_.clear(); if (HeapMemory() > 1024) { storage_.clear(); // also deallocates the heap. offsets_.shrink_to_fit(); } ReuseInternal(); } void ParsedCommand::SendError(std::string_view str, std::string_view type) { if (!is_deferred_reply_) { rb_->SendError(str, type); } else { reply_ = payload::make_error(str, type); } } void ParsedCommand::SendError(facade::OpStatus status) { if (!is_deferred_reply_) { rb_->SendError(status); } else { if (status == OpStatus::OK) reply_ = payload::SimpleString{"OK"}; else reply_ = payload::make_error(StatusToMsg(status)); } } void ParsedCommand::SendError(const facade::ErrorReply& error) { if (error.status) return SendError(*error.status); SendError(error.ToSv(), error.kind); } void ParsedCommand::SendSimpleString(std::string_view str) { if (!is_deferred_reply_) { rb_->SendSimpleString(str); } else { reply_ = payload::make_simple_or_noreply(str); } } void ParsedCommand::SendLong(long val) { DCHECK(!is_deferred_reply_); rb_->SendLong(val); } bool ParsedCommand::CanReply() const { DCHECK(is_deferred_reply_); dfly::Overloaded ov{[](const payload::Payload& pl) { return pl.index() > 0 /* not monostate */; }, [](const SuspendedCommand& task) { return task.blocker->IsCompleted(); }}; return std::visit(ov, reply_); } void ParsedCommand::SendReply() { auto payload_handler = [this](payload::Payload& pl) { CapturingReplyBuilder::Apply(std::move(pl), rb_); }; auto task_handler = [](SuspendedCommand& task) { DCHECK(task.coro); task.coro.resume(); task.coro = {}; }; std::visit(dfly::Overloaded{task_handler, payload_handler}, reply_); } ParsedCommand::SuspendedCommand::~SuspendedCommand() { if (coro) { coro.destroy(); coro = {}; } } } // namespace facade ================================================ FILE: src/facade/parsed_command.h ================================================ // Copyright 2025, DragonflyDB authors. 
// Renders simple string responses based on flags.
// Returns empty string if no response is to be sent.
// Two flags drive the output: `no_reply` suppresses every response and `meta`
// selects the short meta-protocol codes over the classic textual ones.
class MCRender {
 public:
  explicit MCRender(MemcacheCmdFlags flags) : flags_(flags) {
  }

  // "NF" (meta) or "NOT_FOUND".
  std::string RenderNotFound() const;

  // "EN" for meta commands, empty for classic ones.
  std::string RenderMiss() const;

  // "HD" (meta) or "DELETED".
  std::string RenderDeleted() const;

  // "END" for classic commands, empty for meta ones.
  std::string RenderGetEnd() const;

  // ok: "HD" (meta) or "STORED"; otherwise "NS" (meta) or "NOT_STORED".
  std::string RenderStored(bool ok) const;

 private:
  MemcacheCmdFlags flags_;
};
uint64_t parsed_cycle = 0; ParsedCommand* next = nullptr; void Init(SinkReplyBuilder* rb, ConnectionContext* conn_cntx) { rb_ = rb; conn_cntx_ = conn_cntx; } // If true, creates mc specific fields, false - destroys them. void ConfigureMCExtension(bool is_mc) { if (is_mc && !mc_cmd_) { mc_cmd_ = std::make_unique(); mc_cmd_->backed_args = this; } else if (!is_mc) { mc_cmd_.reset(); } } SinkReplyBuilder* rb() const { return rb_; } ConnectionContext* conn_cntx() const { return conn_cntx_; } MemcacheParser::Command* mc_command() const { return mc_cmd_.get(); } size_t UsedMemory() const { size_t sz = HeapMemory() + GetSize(); if (mc_cmd_) { sz += sizeof(*mc_cmd_); } return sz; } // Marks this command as having reply stored in its payload instead of being sent directly. void SetDeferredReply() { is_deferred_reply_ = true; } bool IsDeferredReply() const { return is_deferred_reply_; } void ResetForReuse(); void SendError(std::string_view str, std::string_view type = std::string_view{}); void SendError(facade::OpStatus status); void SendError(const facade::ErrorReply& error); void SendSimpleString(std::string_view str); void SendOk() { SendSimpleString("OK"); } void SendLong(long val); template void ReplyWith(F&& func) { assert(!is_deferred_reply_); using RbType = decltype(OnlyArgType(&std::decay_t::operator())); func(static_cast(rb_)); } // Below are main commands for the async api and all assume that the command defers replies // Whether SendReply() can be called. If not, it must be waited via Blocker() bool CanReply() const; // Reaching zero on blocker means CanReply() turns true util::fb2::EmbeddedBlockingCounter* Blocker() const { return std::get(reply_).blocker; } // Assumes CanReply() is true. 
Sends reply void SendReply(); // Resolve deferred command with reply void Resolve(const facade::ErrorReply& error) { SendError(error); } // Resolve deferred command with async task void Resolve(util::fb2::EmbeddedBlockingCounter* blocker, std::coroutine_handle<> coro) { reply_ = SuspendedCommand{blocker, coro}; } protected: virtual void ReuseInternal() = 0; private: // Suspended asynchronous command. Once blocker is done, the coroutine can be resumed. // Deletes the coroutine on drop. struct SuspendedCommand { SuspendedCommand(util::fb2::EmbeddedBlockingCounter* blocker, std::coroutine_handle<> coro) : blocker{blocker}, coro{coro} { } SuspendedCommand(SuspendedCommand&& other) noexcept : blocker{other.blocker}, coro{std::exchange(other.coro, {})} { } SuspendedCommand& operator=(SuspendedCommand&& other) noexcept { blocker = other.blocker; coro = std::exchange(other.coro, {}); return *this; } // To destroy the coroutine when cancelling (as the handle is non owning) ~SuspendedCommand(); util::fb2::EmbeddedBlockingCounter* blocker; std::coroutine_handle<> coro; }; // if false then the reply was sent directly to reply builder, // otherwise, moved asynchronously into reply_payload_ bool is_deferred_reply_ = false; std::variant reply_; }; #ifdef __linux__ static_assert(sizeof(ParsedCommand) == 232); #endif } // namespace facade ================================================ FILE: src/facade/redis_parser.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "facade/redis_parser.h" #include #include #include "base/logging.h" #include "common/heap_size.h" namespace facade { using namespace std; auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> Result { DCHECK(!str.empty()); *consumed = 0; res->clear(); DVLOG(2) << "Parsing: " << absl::CHexEscape(string_view{reinterpret_cast(str.data()), str.size()}); if (state_ == CMD_COMPLETE_S) { if (InitStart(str[0], res)) { // We recognized a non-INLINE state, starting with a special char. str.remove_prefix(1); *consumed += 1; if (server_mode_ && state_ == PARSE_ARG_S) { // server requests start with ARRAY_LEN_S. state_ = CMD_COMPLETE_S; // reject and reset the state. return BAD_ARRAYLEN; } if (str.empty()) return INPUT_PENDING; } } else { // INLINE mode, aka PING\n // We continue parsing in the middle. if (!cached_expr_) cached_expr_ = res; } DCHECK(state_ != CMD_COMPLETE_S); ResultConsumed resultc{OK, 0}; do { switch (state_) { case MAP_LEN_S: case ARRAY_LEN_S: resultc = ConsumeArrayLen(str); break; case PARSE_ARG_TYPE: arg_c_ = str[0]; if (server_mode_ && arg_c_ != '$') // server side only supports bulk strings. 
return BAD_BULKLEN; resultc.second = 1; state_ = PARSE_ARG_S; break; case PARSE_ARG_S: resultc = ParseArg(str); break; case INLINE_S: DCHECK(parse_stack_.empty()); resultc = ParseInline(str); break; case BULK_STR_S: resultc = ConsumeBulk(str); break; case SLASH_N_S: if (str[0] != '\n') { resultc.first = BAD_STRING; } else { resultc = {OK, 1}; if (arg_c_ == '_') { cached_expr_->emplace_back(RespExpr::NIL); cached_expr_->back().u = Buffer{}; } HandleFinishArg(); } break; default: LOG(FATAL) << "Unexpected state " << int(state_); } *consumed += resultc.second; str.remove_prefix(exchange(resultc.second, 0)); } while (state_ != CMD_COMPLETE_S && resultc.first == OK && !str.empty()); if (state_ != CMD_COMPLETE_S) { if (resultc.first == OK) { resultc.first = INPUT_PENDING; } if (resultc.first == INPUT_PENDING) { // TODO: we still need to handle ':' and ',' cases for client mode // to consume them completely. if (server_mode_ && !str.empty()) { LOG(DFATAL) << "Did not consume all input: " << absl::CHexEscape({reinterpret_cast(str.data()), str.size()}) << ", state: " << int(state_) << " smallbuf: " << absl::CHexEscape( {reinterpret_cast(small_buf_.data()), small_len_}); } StashState(res); } return resultc.first; } if (resultc.first == OK) { DCHECK(cached_expr_); DCHECK_EQ(0, small_len_); if (res != cached_expr_) { DCHECK(!stash_.empty()); *res = *cached_expr_; } } return resultc.first; } bool RedisParser::InitStart(char prefix_b, RespExpr::Vec* res) { buf_stash_.clear(); stash_.clear(); cached_expr_ = res; parse_stack_.clear(); last_stashed_level_ = 0; last_stashed_index_ = 0; switch (prefix_b) { case '$': case ':': case '+': case '-': case '_': // Resp3 NULL case ',': // Resp3 DOUBLE state_ = PARSE_ARG_S; parse_stack_.emplace_back(1, cached_expr_); // expression of length 1. 
arg_c_ = prefix_b; return true; case '*': case '~': // Resp3 SET state_ = ARRAY_LEN_S; return true; case '%': // Resp3 MAP state_ = MAP_LEN_S; return true; } state_ = INLINE_S; return false; } void RedisParser::StashState(RespExpr::Vec* res) { if (cached_expr_->empty() && stash_.empty()) { cached_expr_ = nullptr; return; } if (cached_expr_ == res) { stash_.emplace_back(new RespExpr::Vec(*res)); cached_expr_ = stash_.back().get(); } DCHECK_LT(last_stashed_level_, stash_.size()); while (true) { auto& cur = *stash_[last_stashed_level_]; for (; last_stashed_index_ < cur.size(); ++last_stashed_index_) { auto& e = cur[last_stashed_index_]; if (RespExpr::STRING == e.type) { Buffer& ebuf = get(e.u); if (ebuf.empty() && last_stashed_index_ + 1 == cur.size()) break; if (!ebuf.empty() && !e.has_support) { Blob blob(ebuf.size()); memcpy(blob.data(), ebuf.data(), ebuf.size()); ebuf = Buffer{blob.data(), blob.size()}; buf_stash_.push_back(std::move(blob)); e.has_support = true; } } } if (last_stashed_level_ + 1 == stash_.size()) break; ++last_stashed_level_; last_stashed_index_ = 0; } } auto RedisParser::ParseInline(Buffer str) -> ResultConsumed { DCHECK(!str.empty()); const uint8_t* ptr = str.begin(); const uint8_t* end = str.end(); const uint8_t* token_start = ptr; auto find_token_end = [](const uint8_t* ptr, const uint8_t* end) { while (ptr != end && *ptr > 32) ++ptr; return ptr; }; if (is_broken_token_) { ptr = find_token_end(ptr, end); size_t len = ptr - token_start; ExtendLastString(Buffer(token_start, len)); if (ptr == end) { return {INPUT_PENDING, ptr - token_start}; } is_broken_token_ = false; } while (ptr != end) { // For inline input we only require \n. 
if (*ptr == '\n') { if (cached_expr_->empty()) { ++ptr; continue; // skip empty line } break; } if (*ptr <= 32) { // skip ws/control chars ++ptr; continue; } // token start DCHECK(!is_broken_token_); token_start = ptr; ptr = find_token_end(ptr, end); cached_expr_->emplace_back(RespExpr::STRING); cached_expr_->back().u = Buffer{token_start, size_t(ptr - token_start)}; } uint32_t last_consumed = ptr - str.data(); if (ptr == end) { // we have not finished parsing. if (cached_expr_->empty()) { state_ = CMD_COMPLETE_S; // have not found anything besides whitespace. } else { is_broken_token_ = ptr[-1] > 32; // we stopped in the middle of the token. } return {INPUT_PENDING, last_consumed}; } DCHECK_EQ('\n', *ptr); ++last_consumed; // consume \n as well. state_ = CMD_COMPLETE_S; return {OK, last_consumed}; } // Parse lines like:'$5\r\n' or '*2\r\n'. The first character is already consumed by the caller. auto RedisParser::ParseLen(Buffer str, int64_t* res) -> ResultConsumed { DCHECK(!str.empty()); const char* s = reinterpret_cast(str.data()); const char* pos = reinterpret_cast(memchr(s, '\n', str.size())); if (!pos) { if (str.size() + small_len_ < small_buf_.size()) { memcpy(&small_buf_[small_len_], str.data(), str.size()); small_len_ += str.size(); return {INPUT_PENDING, str.size()}; } LOG(WARNING) << "Unexpected format " << string_view{s, str.size()}; return ResultConsumed{BAD_ARRAYLEN, 0}; } unsigned consumed = pos - s + 1; if (small_len_ > 0) { if (small_len_ + consumed >= small_buf_.size()) { return ResultConsumed{BAD_ARRAYLEN, consumed}; } memcpy(&small_buf_[small_len_], str.data(), consumed); small_len_ += consumed; s = small_buf_.data(); pos = s + small_len_ - 1; small_len_ = 0; } if (pos[-1] != '\r') { return {BAD_ARRAYLEN, consumed}; } // Skip 2 last characters (\r\n). 
string_view len_token{s, size_t(pos - 1 - s)}; bool success = absl::SimpleAtoi(len_token, res); if (success && *res >= -1) { return ResultConsumed{OK, consumed}; } LOG(ERROR) << "Failed to parse len " << absl::CHexEscape(len_token) << " " << absl::CHexEscape(string_view{reinterpret_cast(str.data()), str.size()}) << " " << consumed << " " << int(s == small_buf_.data()); return ResultConsumed{BAD_ARRAYLEN, consumed}; } auto RedisParser::ConsumeArrayLen(Buffer str) -> ResultConsumed { int64_t len; ResultConsumed res = ParseLen(str, &len); if (res.first != OK) { return res; } if (state_ == MAP_LEN_S) { // Map starts with %N followed by an array of 2*N elements. // Even elements are keys, odd elements are values. len *= 2; } if (len > max_arr_len_) { LOG(WARNING) << "Multibulk len is too large " << len; return {BAD_ARRAYLEN, res.second}; } if (server_mode_ && (!parse_stack_.empty() || !cached_expr_->empty())) return {BAD_STRING, res.second}; if (len <= 0) { if (len < 0) { cached_expr_->emplace_back(RespExpr::NIL_ARRAY); cached_expr_->back().u.emplace(nullptr); // nil } else { static RespVec empty_vec; cached_expr_->emplace_back(RespExpr::ARRAY); cached_expr_->back().u = &empty_vec; } if (parse_stack_.empty()) { state_ = CMD_COMPLETE_S; } else { HandleFinishArg(); } return {OK, res.second}; } if (state_ == PARSE_ARG_S) { DCHECK(!server_mode_); cached_expr_->emplace_back(RespExpr::ARRAY); stash_.emplace_back(new RespExpr::Vec()); RespExpr::Vec* arr = stash_.back().get(); arr->reserve(len); cached_expr_->back().u = arr; cached_expr_ = arr; } state_ = PARSE_ARG_TYPE; DVLOG(1) << "PushStack: (" << len << ", " << cached_expr_ << ")"; parse_stack_.emplace_back(len, cached_expr_); return {OK, res.second}; } auto RedisParser::ParseArg(Buffer str) -> ResultConsumed { DCHECK(!str.empty()); if (arg_c_ == '$') { int64_t len; ResultConsumed res = ParseLen(str, &len); if (res.first != OK) { return res; } if (len > 0 && static_cast(len) > max_bulk_len_) { LOG_EVERY_T(WARNING, 1) << 
"Threshold reached with bulk len: " << len << ", consider increasing max_bulk_len"; return {BAD_ARRAYLEN, res.second}; } if (len == -1) { // Resp2 NIL cached_expr_->emplace_back(RespExpr::NIL); cached_expr_->back().u = Buffer{}; HandleFinishArg(); } else { DVLOG(1) << "String(" << len << ")"; cached_expr_->emplace_back(RespExpr::STRING); cached_expr_->back().u = Buffer{}; bulk_len_ = len; state_ = BULK_STR_S; } return {OK, res.second}; } DCHECK(!server_mode_); if (arg_c_ == '_') { // Resp3 NIL // "_\r\n", with '_' consumed into arg_c_. DCHECK_LT(small_len_, 2u); // must be because we never fill here with more than 2 bytes. DCHECK_GE(str.size(), 1u); if (str[0] != '\r' || (str.size() > 1 && str[1] != '\n')) { return {BAD_STRING, 0}; } if (str.size() == 1) { state_ = SLASH_N_S; return {INPUT_PENDING, 1}; } cached_expr_->emplace_back(RespExpr::NIL); cached_expr_->back().u = Buffer{}; HandleFinishArg(); return {OK, 2}; } if (arg_c_ == '*') { return ConsumeArrayLen(str); } const char* s = reinterpret_cast(str.data()); const char* eol = reinterpret_cast(memchr(s, '\n', str.size())); if (arg_c_ == '+' || arg_c_ == '-') { // Simple string or error. DCHECK(!server_mode_); if (!eol) { // if eol is not found we should still read input as bulk string cached_expr_->emplace_back(RespExpr::STRING); cached_expr_->back().u = Buffer{}; bulk_len_ = str.length(); // eol is not found but if '\r' is present decrease bulk_len if (s[bulk_len_ - 1] == '\r') bulk_len_--; state_ = BULK_STR_S; Result r = str.size() < 256 ? OK : BAD_STRING; return {r, 0}; } if (eol[-1] != '\r') return {BAD_STRING, 0}; cached_expr_->emplace_back(arg_c_ == '+' ? RespExpr::STRING : RespExpr::ERROR); cached_expr_->back().u = Buffer{reinterpret_cast(s), size_t((eol - 1) - s)}; } else if (arg_c_ == ':') { DCHECK(!server_mode_); if (!eol) { Result r = str.size() < 32 ? 
INPUT_PENDING : BAD_INT; return {r, 0}; } int64_t ival; std::string_view tok{s, size_t((eol - s) - 1)}; if (eol[-1] != '\r' || !absl::SimpleAtoi(tok, &ival)) return {BAD_INT, 0}; cached_expr_->emplace_back(RespExpr::INT64); cached_expr_->back().u = ival; } else if (arg_c_ == ',') { DCHECK(!server_mode_); if (!eol) { Result r = str.size() < 32 ? INPUT_PENDING : BAD_DOUBLE; return {r, 0}; } double_t dval; std::string_view tok{s, size_t((eol - s) - 1)}; if (eol[-1] != '\r' || !absl::SimpleAtod(tok, &dval)) return {BAD_DOUBLE, 0}; cached_expr_->emplace_back(RespExpr::DOUBLE); cached_expr_->back().u = dval; } else { return {BAD_STRING, 0}; } HandleFinishArg(); return {OK, (eol - s) + 1}; } auto RedisParser::ConsumeBulk(Buffer str) -> ResultConsumed { DCHECK_EQ(small_len_, 0); uint32_t consumed = 0; auto& bulk_str = get(cached_expr_->back().u); bool extend = false; // Handle split simple message or error in client mode if (!server_mode_ && (arg_c_ == '+' || arg_c_ == '-') && !bulk_len_) { // Search first '\r' in next partial message which ends bulk string const char* s = reinterpret_cast(str.data()); const char* pos = reinterpret_cast(memchr(s, '\r', str.size())); bulk_len_ = pos ? pos - s : str.size(); extend = true; } if (str.size() >= bulk_len_) { consumed = bulk_len_; if (bulk_len_) { // is_broken_token_ can be false, if we just parsed the bulk length but have // not parsed the token itself. 
if (is_broken_token_) { memcpy(const_cast(bulk_str.end()), str.data(), bulk_len_); bulk_str = Buffer{bulk_str.data(), bulk_str.size() + bulk_len_}; } else if (extend) { ExtendBulkString(Buffer(str.begin(), bulk_len_)); } else { bulk_str = str.subspan(0, bulk_len_); } str.remove_prefix(exchange(bulk_len_, 0)); is_broken_token_ = false; } if (str.size() >= 2) { if (str[0] != '\r' || str[1] != '\n') { return {BAD_STRING, consumed}; } HandleFinishArg(); return {OK, consumed + 2}; } else if (str.size() == 1) { if (str[0] != '\r') { return {BAD_STRING, consumed}; } state_ = SLASH_N_S; consumed++; } return {INPUT_PENDING, consumed}; } DCHECK(bulk_len_); size_t len = std::min(str.size(), bulk_len_); if (is_broken_token_) { memcpy(const_cast(bulk_str.end()), str.data(), len); bulk_str = Buffer{bulk_str.data(), bulk_str.size() + len}; DVLOG(1) << "Extending bulk stash to size " << bulk_str.size(); } else { DVLOG(1) << "New bulk stash size " << bulk_len_; vector nb(bulk_len_); memcpy(nb.data(), str.data(), len); bulk_str = Buffer{nb.data(), len}; buf_stash_.emplace_back(std::move(nb)); is_broken_token_ = true; cached_expr_->back().has_support = true; } consumed = len; bulk_len_ -= len; return {INPUT_PENDING, consumed}; } void RedisParser::HandleFinishArg() { DCHECK(!parse_stack_.empty()); DCHECK_GT(parse_stack_.back().first, 0u); state_ = PARSE_ARG_TYPE; while (true) { --parse_stack_.back().first; if (parse_stack_.back().first != 0) break; auto* arr = parse_stack_.back().second; DVLOG(1) << "PopStack (" << arr << ")"; parse_stack_.pop_back(); // pop 0. 
if (parse_stack_.empty()) { state_ = CMD_COMPLETE_S; break; } cached_expr_ = parse_stack_.back().second; } small_len_ = 0; } void RedisParser::ExtendLastString(Buffer str) { DCHECK(!cached_expr_->empty() && cached_expr_->back().type == RespExpr::STRING); DCHECK(!buf_stash_.empty()); Buffer& last_str = get(cached_expr_->back().u); DCHECK(last_str.data() == buf_stash_.back().data()); vector nb(last_str.size() + str.size()); memcpy(nb.data(), last_str.data(), last_str.size()); memcpy(nb.data() + last_str.size(), str.data(), str.size()); last_str = RespExpr::Buffer{nb.data(), last_str.size() + str.size()}; buf_stash_.back() = std::move(nb); } void RedisParser::ExtendBulkString(Buffer str) { DCHECK(!cached_expr_->empty() && cached_expr_->back().type == RespExpr::STRING); Buffer& bulk_str = get(cached_expr_->back().u); DCHECK(bulk_str.data() == buf_stash_.back().data()); vector nb(bulk_str.size() + str.size()); memcpy(nb.data(), bulk_str.data(), bulk_str.size()); memcpy(nb.data() + bulk_str.size(), str.data(), str.size()); bulk_str = RespExpr::Buffer{nb.data(), bulk_str.size() + str.size()}; buf_stash_.back() = std::move(nb); } size_t RedisParser::UsedMemory() const { return cmn::HeapSize(parse_stack_) + cmn::HeapSize(stash_) + cmn::HeapSize(buf_stash_); } } // namespace facade ================================================ FILE: src/facade/redis_parser.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include "facade/resp_expr.h" namespace facade { /** * @brief Zero-copy (best-effort) parser. * Note: The client-mode parsing is buggy and should not be used in production. * Currently we only use server-mode parsing in production and client-mode in tests. * It works because tests do not do any incremental parsing. 
* */
class RedisParser {
 public:
  // Outcome of a Parse() call.
  enum Result : uint8_t {
    OK,
    INPUT_PENDING,
    BAD_ARRAYLEN,
    BAD_BULKLEN,
    BAD_STRING,
    BAD_INT,
    BAD_DOUBLE
  };
  using Buffer = RespExpr::Buffer;

  enum Mode : uint8_t {
    SERVER,
    CLIENT
  };

  // max_arr_len bounds the number of elements in an array header,
  // max_bulk_len bounds the length of a single bulk string.
  explicit RedisParser(Mode mode = Mode::SERVER, uint32_t max_arr_len = UINT32_MAX,
                       uint64_t max_bulk_len = UINT64_MAX)
      : server_mode_(mode == Mode::SERVER),
        max_arr_len_(max_arr_len),
        max_bulk_len_(max_bulk_len) {
  }

  /**
   * @brief Parses str into res. "consumed" stores number of bytes consumed from str.
   *
   * A caller should not invalidate str if the parser returns OK, for as long as the caller
   * continues accessing res. However, if the parser returns INPUT_PENDING, the caller may
   * discard the consumed part of str because the parser caches the intermediate state
   * internally according to the 'consumed' result.
   *
   */
  Result Parse(Buffer str, uint32_t* consumed, RespVec* res);

  // Switches the parser to client mode.
  void SetClientMode() {
    server_mode_ = false;
  }

  // Current value of bulk_len_.
  size_t parselen_hint() const {
    return bulk_len_;
  }

  size_t stash_size() const {
    return stash_.size();
  }

  const std::vector>& stash() const {
    return stash_;
  }

  size_t UsedMemory() const;

 private:
  using ResultConsumed = std::pair;

  // Returns true if this is a RESP message, false if INLINE.
  bool InitStart(char prefix_b, RespVec* res);
  void StashState(RespVec* res);

  // Skips the first character (*).
  ResultConsumed ConsumeArrayLen(Buffer str);
  ResultConsumed ParseArg(Buffer str);
  ResultConsumed ConsumeBulk(Buffer str);
  ResultConsumed ParseInline(Buffer str);

  ResultConsumed ParseLen(Buffer str, int64_t* res);

  void HandleFinishArg();
  void ExtendLastString(Buffer str);
  void ExtendBulkString(Buffer str);

  enum State : uint8_t {
    INLINE_S,
    ARRAY_LEN_S,
    MAP_LEN_S,
    PARSE_ARG_TYPE,  // Parse [$:+-]
    PARSE_ARG_S,     // Parse string\r\n
    BULK_STR_S,
    SLASH_N_S,
    CMD_COMPLETE_S,
  };

  State state_ = CMD_COMPLETE_S;
  bool is_broken_token_ = false;  // true, if a token (inline or bulk) is broken during the parsing.
  bool server_mode_ = true;
  uint8_t small_len_ = 0;
  char arg_c_ = 0;

  uint32_t bulk_len_ = 0;
  uint32_t last_stashed_level_ = 0, last_stashed_index_ = 0;
  uint32_t max_arr_len_;
  uint64_t max_bulk_len_;

  // Points either to the result passed by the caller or to the stash.
  RespVec* cached_expr_ = nullptr;

  // expected expression length, pointer to expression vector.
  // For server mode, the length is at most 1.
  absl::InlinedVector, 4> parse_stack_;
  std::vector> stash_;

  using Blob = std::vector;
  std::vector buf_stash_;
  std::array small_buf_;
};

}  // namespace facade

================================================
FILE: src/facade/redis_parser_test.cc
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "facade/redis_parser.h"

#include #include
#include "absl/strings/str_cat.h"
#include "base/gtest.h"
#include "base/logging.h"
#include "common/heap_size.h"
#include "facade/facade_test.h"

using namespace testing;
using namespace std;
namespace facade {

// Matches a RespExpr of type ARRAY whose vector has exactly `expected` elements.
MATCHER_P(ArrArg, expected, absl::StrCat(negation ? "is not" : "is", " equal to:\n", expected)) {
  if (arg.type != RespExpr::ARRAY) {
    *result_listener << "\nWrong type: " << arg.type;
    return false;
  }
  size_t exp_sz = expected;
  size_t actual = get(arg.u)->size();

  if (exp_sz != actual) {
    *result_listener << "\nActual size: " << actual;
    return false;
  }
  return true;
}

// Fixture owning a parser and the buffer that backs the most recently parsed input.
class RedisParserTest : public testing::Test {
 protected:
  static void SetUpTestSuite() {
  }

  RedisParser::Result Parse(std::string_view str);

  RedisParser parser_;
  RespExpr::Vec args_;
  uint32_t consumed_;

  unique_ptr stash_;
};

// Copies `str` into a freshly allocated buffer (replacing the previous one) and parses it.
RedisParser::Result RedisParserTest::Parse(std::string_view str) {
  stash_.reset(new uint8_t[str.size()]);
  auto* ptr = stash_.get();
  memcpy(ptr, str.data(), str.size());
  return parser_.Parse(RedisParser::Buffer{ptr, str.size()}, &consumed_, &args_);
}

// Inline commands, including ones split across several Parse() calls.
TEST_F(RedisParserTest, Inline) {
  RespExpr e{RespExpr::STRING};
  ASSERT_EQ(RespExpr::STRING, e.type);

  const char kCmd1[] = "KEY VAL\r\n";

  ASSERT_EQ(RedisParser::OK, Parse(kCmd1));
  EXPECT_EQ(strlen(kCmd1), consumed_);
  EXPECT_THAT(args_, ElementsAre("KEY", "VAL"));

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("KEY"));
  EXPECT_EQ(3, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" FOO "));
  EXPECT_EQ(5, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" BAR"));
  EXPECT_EQ(4, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse(" \r\n "));
  EXPECT_EQ(3, consumed_);
  EXPECT_THAT(args_, ElementsAre("KEY", "FOO", "BAR"));

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" 1 2"));
  EXPECT_EQ(4, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" 45"));
  EXPECT_EQ(3, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
  EXPECT_THAT(args_, ElementsAre("1", "2", "45"));

  // Empty queries return INPUT_PENDING.
  EXPECT_EQ(RedisParser::INPUT_PENDING, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
}

// Placeholder: inline escaping is not covered yet.
TEST_F(RedisParserTest, InlineEscaping) {
  LOG(ERROR) << "TBD: to be compliant with sdssplitargs";  // TODO:
}

// A single-argument array fed header by header; also checks parselen_hint().
TEST_F(RedisParserTest, Multi1) {
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n"));
  EXPECT_EQ(4, consumed_);
  EXPECT_EQ(0, parser_.parselen_hint());

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("$4\r\n"));
  EXPECT_EQ(4, consumed_);
  EXPECT_EQ(4, parser_.parselen_hint());

  ASSERT_EQ(RedisParser::OK, Parse("PING\r\n"));
  EXPECT_EQ(6, consumed_);
  EXPECT_EQ(0, parser_.parselen_hint());
  EXPECT_THAT(args_, ElementsAre("PING"));
}

// Input split right after the '$' marker, and trailing bytes of a following command.
TEST_F(RedisParserTest, Multi2) {
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n$"));
  EXPECT_EQ(5, consumed_);

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("4\r\nMSET"));
  EXPECT_EQ(7, consumed_);

  ASSERT_EQ(RedisParser::OK, Parse("\r\n*2\r\n"));
  EXPECT_EQ(2, consumed_);

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*2\r\n$3\r\nKEY\r\n$3\r\nVAL"));
  EXPECT_EQ(20, consumed_);

  ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
  EXPECT_THAT(args_, ElementsAre("KEY", "VAL"));
}

// A bulk string split in the middle of its payload.
TEST_F(RedisParserTest, Multi3) {
  const char kFirst[] = "*3\r\n$3\r\nSET\r\n$16\r\nkey:";
  const char kSecond[] = "000002273458\r\n$3\r\nVXK";
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kFirst));
  ASSERT_EQ(strlen(kFirst), consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(kSecond));
  ASSERT_EQ(strlen(kSecond), consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\r\n*3\r\n$3\r\nSET"));
  ASSERT_EQ(2, consumed_);
  EXPECT_THAT(args_, ElementsAre("SET", "key:000002273458", "VXK"));
}

// Client-mode reply types (integers, simple strings, errors, NIL), including split input.
TEST_F(RedisParserTest, ClientMode) {
  parser_.SetClientMode();

  ASSERT_EQ(RedisParser::OK, Parse(":-1\r\n"));
  EXPECT_THAT(args_, ElementsAre(IntArg(-1)));

  ASSERT_EQ(RedisParser::OK, Parse("+OK\r\n"));
  EXPECT_EQ(args_[0], "OK");

  ASSERT_EQ(RedisParser::OK, Parse("-ERR foo bar\r\n"));
  EXPECT_THAT(args_, ElementsAre(ErrArg("ERR foo")));

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("_"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\r"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\n"));
  EXPECT_EQ(1, consumed_);
  EXPECT_THAT(args_, ElementsAre(ArgType(RespExpr::NIL)));

  ASSERT_EQ(RedisParser::OK, Parse("*2\r\n_\r\n_\r\n"));
  ASSERT_EQ(10, consumed_);

  ASSERT_EQ(RedisParser::OK, Parse("*3\r\n+OK\r\n$1\r\n1\r\n*2\r\n$1\r\n1\r\n$-1\r\n"));
  ASSERT_THAT(args_, ElementsAre("OK", "1", ArrLen(2)));

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("+O"));
  EXPECT_EQ(2, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("K\r"));
  EXPECT_EQ(2, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\n"));
  ASSERT_THAT(args_, ElementsAre("OK"));
  EXPECT_EQ(1, consumed_);

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("+OK\r"));
  EXPECT_EQ(4, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\n"));
  ASSERT_THAT(args_, ElementsAre("OK"));
  EXPECT_EQ(1, consumed_);

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("+"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("O"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("K"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\r"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\n"));
  EXPECT_EQ(1, consumed_);
  ASSERT_THAT(args_, ElementsAre("OK"));

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("-"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("ERR\r\n"));
  EXPECT_EQ(5, consumed_);
  ASSERT_THAT(args_, ElementsAre(ErrArg("ERR")));

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("-ERR foo"));
  EXPECT_EQ(8, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
  ASSERT_THAT(args_, ElementsAre("ERR foo"));
}

// Arrays nested inside an array (client mode).
TEST_F(RedisParserTest, Hierarchy) {
  parser_.SetClientMode();

  const char* kThirdArg = "*2\r\n$3\r\n100\r\n$3\r\n200\r\n";
  string resp = absl::StrCat("*3\r\n$3\r\n900\r\n$3\r\n800\r\n", kThirdArg);
  ASSERT_EQ(RedisParser::OK, Parse(resp));
  ASSERT_THAT(args_, ElementsAre("900", "800", ArrArg(2)));
  EXPECT_THAT(args_[2].GetVec(), ElementsAre("100", "200"));

  ASSERT_EQ(RedisParser::OK, Parse("*2\r\n*1\r\n$3\r\n1-0\r\n*1\r\n$2\r\nf1\r\n"));
  ASSERT_THAT(args_, ElementsAre(ArrLen(1), ArrLen(1)));
}

// In server mode an array element that does not start with '$' is rejected.
TEST_F(RedisParserTest, InvalidMult1) {
  ASSERT_EQ(RedisParser::BAD_BULKLEN, Parse("*2\r\n$3\r\nFOO\r\nBAR\r\n"));
}

// Zero-length bulk strings.
TEST_F(RedisParserTest, Empty) {
  ASSERT_EQ(RedisParser::OK, Parse("*2\r\n$0\r\n\r\n$0\r\n\r\n"));
}

// Bulk strings delivered in several chunks, up to a 270000000-byte payload.
TEST_F(RedisParserTest, LargeBulk) {
  string_view prefix("*1\r\n$1024\r\n");

  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(prefix));
  ASSERT_EQ(prefix.size(), consumed_);
  ASSERT_GE(parser_.parselen_hint(), 1024);

  string half(512, 'a');
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half));
  ASSERT_EQ(512, consumed_);
  ASSERT_GE(parser_.parselen_hint(), 512);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half));
  ASSERT_EQ(512, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\r"));
  ASSERT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\n"));
  EXPECT_EQ(1, consumed_);

  string part1 = absl::StrCat(prefix, half);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(part1));
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(half));
  ASSERT_EQ(RedisParser::OK, Parse("\r\n"));

  prefix = "*1\r\n$270000000\r\n";
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(prefix));
  ASSERT_EQ(prefix.size(), consumed_);
  string chunk(1000000, 'a');
  for (unsigned i = 0; i < 270; ++i) {
    ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(chunk));
    ASSERT_EQ(chunk.size(), consumed_);
  }
  ASSERT_EQ(RedisParser::OK, Parse("\r\n"));
  ASSERT_THAT(args_, ElementsAre(ArgType(RespExpr::STRING)));
  EXPECT_EQ(270000000, args_[0].GetBuf().size());
}

// A top-level "_" is rejected in server mode and parsed as NIL in client mode.
TEST_F(RedisParserTest, NILs) {
  ASSERT_EQ(RedisParser::BAD_ARRAYLEN, Parse("_\r\n"));
  parser_.SetClientMode();
  ASSERT_EQ(RedisParser::OK, Parse("_\r\nfooobar"));
  EXPECT_EQ(3, consumed_);
}

// Three levels of array nesting (client mode).
TEST_F(RedisParserTest, NestedArray) {
  parser_.SetClientMode();

  // [[['foo'],['bar']],['car']]
  ASSERT_EQ(RedisParser::OK,
            Parse("*2\r\n*2\r\n*1\r\n$3\r\nfoo\r\n*1\r\n$3\r\nbar\r\n*1\r\n$3\r\ncar\r\n"));

  ASSERT_THAT(args_, ElementsAre(ArrArg(2), ArrArg(1)));
  ASSERT_THAT(args_[0].GetVec(), ElementsAre(ArrArg(1), ArrArg(1)));
  ASSERT_THAT(args_[1].GetVec(), ElementsAre("car"));
}

// Checks cmn::HeapSize() on containers shaped like the parser's stashes.
TEST_F(RedisParserTest, UsedMemory) {
  vector> blobs;
  for (size_t i = 0; i < 100; ++i) {
    blobs.emplace_back(vector(200));
  }
  EXPECT_GT(cmn::HeapSize(blobs), 20000);

  std::vector> stash;
  RespVec vec;
  for (unsigned i = 0; i < 10; ++i) {
    vec.emplace_back(RespExpr::STRING);
    vec.back().u = RespExpr::Buffer(nullptr, 0);
  }

  for (unsigned i = 0; i < 100; i++) {
    stash.emplace_back(new RespExpr::Vec(vec));
  }
  EXPECT_GT(cmn::HeapSize(stash), 30000);
}

// A length line split between '\r' and '\n'.
TEST_F(RedisParserTest, Eol) {
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r"));
  EXPECT_EQ(3, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\n$5\r\n"));
  EXPECT_EQ(5, consumed_);
}

// A bulk terminator split between '\r' and '\n'.
TEST_F(RedisParserTest, BulkSplit) {
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("*1\r\n$4\r\nSADD\r"));
  ASSERT_EQ(13, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\n"));
}

// Inline input with empty lines and a token split across calls.
TEST_F(RedisParserTest, InlineSplit) {
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\n"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("\nPING\n\n"));
  EXPECT_EQ(6, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\n"));
  EXPECT_EQ(1, consumed_);
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("P"));
  ASSERT_EQ(RedisParser::OK, Parse("ING\n"));
}

// A whitespace-only inline line must not prevent parsing the RESP command that follows.
TEST_F(RedisParserTest, InlineReset) {
  ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("\t \r\n"));
  EXPECT_EQ(4, consumed_);
  ASSERT_EQ(RedisParser::OK, Parse("*1\r\n$3\r\nfoo\r\n"));
  EXPECT_EQ(13, consumed_);
}

}  // namespace facade

================================================
FILE: src/facade/reply_builder.cc
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
// #include "facade/reply_builder.h" #include #include #include #include #include #include #include "absl/strings/escaping.h" #include "absl/types/span.h" #include "base/logging.h" #include "facade/error.h" #include "util/fibers/proactor_base.h" #ifdef __APPLE__ #ifndef UIO_MAXIOV // Some versions of MacOSX do not have IOV_MAX #define UIO_MAXIOV 1024 #endif #endif /* NOTE(review): the fallback above defines UIO_MAXIOV while the code in this file compares against IOV_MAX - confirm which macro is intended. */ using namespace std; using namespace double_conversion; namespace facade { namespace { constexpr char kCRLF[] = "\r\n"; constexpr char kSimplePref[] = "+"; constexpr char kLengthPrefix[] = "$"; constexpr char kDoublePref[] = ","; constexpr char kLongPref[] = ":"; constexpr char kNullStringR2[] = "$-1\r\n"; constexpr char kNullStringR3[] = "_\r\n"; constexpr unsigned kConvFlags = DoubleToStringConverter::UNIQUE_ZERO | DoubleToStringConverter::EMIT_POSITIVE_EXPONENT_SIGN; DoubleToStringConverter dfly_conv(kConvFlags, "inf", "nan", 'e', -6, 21, 6, 0); /* Upper bound on the bytes write_piece() emits for v: array length minus the terminator, kFastToBufferSize for integers, size() otherwise. */ template size_t piece_size(const T& v) { if constexpr (is_array_v) return ABSL_ARRAYSIZE(v) - 1; // expect null terminated else if constexpr (is_integral_v) return absl::numbers_internal::kFastToBufferSize; else // string_view return v.size(); } /* The write_piece overloads copy one piece to dest and return the position just past the written bytes. */ template char* write_piece(const char (&arr)[S], char* dest) { return (char*)memcpy(dest, arr, S - 1) + (S - 1); } template enable_if_t, char*> write_piece(T num, char* dest) { static_assert(!is_same_v, "Use arrays for single chars"); return absl::numbers_internal::FastIntToBuffer(num, dest); } char* write_piece(string_view str, char* dest) { return (char*)memcpy(dest, str.data(), str.size()) + str.size(); } } // namespace thread_local SinkReplyBuilder::PendingList SinkReplyBuilder::pending_list; /* Restores the previous batched_ value and flushes when batching was off before this aggregator was created. */ SinkReplyBuilder::ReplyAggregator::~ReplyAggregator() { rb->batched_ = prev; if (!prev) rb->Flush(); } /* Restores the previous scoped_ value and finishes the scope when no scope was active before this one. */ SinkReplyBuilder::ReplyScope::~ReplyScope() { rb->scoped_ = prev; if (!prev) rb->FinishScope(); } void SinkReplyBuilder::SendError(ErrorReply error) { if (error.status) return SendError(*error.status); SendError(error.ToSv(), 
error.kind); } /* OpStatus::OK is reported as the simple string OK; any other status is sent as the error text returned by StatusToMsg(). */ void SinkReplyBuilder::SendError(OpStatus status) { if (status == OpStatus::OK) return SendSimpleString("OK"); SendError(StatusToMsg(status)); } /* Records connection_aborted as the builder error unless an error is already set. */ void SinkReplyBuilder::CloseConnection() { if (!ec_) ec_ = std::make_error_code(std::errc::connection_aborted); } /* Copies all pieces into buffer_ and extends the trailing iovec (or appends a new one) so that it covers the newly written bytes. Flushes first when the buffer cannot hold the pieces. */ template void SinkReplyBuilder::WritePieces(Ts&&... pieces) { if (size_t required = (piece_size(pieces) + ...); buffer_.AppendLen() <= required) Flush(required); auto iovec_end = [](const iovec& v) { return reinterpret_cast(v.iov_base) + v.iov_len; }; // Ensure last iovec points to buffer segment char* dest = reinterpret_cast(buffer_.AppendBuffer().data()); if (vecs_.empty()) { vecs_.push_back(iovec{dest, 0}); } else if (iovec_end(vecs_.back()) != dest) { if (vecs_.size() >= IOV_MAX - 2) Flush(); dest = reinterpret_cast(buffer_.AppendBuffer().data()); vecs_.push_back(iovec{dest, 0}); } DCHECK(iovec_end(vecs_.back()) == dest); char* ptr = dest; ([&]() { ptr = write_piece(pieces, ptr); }(), ...); size_t written = ptr - dest; buffer_.CommitWrite(written); vecs_.back().iov_len += written; total_size_ += written; } /* Enqueues str as its own iovec without copying it; flushes first when vecs_ is close to IOV_MAX entries. */ void SinkReplyBuilder::WriteRef(std::string_view str) { if (vecs_.size() >= IOV_MAX - 2) Flush(); vecs_.push_back(iovec{const_cast(str.data()), str.size()}); total_size_ += str.size(); } /* Sends pending iovecs (if any), resets the builder state and makes sure the buffer capacity is at least expected_buffer_cap. */ void SinkReplyBuilder::Flush(size_t expected_buffer_cap) { if (!vecs_.empty()) Send(); // Grow backing buffer if it was at least half full and still below its max size if (buffer_.InputLen() * 2 > buffer_.Capacity() && buffer_.Capacity() * 2 <= kMaxBufferSize) expected_buffer_cap = max(expected_buffer_cap, buffer_.Capacity() * 2); total_size_ = 0; buffer_.Clear(); vecs_.clear(); guaranteed_pieces_ = 0; DCHECK_LE(expected_buffer_cap, kMaxBufferSize); // big strings should be enqueued as iovecs if (expected_buffer_cap > buffer_.Capacity()) buffer_.Reserve(expected_buffer_cap); } uint64_t SinkReplyBuilder::GetLastSendTimeNs() const { return send_time_ns_; } void SinkReplyBuilder::Send() { DCHECK(sink_ != 
nullptr); DCHECK(!vecs_.empty()); auto& reply_stats = tl_facade_stats->reply_stats; /* The pin is linked into the thread-local pending_list only around the Write() call below; send_time_ns_ is non-zero for the same period. */ send_time_ns_ = util::fb2::ProactorBase::GetMonotonicTimeNs(); PendingPin pin(send_time_ns_); pending_list.push_back(pin); reply_stats.io_write_cnt++; reply_stats.io_write_bytes += total_size_; DVLOG(2) << "Writing " << total_size_ << " bytes"; if (auto ec = sink_->Write(vecs_.data(), vecs_.size()); ec) ec_ = ec; auto it = PendingList::s_iterator_to(pin); pending_list.erase(it); send_time_ns_ = 0; uint64_t after_ns = util::fb2::ProactorBase::GetMonotonicTimeNs(); reply_stats.send_stats.count++; reply_stats.send_stats.total_duration += (after_ns - pin.timestamp_ns); DVLOG(2) << "Finished writing " << total_size_ << " bytes"; } /* Ends the outermost reply scope: flushes unless batching is on, otherwise copies externally referenced iovecs into buffer_ so they no longer depend on caller memory. */ void SinkReplyBuilder::FinishScope() { replies_recorded_++; if (!batched_ || total_size_ * 2 >= kMaxBufferSize /* copying isn't worth it */) return Flush(); // Check if we have enough space to copy all refs to buffer size_t ref_bytes = total_size_ - buffer_.InputLen(); if (ref_bytes > buffer_.AppendLen()) return Flush(ref_bytes); // Copy all external references to buffer to safely keep batching for (size_t i = guaranteed_pieces_; i < vecs_.size(); i++) { auto ib = buffer_.InputBuffer(); if (vecs_[i].iov_base >= ib.data() && vecs_[i].iov_base <= ib.data() + ib.size()) continue; // this is a piece DCHECK_LE(vecs_[i].iov_len, buffer_.AppendLen()); void* dest = buffer_.AppendBuffer().data(); memcpy(dest, vecs_[i].iov_base, vecs_[i].iov_len); buffer_.CommitWrite(vecs_[i].iov_len); vecs_[i].iov_base = dest; } guaranteed_pieces_ = vecs_.size(); // all vecs are pieces } MCReplyBuilder::MCReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) { } /* Writes a memcache value reply: a meta-style VA/HD line when cmd_flags.meta is set, otherwise a VALUE line followed by the payload. Values longer than kMaxInlineSize are enqueued by reference. */ void MCReplyBuilder::SendValue(MemcacheCmdFlags cmd_flags, std::string_view key, std::string_view value, uint64_t mc_token, uint32_t mc_flag, uint32_t ttl_sec) { ReplyScope scope(this); if (cmd_flags.meta) { string flags; if (cmd_flags.return_flags) absl::StrAppend(&flags, " f", mc_flag); if (cmd_flags.return_cas) 
absl::StrAppend(&flags, " c", mc_token); if (cmd_flags.return_ttl) absl::StrAppend(&flags, " t", ttl_sec); if (cmd_flags.return_value) { WritePieces("VA ", value.size(), flags, kCRLF); if (value.size() <= kMaxInlineSize) { WritePieces(value, kCRLF); } else { WriteRef(value); WritePieces(kCRLF); } } else { /* NOTE(review): every flag token is appended with a leading space, so this line carries two spaces before the first flag (and a trailing space when flags is empty) - confirm clients expect that. */ WritePieces("HD ", flags, kCRLF); } } else { WritePieces("VALUE ", key, " ", mc_flag, " ", value.size()); if (cmd_flags.return_cas) WritePieces(" ", mc_token); if (value.size() <= kMaxInlineSize) { WritePieces(kCRLF, value, kCRLF); } else { WritePieces(kCRLF); WriteRef(value); WritePieces(kCRLF); } } } /* Empty strings produce no output at all; everything else is copied and terminated with CRLF. */ void MCReplyBuilder::SendSimpleString(std::string_view str) { if (str.empty()) return; ReplyScope scope(this); WritePieces(str, kCRLF); } void MCReplyBuilder::SendLong(long val) { SendSimpleString(absl::StrCat(val)); } /* The type argument is not used here; the message is remembered in last_error_ and sent with the SERVER_ERROR prefix. */ void MCReplyBuilder::SendError(string_view str, std::string_view type) { last_error_ = str; SendSimpleString(absl::StrCat("SERVER_ERROR ", str)); } void MCReplyBuilder::SendProtocolError(std::string_view str) { SendSimpleString(absl::StrCat("CLIENT_ERROR ", str)); } void MCReplyBuilder::SendClientError(string_view str) { SendSimpleString(absl::StrCat("CLIENT_ERROR ", str)); } /* Enqueues str by reference, without adding a terminator. */ void MCReplyBuilder::SendRaw(std::string_view str) { ReplyScope scope(this); WriteRef(str); } void RedisReplyBuilderBase::SendNull() { ReplyScope scope(this); IsResp3() ? 
WritePieces(kNullStringR3) : WritePieces(kNullStringR2); } void RedisReplyBuilderBase::SendSimpleString(std::string_view str) { ReplyScope scope(this); if (str.size() <= kMaxInlineSize * 2) return WritePieces(kSimplePref, str, kCRLF); WritePieces(kSimplePref); WriteRef(str); WritePieces(kCRLF); } void RedisReplyBuilderBase::SendBulkString(std::string_view str) { ReplyScope scope(this); if (str.size() <= kMaxInlineSize) return WritePieces(kLengthPrefix, uint32_t(str.size()), kCRLF, str, kCRLF); DVLOG(1) << "SendBulk " << str.size(); WritePieces(kLengthPrefix, uint32_t(str.size()), kCRLF); WriteRef(str); WritePieces(kCRLF); } void RedisReplyBuilderBase::SendLong(long val) { ReplyScope scope(this); WritePieces(kLongPref, val, kCRLF); } void RedisReplyBuilderBase::SendDouble(double val) { char buf[DoubleToStringConverter::kBase10MaximalLength + 8]; // +8 to be on the safe side. static_assert(ABSL_ARRAYSIZE(buf) < kMaxInlineSize, "Write temporary string from buf inline"); string_view val_str = FormatDouble(val, buf, ABSL_ARRAYSIZE(buf)); if (!IsResp3()) return SendBulkString(val_str); ReplyScope scope(this); WritePieces(kDoublePref, val_str, kCRLF); } void RedisReplyBuilderBase::SendNullArray() { ReplyScope scope(this); WritePieces("*-1", kCRLF); } constexpr static const char START_SYMBOLS2[4][2] = {"*", "~", "%", ">"}; static_assert(START_SYMBOLS2[unsigned(CollectionType::MAP)][0] == '%' && START_SYMBOLS2[unsigned(CollectionType::SET)][0] == '~'); void RedisReplyBuilderBase::StartCollection(unsigned len, CollectionType ct) { if (!IsResp3()) { // RESP2 supports only arrays if (ct == CollectionType::MAP) len *= 2; ct = CollectionType::ARRAY; } ReplyScope scope(this); WritePieces(START_SYMBOLS2[unsigned(ct)], len, kCRLF); } void RedisReplyBuilderBase::SendError(std::string_view str, std::string_view type) { ReplyScope scope(this); if (type.empty()) { type = str; if (type == kSyntaxErr) type = kSyntaxErrType; } tl_facade_stats->reply_stats.err_count[type]++; last_error_ = 
str; if (str[0] != '-') { WritePieces("-ERR "); } if (str.size() <= kMaxInlineSize) { WritePieces(str, kCRLF); } else { WriteRef(str); WritePieces(kCRLF); } } void RedisReplyBuilderBase::SendProtocolError(std::string_view str) { SendError(absl::StrCat("-ERR Protocol error: ", str), "protocol_error"); } char* RedisReplyBuilderBase::FormatDouble(double d, char* dest, unsigned len) { StringBuilder sb(dest, len); CHECK(dfly_conv.ToShortest(d, &sb)); return sb.Finalize(); } void RedisReplyBuilderBase::SendVerbatimString(std::string_view str, VerbatimFormat format) { DCHECK(format <= VerbatimFormat::MARKDOWN); if (!IsResp3()) return SendBulkString(str); ReplyScope scope(this); WritePieces("=", str.size() + 4, kCRLF, format == VerbatimFormat::MARKDOWN ? "mkd:" : "txt:"); if (str.size() <= kMaxInlineSize) WritePieces(str); else WriteRef(str); WritePieces(kCRLF); } std::string RedisReplyBuilderBase::SerializeCommand(std::string_view command) { return string{command} + kCRLF; } void RedisReplyBuilder::SendSimpleStrArr(const facade::ArgRange& strs) { ReplyScope scope(this); StartArray(strs.Size()); for (std::string_view str : strs) SendSimpleString(str); } void RedisReplyBuilder::SendBulkStrArr(const facade::ArgRange& strs, CollectionType ct) { ReplyScope scope(this); StartCollection(ct == CollectionType::MAP ? strs.Size() / 2 : strs.Size(), ct); for (std::string_view str : strs) SendBulkString(str); } void RedisReplyBuilder::SendScoredArray(ScoredArray arr, bool with_scores) { ReplyScope scope(this); StartArray((with_scores && !IsResp3()) ? 
arr.size() * 2 : arr.size()); for (const auto& [str, score] : arr) { if (with_scores && IsResp3()) StartArray(2); SendBulkString(str); if (with_scores) SendDouble(score); } } /* Emits a two-element array: the label followed by an array of (member, score) pairs. */ void RedisReplyBuilder::SendLabeledScoredArray(std::string_view arr_label, ScoredArray arr) { ReplyScope scope(this); StartArray(2); SendBulkString(arr_label); StartArray(arr.size()); for (const auto& [str, score] : arr) { StartArray(2); SendBulkString(str); SendDouble(score); } } /* Sends every element of longs through SendLong() inside one array; unsigned elements are DCHECKed against an upper limit first. */ template void RedisReplyBuilder::SendLongArr(absl::Span longs) { static_assert(std::is_integral_v, "Must use integral type"); ReplyScope scope(this); StartArray(longs.size()); for (auto v : longs) { if constexpr (std::is_unsigned_v) DCHECK_LE(uint64_t(v), uint64_t(std::numeric_limits::max())); SendLong(v); } } template void RedisReplyBuilder::SendLongArr(absl::Span); template void RedisReplyBuilder::SendLongArr(absl::Span); template void RedisReplyBuilder::SendLongArr(absl::Span); template void RedisReplyBuilder::SendLongArr(absl::Span); void RedisReplyBuilder::StartArray(unsigned len) { StartCollection(len, CollectionType::ARRAY); } void RedisReplyBuilder::SendEmptyArray() { StartArray(0); } } // namespace facade ================================================ FILE: src/facade/reply_builder.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include "facade/facade_stats.h" #include "facade/facade_types.h" #include "io/io.h" namespace facade { enum class RespVersion { kResp2, kResp3 }; // Base class for all reply builders. Offers a simple high level interface for controlling output // modes and sending basic response types. 
class SinkReplyBuilder { /* State shared by the scope guards below: the flag value to restore and the builder it belongs to. */ struct GuardBase { bool prev; SinkReplyBuilder* rb; }; public: constexpr static size_t kMaxInlineSize = 32; constexpr static size_t kMaxBufferSize = 8192; struct PendingPin : public boost::intrusive::list_base_hook< ::boost::intrusive::link_mode<::boost::intrusive::normal_link>> { uint64_t timestamp_ns; PendingPin(uint64_t v = 0) : timestamp_ns(v) { } }; using PendingList = boost::intrusive::list, boost::intrusive::cache_last>; static thread_local PendingList pending_list; explicit SinkReplyBuilder(io::Sink* sink) : sink_(sink) { } virtual ~SinkReplyBuilder() = default; // USE WITH CARE! ReplyScope assumes that all string views in Send calls remain valid for the whole scope // lifetime. This allows the builder to avoid copies by enqueueing long strings directly for // vectorized io. struct ReplyScope : GuardBase { explicit ReplyScope(SinkReplyBuilder* rb) : GuardBase{std::exchange(rb->scoped_, true), rb} { } ~ReplyScope(); }; // Aggregator reduces the number of raw send calls by copying data in an intermediate buffer. // Prefer ReplyScope if possible to additionally reduce the number of copies. 
struct ReplyAggregator : GuardBase { explicit ReplyAggregator(SinkReplyBuilder* rb) : GuardBase{std::exchange(rb->batched_, true), rb} { } ~ReplyAggregator(); }; void Flush(size_t expected_buffer_cap = 0); // Send all accumulated data and reset to clear state std::error_code GetError() const { return ec_; } size_t UsedMemory() const { return buffer_.Capacity(); } size_t RepliesRecorded() const { return replies_recorded_; } /* True only while Send() is inside the sink Write() call; send_time_ns_ is reset to 0 afterwards. */ bool IsSendActive() const { return send_time_ns_ > 0; } void SetBatchMode(bool b) { batched_ = b; } void CloseConnection(); static const ReplyStats& GetThreadLocalStats() { return tl_facade_stats->reply_stats; } public: // High level interface virtual Protocol GetProtocol() const = 0; virtual void SendLong(long val) = 0; virtual void SendSimpleString(std::string_view str) = 0; void SendOk() { SendSimpleString("OK"); } virtual void SendError(std::string_view str, std::string_view type = {}) = 0; // MC and Redis void SendError(OpStatus status); void SendError(ErrorReply error); virtual void SendProtocolError(std::string_view str) = 0; /* Returns the last recorded error message and clears it. */ std::string ConsumeLastError() { return std::exchange(last_error_, {}); } uint64_t GetLastSendTimeNs() const; protected: template void WritePieces(Ts&&... pieces); // Copy pieces into buffer and reference buffer void WriteRef(std::string_view str); // Add iovec bypassing buffer void FinishScope(); // Called when scope ends to flush buffer if needed void Send(); protected: size_t replies_recorded_ = 0; std::string last_error_; private: io::Sink* sink_; std::error_code ec_; bool scoped_ = false, batched_ = false; size_t total_size_ = 0; // sum of vecs_ lengths base::IoBuf buffer_; // backing buffer for pieces // Stores iovecs for a single writev call. Can reference either the buffer (WritePieces) or // external data (WriteRef). Validity is ensured by FinishScope that either flushes before ref // lifetime ends or copies refs to the buffer. 
absl::InlinedVector vecs_; size_t guaranteed_pieces_ = 0; // length of prefix of vecs_ that are guaranteed to be pieces uint64_t send_time_ns_ = 0; }; /* Reply builder for the memcache protocol. */ class MCReplyBuilder : public SinkReplyBuilder { public: explicit MCReplyBuilder(::io::Sink* sink); ~MCReplyBuilder() override = default; Protocol GetProtocol() const final { return Protocol::MEMCACHE; } void SendError(std::string_view str, std::string_view type = std::string_view{}) final; void SendLong(long val) final; void SendClientError(std::string_view str); void SendValue(MemcacheCmdFlags cmd_flags, std::string_view key, std::string_view value, uint64_t mc_token, uint32_t mc_flag, uint32_t ttl_sec); void SendSimpleString(std::string_view str) final; void SendProtocolError(std::string_view str) final; void SendRaw(std::string_view str); }; // Redis reply builder interface for sending RESP data. class RedisReplyBuilderBase : public SinkReplyBuilder { public: enum VerbatimFormat : uint8_t { TXT, MARKDOWN }; explicit RedisReplyBuilderBase(io::Sink* sink) : SinkReplyBuilder(sink) { } ~RedisReplyBuilderBase() override = default; Protocol GetProtocol() const final { return Protocol::REDIS; } virtual void SendNull(); void SendSimpleString(std::string_view str) override; virtual void SendBulkString(std::string_view str); // RESP: Blob String void SendLong(long val) override; virtual void SendDouble(double val); // RESP: Number virtual void SendNullArray(); virtual void StartCollection(unsigned len, CollectionType ct); using SinkReplyBuilder::SendError; void SendError(std::string_view str, std::string_view type = {}) override; void SendProtocolError(std::string_view str) override; virtual void SendVerbatimString(std::string_view str, VerbatimFormat format = TXT); static char* FormatDouble(double d, char* dest, unsigned len); static std::string SerializeCommand(std::string_view command); bool IsResp3() const { return resp_ == RespVersion::kResp3; } void SetRespVersion(RespVersion resp_version) { resp_ = resp_version; } 
RespVersion GetRespVersion() { return resp_; } private: RespVersion resp_ = RespVersion::kResp2; }; // Non essential redis reply builder functions implemented on top of the base resp protocol class RedisReplyBuilder : public RedisReplyBuilderBase { public: using ScoredArray = absl::Span>; RedisReplyBuilder(io::Sink* sink) : RedisReplyBuilderBase(sink) { } ~RedisReplyBuilder() override = default; // One-liner for ReplyScope + StartArray struct ArrayScope : ReplyScope { ArrayScope(RedisReplyBuilder* rb, size_t len) : ReplyScope(rb) { rb->StartArray(len); } }; void SendSimpleStrArr(const facade::ArgRange& strs); void SendBulkStrArr(const facade::ArgRange& strs, CollectionType ct = CollectionType::ARRAY); template void SendLongArr(absl::Span longs); void SendScoredArray(ScoredArray arr, bool with_scores); void SendLabeledScoredArray(std::string_view arr_label, ScoredArray arr); void StartArray(unsigned len); void SendEmptyArray(); }; #define RETURN_ON_PARSE_ERROR(parser, rb) \ do { \ if (auto err = (parser).TakeError(); err) { \ return (rb)->SendError(err.MakeReply()); \ } \ } while (0) } // namespace facade ================================================ FILE: src/facade/reply_builder_test.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "facade/reply_builder.h" #include #include #include #include #include #include "base/gtest.h" #include "base/logging.h" #include "facade/error.h" #include "facade/facade_test.h" #include "facade/redis_parser.h" #include "facade/reply_capture.h" #include "facade/resp_expr_test_utils.h" using namespace testing; using namespace std; namespace facade { namespace { const std::string_view kErrorStrPreFix = "-ERR "; constexpr std::string_view kCRLF = "\r\n"; constexpr char kErrorStartChar = '-'; constexpr char kStringStartChar = '+'; constexpr std::string_view kOKMessage = "+OK\r\n"; constexpr char kArrayStart = '*'; constexpr char kBulkString = '$'; constexpr char kIntStart = ':'; const std::string_view kIntStartString = ":"; const std::string_view kNullBulkString = "$-1\r\n"; const std::string_view kBulkStringStart = "$"; const std::string_view kStringStart = "+"; const std::string_view kErrorStart = "-"; const std::string_view kArrayStartString = "*"; constexpr std::size_t kMinPayloadLen = 3; // the begin type char and "\r\n" at the end /* Expected wire form of an error: msg followed by CRLF, prefixed with kErrorStrPreFix unless msg already starts with kErrorStartChar. */ std::string BuildExpectedErrorString(std::string_view msg) { if (msg.at(0) == kErrorStartChar) { return absl::StrCat(msg, kCRLF); } else { return absl::StrCat(kErrorStrPreFix, msg, kCRLF); } } /* Maps an error message to the key used in the err_count statistics: kSyntaxErrType for kSyntaxErr, the message itself otherwise. */ std::string_view GetErrorType(std::string_view err) { return err == kSyntaxErr ? 
kSyntaxErrType : err; } } // namespace /* Fixture that writes replies through a RedisReplyBuilder into a StringSink and offers helpers to inspect the written bytes. */ class RedisReplyBuilderTest : public testing::Test { public: /* Holds the parse outcome of the sink contents: status, parsed expressions and number of consumed bytes. */ struct ParsingResults { RedisParser::Result result = RedisParser::OK; RespExpr::Vec args; std::uint32_t consumed = 0; ParsingResults(std::optional obj = std::nullopt, size_t buf_pos = 0) { if (!obj.has_value() || obj->Empty()) { return; } holder_.emplace(std::move(*obj)); result = RedisParser::OK; consumed = buf_pos; if (holder_->GetType() == RESPObj::Type::ARRAY) { auto arr = holder_->As(); if (!arr.has_value()) { result = RedisParser::BAD_ARRAYLEN; return; } args.reserve(arr->Size()); for (size_t i = 0; i < arr->Size(); ++i) { args.push_back(expr_builder_.BuildExpr((*arr)[i])); } return; } args.push_back(expr_builder_.BuildExpr(*holder_)); } bool Verify(std::uint32_t expected) const { return consumed == expected && result == RedisParser::OK; } bool IsError() const { return result != RedisParser::OK || (args.size() == 1 && args[0].type == RespExpr::ERROR); } bool IsOk() const { return IsString(); } bool IsNull() const { return result == RedisParser::OK && args.size() == 1 && args.at(0).type == RespExpr::NIL; } bool IsString() const { return args.size() == 1 && result == RedisParser::OK && args[0].type == RespExpr::STRING; } private: std::optional holder_; RespExprBuilder expr_builder_; }; void SetUp() { sink_.Clear(); builder_.reset(new RedisReplyBuilder(&sink_)); ResetStats(); } static void SetUpTestSuite() { tl_facade_stats = new FacadeStats; init_zmalloc_threadlocal(mi_heap_get_backing()); } protected: std::vector RawTokenizedMessage() const { CHECK(!str().empty()); return absl::StrSplit(str(), kCRLF); } std::string_view str() const { return sink_.str(); } /* Returns a copy of the sink contents and clears the sink. */ std::string TakePayload() { std::string ret = sink_.str(); sink_.Clear(); return ret; } std::size_t SinkSize() const { return str().size(); } /* Number of times err was recorded in the thread-local err_count map, 0 when absent. */ unsigned GetError(string_view err) const { const auto& map = SinkReplyBuilder::GetThreadLocalStats().err_count; auto it = map.find(err); return it == map.end() ? 
0 : it->second; } static bool NoErrors() { return tl_facade_stats->reply_stats.err_count.empty(); } static const ReplyStats& GetReplyStats() { return tl_facade_stats->reply_stats; } // Breaks the string we have in sink into tokens. // In RESP each token is built from a series of bytes followed by "\r\n" // This function does not try to parse the message, only to split the string based // on the delimiter "\r\n". It is up to the test to verify these tokens std::vector TokenizeMessage() const; // Call the redis parser with the data in the sink ParsingResults Parse(); io::StringSink sink_; std::unique_ptr builder_; std::unique_ptr parser_buffer_; }; std::vector RedisReplyBuilderTest::TokenizeMessage() const { std::vector message_tokens = RawTokenizedMessage(); CHECK(message_tokens.back().empty()); // we're expecting to last to be empty as it only has \r\n message_tokens.pop_back(); // remove this empty entry std::string_view data = str(); switch (data[0]) { case kArrayStart: // in the case of array. 
we cannot tell the expected tokens number without doing parsing for // sub elements break; case kBulkString: if (data == kNullBulkString) { CHECK(message_tokens.size() == 1) << "NULL bulk string should only have one token, got " << message_tokens.size(); } else { CHECK(message_tokens.size() == 2) << "bulk string should only have two tokens, got " << message_tokens.size(); } break; case kErrorStartChar: case kStringStartChar: case kIntStart: // for errors and string and ints we don't really need to split as there must be only one // entry for \r\n CHECK(message_tokens.size() == 1) << "string/error message must have only one token got " << message_tokens.size(); break; default: LOG(FATAL) << "invalid start char [" << data[0] << "]"; break; } return message_tokens; } /* Prints the consumed byte count, the status and every parsed entry of res. */ std::ostream& operator<<(std::ostream& os, const RedisReplyBuilderTest::ParsingResults& res) { os << "result{consumed bytes:" << res.consumed << ", status: " << res.result << " result count " << res.args.size() << ", first entry result: "; if (!res.args.empty()) { if (res.args.size() > 1) { os << "ARRAY: "; } for (const auto& e : res.args) { os << e << "\n"; } } else { os << "NILL"; } return os << "}"; } /* Copies the sink contents into parser_buffer_ and feeds the copy to a fresh RESPParser. */ RedisReplyBuilderTest::ParsingResults RedisReplyBuilderTest::Parse() { parser_buffer_.reset(new uint8_t[SinkSize()]); auto* ptr = parser_buffer_.get(); memcpy(ptr, str().data(), SinkSize()); RESPParser parser; auto resp_obj = parser.Feed(reinterpret_cast(ptr), SinkSize()); size_t buf_pos = parser.BufferPos(); buf_pos = resp_obj && !buf_pos ? 
SinkSize() : buf_pos; // after a successful parse buf_pos can be 0 ParsingResults result(std::move(resp_obj), buf_pos); return result; } /////////////////////////////////////////////////////////////////////////////// TEST_F(RedisReplyBuilderTest, MessageSend) { // Test each message that is "sent" to the sink builder_->SinkReplyBuilder::SendOk(); ASSERT_EQ(TakePayload(), kOKMessage); builder_->StartArray(10); std::string_view hello_msg = "hello"; builder_->SendBulkString(hello_msg); std::string expected_bulk_string = absl::StrCat( "*10\r\n", kBulkStringStart, std::to_string(hello_msg.size()), kCRLF, hello_msg, kCRLF); ASSERT_EQ(TakePayload(), expected_bulk_string); } TEST_F(RedisReplyBuilderTest, SimpleError) { // test with simple error case. This means that we must comply to // https://redis.io/docs/reference/protocol-spec/#resp-errors std::string_view error = "my error"; std::string_view empty_type; builder_->SendError(error, empty_type); // must start with "-" and ends with "\r\n" // ASSERT_EQ(sink_.str().at(0), kErrorStartChar); ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); ASSERT_EQ(GetError(error), 1); ASSERT_EQ(str(), BuildExpectedErrorString(error)) << " error different from expected - '" << str() << "'"; auto parsing = Parse(); ASSERT_TRUE(parsing.Verify(SinkSize())); ASSERT_TRUE(parsing.IsError()) << " result: " << parsing; EXPECT_THAT(parsing.args, ElementsAre(ErrArg(absl::StrCat("ERR ", error)))); sink_.Clear(); builder_->SendError(OpStatus::OK); // in this case we should not have an error string ASSERT_TRUE(absl::StartsWith(str(), kStringStart)); ASSERT_EQ(str(), kOKMessage); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); ASSERT_EQ(GetError(error), 1); parsing = Parse(); ASSERT_TRUE(parsing.Verify(SinkSize())); ASSERT_TRUE(parsing.IsOk()) << " result: " << parsing; EXPECT_THAT(parsing.args, ElementsAre("OK")); } /* An error much longer than the inline limit must still be framed with the dash prefix and CRLF. */ TEST_F(RedisReplyBuilderTest, VeryLongError) { std::string long_error(10 * 1024, 'X'); // 10KB 
error std::string_view empty_type; builder_->SendError(long_error, empty_type); ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); } /* Each listed status must produce its own error text and bump its err_count entry exactly once. */ TEST_F(RedisReplyBuilderTest, ErrorBuiltInMessage) { OpStatus error_codes[] = { OpStatus::KEY_NOTFOUND, OpStatus::OUT_OF_RANGE, OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY, OpStatus::INVALID_FLOAT, OpStatus::INVALID_INT, OpStatus::SYNTAX_ERR, OpStatus::BUSY_GROUP, OpStatus::INVALID_NUMERIC_RESULT}; for (const auto& err : error_codes) { const std::string_view error_name = StatusToMsg(err); const std::string_view error_type = GetErrorType(error_name); sink_.Clear(); builder_->SendError(err); ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)) << " invalid start char for " << err; ASSERT_TRUE(absl::EndsWith(str(), kCRLF)) << " failed to find correct termination at " << err; ASSERT_EQ(GetError(error_type), 1) << " number of error count is invalid for " << err; ASSERT_EQ(str(), BuildExpectedErrorString(error_name)) << " error different from expected - '" << str() << "'"; auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.Verify(SinkSize())) << " verify for the result is invalid for " << err; ASSERT_TRUE(parsing_output.IsError()) << " expecting error for " << err; } } /* Covers both ErrorReply forms: one built from an OpStatus and one built from an explicit message and type. */ TEST_F(RedisReplyBuilderTest, ErrorReplyBuiltInMessage) { ErrorReply err{OpStatus::OUT_OF_RANGE}; builder_->SendError(err); ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); ASSERT_EQ(GetError(kIndexOutOfRange), 1); ASSERT_EQ(str(), BuildExpectedErrorString(kIndexOutOfRange)); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.Verify(SinkSize())); ASSERT_TRUE(parsing_output.IsError()); sink_.Clear(); err = ErrorReply{"e1", "e2"}; builder_->SendError(err); ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); ASSERT_EQ(GetError("e2"), 1); ASSERT_EQ(str(), BuildExpectedErrorString("e1")); parsing_output = Parse(); 
ASSERT_TRUE(parsing_output.Verify(SinkSize())); ASSERT_TRUE(parsing_output.IsError()); } TEST_F(RedisReplyBuilderTest, ErrorNoneBuiltInMessage) { // All these op codes produce the same error message OpStatus none_unique_codes[] = {OpStatus::SKIPPED, OpStatus::KEY_EXISTS, OpStatus::INVALID_VALUE, OpStatus::TIMED_OUT, OpStatus::STREAM_ID_SMALL}; uint64_t error_count = 0; for (const auto& err : none_unique_codes) { const std::string_view error_name = StatusToMsg(err); const std::string_view error_type = GetErrorType(error_name); sink_.Clear(); builder_->SendError(err); ASSERT_TRUE(absl::StartsWith(str(), kErrorStart)) << " invalid start char for " << err; ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); auto current_error_count = GetError(error_type); error_count++; ASSERT_EQ(current_error_count, error_count) << " number of error count is invalid for " << err; auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.Verify(SinkSize())) << " verify for the result is invalid for " << err; ASSERT_TRUE(parsing_output.IsError()) << " expecting error for " << err; } } TEST_F(RedisReplyBuilderTest, StringMessage) { // This tests messages that contain a simple string // For string this is simple, any string message should start with + and end with \r\n // there can never be more than a single \r\n in it as well as no special chars const std::string_view payloads[] = { "this is a string message", "$$$$$", "12334", "1v%6&*", "@@@", "----", "!!!"}; for (auto payload : payloads) { const std::size_t expected_len = payload.size() + kCRLF.size() + 1; // include '+' at the start sink_.Clear(); builder_->SendSimpleString(payload); ASSERT_EQ(SinkSize(), expected_len); ASSERT_TRUE(absl::StartsWith(str(), kStringStart)); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); // auto message_payload = SimpleStringPayload(); // ASSERT_EQ(message_payload, payload); ASSERT_TRUE(absl::StartsWith(str(), kStringStart)); ASSERT_TRUE(absl::EndsWith(str(), kCRLF)); auto data = str(); 
data.remove_suffix(kCRLF.size()); ASSERT_TRUE(absl::EndsWith(data, payload)); } } TEST_F(RedisReplyBuilderTest, EmptyArray) { // This test builds an array and sends it over the "wire" // The array starts with the '*', then the number of elements in the array // then "\r\n", then each element inside is encoded accordingly // an empty array has this "*0\r\n" form const std::string_view empty_array = "*0\r\n"; const std::string_view null_array = "*-1\r\n"; builder_->StartArray(0); ASSERT_EQ(str(), empty_array); sink_.Clear(); builder_->SendNullArray(); ASSERT_EQ(null_array, str()); sink_.Clear(); builder_->SendEmptyArray(); ASSERT_EQ(str(), empty_array); } /* Builds an array element by element with StartArray + SendSimpleString and checks size, parsed values and raw tokens. */ TEST_F(RedisReplyBuilderTest, StrArray) { std::vector string_vector{"hello", "world", "111", "@3#$^&*~"}; builder_->StartArray(string_vector.size()); std::size_t expected_size = kCRLF.size() + 2; for (auto s : string_vector) { builder_->SendSimpleString(s); expected_size += s.size() + kCRLF.size() + 1; ASSERT_TRUE(NoErrors()); } ASSERT_EQ(SinkSize(), expected_size); // ASSERT_EQ(kArrayStart, str().at(0)); ASSERT_TRUE(absl::StartsWith(str(), absl::StrCat(kArrayStartString, 4))); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.Verify(SinkSize())) << " invalid parsing for the array message by the parser: " << parsing_output; ASSERT_EQ(string_vector.size(), parsing_output.args.size()); ASSERT_THAT(parsing_output.args, ElementsAre(string_vector[0], string_vector[1], string_vector[2], string_vector[3])); std::vector message_tokens = TokenizeMessage(); ASSERT_THAT(message_tokens, ElementsAre("*4", absl::StrCat(kStringStart, string_vector[0]), absl::StrCat(kStringStart, string_vector[1]), absl::StrCat(kStringStart, string_vector[2]), absl::StrCat(kStringStart, string_vector[3]))); } TEST_F(RedisReplyBuilderTest, SendSimpleStrArr) { // This sends an array of strings through SendSimpleStrArr instead of the element-wise calls used in the StrArray test const std::string_view kArrayMessage[] = { // random values "+++", "---", "$$$", "~~~~", 
"@@@", "^^^", "1234", "foo"}; const std::size_t kArrayLen = sizeof(kArrayMessage) / sizeof(kArrayMessage[0]); builder_->SendSimpleStrArr(kArrayMessage); ASSERT_TRUE(NoErrors()); // Tokenize the message and verify content std::vector message_tokens = TokenizeMessage(); ASSERT_THAT(message_tokens, ElementsAre(absl::StrCat(kArrayStartString, kArrayLen), absl::StrCat(kStringStart, kArrayMessage[0]), absl::StrCat(kStringStart, kArrayMessage[1]), absl::StrCat(kStringStart, kArrayMessage[2]), absl::StrCat(kStringStart, kArrayMessage[3]), absl::StrCat(kStringStart, kArrayMessage[4]), absl::StrCat(kStringStart, kArrayMessage[5]), absl::StrCat(kStringStart, kArrayMessage[6]), absl::StrCat(kStringStart, kArrayMessage[7]))); auto parsed_message = Parse(); ASSERT_THAT(parsed_message.args, ElementsAre(kArrayMessage[0], kArrayMessage[1], kArrayMessage[2], kArrayMessage[3], kArrayMessage[4], kArrayMessage[5], kArrayMessage[6], kArrayMessage[7])); } TEST_F(RedisReplyBuilderTest, SendStringViewArr) { // This sends an array of short strings through SendBulkStrArr const std::vector kArrayMessage{ // random values "(((", "}}}", "&&&&", "####", "___", "+++", "0.1234", "bar"}; builder_->SendBulkStrArr(kArrayMessage); ASSERT_TRUE(NoErrors()); // verify content std::vector message_tokens = TokenizeMessage(); // the form of this is *\r\n$\r\n..$\r\n\r\n ASSERT_THAT( message_tokens, ElementsAre(absl::StrCat(kArrayStartString, kArrayMessage.size()), // array size // size + string 0..N absl::StrCat(kBulkStringStart, kArrayMessage[0].size()), kArrayMessage[0], absl::StrCat(kBulkStringStart, kArrayMessage[1].size()), kArrayMessage[1], absl::StrCat(kBulkStringStart, kArrayMessage[2].size()), kArrayMessage[2], absl::StrCat(kBulkStringStart, kArrayMessage[3].size()), kArrayMessage[3], absl::StrCat(kBulkStringStart, kArrayMessage[4].size()), kArrayMessage[4], absl::StrCat(kBulkStringStart, kArrayMessage[5].size()), kArrayMessage[5], absl::StrCat(kBulkStringStart, 
kArrayMessage[6].size()), kArrayMessage[6], absl::StrCat(kBulkStringStart, kArrayMessage[7].size()), kArrayMessage[7])); // Check the parsed message auto parsed_message = Parse(); ASSERT_THAT(parsed_message.args, ElementsAre(kArrayMessage[0], kArrayMessage[1], kArrayMessage[2], kArrayMessage[3], kArrayMessage[4], kArrayMessage[5], kArrayMessage[6], kArrayMessage[7])); } TEST_F(RedisReplyBuilderTest, SendBulkStringArr) { // This sends an array of large strings through SendBulkStrArr const std::vector kArrayMessage{ // Test this one with large values std::string(1024, '.'), std::string(2048, ','), std::string(4096, ' ')}; builder_->SendBulkStrArr(kArrayMessage); ASSERT_TRUE(NoErrors()); std::vector message_tokens = TokenizeMessage(); // the form of this is *\r\n$\r\n..$\r\n\r\n ASSERT_THAT( message_tokens, ElementsAre(absl::StrCat(kArrayStartString, kArrayMessage.size()), // array size // size + string 0..N absl::StrCat(kBulkStringStart, kArrayMessage[0].size()), kArrayMessage[0], absl::StrCat(kBulkStringStart, kArrayMessage[1].size()), kArrayMessage[1], absl::StrCat(kBulkStringStart, kArrayMessage[2].size()), kArrayMessage[2])); // Check the parsed message auto parsed_message = Parse(); ASSERT_TRUE(parsed_message.Verify(SinkSize())) << "message was not successfully parsed: " << parsed_message; ASSERT_THAT(parsed_message.args, ElementsAre(kArrayMessage[0], kArrayMessage[1], kArrayMessage[2])); } TEST_F(RedisReplyBuilderTest, NullBulkString) { // null bulk string == "$-1\r\n" i.e. '$' + -1 + \r + \n builder_->SendNull(); ASSERT_TRUE(NoErrors()); ASSERT_EQ(str(), kNullBulkString); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.Verify(SinkSize())); ASSERT_TRUE(parsing_output.IsNull()); ASSERT_THAT(parsing_output.args, ElementsAre(ArgType(RespExpr::NIL))); } TEST_F(RedisReplyBuilderTest, EmptyBulkString) { // empty bulk string is in the form of "$0\r\n\r\n", i.e. 
length 0 after $ follow by \r\n*2 const std::string_view kEmptyBulkString = "$0\r\n\r\n"; builder_->SendBulkString(std::string_view{}); ASSERT_TRUE(NoErrors()); ASSERT_EQ(str(), kEmptyBulkString); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.Verify(SinkSize())); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(std::string_view{})); } TEST_F(RedisReplyBuilderTest, NoAsciiBulkString) { // Bulk string may contain none ascii chars const char random_bytes[] = {0x12, 0x25, 0x37}; std::size_t data_size = sizeof(random_bytes) / sizeof(random_bytes[0]); std::string_view none_ascii_payload{random_bytes, data_size}; builder_->SendBulkString(none_ascii_payload); ASSERT_TRUE(NoErrors()); const std::string expected_payload = absl::StrCat(kBulkStringStart, data_size, kCRLF, none_ascii_payload, kCRLF); ASSERT_EQ(str(), expected_payload); std::vector message_tokens = TokenizeMessage(); ASSERT_EQ(message_tokens.size(), 2); // length and payload ASSERT_THAT(message_tokens, ElementsAre(absl::StrCat(kBulkStringStart, data_size), none_ascii_payload)); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(none_ascii_payload)); } TEST_F(RedisReplyBuilderTest, BulkStringWithCRLF) { // Verify bulk string that contains the \r\n as payload std::string_view crlf_chars{"\r\n"}; builder_->SendBulkString(crlf_chars); ASSERT_TRUE(NoErrors()); // the expected message in this case is $2\r\n\r\n\r\n std::string expected_message = absl::StrCat(kBulkStringStart, crlf_chars.size(), kCRLF, crlf_chars, kCRLF); ASSERT_EQ(str(), expected_message); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(crlf_chars)); } TEST_F(RedisReplyBuilderTest, BulkStringWithStartBulkString) { // check a bulk string that contains $ as payload std::string message = absl::StrCat(kBulkStringStart, "10"); std::string expected_message = 
absl::StrCat(kBulkStringStart, message.size(), kCRLF, message, kCRLF); builder_->SendBulkString(message); ASSERT_TRUE(NoErrors()); ASSERT_EQ(str(), expected_message); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(message)); } TEST_F(RedisReplyBuilderTest, BulkStringWithStarString) { std::string message = absl::StrCat(kStringStart, "a string message"); std::string expected_message = absl::StrCat(kBulkStringStart, message.size(), kCRLF, message, kCRLF); builder_->SendBulkString(message); ASSERT_EQ(str(), expected_message); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(message)); } TEST_F(RedisReplyBuilderTest, BulkStringWithErrorString) { std::string message = absl::StrCat(kErrorStrPreFix, kSyntaxErrType); std::string expected_message = absl::StrCat(kBulkStringStart, message.size(), kCRLF, message, kCRLF); builder_->SendBulkString(message); ASSERT_TRUE(NoErrors()); ASSERT_EQ(str(), expected_message); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(message)); } TEST_F(RedisReplyBuilderTest, Int) { // message in the form of ":0\r\n" and ":1000\r\n" // this message just starts with ':' and ends with \r\n // and the payload must be successfully parsed into int type const long kPayloadInt = 12345; const std::string expected_output = absl::StrCat(kIntStartString, kPayloadInt, kCRLF); builder_->SendLong(kPayloadInt); ASSERT_EQ(str(), expected_output); long value = 0; std::string_view expected_payload = str().substr(1, SinkSize() - kMinPayloadLen); ASSERT_TRUE(absl::SimpleAtoi(expected_payload, &value)); ASSERT_EQ(value, kPayloadInt); auto parsing_output = Parse(); ASSERT_THAT(parsing_output.args, ElementsAre(IntArg(kPayloadInt))); } TEST_F(RedisReplyBuilderTest, Double) { // There is no direct support for double types in RESP // to send this, it is sent as bulk string 
const std::string_view kPayloadStr = "23.456"; double double_value = 0; CHECK(absl::SimpleAtod(kPayloadStr, &double_value)); const std::string expected_payload = absl::StrCat(kBulkStringStart, kPayloadStr.size(), kCRLF, kPayloadStr, kCRLF); builder_->SendDouble(double_value); ASSERT_TRUE(NoErrors()); std::vector message_tokens = TokenizeMessage(); ASSERT_EQ(str(), expected_payload); ASSERT_THAT(message_tokens, ElementsAre(absl::StrCat(kBulkStringStart, kPayloadStr.size()), kPayloadStr)); auto parsing_output = Parse(); ASSERT_TRUE(parsing_output.IsString()); ASSERT_THAT(parsing_output.args, ElementsAre(kPayloadStr)); } TEST_F(RedisReplyBuilderTest, MixedTypeArray) { // For arrays, we can send an array that contains more than a single type (string/bulk // string/simple string/null..) In this test we are verifying that this is actually working. note // that this is not part of class RedisReplyBuilder API // The entries are: // array start // bulk string // int // int // simple string // simple string // empty bulk string // double (bulk string) std::string long_string(1024, '-'); const unsigned int kArraySize = 6; const char random_bytes[] = {0x12, 0x15, 0x2F}; const std::string_view kFirstBulkString{random_bytes, 3}; const long kFirstLongValue = 54321; const long kSecondLongValue = 87654321; const std::string_view kLongSimpleString{long_string}; const std::string_view kPayloadDoubleStr = "9987654321.0123"; double double_value = 0; CHECK(absl::SimpleAtod(kPayloadDoubleStr, &double_value)); builder_->StartArray(kArraySize); builder_->SendBulkString(kFirstBulkString); builder_->SendLong(kFirstLongValue); builder_->SendLong(kSecondLongValue); builder_->SendSimpleString(kLongSimpleString); // builder_->SendNull(); builder_->SendBulkString(std::string_view{}); builder_->SendDouble(double_value); const std::string_view output_msg = str(); ASSERT_FALSE(output_msg.empty()); ASSERT_TRUE(NoErrors()); std::vector message_tokens = TokenizeMessage(); ASSERT_THAT( message_tokens, 
ElementsAre(absl::StrCat(kArrayStartString, kArraySize), // the length absl::StrCat(kBulkStringStart, kFirstBulkString.size()), kFirstBulkString, absl::StrCat(kIntStartString, kFirstLongValue), absl::StrCat(kIntStartString, kSecondLongValue), absl::StrCat(kStringStart, kLongSimpleString), // ArgType(RespExpr::NIL), absl::StrCat(kBulkStringStart, "0"), std::string_view{}, absl::StrCat(kBulkStringStart, kPayloadDoubleStr.size()), kPayloadDoubleStr)); // // Now we need to parse it and make sure that its a valid message by the parser as well auto parsed_message = Parse(); ASSERT_THAT( parsed_message.args, ElementsAre(ArgType(RespExpr::STRING), ArgType(RespExpr::INT64), ArgType(RespExpr::INT64), ArgType(RespExpr::STRING), ArgType(RespExpr::STRING), ArgType(RespExpr::STRING))); } TEST_F(RedisReplyBuilderTest, BatchMode) { GTEST_SKIP() << "Some differences"; // Test that when the batch mode is enabled, we are getting the same correct results builder_->SetBatchMode(true); // Some random values and sizes const std::vector kInputArray{ std::string(10, 'p'), std::string(48, 'o'), std::string(67, 'y'), std::string(167, 'e'), std::string(478, '*'), std::string(164, 't'), }; builder_->StartArray(kInputArray.size()); ASSERT_EQ(SinkSize(), 0); int count = 0; std::size_t total_bytes = 0; for (const auto& val : kInputArray) { builder_->SendBulkString(val); ASSERT_EQ(SinkSize(), 0) << " sink is not empty at iteration number " << count; ASSERT_EQ(GetReplyStats().io_write_bytes, 0); ASSERT_EQ(GetReplyStats().io_write_cnt, 0); total_bytes += val.size(); ++count; } // in order to actually see the message, we need to disable the batching, then // write something builder_->SetBatchMode(false); builder_->SendBulkString(std::string_view{}); ASSERT_EQ(GetReplyStats().io_write_cnt, 1); // We expecting to have more than the total bytes we count, // since we are not counting the \r\n and the type char as well // as length entries ASSERT_GT(GetReplyStats().io_write_bytes, total_bytes); 
std::vector array_members = TokenizeMessage(); ASSERT_THAT(array_members, ElementsAre(absl::StrCat(kArrayStartString, kInputArray.size()), absl::StrCat(kBulkStringStart, kInputArray[0].size()), kInputArray[0], absl::StrCat(kBulkStringStart, kInputArray[1].size()), kInputArray[1], absl::StrCat(kBulkStringStart, kInputArray[2].size()), kInputArray[2], absl::StrCat(kBulkStringStart, kInputArray[3].size()), kInputArray[3], absl::StrCat(kBulkStringStart, kInputArray[4].size()), kInputArray[4], absl::StrCat(kBulkStringStart, kInputArray[5].size()), kInputArray[5], absl::StrCat(kBulkStringStart, "0"), std::string_view{})); } TEST_F(RedisReplyBuilderTest, Resp3Double) { builder_->SetRespVersion(RespVersion::kResp3); builder_->SendDouble(5.5); ASSERT_TRUE(NoErrors()); ASSERT_EQ(str(), ",5.5\r\n"); } TEST_F(RedisReplyBuilderTest, Resp3NullString) { builder_->SetRespVersion(RespVersion::kResp3); builder_->SendNull(); ASSERT_TRUE(NoErrors()); ASSERT_EQ(TakePayload(), "_\r\n"); } TEST_F(RedisReplyBuilderTest, SendStringArrayAsMap) { const std::vector map_array{"k1", "v1", "k2", "v2"}; builder_->SetRespVersion(RespVersion::kResp2); builder_->SendBulkStrArr(map_array, CollectionType::MAP); ASSERT_TRUE(NoErrors()); ASSERT_EQ(TakePayload(), "*4\r\n$2\r\nk1\r\n$2\r\nv1\r\n$2\r\nk2\r\n$2\r\nv2\r\n") << "SendStringArrayAsMap Resp2 Failed."; builder_->SetRespVersion(RespVersion::kResp3); builder_->SendBulkStrArr(map_array, CollectionType::MAP); ASSERT_TRUE(NoErrors()); ASSERT_EQ(TakePayload(), "%2\r\n$2\r\nk1\r\n$2\r\nv1\r\n$2\r\nk2\r\n$2\r\nv2\r\n") << "SendStringArrayAsMap Resp3 Failed."; } TEST_F(RedisReplyBuilderTest, SendStringArrayAsSet) { const std::vector set_array{"e1", "e2", "e3"}; builder_->SetRespVersion(RespVersion::kResp2); builder_->SendBulkStrArr(set_array, CollectionType::SET); ASSERT_TRUE(NoErrors()); ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n") << "SendStringArrayAsSet Resp2 Failed."; builder_->SetRespVersion(RespVersion::kResp3); 
builder_->SendBulkStrArr(set_array, CollectionType::SET);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(), "~3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
      << "SendStringArrayAsSet Resp3 Failed.";
}

// Checks SendScoredArray for every combination of RESP version and with-scores flag.
// Each assertion message names the combination it covers so a failure points at the right case.
TEST_F(RedisReplyBuilderTest, SendScoredArray) {
  const std::vector<std::pair<std::string, double>> scored_array{
      {"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}};

  // Without scores: both protocol versions are expected to produce a flat array of members.
  builder_->SetRespVersion(RespVersion::kResp2);
  builder_->SendScoredArray(scored_array, false);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
      << "Resp2 WITHOUT scores failed.";

  builder_->SetRespVersion(RespVersion::kResp3);
  builder_->SendScoredArray(scored_array, false);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
      << "Resp3 WITHOUT scores failed.";

  // With scores, RESP2: members and scores interleaved in one flat array of bulk strings.
  builder_->SetRespVersion(RespVersion::kResp2);
  builder_->SendScoredArray(scored_array, true);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(),
            "*6\r\n$2\r\ne1\r\n$3\r\n1.1\r\n$2\r\ne2\r\n$3\r\n2.2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n")
      << "Resp2 WITHSCORES failed.";

  // With scores, RESP3: an array of [member, double] pairs.
  builder_->SetRespVersion(RespVersion::kResp3);
  builder_->SendScoredArray(scored_array, true);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(),
            "*3\r\n*2\r\n$2\r\ne1\r\n,1.1\r\n*2\r\n$2\r\ne2\r\n,2.2\r\n*2\r\n$2\r\ne3\r\n,3.3\r\n")
      << "Resp3 WITHSCORES failed.";
}

// Checks SendLabeledScoredArray: a two-element array holding the label followed by the array of
// [member, score] pairs, with scores encoded per the active RESP version.
TEST_F(RedisReplyBuilderTest, SendLabeledScoredArray) {
  const std::vector<std::pair<std::string, double>> scored_array{
      {"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}};

  builder_->SetRespVersion(RespVersion::kResp2);
  builder_->SendLabeledScoredArray("foobar", scored_array);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(),
            "*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n$3\r\n1.1\r\n*2\r\n$2\r\ne2\r\n$3\r\n2."
"2\r\n*2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n") << "Resp3 failed.\n"; builder_->SetRespVersion(RespVersion::kResp3); builder_->SendLabeledScoredArray("foobar", scored_array); ASSERT_TRUE(NoErrors()); ASSERT_EQ(TakePayload(), "*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n,1.1\r\n*2\r\n$2\r\ne2\r\n,2.2\r\n*" "2\r\n$2\r\ne3\r\n,3.3\r\n") << "Resp3 failed."; } TEST_F(RedisReplyBuilderTest, BasicCapture) { using namespace std; string_view kTestSws[] = {"a1"sv, "a2"sv, "a3"sv, "a4"sv}; CapturingReplyBuilder crb{}; using RRB = RedisReplyBuilder; auto big_arr_cb = [](RRB* r) { r->StartArray(4); { r->StartArray(2); r->SendLong(1); r->StartArray(2); { r->SendLong(2); r->SendLong(3); } } r->SendLong(4); { r->StartArray(2); { r->StartArray(2); r->SendLong(5); r->SendLong(6); } r->SendLong(7); } r->SendLong(8); }; function funcs[] = { [](RRB* r) { r->SendNull(); }, [](RRB* r) { r->SendLong(1L); }, [](RRB* r) { r->SendDouble(6.7); }, [](RRB* r) { r->SendSimpleString("ok"); }, [](RRB* r) { r->SendEmptyArray(); }, [](RRB* r) { r->SendNullArray(); }, [](RRB* r) { r->SendError("e1", "e2"); }, [kTestSws](RRB* r) { r->SendSimpleStrArr(kTestSws); }, [kTestSws](RRB* r) { r->SendBulkStrArr(kTestSws); }, [kTestSws](RRB* r) { r->SendBulkStrArr(kTestSws, CollectionType::SET); }, [kTestSws](RRB* r) { r->SendBulkStrArr(kTestSws, CollectionType::MAP); }, [kTestSws](RRB* r) { r->StartArray(3); r->SendLong(1L); r->SendDouble(2.5); r->SendSimpleStrArr(kTestSws); }, big_arr_cb, }; crb.SetRespVersion(RespVersion::kResp3); builder_->SetRespVersion(RespVersion::kResp3); // Run generator functions on both a regular redis builder // and the capturing builder with its capture applied. 
for (auto& f : funcs) {
    // Direct write to the regular builder gives the reference bytes ...
    f(builder_.get());
    auto expected = TakePayload();
    // ... which must match the bytes produced by capturing and then applying the capture.
    f(&crb);
    CapturingReplyBuilder::Apply(crb.Take(), builder_.get());
    auto actual = TakePayload();
    EXPECT_EQ(expected, actual);
  }

  builder_->SetRespVersion(RespVersion::kResp2);
}

// Pins the textual form produced by RedisReplyBuilder::FormatDouble for fractions, infinities,
// negative zero and values that switch to exponent notation.
TEST_F(RedisReplyBuilderTest, FormatDouble) {
  char buf[64];

  auto format = [&](double d) { return RedisReplyBuilder::FormatDouble(d, buf, sizeof(buf)); };

  EXPECT_STREQ("0.1", format(0.1));
  EXPECT_STREQ("0.2", format(0.2));
  EXPECT_STREQ("0.8", format(0.8));
  EXPECT_STREQ("1.1", format(1.1));
  EXPECT_STREQ("inf", format(INFINITY));
  EXPECT_STREQ("-inf", format(-INFINITY));
  EXPECT_STREQ("0", format(-0.0));
  EXPECT_STREQ("1e-7", format(0.0000001));
  EXPECT_STREQ("111111111111111110000", format(111111111111111111111.0));
  EXPECT_STREQ("1.1111111111111111e+21", format(1111111111111111111111.0));
  EXPECT_STREQ("1e-23", format(1e-23));
}

// Checks SendVerbatimString: RESP3 emits a '=' frame prefixed with the format tag, while RESP2
// emits a plain bulk string. Each failure message names the case it actually covers.
TEST_F(RedisReplyBuilderTest, VerbatimString) {
  // test resp3
  std::string str = "A simple string!";

  builder_->SetRespVersion(RespVersion::kResp3);
  builder_->SendVerbatimString(str, RedisReplyBuilder::VerbatimFormat::TXT);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(), "=20\r\ntxt:A simple string!\r\n") << "Resp3 VerbatimString TXT failed.";

  builder_->SetRespVersion(RespVersion::kResp3);
  builder_->SendVerbatimString(str, RedisReplyBuilder::VerbatimFormat::MARKDOWN);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(), "=20\r\nmkd:A simple string!\r\n")
      << "Resp3 VerbatimString MARKDOWN failed.";

  builder_->SetRespVersion(RespVersion::kResp2);
  builder_->SendVerbatimString(str);
  ASSERT_TRUE(NoErrors());
  ASSERT_EQ(TakePayload(), "$16\r\nA simple string!\r\n") << "Resp2 VerbatimString failed.";
}

// Sends a bulk-string array of 10'000 short elements and checks the parser reads all of them back.
TEST_F(RedisReplyBuilderTest, Issue3449) {
  vector<string> records;
  for (unsigned i = 0; i < 10'000; ++i) {
    records.push_back(absl::StrCat(i));
  }
  builder_->SendBulkStrArr(records);
  ASSERT_TRUE(NoErrors());
  ParsingResults parse_result = Parse();
  ASSERT_FALSE(parse_result.IsError());
  EXPECT_EQ(10000, parse_result.args.size());
} TEST_F(RedisReplyBuilderTest, Issue4424) { vector records; for (unsigned i = 0; i < 800; ++i) { records.push_back(string(100, 'a')); } for (unsigned j = 0; j < 2; ++j) { builder_->SendBulkStrArr(records); ASSERT_TRUE(NoErrors()); ParsingResults parse_result = Parse(); ASSERT_FALSE(parse_result.IsError()) << int(parse_result.result); ASSERT_TRUE(parse_result.Verify(SinkSize())); EXPECT_EQ(800, parse_result.args.size()); sink_.Clear(); } } TEST_F(RedisReplyBuilderTest, MCMetaGetLargeValue) { io::StringSink mc_sink; MCReplyBuilder mc_builder(&mc_sink); MemcacheCmdFlags flags; flags.meta = true; flags.return_value = true; string large_val(16000, 'x'); mc_builder.SendValue(flags, "key", large_val, 0, 0, 0); string_view output = mc_sink.str(); EXPECT_THAT(output, HasSubstr("VA 16000")); EXPECT_THAT(output, HasSubstr(large_val)); } static void BM_FormatDouble(benchmark::State& state) { vector values; char buf[64]; uniform_real_distribution unif(0, 1e9); default_random_engine re; for (unsigned i = 0; i < 100; i++) { values.push_back(unif(re)); } while (state.KeepRunning()) { for (auto d : values) { RedisReplyBuilder::FormatDouble(d, buf, sizeof(buf)); } } } BENCHMARK(BM_FormatDouble); } // namespace facade ================================================ FILE: src/facade/reply_capture.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#include "facade/reply_capture.h"

#include "absl/types/span.h"
#include "base/logging.h"
#include "reply_capture.h"

// Shared prologue of the capturing overrides below: counts the reply in replies_recorded_, and
// when the active reply_mode_ is below `needed` it resets current_ to the empty state and returns
// from the enclosing function, so nothing is captured.
#define SKIP_LESS(needed)     \
  replies_recorded_++;        \
  if (reply_mode_ < needed) { \
    current_ = monostate{};   \
    return;                   \
  }

namespace facade {

using namespace std;
using namespace payload;

// Remembers the message in last_error_ unconditionally, then captures the error unless the reply
// mode is below ONLY_ERR.
void CapturingReplyBuilder::SendError(std::string_view str, std::string_view type) {
  last_error_ = str;
  SKIP_LESS(ReplyMode::ONLY_ERR);
  Capture(make_error(str, type));
}

// Captures a null array, represented as a null collection pointer (FULL mode only).
void CapturingReplyBuilder::SendNullArray() {
  SKIP_LESS(ReplyMode::FULL);
  Capture(unique_ptr{nullptr});
}

// Captures a null value (FULL mode only).
void CapturingReplyBuilder::SendNull() {
  SKIP_LESS(ReplyMode::FULL);
  Capture(nullptr_t{});
}

// Captures an integer reply (FULL mode only).
void CapturingReplyBuilder::SendLong(long val) {
  SKIP_LESS(ReplyMode::FULL);
  Capture(val);
}

// Captures a double reply (FULL mode only).
void CapturingReplyBuilder::SendDouble(double val) {
  SKIP_LESS(ReplyMode::FULL);
  Capture(val);
}

// Captures an owned copy of `str` tagged as a simple string (FULL mode only).
void CapturingReplyBuilder::SendSimpleString(std::string_view str) {
  SKIP_LESS(ReplyMode::FULL);
  Capture(SimpleString{string{str}});
}

// Captures an owned copy of `str` tagged as a bulk string (FULL mode only).
void CapturingReplyBuilder::SendBulkString(std::string_view str) {
  SKIP_LESS(ReplyMode::FULL);
  Capture(BulkString{string{str}});
}

// Opens a new collection on the stack together with the number of elements it still expects;
// a MAP of `len` entries expects len * 2 elements (keys and values).
void CapturingReplyBuilder::StartCollection(unsigned len, CollectionType type) {
  SKIP_LESS(ReplyMode::FULL);
  stack_.emplace(make_unique(len, type), type == CollectionType::MAP ? len * 2 : len);

  // If we added an empty collection, it must be collapsed immediately.
  CollapseFilledCollections();
}

// Returns the captured root payload and resets the builder to the empty state.
// CHECK-fails if a collection is still open on the stack.
CapturingReplyBuilder::Payload CapturingReplyBuilder::Take() {
  CHECK(stack_.empty());
  Payload pl = std::move(current_);
  current_ = monostate{};
  return pl;
}

// Stores an already-built payload as the root value without going through the Send* interface.
// Updates replies_recorded_ based on the alternative held by `val`, then keeps the payload only
// if the reply mode admits it: errors need at least ONLY_ERR, everything else needs FULL.
// Otherwise the root is reset to the empty state.
void CapturingReplyBuilder::SendDirect(Payload&& val) {
  replies_recorded_ += !holds_alternative(val);
  bool is_err = holds_alternative(val);
  ReplyMode min_mode = is_err ? ReplyMode::ONLY_ERR : ReplyMode::FULL;
  if (reply_mode_ >= min_mode) {
    // The root is expected to be empty (variant index 0) before it is overwritten.
    DCHECK_EQ(current_.index(), 0u);
    current_ = std::move(val);
  } else {
    current_ = monostate{};
  }
}

// Adds `val` to the innermost open collection, or stores it as the root value when no collection
// is open. `collapse_if_needed` controls whether completing a collection triggers collapsing.
void CapturingReplyBuilder::Capture(Payload val, bool collapse_if_needed) {
  if (!stack_.empty()) {
    auto& last = stack_.top();
    last.first->arr.push_back(std::move(val));
    // `second` counts the elements still missing; the element that finds it at 1 is the last one.
    if (last.second-- == 1 && collapse_if_needed) {
      CollapseFilledCollections();
    }
  } else {
    DCHECK_EQ(current_.index(), 0u);
    current_ = std::move(val);
  }
}

// Pops every collection at the top of the stack that expects no more elements and captures it
// into its parent (or as the root). Collapsing is disabled on those nested Capture calls because
// this loop itself re-checks the new top after each pop.
void CapturingReplyBuilder::CollapseFilledCollections() {
  while (!stack_.empty() && stack_.top().second == 0) {
    auto pl = std::move(stack_.top());
    stack_.pop();
    Capture(std::move(pl.first), false);
  }
}

// Visitor that replays a captured payload onto the reply builder `rb`, one overload per payload
// alternative.
struct CaptureVisitor {
  // Empty payload: nothing to send.
  void operator()(monostate) {
  }

  void operator()(long v) {
    rb->SendLong(v);
  }

  void operator()(double v) {
    static_cast(rb)->SendDouble(v);
  }

  void operator()(const payload::SimpleString& ss) {
    rb->SendSimpleString(ss);
  }

  void operator()(const payload::BulkString& bs) {
    static_cast(rb)->SendBulkString(bs);
  }

  void operator()(payload::Null) {
    static_cast(rb)->SendNull();
  }

  // An error is stored as a (message, type) pair.
  void operator()(const payload::Error& err) {
    rb->SendError(err->first, err->second);
  }

  // Collections: a null pointer replays as a null array, a zero-length ARRAY as an empty array,
  // anything else as a collection header followed by each stored element in order.
  void operator()(const unique_ptr& cp) {
    auto* builder = static_cast(rb);
    if (!cp) {
      builder->SendNullArray();
      return;
    }
    if (cp->len == 0 && cp->type == CollectionType::ARRAY) {
      builder->SendEmptyArray();
      return;
    }
    builder->StartCollection(cp->len, cp->type);
    for (auto& pl : cp->arr)
      visit(*this, std::move(pl));
  }

  SinkReplyBuilder* rb;
};

// Sends a captured payload to `rb`. When the dynamic_cast succeeds, the payload is handed over
// as-is through SendDirect; otherwise it is replayed call by call via CaptureVisitor.
void CapturingReplyBuilder::Apply(Payload&& pl, SinkReplyBuilder* rb) {
  if (auto* crb = dynamic_cast(rb); crb != nullptr) {
    crb->SendDirect(std::move(pl));
    return;
  }

  CaptureVisitor cv{rb};
  visit(cv, std::move(pl));
}

// Switches the reply filter and discards any root payload captured so far.
void CapturingReplyBuilder::SetReplyMode(ReplyMode mode) {
  reply_mode_ = mode;
  current_ = monostate{};
}

// Returns a non-owning (message, type) reference when `pl` holds an error.
optional CapturingReplyBuilder::TryExtractError(
    const Payload& pl) {
  if (auto* err = get_if(&pl); err != nullptr) {
    return ErrorRef{(*err)->first, (*err)->second};
  }
return nullopt; } } // namespace facade ================================================ FILE: src/facade/reply_capture.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include "facade/reply_builder.h" #include "facade/reply_mode.h" #include "facade/reply_payload.h" namespace facade { struct CaptureVisitor; // CapturingReplyBuilder allows capturing replies and retrieveing them with Take(). // Those replies can be stored standalone and sent with // CapturingReplyBuilder::Apply() to another reply builder. class CapturingReplyBuilder : public RedisReplyBuilder { friend struct CaptureVisitor; public: using RedisReplyBuilder::SendError; void SendError(std::string_view str, std::string_view type) override; void SendLong(long val) override; void SendDouble(double val) override; void SendSimpleString(std::string_view str) override; void SendBulkString(std::string_view str) override; void StartCollection(unsigned len, CollectionType type) override; void SendNullArray() override; void SendNull() override; explicit CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL, RespVersion resp_v = RespVersion::kResp2) : RedisReplyBuilder{nullptr}, reply_mode_{mode} { SetRespVersion(resp_v); } using Payload = payload::Payload; // Non owned Error based on SendError arguments (msg, type) using ErrorRef = std::pair; void SetReplyMode(ReplyMode mode); // Take payload and clear state. Payload Take(); // Send payload to builder. static void Apply(Payload&& pl, SinkReplyBuilder* builder); // If an error is stored inside payload, get a reference to it. static std::optional TryExtractError(const Payload& pl); private: // Send payload directly, bypassing external interface. For efficient passing between two // captures. void SendDirect(Payload&& val); // Capture value and store eiter in current topmost collection or as a standalone value. 
void Capture(Payload val, bool collapse_if_needed = true); // While topmost collection in stack is full, finalize it and add it as a regular value. void CollapseFilledCollections(); ReplyMode reply_mode_; // List of nested active collections that are being built. std::stack, int>> stack_; // Root payload. Payload current_; }; } // namespace facade ================================================ FILE: src/facade/reply_mode.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once namespace facade { // Reply mode allows filtering replies. enum class ReplyMode { NONE, // No replies are recorded ONLY_ERR, // Only errors are recorded FULL // All replies are recorded }; class RedisReplyBuilder; } // namespace facade ================================================ FILE: src/facade/reply_payload.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include "base/function2.hpp" #include "facade/facade_types.h" namespace facade { class SinkReplyBuilder; namespace payload { // SendError (msg, type) using Error = std::unique_ptr>; using Null = std::nullptr_t; // SendNull or SendNullArray struct CollectionPayload; struct SimpleString : public std::string {}; // SendSimpleString struct BulkString : public std::string {}; // SendBulkString using Payload = std::variant>; #ifdef __linux__ static_assert(sizeof(Payload) == 40); #endif struct CollectionPayload { CollectionPayload(unsigned _len, CollectionType _type) : len{_len}, type{_type} { arr.reserve(type == CollectionType::MAP ? 
len * 2 : len); } unsigned len; CollectionType type; std::vector arr; }; inline Error make_error(std::string_view msg, std::string_view type = "") { return std::make_unique>(msg, type); } inline Payload make_simple_or_noreply(std::string_view resp) { if (resp.empty()) return std::monostate{}; else return SimpleString{std::string(resp)}; } } // namespace payload } // namespace facade ================================================ FILE: src/facade/resp_expr.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "facade/resp_expr.h" #include "base/logging.h" namespace facade { void FillBackedArgs(const RespVec& src, cmn::BackedArguments* dest) { auto map = [](const RespExpr& expr) { return expr.GetView(); }; auto range = base::it::Transform(map, base::it::Range(src.begin(), src.end())); dest->Assign(range.begin(), range.end(), src.size()); } } // namespace facade ================================================ FILE: src/facade/resp_expr.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include "facade/facade_types.h" namespace facade { class RespExpr { public: using Buffer = absl::Span; enum Type : uint8_t { STRING, ARRAY, INT64, DOUBLE, NIL, NIL_ARRAY, ERROR }; using Vec = std::vector; Type type; bool has_support; // whether pointers in this item are supported by the external storage. 
std::variant u; RespExpr(Type t = NIL) : type(t), has_support(false) { } static Buffer buffer(std::string* s) { return Buffer{reinterpret_cast(s->data()), s->size()}; } std::string_view GetView() const { Buffer buffer = GetBuf(); return {reinterpret_cast(buffer.data()), buffer.size()}; } std::string GetString() const { return std::string(GetView()); } Buffer GetBuf() const { return std::get(u); } const Vec& GetVec() const { return *std::get(u); } std::optional GetInt() const { return std::holds_alternative(u) ? std::make_optional(std::get(u)) : std::nullopt; } size_t UsedMemory() const { return 0; } static const char* TypeName(Type t); }; using RespVec = RespExpr::Vec; using RespSpan = absl::Span; inline std::string_view ToSV(RespExpr::Buffer buf) { return std::string_view{reinterpret_cast(buf.data()), buf.size()}; } void FillBackedArgs(const RespVec& src, cmn::BackedArguments* dest); } // namespace facade namespace std { ostream& operator<<(ostream& os, const facade::RespExpr& e); ostream& operator<<(ostream& os, facade::RespSpan rspan); } // namespace std ================================================ FILE: src/facade/resp_expr_test_utils.cc ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "facade/resp_expr_test_utils.h" #include #include namespace facade { RespExpr RespExprBuilder::BuildExpr(const RESPObj& obj) { RespExpr expr{RespExpr::NIL}; switch (obj.GetType()) { case RESPObj::Type::INTEGER: { expr.type = RespExpr::INT64; expr.u = obj.As().value(); break; } case RESPObj::Type::DOUBLE: { expr.type = RespExpr::DOUBLE; expr.u = obj.As().value(); break; } case RESPObj::Type::NIL: { expr.type = RespExpr::NIL; break; } case RESPObj::Type::ERROR: { expr.type = RespExpr::ERROR; SetStringPayload(obj, &expr); break; } case RESPObj::Type::STRING: case RESPObj::Type::REPLY_STATUS: { expr.type = RespExpr::STRING; SetStringPayload(obj, &expr); break; } case RESPObj::Type::ARRAY: case RESPObj::Type::MAP: case RESPObj::Type::SET: { auto arr = obj.As(); if (arr.has_value()) { // Check if this is a null array (elements == SIZE_MAX which represents -1) if (arr->Size() == SIZE_MAX) { expr.type = RespExpr::NIL_ARRAY; expr.u.emplace(nullptr); } else { expr.type = RespExpr::ARRAY; auto vec = std::make_unique(); vec->reserve(arr->Size()); for (size_t i = 0; i < arr->Size(); ++i) { vec->push_back(BuildExpr((*arr)[i])); } expr.u = vec.get(); owned_arrays_.emplace_back(std::move(vec)); expr.has_support = true; } } break; } } return expr; } void RespExprBuilder::SetStringPayload(const RESPObj& obj, RespExpr* expr) { auto sv = obj.As().value_or(std::string_view{}); // Copy the string data so we don't hold references into zmalloc-allocated // hiredis replies. The replies can then be freed on their allocating thread. auto owned = std::make_unique(sv.size()); memcpy(owned.get(), sv.data(), sv.size()); expr->u = RespExpr::Buffer{reinterpret_cast(owned.get()), sv.size()}; expr->has_support = true; owned_strings_.emplace_back(std::move(owned)); } } // namespace facade ================================================ FILE: src/facade/resp_expr_test_utils.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. 
// See LICENSE for licensing terms. // #pragma once #include #include #include #include "facade/resp_expr.h" #include "facade/resp_parser.h" namespace facade { class RespExprBuilder { public: RespExpr BuildExpr(const RESPObj& obj); void Clear() { owned_arrays_.clear(); // Note: owned_strings_ is NOT cleared here because test code may still hold // string_view/Buffer references to data from prior ParseResponse calls // (e.g., SHA values, DUMP payloads). This mirrors the old behavior where // tmp_str_vec_ accumulated across calls within a test. } private: void SetStringPayload(const RESPObj& obj, RespExpr* expr); std::vector> owned_arrays_; // Own copies of string data so we don't hold references to zmalloc-allocated // hiredis replies (which must be freed on the same thread they were allocated). std::vector> owned_strings_; }; } // namespace facade ================================================ FILE: src/facade/resp_parser.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "facade/resp_parser.h" #include #include "base/logging.h" extern "C" { #include "redis/hiredis.h" } namespace facade { RESPParser::RESPParser() { reader_ = redisReaderCreate(); } RESPObj::RESPObj(RESPObj&& other) noexcept : reply_(other.reply_), needs_to_free_(other.needs_to_free_) { other.reply_ = nullptr; other.needs_to_free_ = false; } RESPObj& RESPObj::operator=(RESPObj&& other) noexcept { std::swap(needs_to_free_, other.needs_to_free_); std::swap(reply_, other.reply_); return *this; } RESPObj::~RESPObj() { if (needs_to_free_) freeReplyObject(reply_); } RESPObj::Type RESPObj::GetType() const { DCHECK(reply_); return static_cast(reply_->type); } size_t RESPObj::Size() const { if (!reply_) return 0; Type type = GetType(); return (type == Type::ARRAY || type == Type::MAP || type == Type::SET) ? 
reply_->elements : 1; } std::optional RESPParser::Feed(const char* data, size_t len) { int status = REDIS_OK; if (len != 0) { // if no new data we check is previoud data produced a reply status = redisReaderFeed(reader_, data, len); if (status != REDIS_OK) { LOG(ERROR) << "RESP parser error: " << status << " description: " << reader_->errstr << " data: " << std::string_view{data, len}; return std::nullopt; } } void* reply_obj = nullptr; status = redisReaderGetReply(reader_, &reply_obj); if (status != REDIS_OK) { LOG(ERROR) << "RESP parser error: " << status << " description: " << reader_->errstr << " data: " << data; return std::nullopt; } return RESPObj(static_cast(reply_obj), reply_obj != nullptr); } std::ostream& operator<<(std::ostream& os, const RESPObj& obj) { if (obj.Empty()) { os << "nullptr RESPObj"; return os; } switch (obj.GetType()) { // because we check type we don't expect As to return nullopt here case RESPObj::Type::INTEGER: { os << *obj.As(); break; } case RESPObj::Type::DOUBLE: { os << *obj.As(); break; } case RESPObj::Type::ARRAY: { os << *obj.As(); break; } case RESPObj::Type::MAP: [[fallthrough]]; case RESPObj::Type::SET: { os << *obj.As(); break; } case RESPObj::Type::STRING: [[fallthrough]]; case RESPObj::Type::NIL: [[fallthrough]]; case RESPObj::Type::ERROR: [[fallthrough]]; case RESPObj::Type::REPLY_STATUS: { os << *obj.As(); break; } default: os << "Unknown RESPObj type: " << static_cast(obj.GetType()); } return os; } std::ostream& operator<<(std::ostream& os, const RESPArray& arr) { os << "["; for (int64_t i = 0; i < (int64_t)arr.Size() - 1; ++i) { os << arr[i] << ", "; } os << arr[arr.Size() - 1] << "]"; return os; } } // namespace facade ================================================ FILE: src/facade/resp_parser.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include #include "io/io.h" extern "C" { #include "redis/hiredis.h" } namespace facade { class RESPArray; class RESPIterator; class RESPObj { public: enum class Type { STRING = REDIS_REPLY_STRING, ARRAY = REDIS_REPLY_ARRAY, INTEGER = REDIS_REPLY_INTEGER, NIL = REDIS_REPLY_NIL, REPLY_STATUS = REDIS_REPLY_STATUS, DOUBLE = REDIS_REPLY_DOUBLE, ERROR = REDIS_REPLY_ERROR, MAP = REDIS_REPLY_MAP, SET = REDIS_REPLY_SET, }; RESPObj() = default; RESPObj(redisReply* reply, bool needs_to_free) : reply_(reply), needs_to_free_(needs_to_free) { } // TODO remove copy ctor, because it is not a deep copy RESPObj(const RESPObj& other) : reply_(other.reply_), needs_to_free_(false) { } RESPObj& operator=(const RESPObj& other) = delete; RESPObj(RESPObj&& other) noexcept; RESPObj& operator=(RESPObj&& other) noexcept; ~RESPObj(); bool Empty() const { return reply_ == nullptr; } size_t Size() const; Type GetType() const; template std::optional As() const; private: redisReply* reply_ = nullptr; bool needs_to_free_ = true; }; class RESPArray { public: RESPArray(redisReply* arr_obj = nullptr) : arr_obj_(arr_obj) { } size_t Size() const { return arr_obj_->elements; } bool Empty() const { return Size() == 0; } RESPObj operator[](size_t index) const { return RESPObj(arr_obj_->element[index], false); } private: redisReply* arr_obj_ = nullptr; }; class RESPParser { public: RESPParser(); ~RESPParser() { redisReaderFree(reader_); } std::optional Feed(const char* data, size_t len); size_t BufferPos() const { return reader_->pos; } private: redisReader* reader_; }; std::ostream& operator<<(std::ostream& os, const RESPObj& obj); std::ostream& operator<<(std::ostream& os, const RESPArray& arr); class RESPIterator { public: RESPIterator() = default; RESPIterator(const RESPObj& obj) : obj_(obj) { } RESPIterator(RESPIterator&&) = default; RESPIterator& operator=(RESPIterator&&) = default; bool HasNext() const { return index_ < obj_.Size(); } bool HasError() const { 
return index_ == std::numeric_limits::max(); } // Consume next values and return as tuple or single value // if extraction fails, set error state template auto Next() { std::conditional_t> res{}; bool success = true; if constexpr (sizeof...(Ts) == 0) { success = Check(&res); } else { success = std::apply([this](auto&... args) { return Check(&args...); }, res); } SetError(!success); return res; } // increase index only if all args are successfully extracted template bool Check(Arg* arg, Args*... args) { auto tmp_index = index_; if (index_ + sizeof...(Args) < obj_.Size()) { if (auto arr = obj_.As(); arr.has_value()) { if (GetEntry(*arr, index_++, arg) && (GetEntry(*arr, index_++, args) && ...)) { return true; } } else if (auto val = obj_.As(); val.has_value()) { assert(sizeof...(Args) == 0 && index_ == 0); *arg = std::move(*val); return true; } } index_ = tmp_index; return false; } void SetError(bool set = true) { if (set) index_ = std::numeric_limits::max(); } private: template bool GetEntry(const RESPArray& arr, size_t idx, Arg* arg) { if (auto val = arr[idx].As(); val.has_value()) { *arg = std::move(*val); return true; } return false; } private: RESPObj obj_; size_t index_ = 0; }; template std::optional RESPObj::As() const { if (!reply_) { return std::nullopt; } if constexpr (std::is_constructible_v) { if (reply_->type == REDIS_REPLY_STRING || reply_->type == REDIS_REPLY_ERROR || reply_->type == REDIS_REPLY_STATUS) { return T{std::string_view{reply_->str, reply_->len}}; } else if (reply_->type == REDIS_REPLY_NIL) { return T{std::string_view("NIL")}; } } else if constexpr (std::is_integral_v) { if (reply_->type == REDIS_REPLY_INTEGER) { return static_cast(reply_->integer); } } else if constexpr (std::is_floating_point_v) { if (reply_->type == REDIS_REPLY_DOUBLE) { return static_cast(reply_->dval); } } else if constexpr (std::is_same_v) { // MAP and SET use the same elements/element structure as ARRAY in hiredis if (reply_->type == REDIS_REPLY_ARRAY || reply_->type 
== REDIS_REPLY_MAP || reply_->type == REDIS_REPLY_SET) { return RESPArray(reply_); } } else if constexpr (std::is_same_v) { return RESPObj(reply_, false); } else if constexpr (std::is_same_v) { return RESPIterator(RESPObj(reply_, false)); } // TODO add other types and errors processing return std::nullopt; } } // namespace facade ================================================ FILE: src/facade/resp_parser_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "facade/resp_parser.h" #include #include "base/gtest.h" #include "base/logging.h" using namespace testing; using namespace std; namespace facade { class RESPParserTest : public testing::Test { protected: static void SetUpTestSuite() { init_zmalloc_threadlocal(mi_heap_get_backing()); } }; TEST_F(RESPParserTest, BaseRespTypesTest) { using Fields = std::map; using Docs = std::map; std::string msg1 = "*17\r\n:8\r\n$2\r\ns0\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "0\r\n$2\r\ns3\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "3\r\n$2\r\ns7\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "7\r\n$2\r\ns8\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "8\r\n$2\r\ns4\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "4\r\n$2\r\ns9\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest 9\r\n"; std::string msg2 = "$2\r\ns1\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "1\r\n$2\r\ns5\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest 5\r\n"; RESPParser reader; auto reply = reader.Feed(msg1.c_str(), msg1.size()); ASSERT_TRUE(reply->Empty()); reply = reader.Feed(msg2.c_str(), msg2.size()); ASSERT_FALSE(reply->Empty()); EXPECT_EQ(reply->GetType(), RESPObj::Type::ARRAY); auto array = *reply->As(); EXPECT_GE(array.Size(), 1); EXPECT_EQ(array[0].GetType(), RESPObj::Type::INTEGER); Docs search_results; for (size_t i = 1; i < array.Size(); i += 2) { auto& fields = search_results[*array[i].As()]; auto field_array = *array[i + 1].As(); for (size_t j = 0; j < field_array.Size(); j += 2) { std::string field_name = 
*field_array[j].As(); std::string field_value = *field_array[j + 1].As(); fields[field_name] = field_value; } } EXPECT_EQ(search_results.size(), 8); EXPECT_EQ(search_results["s0"]["title"], "test 0"); EXPECT_EQ(search_results["s1"]["title"], "test 1"); EXPECT_EQ(search_results["s3"]["title"], "test 3"); EXPECT_EQ(search_results["s4"]["title"], "test 4"); EXPECT_EQ(search_results["s5"]["title"], "test 5"); EXPECT_EQ(search_results["s7"]["title"], "test 7"); EXPECT_EQ(search_results["s8"]["title"], "test 8"); EXPECT_EQ(search_results["s9"]["title"], "test 9"); } TEST_F(RESPParserTest, RESPIteratorTest) { using Fields = std::map; using Docs = std::map; std::string msg1 = "*17\r\n:8\r\n$2\r\ns0\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "0\r\n$2\r\ns3\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "3\r\n$2\r\ns7\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "7\r\n$2\r\ns8\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "8\r\n$2\r\ns4\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "4\r\n$2\r\ns9\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest 9\r\n"; std::string msg2 = "$2\r\ns1\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest " "1\r\n$2\r\ns5\r\n*2\r\n$5\r\ntitle\r\n$6\r\ntest 5\r\n"; RESPParser reader; auto reply = reader.Feed(msg1.c_str(), msg1.size()); ASSERT_TRUE(reply->Empty()); reply = reader.Feed(msg2.c_str(), msg2.size()); ASSERT_FALSE(reply->Empty()); RESPIterator it(*reply); EXPECT_EQ(it.Next(), 8); Docs search_results; while (it.HasNext()) { auto [doc_id, field_it] = it.Next(); auto& fields = search_results[std::move(doc_id)]; while (field_it.HasNext()) { auto [field_name, field_value] = field_it.Next(); fields.emplace(field_name, field_value); } } EXPECT_EQ(search_results.size(), 8); EXPECT_EQ(search_results["s0"]["title"], "test 0"); EXPECT_EQ(search_results["s1"]["title"], "test 1"); EXPECT_EQ(search_results["s3"]["title"], "test 3"); EXPECT_EQ(search_results["s4"]["title"], "test 4"); EXPECT_EQ(search_results["s5"]["title"], "test 5"); EXPECT_EQ(search_results["s7"]["title"], "test 7"); 
EXPECT_EQ(search_results["s8"]["title"], "test 8"); EXPECT_EQ(search_results["s9"]["title"], "test 9"); } } // namespace facade ================================================ FILE: src/facade/resp_srv_parser.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "facade/resp_srv_parser.h" #include #include #include "base/logging.h" #include "common/backed_args.h" #include "common/heap_size.h" namespace facade { using namespace std; auto RespSrvParser::Parse(Buffer str, uint32_t* consumed, cmn::BackedArguments* args) -> Result { DCHECK(!str.empty()); *consumed = 0; DVLOG(2) << "Parsing: " << absl::CHexEscape(string_view{reinterpret_cast(str.data()), str.size()}); if (state_ == CMD_COMPLETE_S) { args->clear(); buf_stash_.clear(); if (str[0] == '*') { // We recognized a non-INLINE state, starting with '*' str.remove_prefix(1); *consumed += 1; state_ = ARRAY_LEN_S; if (str.empty()) return INPUT_PENDING; } else { // INLINE mode, aka PING\n state_ = INLINE_S; } } ResultConsumed resultc{OK, 0}; do { switch (state_) { case ARRAY_LEN_S: resultc = ConsumeArrayLen(str, args); break; case PARSE_ARG_TYPE: if (str[0] != '$') // server side only supports bulk strings. 
return BAD_BULKLEN; resultc.second = 1; state_ = PARSE_ARG_S; break; case PARSE_ARG_S: resultc = ParseArg(str, args); break; case INLINE_S: resultc = ParseInline(str, args); break; case BULK_STR_S: resultc = ConsumeBulk(str, args); break; case SLASH_N_S: if (str[0] != '\n') { resultc.first = BAD_STRING; } else { resultc = {OK, 1}; HandleFinishArg(); } break; default: LOG(FATAL) << "Unexpected state " << int(state_); } *consumed += resultc.second; str.remove_prefix(exchange(resultc.second, 0)); } while (state_ != CMD_COMPLETE_S && resultc.first == OK && !str.empty()); if (state_ != CMD_COMPLETE_S) { if (resultc.first == OK) { resultc.first = INPUT_PENDING; } if (resultc.first == INPUT_PENDING) { if (!str.empty()) { LOG(DFATAL) << "Did not consume all input: " << absl::CHexEscape({reinterpret_cast(str.data()), str.size()}) << ", state: " << int(state_) << " smallbuf: " << absl::CHexEscape( {reinterpret_cast(small_buf_.data()), small_len_}); } } return resultc.first; } return resultc.first; } auto RespSrvParser::ParseInline(Buffer str, cmn::BackedArguments* args) -> ResultConsumed { DCHECK(!str.empty()); const uint8_t* ptr = str.begin(); const uint8_t* end = str.end(); const uint8_t* token_start = ptr; auto find_token_end = [](const uint8_t* ptr, const uint8_t* end) { while (ptr != end && *ptr > 32) ++ptr; return ptr; }; if (!buf_stash_.empty()) { ptr = find_token_end(ptr, end); size_t len = ptr - token_start; buf_stash_.append(reinterpret_cast(token_start), len); if (ptr == end) { return {INPUT_PENDING, ptr - token_start}; } args->PushArg(buf_stash_); buf_stash_.clear(); } while (ptr != end) { // For inline input we only require \n. 
if (*ptr == '\n') { if (args->empty()) { ++ptr; continue; // skip empty line } break; } if (*ptr <= 32) { // skip ws/control chars ++ptr; continue; } // token start DCHECK(buf_stash_.empty()); token_start = ptr; ptr = find_token_end(ptr, end); if (ptr != end) { args->PushArg( string_view{reinterpret_cast(token_start), size_t(ptr - token_start)}); } } uint32_t last_consumed = ptr - str.data(); if (ptr == end) { // we have not finished parsing. bool is_broken_token = ptr[-1] > 32; // we stopped in the middle of the token. if (is_broken_token) { DCHECK(buf_stash_.empty()); buf_stash_.append(reinterpret_cast(token_start), size_t(ptr - token_start)); } else if (args->empty()) { state_ = CMD_COMPLETE_S; // have not found anything besides whitespace. } return {INPUT_PENDING, last_consumed}; } DCHECK_EQ('\n', *ptr); ++last_consumed; // consume \n as well. state_ = CMD_COMPLETE_S; return {OK, last_consumed}; } // Parse lines like:'$5\r\n' or '*2\r\n'. The first character is already consumed by the caller. auto RespSrvParser::ParseLen(Buffer str, int64_t* res) -> ResultConsumed { DCHECK(!str.empty()); const char* s = reinterpret_cast(str.data()); const char* pos = reinterpret_cast(memchr(s, '\n', str.size())); if (!pos) { if (str.size() + small_len_ < small_buf_.size()) { memcpy(&small_buf_[small_len_], str.data(), str.size()); small_len_ += str.size(); return {INPUT_PENDING, str.size()}; } LOG(WARNING) << "Unexpected format " << string_view{s, str.size()}; return ResultConsumed{BAD_ARRAYLEN, 0}; } unsigned consumed = pos - s + 1; if (small_len_ > 0) { if (small_len_ + consumed >= small_buf_.size()) { return ResultConsumed{BAD_ARRAYLEN, consumed}; } memcpy(&small_buf_[small_len_], str.data(), consumed); small_len_ += consumed; s = small_buf_.data(); pos = s + small_len_ - 1; small_len_ = 0; } if (pos[-1] != '\r') { return {BAD_ARRAYLEN, consumed}; } // Skip 2 last characters (\r\n). 
string_view len_token{s, size_t(pos - 1 - s)}; bool success = absl::SimpleAtoi(len_token, res); if (success && *res >= -1) { return ResultConsumed{OK, consumed}; } LOG(ERROR) << "Failed to parse len " << absl::CHexEscape(len_token) << " " << absl::CHexEscape(string_view{reinterpret_cast(str.data()), str.size()}) << " " << consumed << " " << int(s == small_buf_.data()); return ResultConsumed{BAD_ARRAYLEN, consumed}; } auto RespSrvParser::ConsumeArrayLen(Buffer str, cmn::BackedArguments* args) -> ResultConsumed { int64_t len; ResultConsumed res = ParseLen(str, &len); if (res.first != OK) { return res; } if (len <= 0) { return {BAD_ARRAYLEN, res.second}; } if (len > max_arr_len_) { LOG(WARNING) << "Multibulk len is too large " << len; return {BAD_ARRAYLEN, res.second}; } state_ = PARSE_ARG_TYPE; arg_len_ = len; args->Reserve(len, 0); return {OK, res.second}; } auto RespSrvParser::ParseArg(Buffer str, cmn::BackedArguments* args) -> ResultConsumed { DCHECK(!str.empty()); int64_t len; ResultConsumed res = ParseLen(str, &len); if (res.first != OK) { return res; } if (len > 0 && static_cast(len) > max_bulk_len_) { LOG_EVERY_T(WARNING, 1) << "Threshold reached with bulk len: " << len << ", consider increasing max_bulk_len"; return {BAD_BULKLEN, res.second}; } if (len < 0) { return {BAD_BULKLEN, res.second}; } bulk_len_ = len; state_ = BULK_STR_S; args->PushArg(size_t(len)); return {OK, res.second}; } auto RespSrvParser::ConsumeBulk(Buffer str, cmn::BackedArguments* args) -> ResultConsumed { DCHECK_EQ(small_len_, 0); uint32_t consumed = 0; if (str.size() >= bulk_len_) { consumed = bulk_len_; if (bulk_len_) { char* last_arg = args->data(args->size() - 1); // Get pointer to last argument. 
DCHECK_GE(args->elem_len(args->size() - 1), bulk_len_); char* start = last_arg + (args->elem_len(args->size() - 1) - bulk_len_); memcpy(start, str.data(), bulk_len_); str.remove_prefix(exchange(bulk_len_, 0)); } if (str.size() >= 2) { if (str[0] != '\r' || str[1] != '\n') { return {BAD_STRING, consumed}; } HandleFinishArg(); return {OK, consumed + 2}; } if (str.size() == 1) { if (str[0] != '\r') { return {BAD_STRING, consumed}; } state_ = SLASH_N_S; consumed++; } return {INPUT_PENDING, consumed}; } DCHECK(bulk_len_); DCHECK_GE(args->elem_len(args->size() - 1), bulk_len_); size_t len = std::min(str.size(), bulk_len_); char* last_arg = args->data(args->size() - 1); // Get pointer to last argument. char* start = last_arg + (args->elem_len(args->size() - 1) - bulk_len_); memcpy(start, str.data(), len); consumed = len; bulk_len_ -= len; return {INPUT_PENDING, consumed}; } void RespSrvParser::HandleFinishArg() { state_ = (--arg_len_ == 0) ? CMD_COMPLETE_S : PARSE_ARG_TYPE; small_len_ = 0; } size_t RespSrvParser::UsedMemory() const { return cmn::HeapSize(buf_stash_); } } // namespace facade ================================================ FILE: src/facade/resp_srv_parser.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include "common/backed_args.h" namespace facade { /** * @brief RESP server-side parser. */ class RespSrvParser { public: enum Result : uint8_t { OK, INPUT_PENDING, BAD_ARRAYLEN, BAD_BULKLEN, BAD_STRING, }; using Buffer = absl::Span; explicit RespSrvParser(uint32_t max_arr_len = UINT32_MAX, uint32_t max_bulk_len = UINT32_MAX) : max_arr_len_(max_arr_len), max_bulk_len_(max_bulk_len) { } /** * @brief Parses str into res. "consumed" stores number of bytes consumed from str. * * A caller should not invalidate str if the parser returns RESP_OK as long as he continues * accessing res. 
However, if parser returns INPUT_PENDING a caller may discard consumed * part of str because parser caches the intermediate state internally according to 'consumed' * result. */ Result Parse(Buffer str, uint32_t* consumed, cmn::BackedArguments* dest); size_t parselen_hint() const { return bulk_len_; } size_t UsedMemory() const; private: using ResultConsumed = std::pair; // Skips the first character (*). ResultConsumed ConsumeArrayLen(Buffer str, cmn::BackedArguments* args); ResultConsumed ParseArg(Buffer str, cmn::BackedArguments* args); ResultConsumed ConsumeBulk(Buffer str, cmn::BackedArguments* args); ResultConsumed ParseInline(Buffer str, cmn::BackedArguments* args); ResultConsumed ParseLen(Buffer str, int64_t* res); void HandleFinishArg(); enum State : uint8_t { INLINE_S, ARRAY_LEN_S, PARSE_ARG_TYPE, PARSE_ARG_S, // Parse string\r\n BULK_STR_S, SLASH_N_S, CMD_COMPLETE_S, }; State state_ = CMD_COMPLETE_S; uint8_t small_len_ = 0; uint32_t bulk_len_ = 0, arg_len_ = 0; uint32_t max_arr_len_; uint32_t max_bulk_len_; std::string buf_stash_; std::array small_buf_; }; } // namespace facade ================================================ FILE: src/facade/resp_srv_parser_test.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//
#include "facade/resp_srv_parser.h"

// NOTE(review): the header names of the two includes below are blank in this copy of the
// file, and template argument lists look stripped (e.g. `static_cast(result)`). Code tokens
// are kept exactly as found — confirm against the repository.
#include
#include

#include "base/gtest.h"
#include "base/logging.h"

using namespace testing;
using namespace std;

namespace facade {

// Custom printer for RespSrvParser::Result to make test output more readable
void PrintTo(const RespSrvParser::Result& result, std::ostream* os) {
  switch (result) {
    case RespSrvParser::OK:
      *os << "OK";
      break;
    case RespSrvParser::INPUT_PENDING:
      *os << "INPUT_PENDING";
      break;
    case RespSrvParser::BAD_ARRAYLEN:
      *os << "BAD_ARRAYLEN";
      break;
    case RespSrvParser::BAD_BULKLEN:
      *os << "BAD_BULKLEN";
      break;
    case RespSrvParser::BAD_STRING:
      *os << "BAD_STRING";
      break;
    default:
      *os << "UNKNOWN(" << static_cast(result) << ")";
      break;
  }
}

// Fixture holding one parser plus the output arguments shared by consecutive Parse() calls.
class RespSrvParserTest : public testing::Test {
 protected:
  // Feeds `str` to parser_; the consumed byte count lands in consumed_, arguments in args_.
  RespSrvParser::Result Parse(std::string_view str);

  RespSrvParser parser_;
  cmn::BackedArguments args_;
  uint32_t consumed_;
};

RespSrvParser::Result RespSrvParserTest::Parse(std::string_view str) {
  RespSrvParser::Buffer buf{reinterpret_cast(str.data()), str.size()};
  return parser_.Parse(buf, &consumed_, &args_);
}

// Inline (non-array) commands, both complete and split across several chunks.
TEST_F(RespSrvParserTest, Inline) {
  const char kCmd1[] = "KEY VAL\r\n";

  ASSERT_EQ(RespSrvParser::OK, Parse(kCmd1));
  EXPECT_EQ(strlen(kCmd1), consumed_);
  EXPECT_THAT(args_, ElementsAre("KEY", "VAL"));

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("KEY"));
  EXPECT_EQ(3, consumed_);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(" FOO "));
  EXPECT_EQ(5, consumed_);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(" BAR"));
  EXPECT_EQ(4, consumed_);
  // Only the bytes up to and including '\n' are consumed; the trailing space is left over.
  ASSERT_EQ(RespSrvParser::OK, Parse(" \r\n "));
  EXPECT_EQ(3, consumed_);
  EXPECT_THAT(args_, ElementsAre("KEY", "FOO", "BAR"));

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(" 1 2"));
  EXPECT_EQ(4, consumed_);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(" 45"));
  EXPECT_EQ(3, consumed_);
  ASSERT_EQ(RespSrvParser::OK, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
  EXPECT_THAT(args_, ElementsAre("1", "2", "45"));

  // Empty queries return INPUT_PENDING.
  EXPECT_EQ(RespSrvParser::INPUT_PENDING, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
  ASSERT_EQ(RespSrvParser::OK, Parse("_\r\n"));
  EXPECT_THAT(args_, ElementsAre("_"));
}

// One-element array delivered as three separate chunks; also checks parselen_hint().
TEST_F(RespSrvParserTest, Multi1) {
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("*1\r\n"));
  EXPECT_EQ(4, consumed_);
  EXPECT_EQ(0, parser_.parselen_hint());

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("$4\r\n"));
  EXPECT_EQ(4, consumed_);
  EXPECT_EQ(4, parser_.parselen_hint());

  ASSERT_EQ(RespSrvParser::OK, Parse("PING\r\n"));
  EXPECT_EQ(6, consumed_);
  EXPECT_EQ(0, parser_.parselen_hint());
  EXPECT_THAT(args_, ElementsAre("PING"));
}

// Chunk boundaries that fall right after '$' and in front of the closing "\r\n".
TEST_F(RespSrvParserTest, Multi2) {
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("*1\r\n$"));
  EXPECT_EQ(5, consumed_);

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("4\r\nMSET"));
  EXPECT_EQ(7, consumed_);

  // Parsing stops once the command is complete; the following "*2\r\n" is not consumed.
  ASSERT_EQ(RespSrvParser::OK, Parse("\r\n*2\r\n"));
  EXPECT_EQ(2, consumed_);

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("*2\r\n$3\r\nKEY\r\n$3\r\nVAL"));
  EXPECT_EQ(20, consumed_);

  ASSERT_EQ(RespSrvParser::OK, Parse("\r\n"));
  EXPECT_EQ(2, consumed_);
  EXPECT_THAT(args_, ElementsAre("KEY", "VAL"));
}

// A bulk string split in the middle of its payload.
TEST_F(RespSrvParserTest, Multi3) {
  const char kFirst[] = "*3\r\n$3\r\nSET\r\n$16\r\nkey:";
  const char kSecond[] = "000002273458\r\n$3\r\nVXK";
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(kFirst));
  ASSERT_EQ(strlen(kFirst), consumed_);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(kSecond));
  ASSERT_EQ(strlen(kSecond), consumed_);
  ASSERT_EQ(RespSrvParser::OK, Parse("\r\n*3\r\n$3\r\nSET"));
  ASSERT_EQ(2, consumed_);
  EXPECT_THAT(args_, ElementsAre("SET", "key:000002273458", "VXK"));
}

// The second array element lacks the '$' bulk marker.
TEST_F(RespSrvParserTest, InvalidMult1) {
  ASSERT_EQ(RespSrvParser::BAD_BULKLEN, Parse("*2\r\n$3\r\nFOO\r\nBAR\r\n"));
}

// Zero-length bulk strings are accepted.
TEST_F(RespSrvParserTest, Empty) {
  ASSERT_EQ(RespSrvParser::OK, Parse("*2\r\n$0\r\n\r\n$0\r\n\r\n"));
}

// Bulk payloads delivered in many chunks, including a 27,000,000-byte payload.
TEST_F(RespSrvParserTest, LargeBulk) {
  string_view prefix("*1\r\n$1024\r\n");

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(prefix));
  ASSERT_EQ(prefix.size(), consumed_);
  ASSERT_GE(parser_.parselen_hint(), 1024);

  string half(512, 'a');
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(half));
  ASSERT_EQ(512, consumed_);
  ASSERT_GE(parser_.parselen_hint(), 512);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(half));
  ASSERT_EQ(512, consumed_);
  // The terminator itself is split into "\r" and "\n".
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("\r"));
  ASSERT_EQ(1, consumed_);
  ASSERT_EQ(RespSrvParser::OK, Parse("\n"));
  EXPECT_EQ(1, consumed_);

  string part1 = absl::StrCat(prefix, half);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(part1));
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(half));
  ASSERT_EQ(RespSrvParser::OK, Parse("\r\n"));

  prefix = "*1\r\n$27000000\r\n";
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(prefix));
  ASSERT_EQ(prefix.size(), consumed_);
  string chunk(1000000, 'a');
  for (unsigned i = 0; i < 27; ++i) {
    ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse(chunk));
    ASSERT_EQ(chunk.size(), consumed_);
  }
  ASSERT_EQ(RespSrvParser::OK, Parse("\r\n"));
  ASSERT_EQ(args_.size(), 1);
  EXPECT_EQ(27000000u, args_[0].size());
}

// The "\r\n" that ends a length line is split across two chunks.
TEST_F(RespSrvParserTest, Eol) {
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("*1\r"));
  EXPECT_EQ(3, consumed_);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("\n$5\r\n"));
  EXPECT_EQ(5, consumed_);
}

// The "\r\n" that ends a bulk payload is split across two chunks.
TEST_F(RespSrvParserTest, BulkSplit) {
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("*1\r\n$4\r\nSADD\r"));
  ASSERT_EQ(13, consumed_);
  ASSERT_EQ(RespSrvParser::OK, Parse("\n"));
}

// Inline input made of empty lines and a token split across chunks.
TEST_F(RespSrvParserTest, InlineSplit) {
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("\n"));
  EXPECT_EQ(1, consumed_);

  ASSERT_EQ(RespSrvParser::OK, Parse("\nPING\n\n"));
  EXPECT_EQ(6, consumed_);
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("\n"));
  EXPECT_EQ(1, consumed_);

  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("P"));
  ASSERT_EQ(RespSrvParser::OK, Parse("ING\n"));
}

// After a whitespace-only inline line the parser accepts an array command again.
TEST_F(RespSrvParserTest, InlineReset) {
  ASSERT_EQ(RespSrvParser::INPUT_PENDING, Parse("\t \r\n"));
  EXPECT_EQ(4, consumed_);
  ASSERT_EQ(RespSrvParser::OK, Parse("*1\r\n$3\r\nfoo\r\n"));
  EXPECT_EQ(13, consumed_);
}

}  // namespace facade
================================================ FILE: src/facade/resp_validator.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include #include #include #include #include "base/flags.h" #include "base/init.h" #include "facade/redis_parser.h" #include "io/io.h" using namespace facade; using namespace std; ABSL_FLAG(string, input, "", "If not empty - reads data from the file instead of stdin. "); // Validates RESP3 server responses by using RespParser. // Server traffic can be recorded using: // tcpflow -i any port 6379 -o /tmp/tcp_flow int main(int argc, char* argv[]) { MainInitGuard guard(&argc, &argv); RedisParser parser(RedisParser::Mode::CLIENT); RedisParser::Result parse_result = RedisParser::OK; char buf[1024]; istream* input_stream = &cin; if (!absl::GetFlag(FLAGS_input).empty()) { input_stream = new ifstream(absl::GetFlag(FLAGS_input), ios::binary); if (!input_stream->good()) { cerr << "Failed to open input file: " << absl::GetFlag(FLAGS_input) << "\n"; return -1; } } size_t len = 0, offset = 0; do { input_stream->read(buf + len, sizeof(buf) - len); size_t read = input_stream->gcount(); if (read == 0) { if (parse_result != RedisParser::OK) { cerr << "unexpected: " << parse_result << "\n"; } break; } DVLOG(1) << "Read " << read << " bytes from input stream, offset: " << offset; len += read; RespExpr::Vec args; uint32_t consumed = 0; char* next = buf; while (len) { string_view sv{next, len}; parse_result = parser.Parse(io::Buffer(sv), &consumed, &args); if (parse_result != RedisParser::OK && parse_result != RedisParser::INPUT_PENDING) { cerr << "Parse error: " << int(parse_result) << " at offset " << offset << " when parsing: " << absl::CHexEscape({reinterpret_cast(next), len}) << "\n"; return -1; } if (consumed == 0) { // not enough data to parse. 
DVLOG(1) << "No data consumed, waiting for more input."; memcpy(buf, next, len); // move the remaining data to the start of the buffer. break; } len -= consumed; next += consumed; offset += consumed; } } while (!input_stream->eof()); if (input_stream != &cin) { delete input_stream; } cout << "LGTM\n"; return 0; } ================================================ FILE: src/facade/service_interface.cc ================================================ // Copyright 2024, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "facade/service_interface.h" #include #include "facade/facade_types.h" namespace facade { std::string ServiceInterface::ContextInfo::Format() const { char buf[16] = {0}; std::string res = absl::StrCat("db=", db_index); unsigned index = 0; if (async_dispatch) buf[index++] = 'a'; if (conn_closing) buf[index++] = 't'; if (subscribers) buf[index++] = 'P'; if (blocked) buf[index++] = 'b'; if (index) absl::StrAppend(&res, " flags=", buf); return res; } DispatchResult ServiceInterface::DispatchCommandSimple(ParsedCommand* cmd, AsyncPreference mode) { return DispatchCommand(ParsedArgs{*cmd}, cmd, mode); } } // namespace facade ================================================ FILE: src/facade/service_interface.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include "facade/facade_types.h" #include "facade/parsed_command.h" #include "util/fiber_socket_base.h" namespace util { class HttpListenerBase; } // namespace util namespace facade { class ConnectionContext; class Connection; class SinkReplyBuilder; class MCReplyBuilder; // Controls asynchronicity of command dispatch enum class AsyncPreference : uint8_t { ONLY_SYNC, // Caller supports only synchronous dispatch PREFER_ASYNC, // Prefer async if available ONLY_ASYNC, // Only async execution is possible (command is dispatched in pipeline) }; enum class DispatchResult : uint8_t { OK, OOM, ERROR, WOULD_BLOCK // Returned if ONLY_ASYNC was set, but only synchronous execution is possible }; struct DispatchManyResult { uint32_t processed; // how many commands out of passed were actually processed // whether to account the processed commands in stats. This is needed to consistently // account commands that were included based on squash_stats_latency_lower_limit filter. bool account_in_stats; }; class ServiceInterface { public: virtual ~ServiceInterface() { } virtual DispatchResult DispatchCommand(ParsedArgs args, ParsedCommand* cmd, AsyncPreference) = 0; DispatchResult DispatchCommandSimple(ParsedCommand* cmd, AsyncPreference mode); virtual DispatchManyResult DispatchManyCommands(std::function arg_gen, unsigned count, SinkReplyBuilder* builder, ConnectionContext* cntx) = 0; virtual DispatchResult DispatchMC(ParsedCommand* cmd, AsyncPreference) = 0; virtual ConnectionContext* CreateContext(Connection* owner) = 0; virtual ParsedCommand* AllocateParsedCommand() = 0; virtual void ConfigureHttpHandlers(util::HttpListenerBase* base, bool is_privileged) { } virtual void OnConnectionClose(ConnectionContext* cntx) { } struct ContextInfo { std::string Format() const; unsigned db_index; bool async_dispatch, conn_closing, subscribers, blocked; }; virtual ContextInfo GetContextInfo(ConnectionContext* cntx) const { return {}; } }; } // namespace facade 
================================================ FILE: src/facade/socket_utils.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "socket_utils.h" #include #include #ifdef __linux__ #include #include #include "absl/strings/str_cat.h" #include "io/proc_reader.h" #endif namespace { int get_socket_family(int fd) { struct sockaddr_storage ss; socklen_t len = sizeof(ss); if (getsockname(fd, (struct sockaddr*)&ss, &len) == -1) { return -1; // Indicate an error } return ss.ss_family; } } // namespace namespace dfly { // Returns information about the TCP socket state by its descriptor std::string GetSocketInfo(int socket_fd) { if (socket_fd < 0) return "invalid socket"; #ifdef __linux__ struct stat sock_stat; if (fstat(socket_fd, &sock_stat) != 0) { return "could not stat socket"; } io::Result tcp_info; int family = get_socket_family(socket_fd); if (family == AF_INET) { tcp_info = io::ReadTcpInfo(sock_stat.st_ino); } else if (family == AF_INET6) { tcp_info = io::ReadTcp6Info(sock_stat.st_ino); } else { return "unsupported socket family"; } if (!tcp_info) { return "socket not found in /proc/net/tcp or /proc/net/tcp6"; } std::string state_str = io::TcpStateToString(tcp_info->state); if (tcp_info->is_ipv6) { char local_ip[INET6_ADDRSTRLEN], remote_ip[INET6_ADDRSTRLEN]; inet_ntop(AF_INET6, &tcp_info->local_addr6, local_ip, sizeof(local_ip)); inet_ntop(AF_INET6, &tcp_info->remote_addr6, remote_ip, sizeof(remote_ip)); return absl::StrCat("State: ", state_str, ", Local: [", local_ip, "]:", tcp_info->local_port, ", Remote: [", remote_ip, "]:", tcp_info->remote_port, ", Inode: ", tcp_info->inode); } else { char local_ip[INET_ADDRSTRLEN], remote_ip[INET_ADDRSTRLEN]; struct in_addr addr; addr.s_addr = htonl(tcp_info->local_addr); inet_ntop(AF_INET, &addr, local_ip, sizeof(local_ip)); addr.s_addr = htonl(tcp_info->remote_addr); inet_ntop(AF_INET, &addr, remote_ip, 
sizeof(remote_ip)); return absl::StrCat("State: ", state_str, ", Local: ", local_ip, ":", tcp_info->local_port, ", Remote: ", remote_ip, ":", tcp_info->remote_port, ", Inode: ", tcp_info->inode); } #else return "socket info not available on this platform"; #endif } } // namespace dfly ================================================ FILE: src/facade/socket_utils.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include namespace dfly { // Returns information about the TCP socket state by its descriptor std::string GetSocketInfo(int socket_fd); } // namespace dfly ================================================ FILE: src/facade/tls_helpers.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include "tls_helpers.h" #include #ifdef DFLY_USE_SSL #include #endif #include #include #include "base/flags.h" #include "base/logging.h" #include "facade/facade_stats.h" #include "facade/facade_types.h" ABSL_FLAG(std::string, tls_cert_file, "", "cert file for tls connections"); ABSL_FLAG(std::string, tls_key_file, "", "key file for tls connections"); ABSL_FLAG(std::string, tls_ca_cert_file, "", "ca signed certificate to validate tls connections"); ABSL_FLAG(std::string, tls_ca_cert_dir, "", "ca signed certificates directory. 
Use c_rehash before, read description in " "https://www.openssl.org/docs/man3.0/man1/c_rehash.html"); ABSL_FLAG(std::string, tls_ciphers, "DEFAULT:!MEDIUM", "TLS ciphers configuration for tls1.2"); ABSL_FLAG(std::string, tls_cipher_suites, "", "TLS ciphers configuration for tls1.3"); ABSL_FLAG(bool, tls_prefer_server_ciphers, false, "If true, prefer server ciphers over client ciphers"); ABSL_FLAG(bool, tls_session_caching, false, "If true enables session caching and tickets"); ABSL_FLAG(size_t, tls_session_cache_size, 20 * 1024, "Size of the cache for tls sessions"); ABSL_FLAG(size_t, tls_session_cache_timeout, 300, "Timeout for each session/ticket"); namespace facade { #ifdef DFLY_USE_SSL // Creates the TLS context. Returns nullptr if the TLS configuration is invalid. // To connect: openssl s_client -state -crlf -connect 127.0.0.1:6380 SSL_CTX* CreateSslCntx(TlsContextRole role) { using absl::GetFlag; const auto& tls_key_file = GetFlag(FLAGS_tls_key_file); if (tls_key_file.empty()) { LOG(ERROR) << "To use TLS, a server certificate must be provided with the --tls_key_file flag!"; return nullptr; } SSL_CTX* ctx; if (role == TlsContextRole::SERVER) { ctx = SSL_CTX_new(TLS_server_method()); } else { ctx = SSL_CTX_new(TLS_client_method()); } unsigned mask = SSL_VERIFY_NONE; if (SSL_CTX_use_PrivateKey_file(ctx, tls_key_file.c_str(), SSL_FILETYPE_PEM) != 1) { LOG(ERROR) << "Failed to load TLS key"; return nullptr; } const auto& tls_cert_file = GetFlag(FLAGS_tls_cert_file); if (!tls_cert_file.empty()) { // TO connect with redis-cli you need both tls-key-file and tls-cert-file // loaded. 
Use `redis-cli --tls -p 6380 --insecure PING` to test if (SSL_CTX_use_certificate_chain_file(ctx, tls_cert_file.c_str()) != 1) { LOG(ERROR) << "Failed to load TLS certificate"; return nullptr; } } const auto tls_ca_cert_file = GetFlag(FLAGS_tls_ca_cert_file); const auto tls_ca_cert_dir = GetFlag(FLAGS_tls_ca_cert_dir); if (!tls_ca_cert_file.empty() || !tls_ca_cert_dir.empty()) { const auto* file = tls_ca_cert_file.empty() ? nullptr : tls_ca_cert_file.data(); const auto* dir = tls_ca_cert_dir.empty() ? nullptr : tls_ca_cert_dir.data(); if (SSL_CTX_load_verify_locations(ctx, file, dir) != 1) { LOG(ERROR) << "Failed to load TLS verify locations (CA cert file or CA cert dir)"; return nullptr; } mask = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT; } if (!GetFlag(FLAGS_tls_ciphers).empty()) { DFLY_SSL_CHECK(1 == SSL_CTX_set_cipher_list(ctx, GetFlag(FLAGS_tls_ciphers).c_str())); } // Relevant only for TLS 1.3 connections. if (!GetFlag(FLAGS_tls_cipher_suites).empty()) { SSL_CTX_set_ciphersuites(ctx, GetFlag(FLAGS_tls_cipher_suites).c_str()); } SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); SSL_CTX_set_verify(ctx, mask, NULL); DFLY_SSL_CHECK(1 == SSL_CTX_set_dh_auto(ctx, 1)); if (GetFlag(FLAGS_tls_prefer_server_ciphers)) { SSL_CTX_set_options(ctx, SSL_OP_CIPHER_SERVER_PREFERENCE); } if (GetFlag(FLAGS_tls_session_caching)) { SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_SERVER); SSL_CTX_sess_set_cache_size(ctx, GetFlag(FLAGS_tls_session_cache_size)); SSL_CTX_set_timeout(ctx, GetFlag(FLAGS_tls_session_cache_timeout)); SSL_CTX_set_session_id_context(ctx, (const unsigned char*)"dragonfly", 9); } SSL_CTX_set_info_callback(ctx, [](const SSL* ssl, int where, int ret) { // When we skip the handshake we never reach this state. if (where & SSL_CB_HANDSHAKE_START) { ++tl_facade_stats->conn_stats.handshakes_started; } // When we skip the handshake, we never reach this state. 
if (where & SSL_CB_HANDSHAKE_DONE) { ++tl_facade_stats->conn_stats.handshakes_completed; } }); return ctx; } void PrintSSLError() { ERR_print_errors_cb( [](const char* str, size_t len, void* u) { LOG(ERROR) << std::string_view(str, len); return 1; }, nullptr); } #endif } // namespace facade ================================================ FILE: src/facade/tls_helpers.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #ifdef DFLY_USE_SSL #include #endif namespace facade { #ifdef DFLY_USE_SSL enum class TlsContextRole { SERVER, CLIENT }; SSL_CTX* CreateSslCntx(TlsContextRole role); void PrintSSLError(); #define DFLY_SSL_CHECK(condition) \ if (!(condition)) { \ LOG(ERROR) << "OpenSSL Error: " #condition; \ PrintSSLError(); \ exit(17); \ } #endif } // namespace facade ================================================ FILE: src/huff/LICENSE ================================================ BSD License For Zstandard software Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name Facebook, nor Meta, nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: src/huff/README.md ================================================ The code in this folder exposes internal functions that are used by ZSTD. These functions are part of https://github.com/Cyan4973/FiniteStateEntropy project. Since we already link to ZSTD, it is convenient that we get this functionality for free. ================================================ FILE: src/huff/hist.h ================================================ /* ****************************************************************** * hist : Histogram functions * part of Finite State Entropy project * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - FSE source repository : https://github.com/Cyan4973/FiniteStateEntropy * - Public forum : https://groups.google.com/forum/#!forum/lz4c * * This source code is licensed under both the BSD-style license (found in the * LICENSE file in the root directory of this source tree) and the GPLv2 (found * in the COPYING file in the root directory of this source tree). * You may select, at your option, one of the above-listed licenses. 
****************************************************************** */ /* --- dependencies --- */ #include /* size_t */ /* --- simple histogram functions --- */ /*! HIST_count(): * Provides the precise count of each byte within a table 'count'. * 'count' is a table of unsigned int, of minimum size (*maxSymbolValuePtr+1). * Updates *maxSymbolValuePtr with actual largest symbol value detected. * @return : count of the most frequent symbol (which isn't identified). * or an error code, which can be tested using HIST_isError(). * note : if return == srcSize, there is only one symbol. */ size_t HIST_count(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize); unsigned HIST_isError(size_t code); /**< tells if a return value is an error code */ /* --- advanced histogram functions --- */ #define HIST_WKSP_SIZE_U32 1024 #define HIST_WKSP_SIZE (HIST_WKSP_SIZE_U32 * sizeof(unsigned)) /** HIST_count_wksp() : * Same as HIST_count(), but using an externally provided scratch buffer. * Benefit is this function will use very little stack space. * `workSpace` is a writable buffer which must be 4-bytes aligned, * `workSpaceSize` must be >= HIST_WKSP_SIZE */ size_t HIST_count_wksp(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, void* workSpace, size_t workSpaceSize); /** HIST_countFast() : * same as HIST_count(), but blindly trusts that all byte values within src are <= *maxSymbolValuePtr. * This function is unsafe, and will segfault if any value within `src` is `> *maxSymbolValuePtr` */ size_t HIST_countFast(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize); /** HIST_countFast_wksp() : * Same as HIST_countFast(), but using an externally provided scratch buffer. 
* `workSpace` is a writable buffer which must be 4-bytes aligned, * `workSpaceSize` must be >= HIST_WKSP_SIZE */ size_t HIST_countFast_wksp(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, void* workSpace, size_t workSpaceSize); /*! HIST_count_simple() : * Same as HIST_countFast(), this function is unsafe, * and will segfault if any value within `src` is `> *maxSymbolValuePtr`. * It is also a bit slower for large inputs. * However, it does not need any additional memory (not even on stack). * @return : count of the most frequent symbol. * Note this function doesn't produce any error (i.e. it must succeed). */ unsigned HIST_count_simple(unsigned* count, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize); /*! HIST_add() : * Lowest level: just add nb of occurrences of characters from @src into @count. * @count is not reset. @count array is presumed large enough (i.e. 1 KB). @ This function does not need any additional stack memory. */ void HIST_add(unsigned* count, const void* src, size_t srcSize); ================================================ FILE: src/huff/huf.h ================================================ /* ****************************************************************** * huff0 huffman codec, * part of Finite State Entropy library * Copyright (c) Meta Platforms, Inc. and affiliates. * * You can contact the author at : * - Source repository : https://github.com/Cyan4973/FiniteStateEntropy * * This source code is licensed under both the BSD-style license (found in the * LICENSE file in the root directory of this source tree) and the GPLv2 (found * in the COPYING file in the root directory of this source tree). * You may select, at your option, one of the above-listed licenses. 
****************************************************************** */ #ifndef HUF_H_298734234 #define HUF_H_298734234 /* *** Dependencies *** */ #include /* size_t */ #include "mem.h" /* U32 */ /* *** Tool functions *** */ #define HUF_BLOCKSIZE_MAX (128 * 1024) /**< maximum input size for a single block compressed with HUF_compress */ size_t HUF_compressBound(size_t size); /**< maximum compressed size (worst case) */ /* Error Management */ unsigned HUF_isError(size_t code); /**< tells if a return value is an error code */ const char* HUF_getErrorName(size_t code); /**< provides error code string (useful for debugging) */ #define HUF_WORKSPACE_SIZE ((8 << 10) + 512 /* sorting scratch space */) #define HUF_WORKSPACE_SIZE_U64 (HUF_WORKSPACE_SIZE / sizeof(U64)) /* *** Constants *** */ #define HUF_TABLELOG_MAX 12 /* max runtime value of tableLog (due to static allocation); can be modified up to HUF_TABLELOG_ABSOLUTEMAX */ #define HUF_TABLELOG_DEFAULT 11 /* default tableLog value when none specified */ #define HUF_SYMBOLVALUE_MAX 255 #define HUF_TABLELOG_ABSOLUTEMAX 12 /* absolute limit of HUF_MAX_TABLELOG. Beyond that value, code does not work */ #if (HUF_TABLELOG_MAX > HUF_TABLELOG_ABSOLUTEMAX) # error "HUF_TABLELOG_MAX is too large !" #endif /* **************************************** * Static allocation ******************************************/ /* HUF buffer bounds */ #define HUF_CTABLEBOUND 129 #define HUF_BLOCKBOUND(size) (size + (size>>8) + 8) /* only true when incompressible is pre-filtered with fast heuristic */ #define HUF_COMPRESSBOUND(size) (HUF_CTABLEBOUND + HUF_BLOCKBOUND(size)) /* Macro version, useful for static allocation */ /* static allocation of HUF's Compression Table */ /* this is a private definition, just exposed for allocation and strict aliasing purpose. 
never EVER access its members directly */ typedef size_t HUF_CElt; /* consider it an incomplete type */ #define HUF_CTABLE_SIZE_ST(maxSymbolValue) ((maxSymbolValue)+2) /* Use tables of size_t, for proper alignment */ #define HUF_CTABLE_SIZE(maxSymbolValue) (HUF_CTABLE_SIZE_ST(maxSymbolValue) * sizeof(size_t)) #define HUF_CREATE_STATIC_CTABLE(name, maxSymbolValue) \ HUF_CElt name[HUF_CTABLE_SIZE_ST(maxSymbolValue)] /* no final ; */ /* static allocation of HUF's DTable */ typedef U32 HUF_DTable; #define HUF_DTABLE_SIZE(maxTableLog) (1 + (1<<(maxTableLog))) #define HUF_CREATE_STATIC_DTABLEX1(DTable, maxTableLog) \ HUF_DTable DTable[HUF_DTABLE_SIZE((maxTableLog)-1)] = { ((U32)((maxTableLog)-1) * 0x01000001) } #define HUF_CREATE_STATIC_DTABLEX2(DTable, maxTableLog) \ HUF_DTable DTable[HUF_DTABLE_SIZE(maxTableLog)] = { ((U32)(maxTableLog) * 0x01000001) } /* **************************************** * Advanced decompression functions ******************************************/ /** * Huffman flags bitset. * For all flags, 0 is the default value. */ typedef enum { /** * If compiled with DYNAMIC_BMI2: Set flag only if the CPU supports BMI2 at runtime. * Otherwise: Ignored. */ HUF_flags_bmi2 = (1 << 0), /** * If set: Test possible table depths to find the one that produces the smallest header + encoded size. * If unset: Use heuristic to find the table depth. */ HUF_flags_optimalDepth = (1 << 1), /** * If set: If the previous table can encode the input, always reuse the previous table. * If unset: If the previous table can encode the input, reuse the previous table if it results in a smaller output. */ HUF_flags_preferRepeat = (1 << 2), /** * If set: Sample the input and check if the sample is uncompressible, if it is then don't attempt to compress. * If unset: Always histogram the entire input. 
*/ HUF_flags_suspectUncompressible = (1 << 3), /** * If set: Don't use assembly implementations * If unset: Allow using assembly implementations */ HUF_flags_disableAsm = (1 << 4), /** * If set: Don't use the fast decoding loop, always use the fallback decoding loop. * If unset: Use the fast decoding loop when possible. */ HUF_flags_disableFast = (1 << 5) } HUF_flags_e; /* **************************************** * HUF detailed API * ****************************************/ #define HUF_OPTIMAL_DEPTH_THRESHOLD ZSTD_btultra /*! HUF_compress() does the following: * 1. count symbol occurrence from source[] into table count[] using FSE_count() (exposed within "fse.h") * 2. (optional) refine tableLog using HUF_optimalTableLog() * 3. build Huffman table from count using HUF_buildCTable() * 4. save Huffman table to memory buffer using HUF_writeCTable() * 5. encode the data stream using HUF_compress4X_usingCTable() * * The following API allows targeting specific sub-functions for advanced tasks. * For example, it's possible to compress several blocks using the same 'CTable', * or to save and regenerate 'CTable' using external methods. 
*/ unsigned HUF_minTableLog(unsigned symbolCardinality); unsigned HUF_cardinality(const unsigned* count, unsigned maxSymbolValue); unsigned HUF_optimalTableLog(unsigned maxTableLog, size_t srcSize, unsigned maxSymbolValue, void* workSpace, size_t wkspSize, HUF_CElt* table, const unsigned* count, int flags); /* table is used as scratch space for building and testing tables, not a return value */ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize, const HUF_CElt* CTable, unsigned maxSymbolValue, unsigned huffLog, void* workspace, size_t workspaceSize); size_t HUF_compress4X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue); typedef enum { HUF_repeat_none, /**< Cannot use the previous table */ HUF_repeat_check, /**< Can use the previous table but it must be checked. Note : The previous table must have been constructed by HUF_compress{1, 4}X_repeat */ HUF_repeat_valid /**< Can use the previous table and it is assumed to be valid */ } HUF_repeat; /** HUF_compress4X_repeat() : * Same as HUF_compress4X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. * If it uses hufTable it does not modify hufTable or repeat. * If it doesn't, it sets *repeat = HUF_repeat_none, and it sets hufTable to the table used. * If preferRepeat then the old table will always be used if valid. 
* If suspectUncompressible then some sampling checks will be run to potentially skip huffman coding */ size_t HUF_compress4X_repeat(void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize, /**< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ HUF_CElt* hufTable, HUF_repeat* repeat, int flags); /** HUF_buildCTable_wksp() : * Same as HUF_buildCTable(), but using externally allocated scratch buffer. * `workSpace` must be aligned on 4-bytes boundaries, and its size must be >= HUF_CTABLE_WORKSPACE_SIZE. */ #define HUF_CTABLE_WORKSPACE_SIZE_U32 ((4 * (HUF_SYMBOLVALUE_MAX + 1)) + 192) #define HUF_CTABLE_WORKSPACE_SIZE (HUF_CTABLE_WORKSPACE_SIZE_U32 * sizeof(unsigned)) size_t HUF_buildCTable_wksp (HUF_CElt* tree, const unsigned* count, U32 maxSymbolValue, U32 maxNbBits, void* workSpace, size_t wkspSize); /*! HUF_readStats() : * Read compact Huffman tree, saved by HUF_writeCTable(). * `huffWeight` is destination buffer. * @return : size read from `src` , or an error Code . * Note : Needed by HUF_readCTable() and HUF_readDTableXn() . */ size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, const void* src, size_t srcSize); /*! HUF_readStats_wksp() : * Same as HUF_readStats() but takes an external workspace which must be * 4-byte aligned and its size must be >= HUF_READ_STATS_WORKSPACE_SIZE. * If the CPU has BMI2 support, pass bmi2=1, otherwise pass bmi2=0. 
*/ #define HUF_READ_STATS_WORKSPACE_SIZE_U32 FSE_DECOMPRESS_WKSP_SIZE_U32(6, HUF_TABLELOG_MAX-1) #define HUF_READ_STATS_WORKSPACE_SIZE (HUF_READ_STATS_WORKSPACE_SIZE_U32 * sizeof(unsigned)) size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats, U32* nbSymbolsPtr, U32* tableLogPtr, const void* src, size_t srcSize, void* workspace, size_t wkspSize, int flags); /** HUF_readCTable() : * Loading a CTable saved with HUF_writeCTable() */ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void* src, size_t srcSize, unsigned *hasZeroWeights); /** HUF_getNbBitsFromCTable() : * Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX * Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0 * Note 2 : is not inlined, as HUF_CElt definition is private */ U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue); typedef struct { BYTE tableLog; BYTE maxSymbolValue; BYTE unused[sizeof(size_t) - 2]; } HUF_CTableHeader; /** HUF_readCTableHeader() : * @returns The header from the CTable specifying the tableLog and the maxSymbolValue. */ HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable); /* * HUF_decompress() does the following: * 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics * 2. build Huffman table from save, using HUF_readDTableX?() * 3. decode 1 or 4 segments in parallel using HUF_decompress?X?_usingDTable() */ /** HUF_selectDecoder() : * Tells which decoder is likely to decode faster, * based on a set of pre-computed metrics. * @return : 0==HUF_decompress4X1, 1==HUF_decompress4X2 . * Assumption : 0 < dstSize <= 128 KB */ U32 HUF_selectDecoder (size_t dstSize, size_t cSrcSize); /** * The minimum workspace size for the `workSpace` used in * HUF_readDTableX1_wksp() and HUF_readDTableX2_wksp(). 
* * The space used depends on HUF_TABLELOG_MAX, ranging from ~1500 bytes when * HUF_TABLE_LOG_MAX=12 to ~1850 bytes when HUF_TABLE_LOG_MAX=15. * Buffer overflow errors may potentially occur if code modifications result in * a required workspace size greater than that specified in the following * macro. */ #define HUF_DECOMPRESS_WORKSPACE_SIZE ((2 << 10) + (1 << 9)) #define HUF_DECOMPRESS_WORKSPACE_SIZE_U32 (HUF_DECOMPRESS_WORKSPACE_SIZE / sizeof(U32)) /* ====================== */ /* single stream variants */ /* ====================== */ size_t HUF_compress1X_usingCTable(void* dst, size_t dstSize, const void* src, size_t srcSize, const HUF_CElt* CTable, int flags); /** HUF_compress1X_repeat() : * Same as HUF_compress1X_wksp(), but considers using hufTable if *repeat != HUF_repeat_none. * If it uses hufTable it does not modify hufTable or repeat. * If it doesn't, it sets *repeat = HUF_repeat_none, and it sets hufTable to the table used. * If preferRepeat then the old table will always be used if valid. * If suspectUncompressible then some sampling checks will be run to potentially skip huffman coding */ size_t HUF_compress1X_repeat(void* dst, size_t dstSize, const void* src, size_t srcSize, unsigned maxSymbolValue, unsigned tableLog, void* workSpace, size_t wkspSize, /**< `workSpace` must be aligned on 4-bytes boundaries, `wkspSize` must be >= HUF_WORKSPACE_SIZE */ HUF_CElt* hufTable, HUF_repeat* repeat, int flags); size_t HUF_decompress1X_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #ifndef HUF_FORCE_DECOMPRESS_X1 size_t HUF_decompress1X2_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); /**< double-symbols decoder */ #endif /* BMI2 variants. * If the CPU has BMI2 support, pass bmi2=1, otherwise pass bmi2=0. 
*/ size_t HUF_decompress1X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); #ifndef HUF_FORCE_DECOMPRESS_X2 size_t HUF_decompress1X1_DCtx_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #endif size_t HUF_decompress4X_usingDTable(void* dst, size_t maxDstSize, const void* cSrc, size_t cSrcSize, const HUF_DTable* DTable, int flags); size_t HUF_decompress4X_hufOnly_wksp(HUF_DTable* dctx, void* dst, size_t dstSize, const void* cSrc, size_t cSrcSize, void* workSpace, size_t wkspSize, int flags); #ifndef HUF_FORCE_DECOMPRESS_X2 size_t HUF_readDTableX1_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); #endif #ifndef HUF_FORCE_DECOMPRESS_X1 size_t HUF_readDTableX2_wksp(HUF_DTable* DTable, const void* src, size_t srcSize, void* workSpace, size_t wkspSize, int flags); #endif #endif /* HUF_H_298734234 */ ================================================ FILE: src/huff/mem.h ================================================ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the * LICENSE file in the root directory of this source tree) and the GPLv2 (found * in the COPYING file in the root directory of this source tree). * You may select, at your option, one of the above-listed licenses. 
*/ #ifndef MEM_H_MODULE #define MEM_H_MODULE /*-**************************************** * Dependencies ******************************************/ #include /* size_t, ptrdiff_t */ #include /* intptr_t */ #define MEM_STATIC typedef uint32_t U32; typedef uint8_t BYTE; #endif /* MEM_H_MODULE */ ================================================ FILE: src/redis/CMakeLists.txt ================================================ option(REDIS_ZMALLOC_MI "Implement zmalloc layer using mimalloc allocator" ON) if (REDIS_ZMALLOC_MI) set(ZMALLOC_SRC "zmalloc_mi.c") set(ZMALLOC_DEPS "TRDP::mimalloc2") else() set(ZMALLOC_SRC "zmalloc.c") set(ZMALLOC_DEPS "") endif() add_library(redis_lib crc16.c crc64.c crcspeed.c debug.c intset.c geo.c geohash.c geohash_helper.c hiredis.c read.c listpack.c lzf_c.c lzf_d.c sds.c rax.c redis_aux.c t_stream.c util.c ziplist.c hyperloglog.c ${ZMALLOC_SRC}) cxx_link(redis_lib ${ZMALLOC_DEPS}) add_library(redis_test_lib dict.c siphash.c) cxx_link(redis_test_lib redis_lib) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") target_compile_options(redis_lib PRIVATE -Wno-maybe-uninitialized) endif() if (REDIS_ZMALLOC_MI) target_compile_definitions(redis_lib PUBLIC USE_ZMALLOC_MI) endif() add_subdirectory(lua) ================================================ FILE: src/redis/LICENSE.redis ================================================ Copyright (c) 2006-2020, Salvatore Sanfilippo All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
* Neither the name of Redis nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: src/redis/config.h ================================================ /* * Copyright (c) 2009-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __CONFIG_H #define __CONFIG_H #ifdef __APPLE__ #include #endif #ifdef __linux__ #include #endif /* Define redis_fstat to fstat or fstat64() */ #if defined(__APPLE__) && !defined(MAC_OS_X_VERSION_10_6) #define redis_fstat fstat64 #define redis_stat stat64 #else #define redis_fstat fstat #define redis_stat stat #endif /* Test for proc filesystem */ #ifdef __linux__ #define HAVE_PROC_STAT 1 #define HAVE_PROC_MAPS 1 #define HAVE_PROC_SMAPS 1 #define HAVE_PROC_SOMAXCONN 1 #define HAVE_PROC_OOM_SCORE_ADJ 1 #endif /* Test for task_info() */ #if defined(__APPLE__) #define HAVE_TASKINFO 1 #endif /* Test for backtrace() */ #if defined(__APPLE__) || (defined(__linux__) && defined(__GLIBC__)) || \ defined(__FreeBSD__) || ((defined(__OpenBSD__) || defined(__NetBSD__)) && defined(USE_BACKTRACE))\ || defined(__DragonFly__) || (defined(__UCLIBC__) && defined(__UCLIBC_HAS_BACKTRACE__)) #define HAVE_BACKTRACE 1 #endif /* MSG_NOSIGNAL. 
*/ #ifdef __linux__ #define HAVE_MSG_NOSIGNAL 1 #endif /* Test for polling API */ #ifdef __linux__ #define HAVE_EPOLL 1 #endif #if (defined(__APPLE__) && defined(MAC_OS_X_VERSION_10_6)) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined (__NetBSD__) #define HAVE_KQUEUE 1 #endif #ifdef __sun #include #ifdef _DTRACE_VERSION #define HAVE_EVPORT 1 #define HAVE_PSINFO 1 #endif #endif /* Define redis_fsync to fdatasync() in Linux and fsync() for all the rest */ #ifdef __linux__ #define redis_fsync fdatasync #else #define redis_fsync fsync #endif #if __GNUC__ >= 4 #define valkey_unreachable __builtin_unreachable #else #define valkey_unreachable abort #endif #if __GNUC__ >= 3 #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) #else #define likely(x) (x) #define unlikely(x) (x) #endif /* Define rdb_fsync_range to sync_file_range() on Linux, otherwise we use * the plain fsync() call. */ #if (defined(__linux__) && defined(SYNC_FILE_RANGE_WAIT_BEFORE)) #define rdb_fsync_range(fd,off,size) sync_file_range(fd,off,size,SYNC_FILE_RANGE_WAIT_BEFORE|SYNC_FILE_RANGE_WRITE) #else #define rdb_fsync_range(fd,off,size) fsync(fd) #endif /* Check if we can use setproctitle(). * BSD systems have support for it, we provide an implementation for * Linux and osx. 
*/ #if (defined __NetBSD__ || defined __FreeBSD__ || defined __OpenBSD__) #define USE_SETPROCTITLE #endif #if defined(__HAIKU__) #define ESOCKTNOSUPPORT 0 #endif #if (defined __linux || defined __APPLE__) #define USE_SETPROCTITLE #define INIT_SETPROCTITLE_REPLACEMENT void spt_init(int argc, char *argv[]); void setproctitle(const char *fmt, ...); #endif /* Byte ordering detection */ #include /* This will likely define BYTE_ORDER */ #ifndef BYTE_ORDER #if (BSD >= 199103) # include #else #if defined(linux) || defined(__linux__) # include #else #define LITTLE_ENDIAN 1234 /* least-significant byte first (vax, pc) */ #define BIG_ENDIAN 4321 /* most-significant byte first (IBM, net) */ #define PDP_ENDIAN 3412 /* LSB first in word, MSW first in long (pdp)*/ #if defined(__i386__) || defined(__x86_64__) || defined(__amd64__) || \ defined(vax) || defined(ns32000) || defined(sun386) || \ defined(MIPSEL) || defined(_MIPSEL) || defined(BIT_ZERO_ON_RIGHT) || \ defined(__alpha__) || defined(__alpha) #define BYTE_ORDER LITTLE_ENDIAN #endif #if defined(sel) || defined(pyr) || defined(mc68000) || defined(sparc) || \ defined(is68k) || defined(tahoe) || defined(ibm032) || defined(ibm370) || \ defined(MIPSEB) || defined(_MIPSEB) || defined(_IBMR2) || defined(DGUX) ||\ defined(apollo) || defined(__convex__) || defined(_CRAY) || \ defined(__hppa) || defined(__hp9000) || \ defined(__hp9000s300) || defined(__hp9000s700) || \ defined (BIT_ZERO_ON_LEFT) || defined(m68k) || defined(__sparc) #define BYTE_ORDER BIG_ENDIAN #endif #endif /* linux */ #endif /* BSD */ #endif /* BYTE_ORDER */ /* Sometimes after including an OS-specific header that defines the * endianness we end with __BYTE_ORDER but not with BYTE_ORDER that is what * the Redis code uses. In this case let's define everything without the * underscores. 
*/ #ifndef BYTE_ORDER #ifdef __BYTE_ORDER #if defined(__LITTLE_ENDIAN) && defined(__BIG_ENDIAN) #ifndef LITTLE_ENDIAN #define LITTLE_ENDIAN __LITTLE_ENDIAN #endif #ifndef BIG_ENDIAN #define BIG_ENDIAN __BIG_ENDIAN #endif #if (__BYTE_ORDER == __LITTLE_ENDIAN) #define BYTE_ORDER LITTLE_ENDIAN #else #define BYTE_ORDER BIG_ENDIAN #endif #endif #endif #endif #if !defined(BYTE_ORDER) || \ (BYTE_ORDER != BIG_ENDIAN && BYTE_ORDER != LITTLE_ENDIAN) /* you must determine what the correct bit order is for * your compiler - the next line is an intentional error * which will force your compiles to bomb until you fix * the above macros. */ #error "Undefined or invalid BYTE_ORDER" #endif #if (__i386 || __amd64 || __powerpc__) && __GNUC__ #define GNUC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) #if defined(__clang__) #define HAVE_ATOMIC #endif #if (defined(__GLIBC__) && defined(__GLIBC_PREREQ)) #if (GNUC_VERSION >= 40100 && __GLIBC_PREREQ(2, 6)) #define HAVE_ATOMIC #endif #endif #endif /* Make sure we can test for ARM just checking for __arm__, since sometimes * __arm is defined but __arm__ is not. */ #if defined(__arm) && !defined(__arm__) #define __arm__ #endif #if defined (__aarch64__) && !defined(__arm64__) #define __arm64__ #endif /* Make sure we can test for SPARC just checking for __sparc__. 
*/ #if defined(__sparc) && !defined(__sparc__) #define __sparc__ #endif #if defined(__sparc__) || defined(__arm__) #define USE_ALIGNED_ACCESS #endif /* Define for redis_set_thread_title */ #ifdef __linux__ #define redis_set_thread_title(name) pthread_setname_np(pthread_self(), name) #else #if (defined __FreeBSD__ || defined __OpenBSD__) #include #define redis_set_thread_title(name) pthread_set_name_np(pthread_self(), name) #elif defined __NetBSD__ #include #define redis_set_thread_title(name) pthread_setname_np(pthread_self(), "%s", name) #elif defined __HAIKU__ #include #define redis_set_thread_title(name) rename_thread(find_thread(0), name) #else #if (defined __APPLE__ && defined(MAC_OS_X_VERSION_10_7)) int pthread_setname_np(const char *name); #include #define redis_set_thread_title(name) pthread_setname_np(name) #else #define redis_set_thread_title(name) #endif #endif #endif /* Check if we can use setcpuaffinity(). */ #if (defined __linux || defined __NetBSD__ || defined __FreeBSD__ || defined __DragonFly__) #define USE_SETCPUAFFINITY void setcpuaffinity(const char *cpulist); #endif #endif ================================================ FILE: src/redis/crc16.c ================================================ #include "crc16.h" /* * Copyright 2001-2010 Georges Menie (www.menie.org) * Copyright 2010-2012 Salvatore Sanfilippo (adapted to Redis coding style) * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. 
* * Neither the name of the University of California, Berkeley nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE REGENTS AND CONTRIBUTORS BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ /* CRC16 implementation according to CCITT standards. 
* * Note by @antirez: this is actually the XMODEM CRC 16 algorithm, using the * following parameters: * * Name : "XMODEM", also known as "ZMODEM", "CRC-16/ACORN" * Width : 16 bit * Poly : 1021 (That is actually x^16 + x^12 + x^5 + 1) * Initialization : 0000 * Reflect Input byte : False * Reflect Output CRC : False * Xor constant to output CRC : 0000 * Output for "123456789" : 31C3 */ static const uint16_t crc16tab[256] = { 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, 0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, 0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, 0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, 0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, 0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, 0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, 0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, 0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, 0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, 0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, 0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, 0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, 0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, 0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, 0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, 0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, 0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, 0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, 0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, 0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, 0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, 0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, 0xd94c, 0xc96d, 0xf90e, 0xe92f, 
0x99c8, 0x89e9, 0xb98a, 0xa9ab, 0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, 0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, 0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, 0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, 0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, 0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, 0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0}; uint16_t crc16(const char* buf, int len) { int counter; uint16_t crc = 0; for (counter = 0; counter < len; counter++) crc = (crc << 8) ^ crc16tab[((crc >> 8) ^ *buf++) & 0x00FF]; return crc; } ================================================ FILE: src/redis/crc16.h ================================================ #ifndef CRC16_H #define CRC16_H #include uint16_t crc16(const char* buf, int len); #endif ================================================ FILE: src/redis/crc64.c ================================================ /* Copyright (c) 2014, Matt Stancliff * Copyright (c) 2020, Amazon Web Services * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include "crc64.h" #include "crcspeed.h" static uint64_t crc64_table[8][256] = {{0}}; #define POLY UINT64_C(0xad93d23594c935a9) /******************** BEGIN GENERATED PYCRC FUNCTIONS ********************/ /** * Generated on Sun Dec 21 14:14:07 2014, * by pycrc v0.8.2, https://www.tty1.net/pycrc/ * * LICENSE ON GENERATED CODE: * ========================== * As of version 0.6, pycrc is released under the terms of the MIT licence. * The code generated by pycrc is not considered a substantial portion of the * software, therefore the author of pycrc will not claim any copyright on * the generated code. * ========================== * * CRC configuration: * Width = 64 * Poly = 0xad93d23594c935a9 * XorIn = 0xffffffffffffffff * ReflectIn = True * XorOut = 0x0000000000000000 * ReflectOut = True * Algorithm = bit-by-bit-fast * * Modifications after generation (by matt): * - included finalize step in-line with update for single-call generation * - re-worked some inner variable architectures * - adjusted function parameters to match expected prototypes. *****************************************************************************/ /** * Reflect all bits of a \a data word of \a data_len bytes. 
* * \param data The data word to be reflected. * \param data_len The width of \a data expressed in number of bits. * \return The reflected data. *****************************************************************************/ static inline uint_fast64_t crc_reflect(uint_fast64_t data, size_t data_len) { uint_fast64_t ret = data & 0x01; for (size_t i = 1; i < data_len; i++) { data >>= 1; ret = (ret << 1) | (data & 0x01); } return ret; } /** * Update the crc value with new data. * * \param crc The current crc value. * \param data Pointer to a buffer of \a data_len bytes. * \param data_len Number of bytes in the \a data buffer. * \return The updated crc value. ******************************************************************************/ uint64_t _crc64(uint_fast64_t crc, const void *in_data, const uint64_t len) { const uint8_t *data = in_data; unsigned long long bit; for (uint64_t offset = 0; offset < len; offset++) { uint8_t c = data[offset]; for (uint_fast8_t i = 0x01; i & 0xff; i <<= 1) { bit = crc & 0x8000000000000000; if (c & i) { bit = !bit; } crc <<= 1; if (bit) { crc ^= POLY; } } crc &= 0xffffffffffffffff; } crc = crc & 0xffffffffffffffff; return crc_reflect(crc, 64) ^ 0x0000000000000000; } /******************** END GENERATED PYCRC FUNCTIONS ********************/ /* Initializes the 16KB lookup tables. 
*/ void crc64_init(void) { crcspeed64native_init(_crc64, crc64_table); } /* Compute crc64 */ uint64_t crc64(uint64_t crc, const unsigned char *s, uint64_t l) { return crcspeed64native(crc64_table, crc, (void *) s, l); } /* Test main */ #ifdef REDIS_TEST #include #define UNUSED(x) (void)(x) int crc64Test(int argc, char *argv[], int flags) { UNUSED(argc); UNUSED(argv); UNUSED(flags); crc64_init(); printf("[calcula]: e9c6d914c4b8d9ca == %016" PRIx64 "\n", (uint64_t)_crc64(0, "123456789", 9)); printf("[64speed]: e9c6d914c4b8d9ca == %016" PRIx64 "\n", (uint64_t)crc64(0, (unsigned char*)"123456789", 9)); char li[] = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed " "do eiusmod tempor incididunt ut labore et dolore magna " "aliqua. Ut enim ad minim veniam, quis nostrud exercitation " "ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis " "aute irure dolor in reprehenderit in voluptate velit esse " "cillum dolore eu fugiat nulla pariatur. Excepteur sint " "occaecat cupidatat non proident, sunt in culpa qui officia " "deserunt mollit anim id est laborum."; printf("[calcula]: c7794709e69683b3 == %016" PRIx64 "\n", (uint64_t)_crc64(0, li, sizeof(li))); printf("[64speed]: c7794709e69683b3 == %016" PRIx64 "\n", (uint64_t)crc64(0, (unsigned char*)li, sizeof(li))); return 0; } #endif #ifdef REDIS_TEST_MAIN int main(int argc, char *argv[]) { return crc64Test(argc, argv); } #endif ================================================ FILE: src/redis/crc64.h ================================================ #ifndef CRC64_H #define CRC64_H #include void crc64_init(void); uint64_t crc64(uint64_t crc, const unsigned char *s, uint64_t l); #ifdef REDIS_TEST int crc64Test(int argc, char *argv[], int flags); #endif #endif ================================================ FILE: src/redis/crcspeed.c ================================================ /* * Copyright (C) 2013 Mark Adler * Originally by: crc64.c Version 1.4 16 Dec 2013 Mark Adler * Modifications by Matt Stancliff 
: * - removed CRC64-specific behavior * - added generation of lookup tables by parameters * - removed inversion of CRC input/result * - removed automatic initialization in favor of explicit initialization This software is provided 'as-is', without any express or implied warranty. In no event will the author be held liable for any damages arising from the use of this software. Permission is granted to anyone to use this software for any purpose, including commercial applications, and to alter it and redistribute it freely, subject to the following restrictions: 1. The origin of this software must not be misrepresented; you must not claim that you wrote the original software. If you use this software in a product, an acknowledgment in the product documentation would be appreciated but is not required. 2. Altered source versions must be plainly marked as such, and must not be misrepresented as being the original software. 3. This notice may not be removed or altered from any source distribution. Mark Adler madler@alumni.caltech.edu */ #include "crcspeed.h" /* Fill in a CRC constants table. */ void crcspeed64little_init(crcfn64 crcfn, uint64_t table[8][256]) { uint64_t crc; /* generate CRCs for all single byte sequences */ for (int n = 0; n < 256; n++) { unsigned char v = n; table[0][n] = crcfn(0, &v, 1); } /* generate nested CRC table for future slice-by-8 lookup */ for (int n = 0; n < 256; n++) { crc = table[0][n]; for (int k = 1; k < 8; k++) { crc = table[0][crc & 0xff] ^ (crc >> 8); table[k][n] = crc; } } } void crcspeed16little_init(crcfn16 crcfn, uint16_t table[8][256]) { uint16_t crc; /* generate CRCs for all single byte sequences */ for (int n = 0; n < 256; n++) { table[0][n] = crcfn(0, &n, 1); } /* generate nested CRC table for future slice-by-8 lookup */ for (int n = 0; n < 256; n++) { crc = table[0][n]; for (int k = 1; k < 8; k++) { crc = table[0][(crc >> 8) & 0xff] ^ (crc << 8); table[k][n] = crc; } } } /* Reverse the bytes in a 64-bit word. 
*/
static inline uint64_t rev8(uint64_t a) {
#if defined(__GNUC__) || defined(__clang__)
    return __builtin_bswap64(a);
#else
    uint64_t m;

    /* Portable fallback: swap neighbouring bytes, then neighbouring 16-bit
     * groups, then the two 32-bit halves. */
    m = UINT64_C(0xff00ff00ff00ff);
    a = ((a >> 8) & m) | (a & m) << 8;
    m = UINT64_C(0xffff0000ffff);
    a = ((a >> 16) & m) | (a & m) << 16;
    return a >> 32 | a << 32;
#endif
}

/* This function is called once to initialize the CRC table for use on a
   big-endian architecture.
   fn:        reference CRC routine forwarded to crcspeed64little_init().
   big_table: output; receives the little-endian table with every entry
              byte-reversed by rev8(). */
void crcspeed64big_init(crcfn64 fn, uint64_t big_table[8][256]) {
    /* Create the little endian table then reverse all the entries. */
    crcspeed64little_init(fn, big_table);
    for (int k = 0; k < 8; k++) {
        for (int n = 0; n < 256; n++) {
            big_table[k][n] = rev8(big_table[k][n]);
        }
    }
}

/* 16-bit counterpart of crcspeed64big_init().
   NOTE(review): rev8() reverses a full 64-bit word. A 16-bit entry widened to
   64 bits has its two bytes moved to the top of the result, so the value
   stored back into the uint16_t slot keeps only the low 16 bits, which are
   zero. Confirm whether a 16-bit byte swap was intended here. */
void crcspeed16big_init(crcfn16 fn, uint16_t big_table[8][256]) {
    /* Create the little endian table then reverse all the entries. */
    crcspeed16little_init(fn, big_table);
    for (int k = 0; k < 8; k++) {
        for (int n = 0; n < 256; n++) {
            big_table[k][n] = rev8(big_table[k][n]);
        }
    }
}

/* Calculate a non-inverted CRC multiple bytes at a time on a little-endian
 * architecture. If you need inverted CRC, invert *before* calling and invert
 * *after* calling.
 * 64 bit crc = process 8 bytes at once;
 *
 * little_table: lookup table (see crcspeed64little_init()).
 * crc:          starting CRC value.
 * buf, len:     input bytes and their count.
 * Returns the updated CRC. */
uint64_t crcspeed64little(uint64_t little_table[8][256], uint64_t crc,
                          void *buf, size_t len) {
    unsigned char *next = buf;

    /* process individual bytes until we reach an 8-byte aligned pointer */
    while (len && ((uintptr_t)next & 7) != 0) {
        crc = little_table[0][(crc ^ *next++) & 0xff] ^ (crc >> 8);
        len--;
    }

    /* fast middle processing, 8 bytes (aligned!) per loop */
    while (len >= 8) {
        /* Fold the next 8 input bytes into the CRC with one 64-bit load, then
         * look up each byte of the result in its own table. */
        crc ^= *(uint64_t *)next;
        crc = little_table[7][crc & 0xff] ^
              little_table[6][(crc >> 8) & 0xff] ^
              little_table[5][(crc >> 16) & 0xff] ^
              little_table[4][(crc >> 24) & 0xff] ^
              little_table[3][(crc >> 32) & 0xff] ^
              little_table[2][(crc >> 40) & 0xff] ^
              little_table[1][(crc >> 48) & 0xff] ^
              little_table[0][crc >> 56];
        next += 8;
        len -= 8;
    }

    /* process remaining bytes (can't be larger than 8) */
    while (len) {
        crc = little_table[0][(crc ^ *next++) & 0xff] ^ (crc >> 8);
        len--;
    }

    return crc;
}

/* 16-bit variant of crcspeed64little(): same three phases (unaligned head,
 * 8-byte blocks, tail). The byte-wise steps index with the high byte of the
 * CRC and shift the CRC left by 8.
 *
 * little_table: lookup table (see crcspeed16little_init()).
 * crc:          starting CRC value.
 * buf, len:     input bytes and their count.
 * Returns the updated CRC. */
uint16_t crcspeed16little(uint16_t little_table[8][256], uint16_t crc,
                          void *buf, size_t len) {
    unsigned char *next = buf;

    /* process individual bytes until we reach an 8-byte aligned pointer */
    while (len && ((uintptr_t)next & 7) != 0) {
        crc = little_table[0][((crc >> 8) ^ *next++) & 0xff] ^ (crc << 8);
        len--;
    }

    /* fast middle processing, 8 bytes (aligned!) per loop */
    while (len >= 8) {
        uint64_t n = *(uint64_t *)next;
        /* Only the first two input bytes are combined with the 16-bit CRC;
         * the other six are looked up on their own. */
        crc = little_table[7][(n & 0xff) ^ ((crc >> 8) & 0xff)] ^
              little_table[6][((n >> 8) & 0xff) ^ (crc & 0xff)] ^
              little_table[5][(n >> 16) & 0xff] ^
              little_table[4][(n >> 24) & 0xff] ^
              little_table[3][(n >> 32) & 0xff] ^
              little_table[2][(n >> 40) & 0xff] ^
              little_table[1][(n >> 48) & 0xff] ^
              little_table[0][n >> 56];
        next += 8;
        len -= 8;
    }

    /* process remaining bytes (can't be larger than 8) */
    while (len) {
        crc = little_table[0][((crc >> 8) ^ *next++) & 0xff] ^ (crc << 8);
        len--;
    }

    return crc;
}

/* Calculate a non-inverted CRC eight bytes at a time on a big-endian * architecture.
*/
/* big_table: lookup table (see crcspeed64big_init()).
 * crc:       starting CRC value; byte-reversed on entry and again on return.
 * buf, len:  input bytes and their count.
 * Returns the updated CRC. */
uint64_t crcspeed64big(uint64_t big_table[8][256], uint64_t crc, void *buf,
                       size_t len) {
    unsigned char *next = buf;

    crc = rev8(crc);
    /* byte-wise until the pointer is 8-byte aligned */
    while (len && ((uintptr_t)next & 7) != 0) {
        crc = big_table[0][(crc >> 56) ^ *next++] ^ (crc << 8);
        len--;
    }

    /* 8 bytes per iteration, using one 64-bit load */
    while (len >= 8) {
        crc ^= *(uint64_t *)next;
        crc = big_table[0][crc & 0xff] ^
              big_table[1][(crc >> 8) & 0xff] ^
              big_table[2][(crc >> 16) & 0xff] ^
              big_table[3][(crc >> 24) & 0xff] ^
              big_table[4][(crc >> 32) & 0xff] ^
              big_table[5][(crc >> 40) & 0xff] ^
              big_table[6][(crc >> 48) & 0xff] ^
              big_table[7][crc >> 56];
        next += 8;
        len -= 8;
    }

    /* remaining tail bytes */
    while (len) {
        crc = big_table[0][(crc >> 56) ^ *next++] ^ (crc << 8);
        len--;
    }

    return rev8(crc);
}

/* WARNING: Completely untested on big endian architecture.  Possibly broken.
 * The 16-bit input CRC is widened to 64 bits and byte-reversed with rev8()
 * before processing; the result of the final rev8() is narrowed back to
 * uint16_t by the return. */
uint16_t crcspeed16big(uint16_t big_table[8][256], uint16_t crc_in, void *buf,
                       size_t len) {
    unsigned char *next = buf;
    uint64_t crc = crc_in;

    crc = rev8(crc);
    /* byte-wise until the pointer is 8-byte aligned */
    while (len && ((uintptr_t)next & 7) != 0) {
        crc = big_table[0][((crc >> (56 - 8)) ^ *next++) & 0xff] ^ (crc >> 8);
        len--;
    }

    /* 8 bytes per iteration, using one 64-bit load */
    while (len >= 8) {
        uint64_t n = *(uint64_t *)next;
        crc = big_table[0][(n & 0xff) ^ ((crc >> (56 - 8)) & 0xff)] ^
              big_table[1][((n >> 8) & 0xff) ^ (crc & 0xff)] ^
              big_table[2][(n >> 16) & 0xff] ^
              big_table[3][(n >> 24) & 0xff] ^
              big_table[4][(n >> 32) & 0xff] ^
              big_table[5][(n >> 40) & 0xff] ^
              big_table[6][(n >> 48) & 0xff] ^
              big_table[7][n >> 56];
        next += 8;
        len -= 8;
    }

    /* remaining tail bytes */
    while (len) {
        crc = big_table[0][((crc >> (56 - 8)) ^ *next++) & 0xff] ^ (crc >> 8);
        len--;
    }

    return rev8(crc);
}

/* Return the CRC of buf[0..len-1] with initial crc, processing eight bytes
   at a time using passed-in lookup table.
   This selects one of two routines depending on the endianness of
   the architecture. */
uint64_t crcspeed64native(uint64_t table[8][256], uint64_t crc, void *buf,
                          size_t len) {
    uint64_t n = 1;

    /* The first byte of n in memory is non-zero only when the least
     * significant byte is stored first, i.e. on a little-endian host. */
    return *(char *)&n ? crcspeed64little(table, crc, buf, len)
                       : crcspeed64big(table, crc, buf, len);
}

/* 16-bit counterpart of crcspeed64native(): same run-time byte-order probe,
   dispatching to crcspeed16little() or crcspeed16big(). */
uint16_t crcspeed16native(uint16_t table[8][256], uint16_t crc, void *buf,
                          size_t len) {
    uint64_t n = 1;

    return *(char *)&n ? crcspeed16little(table, crc, buf, len)
                       : crcspeed16big(table, crc, buf, len);
}

/* Initialize CRC lookup table in architecture-dependent manner.
   fn:    reference CRC routine used to generate the table.
   table: output table, filled by the little- or big-endian initializer
          chosen with the same byte-order probe as crcspeed64native(). */
void crcspeed64native_init(crcfn64 fn, uint64_t table[8][256]) {
    uint64_t n = 1;

    *(char *)&n ? crcspeed64little_init(fn, table)
                : crcspeed64big_init(fn, table);
}

/* 16-bit counterpart of crcspeed64native_init(). */
void crcspeed16native_init(crcfn16 fn, uint16_t table[8][256]) {
    uint64_t n = 1;

    *(char *)&n ? crcspeed16little_init(fn, table)
                : crcspeed16big_init(fn, table);
}

================================================
FILE: src/redis/crcspeed.h
================================================
/* Copyright (c) 2014, Matt Stancliff * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef CRCSPEED_H #define CRCSPEED_H #include #include typedef uint64_t (*crcfn64)(uint64_t, const void *, const uint64_t); typedef uint16_t (*crcfn16)(uint16_t, const void *, const uint64_t); /* CRC-64 */ void crcspeed64little_init(crcfn64 fn, uint64_t table[8][256]); void crcspeed64big_init(crcfn64 fn, uint64_t table[8][256]); void crcspeed64native_init(crcfn64 fn, uint64_t table[8][256]); uint64_t crcspeed64little(uint64_t table[8][256], uint64_t crc, void *buf, size_t len); uint64_t crcspeed64big(uint64_t table[8][256], uint64_t crc, void *buf, size_t len); uint64_t crcspeed64native(uint64_t table[8][256], uint64_t crc, void *buf, size_t len); /* CRC-16 */ void crcspeed16little_init(crcfn16 fn, uint16_t table[8][256]); void crcspeed16big_init(crcfn16 fn, uint16_t table[8][256]); void crcspeed16native_init(crcfn16 fn, uint16_t table[8][256]); uint16_t crcspeed16little(uint16_t table[8][256], uint16_t crc, void *buf, size_t len); uint16_t crcspeed16big(uint16_t table[8][256], uint16_t crc, void *buf, size_t len); uint16_t crcspeed16native(uint16_t table[8][256], uint16_t crc, void *buf, size_t len); #endif ================================================ FILE: src/redis/debug.c ================================================ /* * Copyright (c) 2009-2020, Salvatore Sanfilippo * Copyright (c) 2020, Redis Labs, Inc * All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include "util.h" int verbosity = LL_NOTICE; void serverLog(int level, const char *fmt, ...) { va_list ap; char msg[LOG_MAX_LEN]; if ((level&0xff) < verbosity) return; va_start(ap, fmt); vsnprintf(msg, sizeof(msg), fmt, ap); va_end(ap); fprintf(stdout, "%s\n",msg); } void _serverPanic(const char *file, int line, const char *msg, ...) 
{ va_list ap; va_start(ap,msg); char fmtmsg[256]; vsnprintf(fmtmsg,sizeof(fmtmsg),msg,ap); va_end(ap); serverLog(LL_WARNING, "------------------------------------------------"); serverLog(LL_WARNING, "!!! Software Failure. Press left mouse button to continue"); serverLog(LL_WARNING, "Guru Meditation: %s #%s:%d", fmtmsg,file,line); #ifndef NDEBUG #if defined(__APPLE__) __assert_rtn(msg, file, line, ""); #elif defined(__FreeBSD__) __assert("", file, line, msg); #else __assert_fail(msg, file, line, ""); #endif #endif } void _serverAssert(const char *estr, const char *file, int line) { serverLog(LL_WARNING,"=== ASSERTION FAILED ==="); serverLog(LL_WARNING,"==> %s:%d '%s' is not true",file,line,estr); } ================================================ FILE: src/redis/dict.c ================================================ /* Hash Tables Implementation. * * This file implements in memory hash tables with insert/del/replace/find/ * get-random-element operations. Hash tables will auto resize if needed * tables of power of two in size are used, collisions are handled by * chaining. See the source code for more information... :) * * Copyright (c) 2006-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include #include "dict.h" #include "zmalloc.h" #if !defined(DICT_BENCHMARK_MAIN) && defined(ROMAN_REDIS_ASSERT_DISABLED) #include "redisassert.h" #else #include #endif /* Using dictEnableResize() / dictDisableResize() we make possible to * enable/disable resizing of the hash table as needed. This is very important * for Redis, as we use copy-on-write and don't want to move too much memory * around when there is a child performing saving operations. * * Note that even when dict_can_resize is set to 0, not all resizes are * prevented: a hash table is still allowed to grow if the ratio between * the number of elements and the buckets > dict_force_resize_ratio. 
*/ static int dict_can_resize = 1; static unsigned int dict_force_resize_ratio = 5; /* -------------------------- private prototypes ---------------------------- */ static int _dictExpandIfNeeded(dict *d); static signed char _dictNextExp(unsigned long size); static long _dictKeyIndex(dict *d, const void *key, uint64_t hash, dictEntry **existing); static int _dictInit(dict *d, dictType *type); /* -------------------------- hash functions -------------------------------- */ static uint8_t dict_hash_function_seed[16]; void dictSetHashFunctionSeed(uint8_t *seed) { memcpy(dict_hash_function_seed,seed,sizeof(dict_hash_function_seed)); } uint8_t *dictGetHashFunctionSeed(void) { return dict_hash_function_seed; } /* The default hashing function uses SipHash implementation * in siphash.c. */ uint64_t siphash(const uint8_t *in, const size_t inlen, const uint8_t *k); uint64_t siphash_nocase(const uint8_t *in, const size_t inlen, const uint8_t *k); uint64_t dictGenHashFunction(const void *key, size_t len) { return siphash(key,len,dict_hash_function_seed); } uint64_t dictGenCaseHashFunction(const unsigned char *buf, size_t len) { return siphash_nocase(buf,len,dict_hash_function_seed); } /* ----------------------------- API implementation ------------------------- */ /* Reset hash table parameters already initialized with _dictInit()*/ static void _dictReset(dict *d, int htidx) { d->ht_table[htidx] = NULL; d->ht_size_exp[htidx] = -1; d->ht_used[htidx] = 0; } /* Create a new hash table */ dict *dictCreate(dictType *type) { dict *d = zmalloc(sizeof(*d)); _dictInit(d,type); return d; } /* Initialize the hash table */ int _dictInit(dict *d, dictType *type) { _dictReset(d, 0); _dictReset(d, 1); d->type = type; d->rehashidx = -1; d->pauserehash = 0; return DICT_OK; } /* Resize the table to the minimal size that contains all the elements, * but with the invariant of a USED/BUCKETS ratio near to <= 1 */ int dictResize(dict *d) { unsigned long minimal; if (!dict_can_resize || 
dictIsRehashing(d)) return DICT_ERR; minimal = d->ht_used[0]; if (minimal < DICT_HT_INITIAL_SIZE) minimal = DICT_HT_INITIAL_SIZE; return dictExpand(d, minimal); } /* Expand or create the hash table, * when malloc_failed is non-NULL, it'll avoid panic if malloc fails (in which case it'll be set to 1). * Returns DICT_OK if expand was performed, and DICT_ERR if skipped. */ int _dictExpand(dict *d, unsigned long size, int* malloc_failed) { if (malloc_failed) *malloc_failed = 0; /* the size is invalid if it is smaller than the number of * elements already inside the hash table */ if (dictIsRehashing(d) || d->ht_used[0] > size) return DICT_ERR; /* the new hash table */ dictEntry **new_ht_table; unsigned long new_ht_used; signed char new_ht_size_exp = _dictNextExp(size); /* Detect overflows */ size_t newsize = 1ul<ht_size_exp[0]) return DICT_ERR; /* Allocate the new hash table and initialize all pointers to NULL */ if (malloc_failed) { new_ht_table = ztrycalloc(newsize*sizeof(dictEntry*)); *malloc_failed = new_ht_table == NULL; if (*malloc_failed) return DICT_ERR; } else new_ht_table = zcalloc(newsize*sizeof(dictEntry*)); new_ht_used = 0; /* Is this the first initialization? If so it's not really a rehashing * we just set the first hash table so that it can accept keys. */ if (d->ht_table[0] == NULL) { d->ht_size_exp[0] = new_ht_size_exp; d->ht_used[0] = new_ht_used; d->ht_table[0] = new_ht_table; return DICT_OK; } /* Prepare a second hash table for incremental rehashing */ d->ht_size_exp[1] = new_ht_size_exp; d->ht_used[1] = new_ht_used; d->ht_table[1] = new_ht_table; d->rehashidx = 0; return DICT_OK; } /* return DICT_ERR if expand was not performed */ int dictExpand(dict *d, unsigned long size) { return _dictExpand(d, size, NULL); } /* return DICT_ERR if expand failed due to memory allocation failure */ int dictTryExpand(dict *d, unsigned long size) { int malloc_failed; _dictExpand(d, size, &malloc_failed); return malloc_failed? 
DICT_ERR : DICT_OK; } /* Performs N steps of incremental rehashing. Returns 1 if there are still * keys to move from the old to the new hash table, otherwise 0 is returned. * * Note that a rehashing step consists in moving a bucket (that may have more * than one key as we use chaining) from the old to the new hash table, however * since part of the hash table may be composed of empty spaces, it is not * guaranteed that this function will rehash even a single bucket, since it * will visit at max N*10 empty buckets in total, otherwise the amount of * work it does would be unbound and the function may block for a long time. */ int dictRehash(dict *d, int n) { int empty_visits = n*10; /* Max number of empty buckets to visit. */ if (!dictIsRehashing(d)) return 0; while(n-- && d->ht_used[0] != 0) { dictEntry *de, *nextde; /* Note that rehashidx can't overflow as we are sure there are more * elements because ht[0].used != 0 */ assert(DICTHT_SIZE(d->ht_size_exp[0]) > (unsigned long)d->rehashidx); while(d->ht_table[0][d->rehashidx] == NULL) { d->rehashidx++; if (--empty_visits == 0) return 1; } de = d->ht_table[0][d->rehashidx]; /* Move all the keys in this bucket from the old to the new hash HT */ while(de) { uint64_t h; nextde = de->next; /* Get the index in the new hash table */ h = dictHashKey(d, de->key) & DICTHT_SIZE_MASK(d->ht_size_exp[1]); de->next = d->ht_table[1][h]; d->ht_table[1][h] = de; d->ht_used[0]--; d->ht_used[1]++; de = nextde; } d->ht_table[0][d->rehashidx] = NULL; d->rehashidx++; } /* Check if we already rehashed the whole table... */ if (d->ht_used[0] == 0) { zfree(d->ht_table[0]); /* Copy the new ht onto the old one */ d->ht_table[0] = d->ht_table[1]; d->ht_used[0] = d->ht_used[1]; d->ht_size_exp[0] = d->ht_size_exp[1]; _dictReset(d, 1); d->rehashidx = -1; return 0; } /* More to rehash... 
*/ return 1; } long long timeInMilliseconds(void) { struct timeval tv; gettimeofday(&tv,NULL); return (((long long)tv.tv_sec)*1000)+(tv.tv_usec/1000); } /* Rehash in ms+"delta" milliseconds. The value of "delta" is larger * than 0, and is smaller than 1 in most cases. The exact upper bound * depends on the running time of dictRehash(d,100).*/ int dictRehashMilliseconds(dict *d, int ms) { if (d->pauserehash > 0) return 0; long long start = timeInMilliseconds(); int rehashes = 0; while(dictRehash(d,100)) { rehashes += 100; if (timeInMilliseconds()-start > ms) break; } return rehashes; } /* This function performs just a step of rehashing, and only if hashing has * not been paused for our hash table. When we have iterators in the * middle of a rehashing we can't mess with the two hash tables otherwise * some element can be missed or duplicated. * * This function is called by common lookup or update operations in the * dictionary so that the hash table automatically migrates from H1 to H2 * while it is actively used. */ static void _dictRehashStep(dict *d) { if (d->pauserehash == 0) dictRehash(d,1); } /* Add an element to the target hash table */ int dictAdd(dict *d, void *key, void *val) { dictEntry *entry = dictAddRaw(d,key,NULL); if (!entry) return DICT_ERR; dictSetVal(d, entry, val); return DICT_OK; } /* Low level add or find: * This function adds the entry but instead of setting a value returns the * dictEntry structure to the user, that will make sure to fill the value * field as they wish. * * This function is also directly exposed to the user API to be called * mainly in order to store non-pointers inside the hash value, example: * * entry = dictAddRaw(dict,mykey,NULL); * if (entry != NULL) dictSetSignedIntegerVal(entry,1000); * * Return values: * * If key already exists NULL is returned, and "*existing" is populated * with the existing entry if existing is not NULL. * * If key was added, the hash entry is returned to be manipulated by the caller. 
*/
dictEntry *dictAddRaw(dict *d, void *key, dictEntry **existing)
{
    /* Contribute one incremental rehash step before touching the tables. */
    if (dictIsRehashing(d)) _dictRehashStep(d);

    /* _dictKeyIndex() reports -1 when no slot is available for the key, in
     * which case there is nothing to insert. */
    const long slot = _dictKeyIndex(d, key, dictHashKey(d,key), existing);
    if (slot == -1) return NULL;

    /* While a rehash is running new entries go to the second table.
     * The entry is pushed at the head of the bucket chain, on the assumption
     * that recently added entries are the ones looked up most often. */
    const int target = dictIsRehashing(d) ? 1 : 0;
    dictEntry *fresh = zmalloc(sizeof(*fresh));

    fresh->next = d->ht_table[target][slot];
    d->ht_table[target][slot] = fresh;
    d->ht_used[target]++;

    /* Only the key is filled in here; the value is left to the caller. */
    dictSetKey(d, fresh, key);
    return fresh;
}

/* Insert 'key' with 'val', or overwrite the value when the key is already
 * present.
 *
 * Returns 1 when a brand new entry was created and 0 when an existing entry
 * had its value replaced. */
int dictReplace(dict *d, void *key, void *val)
{
    dictEntry *found;
    dictEntry *added = dictAddRaw(d,key,&found);

    if (added != NULL) {
        /* The key was not there: dictAddRaw() created the entry for us. */
        dictSetVal(d, added, val);
        return 1;
    }

    /* Install the new value first and release the old one afterwards. The
     * order matters when both are the same reference counted object: it has
     * to be incremented (set) before it is decremented (free). A copy of the
     * entry keeps the old value reachable for the destructor. */
    dictEntry previous = *found;
    dictSetVal(d, found, val);
    dictFreeVal(d, &previous);
    return 0;
}

/* Add or Find:
 * dictAddOrFind() is simply a version of dictAddRaw() that always
 * returns the hash entry of the specified key, even if the key already
 * exists and can't be added (in that case the entry of the already
 * existing key is returned.)
 *
 * See dictAddRaw() for more information.
*/ dictEntry *dictAddOrFind(dict *d, void *key) { dictEntry *entry, *existing; entry = dictAddRaw(d,key,&existing); return entry ? entry : existing; } /* Search and remove an element. This is a helper function for * dictDelete() and dictUnlink(), please check the top comment * of those functions. */ static dictEntry *dictGenericDelete(dict *d, const void *key, int nofree) { uint64_t h, idx; dictEntry *he, *prevHe; int table; /* dict is empty */ if (dictSize(d) == 0) return NULL; if (dictIsRehashing(d)) _dictRehashStep(d); h = dictHashKey(d, key); for (table = 0; table <= 1; table++) { idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[table]); he = d->ht_table[table][idx]; prevHe = NULL; while(he) { if (key==he->key || dictCompareKeys(d, key, he->key)) { /* Unlink the element from the list */ if (prevHe) prevHe->next = he->next; else d->ht_table[table][idx] = he->next; if (!nofree) { dictFreeUnlinkedEntry(d, he); } d->ht_used[table]--; return he; } prevHe = he; he = he->next; } if (!dictIsRehashing(d)) break; } return NULL; /* not found */ } /* Remove an element, returning DICT_OK on success or DICT_ERR if the * element was not found. */ int dictDelete(dict *ht, const void *key) { return dictGenericDelete(ht,key,0) ? DICT_OK : DICT_ERR; } /* Remove an element from the table, but without actually releasing * the key, value and dictionary entry. The dictionary entry is returned * if the element was found (and unlinked from the table), and the user * should later call `dictFreeUnlinkedEntry()` with it in order to release it. * Otherwise if the key is not found, NULL is returned. * * This function is useful when we want to remove something from the hash * table but want to use its value before actually deleting the entry. 
* Without this function the pattern would require two lookups:
 *
 *  entry = dictFind(dictionary,key);
 *  // Do something with entry
 *  dictDelete(dictionary,key);
 *
 * Thanks to this function it is possible to avoid this, and use
 * instead:
 *
 * entry = dictUnlink(dictionary,key);
 * // Do something with entry
 * dictFreeUnlinkedEntry(dictionary,entry); // <- This does not need to lookup again.
 */
dictEntry *dictUnlink(dict *d, const void *key) {
    return dictGenericDelete(d,key,1);
}

/* You need to call this function to really free the entry after a call
 * to dictUnlink(). It's safe to call this function with 'he' = NULL. */
void dictFreeUnlinkedEntry(dict *d, dictEntry *he) {
    if (he == NULL) return;
    dictFreeKey(d, he);
    dictFreeVal(d, he);
    zfree(he);
}

/* Destroy an entire dictionary */
/* Frees every entry of table 'htidx', then the bucket array itself, and
 * resets the table to the "not allocated" state. When 'callback' is not
 * NULL it is invoked with the dict once every 65536 visited buckets
 * (starting with bucket 0). */
int _dictClear(dict *d, int htidx, void(callback)(dict*)) {
    unsigned long i;

    /* Free all the elements */
    /* The loop stops early as soon as the table has no entries left. */
    for (i = 0; i < DICTHT_SIZE(d->ht_size_exp[htidx]) && d->ht_used[htidx] > 0; i++) {
        dictEntry *he, *nextHe;

        if (callback && (i & 65535) == 0) callback(d);

        if ((he = d->ht_table[htidx][i]) == NULL) continue;
        while(he) {
            /* Save the successor before the entry is released. */
            nextHe = he->next;
            dictFreeKey(d, he);
            dictFreeVal(d, he);
            zfree(he);
            d->ht_used[htidx]--;
            he = nextHe;
        }
    }
    /* Free the table and the allocated cache structure */
    zfree(d->ht_table[htidx]);
    /* Re-initialize the table */
    _dictReset(d, htidx);
    return DICT_OK; /* never fails */
}

/* Clear & Release the hash table */
void dictRelease(dict *d)
{
    _dictClear(d,0,NULL);
    _dictClear(d,1,NULL);
    zfree(d);
}

/* Look up 'key' and return its entry, or NULL when it is not stored.
 * While a rehash is in progress one rehash step is performed first and both
 * tables are searched. */
dictEntry *dictFind(dict *d, const void *key)
{
    dictEntry *he;
    uint64_t h, idx, table;

    if (dictSize(d) == 0) return NULL; /* dict is empty */
    if (dictIsRehashing(d)) _dictRehashStep(d);
    h = dictHashKey(d, key);
    for (table = 0; table <= 1; table++) {
        idx = h & DICTHT_SIZE_MASK(d->ht_size_exp[table]);
        he = d->ht_table[table][idx];
        while(he) {
            /* Pointer equality is checked first as a shortcut. */
            if (key==he->key || dictCompareKeys(d, key, he->key))
                return he;
            he = he->next;
        }
        /* The second table is only populated while rehashing. */
        if (!dictIsRehashing(d)) return NULL;
    }
    return NULL;
}

/* Return the value stored under 'key', or NULL when the key is missing. */
void *dictFetchValue(dict *d, const void *key) {
    dictEntry *he;

    he = dictFind(d,key);
    return he ? dictGetVal(he) : NULL;
}

/* A fingerprint is a 64 bit number that represents the state of the dictionary
 * at a given time, it's just a few dict properties xored together.
 * When an unsafe iterator is initialized, we get the dict fingerprint, and check
 * the fingerprint again when the iterator is released.
 * If the two fingerprints are different it means that the user of the iterator
 * performed forbidden operations against the dictionary while iterating. */
unsigned long long dictFingerprint(dict *d) {
    unsigned long long integers[6], hash = 0;
    int j;

    /* Table pointer, size exponent and entry count of both tables. */
    integers[0] = (long) d->ht_table[0];
    integers[1] = d->ht_size_exp[0];
    integers[2] = d->ht_used[0];
    integers[3] = (long) d->ht_table[1];
    integers[4] = d->ht_size_exp[1];
    integers[5] = d->ht_used[1];

    /* We hash N integers by summing every successive integer with the integer
     * hashing of the previous sum. Basically:
     *
     * Result = hash(hash(hash(int1)+int2)+int3) ...
     *
     * This way the same set of integers in a different order will (likely) hash
     * to a different number. */
    for (j = 0; j < 6; j++) {
        hash += integers[j];
        /* For the hashing step we use Tomas Wang's 64 bit integer hash. */
        hash = (~hash) + (hash << 21); // hash = (hash << 21) - hash - 1;
        hash = hash ^ (hash >> 24);
        hash = (hash + (hash << 3)) + (hash << 8); // hash * 265
        hash = hash ^ (hash >> 14);
        hash = (hash + (hash << 2)) + (hash << 4); // hash * 21
        hash = hash ^ (hash >> 28);
        hash = hash + (hash << 31);
    }
    return hash;
}

/* Allocate an unsafe iterator positioned before the first bucket of table 0.
 * The caller releases it with dictReleaseIterator(). */
dictIterator *dictGetIterator(dict *d)
{
    dictIterator *iter = zmalloc(sizeof(*iter));

    iter->d = d;
    iter->table = 0;
    iter->index = -1;   /* -1 together with table 0 means "not started yet". */
    iter->safe = 0;
    iter->entry = NULL;
    iter->nextEntry = NULL;
    return iter;
}

/* Same as dictGetIterator() but the iterator is flagged as safe: rehashing is
 * paused from the first dictNext() call until the iterator is released. */
dictIterator *dictGetSafeIterator(dict *d) {
    dictIterator *i = dictGetIterator(d);

    i->safe = 1;
    return i;
}

/* Advance the iterator and return the next entry, or NULL once every bucket
 * of table 0 (and of table 1 while rehashing) has been visited. */
dictEntry *dictNext(dictIterator *iter)
{
    while (1) {
        if (iter->entry == NULL) {
            /* First call: pause rehashing for safe iterators, or record the
             * fingerprint that dictReleaseIterator() verifies for unsafe
             * ones. */
            if (iter->index == -1 && iter->table == 0) {
                if (iter->safe)
                    dictPauseRehashing(iter->d);
                else
                    iter->fingerprint = dictFingerprint(iter->d);
            }
            iter->index++;
            if (iter->index >= (long) DICTHT_SIZE(iter->d->ht_size_exp[iter->table])) {
                /* End of the current table: continue with the rehashing
                 * target if there is one, otherwise the iteration is over. */
                if (dictIsRehashing(iter->d) && iter->table == 0) {
                    iter->table++;
                    iter->index = 0;
                } else {
                    break;
                }
            }
            iter->entry = iter->d->ht_table[iter->table][iter->index];
        } else {
            iter->entry = iter->nextEntry;
        }
        if (iter->entry) {
            /* We need to save the 'next' here, the iterator user
             * may delete the entry we are returning. */
            iter->nextEntry = iter->entry->next;
            return iter->entry;
        }
    }
    return NULL;
}

/* Release an iterator. If dictNext() was called at least once, a safe
 * iterator resumes rehashing, while an unsafe iterator asserts that the dict
 * fingerprint did not change during the iteration. */
void dictReleaseIterator(dictIterator *iter)
{
    if (!(iter->index == -1 && iter->table == 0)) {
        if (iter->safe)
            dictResumeRehashing(iter->d);
        else
            assert(iter->fingerprint == dictFingerprint(iter->d));
    }
    zfree(iter);
}

/* Function to reverse bits. Algorithm from:
 * http://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel */
static unsigned long rev(unsigned long v) {
    unsigned long s = CHAR_BIT * sizeof(v); // bit size; must be power of 2
    unsigned long mask = ~0UL;
    /* Swap halves, then quarters, ... down to single bits. */
    while ((s >>= 1) > 0) {
        mask ^= (mask << s);
        v = ((v >> s) & mask) | ((v << s) & ~mask);
    }
    return v;
}

/* dictScan() is used to iterate over the elements of a dictionary.
* * Iterating works the following way: * * 1) Initially you call the function using a cursor (v) value of 0. * 2) The function performs one step of the iteration, and returns the * new cursor value you must use in the next call. * 3) When the returned cursor is 0, the iteration is complete. * * The function guarantees all elements present in the * dictionary get returned between the start and end of the iteration. * However it is possible some elements get returned multiple times. * * For every element returned, the callback argument 'fn' is * called with 'privdata' as first argument and the dictionary entry * 'de' as second argument. * * HOW IT WORKS. * * The iteration algorithm was designed by Pieter Noordhuis. * The main idea is to increment a cursor starting from the higher order * bits. That is, instead of incrementing the cursor normally, the bits * of the cursor are reversed, then the cursor is incremented, and finally * the bits are reversed again. * * This strategy is needed because the hash table may be resized between * iteration calls. * * dict.c hash tables are always power of two in size, and they * use chaining, so the position of an element in a given table is given * by computing the bitwise AND between Hash(key) and SIZE-1 * (where SIZE-1 is always the mask that is equivalent to taking the rest * of the division between the Hash of the key and SIZE). * * For example if the current hash table size is 16, the mask is * (in binary) 1111. The position of a key in the hash table will always be * the last four bits of the hash output, and so forth. * * WHAT HAPPENS IF THE TABLE CHANGES IN SIZE? * * If the hash table grows, elements can go anywhere in one multiple of * the old bucket: for example let's say we already iterated with * a 4 bit cursor 1100 (the mask is 1111 because hash table size = 16). * * If the hash table will be resized to 64 elements, then the new mask will * be 111111. 
The new buckets you obtain by substituting in ??1100 * with either 0 or 1 can be targeted only by keys we already visited * when scanning the bucket 1100 in the smaller hash table. * * By iterating the higher bits first, because of the inverted counter, the * cursor does not need to restart if the table size gets bigger. It will * continue iterating using cursors without '1100' at the end, and also * without any other combination of the final 4 bits already explored. * * Similarly when the table size shrinks over time, for example going from * 16 to 8, if a combination of the lower three bits (the mask for size 8 * is 111) were already completely explored, it would not be visited again * because we are sure we tried, for example, both 0111 and 1111 (all the * variations of the higher bit) so we don't need to test it again. * * WAIT... YOU HAVE *TWO* TABLES DURING REHASHING! * * Yes, this is true, but we always iterate the smaller table first, then * we test all the expansions of the current cursor into the larger * table. For example if the current cursor is 101 and we also have a * larger table of size 16, we also test (0)101 and (1)101 inside the larger * table. This reduces the problem back to having only one table, where * the larger one, if it exists, is just an expansion of the smaller one. * * LIMITATIONS * * This iterator is completely stateless, and this is a huge advantage, * including no additional memory used. * * The disadvantages resulting from this design are: * * 1) It is possible we return elements more than once. However this is usually * easy to deal with in the application level. * 2) The iterator must return multiple elements per call, as it needs to always * return all the keys chained in a given bucket, and all the expansions, so * we are sure we don't miss keys moving during rehashing. * 3) The reverse cursor is somewhat hard to understand at first, but this * comment is supposed to help. 
*/
unsigned long dictScan(dict *d,
                       unsigned long v,
                       dictScanFunction *fn,
                       dictScanBucketFunction* bucketfn,
                       void *privdata)
{
    const dictEntry *cur, *following;
    unsigned long small_mask, big_mask;
    int small = 0, big = 1;

    if (dictSize(d) == 0) return 0;

    /* Keep the tables stable in case the callbacks call dictFind or alike. */
    dictPauseRehashing(d);

    if (dictIsRehashing(d)) {
        /* Two tables: 'small' must index the one with fewer buckets. */
        if (DICTHT_SIZE(d->ht_size_exp[small]) > DICTHT_SIZE(d->ht_size_exp[big])) {
            small = 1;
            big = 0;
        }

        small_mask = DICTHT_SIZE_MASK(d->ht_size_exp[small]);
        big_mask = DICTHT_SIZE_MASK(d->ht_size_exp[big]);

        /* Report the bucket selected by the cursor in the smaller table. */
        if (bucketfn) bucketfn(d, &d->ht_table[small][v & small_mask]);
        for (cur = d->ht_table[small][v & small_mask]; cur; cur = following) {
            /* Fetch the successor before handing the entry to the callback. */
            following = cur->next;
            fn(privdata, cur);
        }

        /* Then report every bucket of the larger table that is an expansion
         * of that same cursor. */
        do {
            if (bucketfn) bucketfn(d, &d->ht_table[big][v & big_mask]);
            for (cur = d->ht_table[big][v & big_mask]; cur; cur = following) {
                following = cur->next;
                fn(privdata, cur);
            }

            /* Advance the bits of the reversed cursor that the smaller mask
             * does not cover. */
            v |= ~big_mask;
            v = rev(v);
            v++;
            v = rev(v);

            /* Go on until the bits in the mask difference wrap to zero. */
        } while (v & (small_mask ^ big_mask));
    } else {
        /* Single table. */
        small_mask = DICTHT_SIZE_MASK(d->ht_size_exp[small]);

        if (bucketfn) bucketfn(d, &d->ht_table[small][v & small_mask]);
        for (cur = d->ht_table[small][v & small_mask]; cur; cur = following) {
            following = cur->next;
            fn(privdata, cur);
        }

        /* Set the bits outside the mask so that the increment of the
         * reversed cursor carries into the masked bits only. */
        v |= ~small_mask;

        /* Reverse, increment, reverse back. */
        v = rev(v);
        v++;
        v = rev(v);
    }

    dictResumeRehashing(d);

    return v;
}

/* ------------------------- private functions ------------------------------ */

/* Because we may need to allocate huge memory chunk at once when dict
 * expands, we will check this allocation is allowed or not if
the dict * type has expandAllowed member function. */ static int dictTypeExpandAllowed(dict *d) { if (d->type->expandAllowed == NULL) return 1; return d->type->expandAllowed( DICTHT_SIZE(_dictNextExp(d->ht_used[0] + 1)) * sizeof(dictEntry*), (double)d->ht_used[0] / DICTHT_SIZE(d->ht_size_exp[0])); } /* Expand the hash table if needed */ static int _dictExpandIfNeeded(dict *d) { /* Incremental rehashing already in progress. Return. */ if (dictIsRehashing(d)) return DICT_OK; /* If the hash table is empty expand it to the initial size. */ if (DICTHT_SIZE(d->ht_size_exp[0]) == 0) return dictExpand(d, DICT_HT_INITIAL_SIZE); /* If we reached the 1:1 ratio, and we are allowed to resize the hash * table (global setting) or we should avoid it but the ratio between * elements/buckets is over the "safe" threshold, we resize doubling * the number of buckets. */ if (d->ht_used[0] >= DICTHT_SIZE(d->ht_size_exp[0]) && (dict_can_resize || d->ht_used[0]/ DICTHT_SIZE(d->ht_size_exp[0]) > dict_force_resize_ratio) && dictTypeExpandAllowed(d)) { return dictExpand(d, d->ht_used[0] + 1); } return DICT_OK; } /* TODO: clz optimization */ /* Our hash table capability is a power of two */ static signed char _dictNextExp(unsigned long size) { unsigned char e = DICT_HT_INITIAL_EXP; if (size >= LONG_MAX) return (8*sizeof(long)-1); while(1) { if (((unsigned long)1<= size) return e; e++; } } /* Returns the index of a free slot that can be populated with * a hash entry for the given 'key'. * If the key already exists, -1 is returned * and the optional output parameter may be filled. * * Note that if we are in the process of rehashing the hash table, the * index is always returned in the context of the second (new) hash table. 
*/ static long _dictKeyIndex(dict *d, const void *key, uint64_t hash, dictEntry **existing) { unsigned long idx, table; dictEntry *he; if (existing) *existing = NULL; /* Expand the hash table if needed */ if (_dictExpandIfNeeded(d) == DICT_ERR) return -1; for (table = 0; table <= 1; table++) { idx = hash & DICTHT_SIZE_MASK(d->ht_size_exp[table]); /* Search if this slot does not already contain the given key */ he = d->ht_table[table][idx]; while(he) { if (key==he->key || dictCompareKeys(d, key, he->key)) { if (existing) *existing = he; return -1; } he = he->next; } if (!dictIsRehashing(d)) break; } return idx; } void dictEmpty(dict *d, void(callback)(dict*)) { _dictClear(d,0,callback); _dictClear(d,1,callback); d->rehashidx = -1; d->pauserehash = 0; } void dictEnableResize(void) { dict_can_resize = 1; } void dictDisableResize(void) { dict_can_resize = 0; } uint64_t dictGetHash(dict *d, const void *key) { return dictHashKey(d, key); } /* Finds the dictEntry reference by using pointer and pre-calculated hash. * oldkey is a dead pointer and should not be accessed. * the hash value should be provided using dictGetHash. * no string / key comparison is performed. * return value is the reference to the dictEntry if found, or NULL if not found. 
*/ dictEntry **dictFindEntryRefByPtrAndHash(dict *d, const void *oldptr, uint64_t hash) { dictEntry *he, **heref; unsigned long idx, table; if (dictSize(d) == 0) return NULL; /* dict is empty */ for (table = 0; table <= 1; table++) { idx = hash & DICTHT_SIZE_MASK(d->ht_size_exp[table]); heref = &d->ht_table[table][idx]; he = *heref; while(he) { if (oldptr==he->key) return heref; heref = &he->next; he = *heref; } if (!dictIsRehashing(d)) return NULL; } return NULL; } /* ------------------------------- Debugging ---------------------------------*/ #define DICT_STATS_VECTLEN 50 size_t _dictGetStatsHt(char *buf, size_t bufsize, dict *d, int htidx) { unsigned long i, slots = 0, chainlen, maxchainlen = 0; unsigned long totchainlen = 0; unsigned long clvector[DICT_STATS_VECTLEN]; size_t l = 0; if (d->ht_used[htidx] == 0) { return snprintf(buf,bufsize, "No stats available for empty dictionaries\n"); } /* Compute stats. */ for (i = 0; i < DICT_STATS_VECTLEN; i++) clvector[i] = 0; for (i = 0; i < DICTHT_SIZE(d->ht_size_exp[htidx]); i++) { dictEntry *he; if (d->ht_table[htidx][i] == NULL) { clvector[0]++; continue; } slots++; /* For each hash entry on this slot... */ chainlen = 0; he = d->ht_table[htidx][i]; while(he) { chainlen++; he = he->next; } clvector[(chainlen < DICT_STATS_VECTLEN) ? chainlen : (DICT_STATS_VECTLEN-1)]++; if (chainlen > maxchainlen) maxchainlen = chainlen; totchainlen += chainlen; } /* Generate human readable stats. */ l += snprintf(buf+l,bufsize-l, "Hash table %d stats (%s):\n" " table size: %lu\n" " number of elements: %lu\n" " different slots: %lu\n" " max chain length: %lu\n" " avg chain length (counted): %.02f\n" " avg chain length (computed): %.02f\n" " Chain length distribution:\n", htidx, (htidx == 0) ? 
"main hash table" : "rehashing target", DICTHT_SIZE(d->ht_size_exp[htidx]), d->ht_used[htidx], slots, maxchainlen, (float)totchainlen/slots, (float)d->ht_used[htidx]/slots); for (i = 0; i < DICT_STATS_VECTLEN-1; i++) { if (clvector[i] == 0) continue; if (l >= bufsize) break; l += snprintf(buf+l,bufsize-l, " %ld: %ld (%.02f%%)\n", i, clvector[i], ((float)clvector[i]/DICTHT_SIZE(d->ht_size_exp[htidx]))*100); } /* Unlike snprintf(), return the number of characters actually written. */ if (bufsize) buf[bufsize-1] = '\0'; return strlen(buf); } void dictGetStats(char *buf, size_t bufsize, dict *d) { size_t l; char *orig_buf = buf; size_t orig_bufsize = bufsize; l = _dictGetStatsHt(buf,bufsize,d,0); buf += l; bufsize -= l; if (dictIsRehashing(d) && bufsize > 0) { _dictGetStatsHt(buf,bufsize,d,1); } /* Make sure there is a NULL term at the end. */ if (orig_bufsize) orig_buf[orig_bufsize-1] = '\0'; } /* ------------------------------- Benchmark ---------------------------------*/ #ifdef REDIS_TEST #include "testhelp.h" #define UNUSED(V) ((void) V) uint64_t hashCallback(const void *key) { return dictGenHashFunction((unsigned char*)key, strlen((char*)key)); } int compareCallback(dict *d, const void *key1, const void *key2) { int l1,l2; UNUSED(d); l1 = strlen((char*)key1); l2 = strlen((char*)key2); if (l1 != l2) return 0; return memcmp(key1, key2, l1) == 0; } void freeCallback(dict *d, void *val) { UNUSED(d); zfree(val); } char *stringFromLongLong(long long value) { char buf[32]; int len; char *s; len = sprintf(buf,"%lld",value); s = zmalloc(len+1); memcpy(s, buf, len); s[len] = '\0'; return s; } dictType BenchmarkDictType = { hashCallback, NULL, NULL, compareCallback, freeCallback, NULL, NULL }; #define start_benchmark() start = timeInMilliseconds() #define end_benchmark(msg) do { \ elapsed = timeInMilliseconds()-start; \ printf(msg ": %ld items in %lld ms\n", count, elapsed); \ } while(0) /* ./redis-server test dict [ | --accurate] */ int dictTest(int argc, char **argv, int 
flags) { long j; long long start, elapsed; dict *dict = dictCreate(&BenchmarkDictType); long count = 0; int accurate = (flags & REDIS_TEST_ACCURATE); if (argc == 4) { if (accurate) { count = 5000000; } else { count = strtol(argv[3],NULL,10); } } else { count = 5000; } start_benchmark(); for (j = 0; j < count; j++) { int retval = dictAdd(dict,stringFromLongLong(j),(void*)j); assert(retval == DICT_OK); } end_benchmark("Inserting"); assert((long)dictSize(dict) == count); /* Wait for rehashing. */ while (dictIsRehashing(dict)) { dictRehashMilliseconds(dict,100); } start_benchmark(); for (j = 0; j < count; j++) { char *key = stringFromLongLong(j); dictEntry *de = dictFind(dict,key); assert(de != NULL); zfree(key); } end_benchmark("Linear access of existing elements"); start_benchmark(); for (j = 0; j < count; j++) { char *key = stringFromLongLong(j); dictEntry *de = dictFind(dict,key); assert(de != NULL); zfree(key); } end_benchmark("Linear access of existing elements (2nd round)"); start_benchmark(); for (j = 0; j < count; j++) { char *key = stringFromLongLong(rand() % count); dictEntry *de = dictFind(dict,key); assert(de != NULL); zfree(key); } end_benchmark("Random access of existing elements"); start_benchmark(); for (j = 0; j < count; j++) { dictEntry *de = dictGetRandomKey(dict); assert(de != NULL); } end_benchmark("Accessing random keys"); start_benchmark(); for (j = 0; j < count; j++) { char *key = stringFromLongLong(rand() % count); key[0] = 'X'; dictEntry *de = dictFind(dict,key); assert(de == NULL); zfree(key); } end_benchmark("Accessing missing"); start_benchmark(); for (j = 0; j < count; j++) { char *key = stringFromLongLong(j); int retval = dictDelete(dict,key); assert(retval == DICT_OK); key[0] += 17; /* Change first number to letter. 
*/ retval = dictAdd(dict,key,(void*)j); assert(retval == DICT_OK); } end_benchmark("Removing and adding"); dictRelease(dict); return 0; } #endif ================================================ FILE: src/redis/dict.h ================================================ /* Hash Tables Implementation. * * This file implements in-memory hash tables with insert/del/replace/find/ * get-random-element operations. Hash tables will auto-resize if needed * tables of power of two in size are used, collisions are handled by * chaining. See the source code for more information... :) * * Copyright (c) 2006-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __DICT_H #define __DICT_H #include #include #include #define DICT_OK 0 #define DICT_ERR 1 /* Unused arguments generate annoying warnings... */ #define DICT_NOTUSED(V) ((void) V) typedef struct dictEntry { void *key; union { void *val; uint64_t u64; int64_t s64; double d; } v; struct dictEntry *next; /* Next entry in the same hash bucket. */ } dictEntry; typedef struct dict dict; typedef struct dictType { uint64_t (*hashFunction)(const void *key); void *(*keyDup)(dict *d, const void *key); void *(*valDup)(dict *d, const void *obj); int (*keyCompare)(dict *d, const void *key1, const void *key2); void (*keyDestructor)(dict *d, void *key); void (*valDestructor)(dict *d, void *obj); int (*expandAllowed)(size_t moreMem, double usedRatio); } dictType; #define DICTHT_SIZE(exp) ((exp) == -1 ? 0 : (unsigned long)1<<(exp)) #define DICTHT_SIZE_MASK(exp) ((exp) == -1 ? 0 : (DICTHT_SIZE(exp))-1) struct dict { dictType *type; dictEntry **ht_table[2]; unsigned long ht_used[2]; long rehashidx; /* rehashing not in progress if rehashidx == -1 */ /* Keep small vars at end for optimal (minimal) struct padding */ int16_t pauserehash; /* If >0 rehashing is paused (<0 indicates coding error) */ signed char ht_size_exp[2]; /* exponent of size. 
(size = 1<type->valDestructor) \ (d)->type->valDestructor((d), (entry)->v.val) #define dictSetVal(d, entry, _val_) do { \ if ((d)->type->valDup) \ (entry)->v.val = (d)->type->valDup((d), _val_); \ else \ (entry)->v.val = (_val_); \ } while(0) #define dictSetSignedIntegerVal(entry, _val_) \ do { (entry)->v.s64 = _val_; } while(0) #define dictSetUnsignedIntegerVal(entry, _val_) \ do { (entry)->v.u64 = _val_; } while(0) #define dictSetDoubleVal(entry, _val_) \ do { (entry)->v.d = _val_; } while(0) #define dictFreeKey(d, entry) \ if ((d)->type->keyDestructor) \ (d)->type->keyDestructor((d), (entry)->key) #define dictSetKey(d, entry, _key_) do { \ if ((d)->type->keyDup) \ (entry)->key = (d)->type->keyDup((d), _key_); \ else \ (entry)->key = (_key_); \ } while(0) #define dictCompareKeys(d, key1, key2) \ (((d)->type->keyCompare) ? \ (d)->type->keyCompare((d), key1, key2) : \ (key1) == (key2)) #define dictHashKey(d, key) (d)->type->hashFunction(key) #define dictGetKey(he) ((he)->key) #define dictGetVal(he) ((he)->v.val) #define dictGetSignedIntegerVal(he) ((he)->v.s64) #define dictGetUnsignedIntegerVal(he) ((he)->v.u64) #define dictGetDoubleVal(he) ((he)->v.d) #define dictSlots(d) (DICTHT_SIZE((d)->ht_size_exp[0])+DICTHT_SIZE((d)->ht_size_exp[1])) #define dictSize(d) ((d)->ht_used[0]+(d)->ht_used[1]) #define dictIsRehashing(d) ((d)->rehashidx != -1) #define dictPauseRehashing(d) (d)->pauserehash++ #define dictResumeRehashing(d) (d)->pauserehash-- /* If our unsigned long type can store a 64 bit number, use a 64 bit PRNG. 
*/ #if ULONG_MAX >= 0xffffffffffffffff #define randomULong() ((unsigned long) genrand64_int64()) #else #define randomULong() random() #endif /* API */ dict *dictCreate(dictType *type); int dictExpand(dict *d, unsigned long size); int dictTryExpand(dict *d, unsigned long size); int dictAdd(dict *d, void *key, void *val); dictEntry *dictAddRaw(dict *d, void *key, dictEntry **existing); dictEntry *dictAddOrFind(dict *d, void *key); int dictReplace(dict *d, void *key, void *val); int dictDelete(dict *d, const void *key); dictEntry *dictUnlink(dict *d, const void *key); void dictFreeUnlinkedEntry(dict *d, dictEntry *he); void dictRelease(dict *d); dictEntry * dictFind(dict *d, const void *key); void *dictFetchValue(dict *d, const void *key); int dictResize(dict *d); dictIterator *dictGetIterator(dict *d); dictIterator *dictGetSafeIterator(dict *d); dictEntry *dictNext(dictIterator *iter); void dictReleaseIterator(dictIterator *iter); dictEntry *dictGetRandomKey(dict *d); dictEntry *dictGetFairRandomKey(dict *d); unsigned int dictGetSomeKeys(dict *d, dictEntry **des, unsigned int count); void dictGetStats(char *buf, size_t bufsize, dict *d); uint64_t dictGenHashFunction(const void *key, size_t len); uint64_t dictGenCaseHashFunction(const unsigned char *buf, size_t len); void dictEmpty(dict *d, void(callback)(dict*)); void dictEnableResize(void); void dictDisableResize(void); int dictRehash(dict *d, int n); int dictRehashMilliseconds(dict *d, int ms); void dictSetHashFunctionSeed(uint8_t *seed); uint8_t *dictGetHashFunctionSeed(void); unsigned long dictScan(dict *d, unsigned long v, dictScanFunction *fn, dictScanBucketFunction *bucketfn, void *privdata); uint64_t dictGetHash(dict *d, const void *key); dictEntry **dictFindEntryRefByPtrAndHash(dict *d, const void *oldptr, uint64_t hash); #endif /* __DICT_H */ ================================================ FILE: src/redis/endianconv.h ================================================ /* See endianconv.c top comments for 
more information * * ---------------------------------------------------------------------------- * * Copyright (c) 2011-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. 
*/

#ifndef __ENDIANCONV_H
#define __ENDIANCONV_H

#include "config.h"
#include

/* Byte-order helpers. The mem* functions take a pointer to the value, the
 * int* functions take the value and return one of the same width. Their
 * bodies live in endianconv.c and are not visible from this header.
 * NOTE(review): presumably they reverse the byte order of a 16/32/64 bit
 * quantity -- confirm against endianconv.c. */
void memrev16(void *p);
void memrev32(void *p);
void memrev64(void *p);
uint16_t intrev16(uint16_t v);
uint32_t intrev32(uint32_t v);
uint64_t intrev64(uint64_t v);

/* variants of the function doing the actual conversion only if the target
 * host is big endian.
 *
 * When BYTE_ORDER == LITTLE_ENDIAN the mem*ifbe() macros expand to a no-op
 * expression and the int*ifbe() macros expand to their argument unchanged;
 * otherwise they forward to the functions declared above. */
#if (BYTE_ORDER == LITTLE_ENDIAN)
#define memrev16ifbe(p) ((void)(0))
#define memrev32ifbe(p) ((void)(0))
#define memrev64ifbe(p) ((void)(0))
#define intrev16ifbe(v) (v)
#define intrev32ifbe(v) (v)
#define intrev64ifbe(v) (v)
#else
#define memrev16ifbe(p) memrev16(p)
#define memrev32ifbe(p) memrev32(p)
#define memrev64ifbe(p) memrev64(p)
#define intrev16ifbe(v) intrev16(v)
#define intrev32ifbe(v) intrev32(v)
#define intrev64ifbe(v) intrev64(v)
#endif

/* The functions htonu64() and ntohu64() convert the specified value to
 * network byte ordering and back. In big endian systems they are no-ops.
 * On any other byte order both expand to intrev64(). */
#if (BYTE_ORDER == BIG_ENDIAN)
#define htonu64(v) (v)
#define ntohu64(v) (v)
#else
#define htonu64(v) intrev64(v)
#define ntohu64(v) intrev64(v)
#endif

/* Test entry point, only declared when building with REDIS_TEST. */
#ifdef REDIS_TEST
int endianconvTest(int argc, char *argv[], int flags);
#endif

#endif

================================================
FILE: src/redis/geo.c
================================================
/*
 * Copyright (c) 2014, Matt Stancliff .
 * Copyright (c) 2015-2016, Salvatore Sanfilippo .
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
* * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include "geo.h" #include "geohash_helper.h" #include "listpack.h" #include "util.h" #include "zmalloc.h" #include "sds.h" // D - noop #define D(...) while (0) /* ==================================================================== * This file implements the following commands: * * - geoadd - add coordinates for value to geoset * - georadius - search radius by coordinates in geoset * - georadiusbymember - search radius based on geoset member position * ==================================================================== */ /* ==================================================================== * geoArray implementation * ==================================================================== */ /* Create a new array of geoPoints. */ geoArray *geoArrayCreate(void) { geoArray *ga = zmalloc(sizeof(*ga)); /* It gets allocated on first geoArrayAppend() call. */ ga->array = NULL; ga->buckets = 0; ga->used = 0; return ga; } /* Add and populate with data a new entry to the geoArray. 
*/ geoPoint *geoArrayAppend(geoArray *ga, double *xy, double dist, double score, char *member) { if (ga->used == ga->buckets) { ga->buckets = (ga->buckets == 0) ? 8 : ga->buckets*2; ga->array = zrealloc(ga->array,sizeof(geoPoint)*ga->buckets); } geoPoint *gp = ga->array+ga->used; gp->longitude = xy[0]; gp->latitude = xy[1]; gp->dist = dist; gp->member = member; gp->score = score; ga->used++; return gp; } /* Destroy a geoArray created with geoArrayCreate(). */ void geoArrayFree(geoArray *ga) { size_t i; for (i = 0; i < ga->used; i++) sdsfree(ga->array[i].member); zfree(ga->array); zfree(ga); } /* ==================================================================== * Helpers * ==================================================================== */ int decodeGeohash(double bits, double *xy) { GeoHashBits hash = { .bits = (uint64_t)bits, .step = GEO_STEP_MAX }; return geohashDecodeToLongLatWGS84(hash, xy); } /* Helper function for geoGetPointsInRange(): given a sorted set score * representing a point, and a GeoShape, checks if the point is within the search area. * * shape: the rectangle * score: the encoded version of lat,long * xy: output variable, the decoded lat,long * distance: output variable, the distance between the center of the shape and the point * * Return values: * * The return value is C_OK if the point is within search area, or C_ERR if it is outside. * "*xy" is populated with the decoded lat,long. * "*distance" is populated with the distance between the center of the shape and the point. */ int geoWithinShape(GeoShape *shape, double score, double *xy, double *distance) { if (!decodeGeohash(score,xy)) return C_ERR; /* Can't decode. */ /* Note that geohashGetDistanceIfInRadiusWGS84() takes arguments in * reverse order: longitude first, latitude later. 
*/ if (shape->type == CIRCULAR_TYPE) { if (!geohashGetDistanceIfInRadiusWGS84(shape->xy[0], shape->xy[1], xy[0], xy[1], shape->t.radius*shape->conversion, distance)) return C_ERR; } else if (shape->type == RECTANGLE_TYPE) { if (!geohashGetDistanceIfInRectangle(shape->t.r.width * shape->conversion, shape->t.r.height * shape->conversion, shape->xy[0], shape->xy[1], xy[0], xy[1], distance)) return C_ERR; } return C_OK; } /* Compute the sorted set scores min (inclusive), max (exclusive) we should * query in order to retrieve all the elements inside the specified area * 'hash'. The two scores are returned by reference in *min and *max. */ void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max) { /* We want to compute the sorted set scores that will include all the * elements inside the specified Geohash 'hash', which has as many * bits as specified by hash.step * 2. * * So if step is, for example, 3, and the hash value in binary * is 101010, since our score is 52 bits we want every element which * is in binary: 101010????????????????????????????????????????????? * Where ? can be 0 or 1. * * To get the min score we just use the initial hash value left * shifted enough to get the 52 bit value. Later we increment the * 6 bit prefix (see the hash.bits++ statement), and get the new * prefix: 101011, which we align again to 52 bits to get the maximum * value (which is excluded from the search). So we get everything * between the two following scores (represented in binary): * * 1010100000000000000000000000000000000000000000000000 (included) * and * 1010110000000000000000000000000000000000000000000000 (excluded). 
*/ *min = geohashAlign52Bits(hash); hash.bits++; *max = geohashAlign52Bits(hash); } ================================================ FILE: src/redis/geo.h ================================================ #ifndef __GEO_H__ #define __GEO_H__ #include /* for size_t */ #include "geohash_helper.h" /* Structures used inside geo.c in order to represent points and array of * points on the earth. */ typedef struct geoPoint { double longitude; double latitude; double dist; double score; char *member; } geoPoint; typedef struct geoArray { struct geoPoint *array; size_t buckets; size_t used; } geoArray; int geoWithinShape(GeoShape *shape, double score, double *xy, double *distance); void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max); #endif ================================================ FILE: src/redis/geohash.c ================================================ /* * Copyright (c) 2013-2014, yinqiwen * Copyright (c) 2014, Matt Stancliff . * Copyright (c) 2015-2016, Salvatore Sanfilippo . * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF * THE POSSIBILITY OF SUCH DAMAGE. */ #include #include "geohash.h" /** * Hashing works like this: * Divide the world into 4 buckets. Label each one as such: * ----------------- * | | | * | | | * | 0,1 | 1,1 | * ----------------- * | | | * | | | * | 0,0 | 1,0 | * ----------------- */ /* Interleave lower bits of x and y, so the bits of x * are in the even positions and bits from y in the odd; * x and y must initially be less than 2**32 (4294967296). * From: https://graphics.stanford.edu/~seander/bithacks.html#InterleaveBMN */ static inline uint64_t interleave64(uint32_t xlo, uint32_t ylo) { static const uint64_t B[] = {0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL, 0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL}; static const unsigned int S[] = {1, 2, 4, 8, 16}; uint64_t x = xlo; uint64_t y = ylo; x = (x | (x << S[4])) & B[4]; y = (y | (y << S[4])) & B[4]; x = (x | (x << S[3])) & B[3]; y = (y | (y << S[3])) & B[3]; x = (x | (x << S[2])) & B[2]; y = (y | (y << S[2])) & B[2]; x = (x | (x << S[1])) & B[1]; y = (y | (y << S[1])) & B[1]; x = (x | (x << S[0])) & B[0]; y = (y | (y << S[0])) & B[0]; return x | (y << 1); } /* reverse the interleave process * derived from http://stackoverflow.com/questions/4909263 */ static inline uint64_t deinterleave64(uint64_t interleaved) { static const uint64_t B[] = {0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL, 0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL}; static const unsigned int S[] 
= {0, 1, 2, 4, 8, 16}; uint64_t x = interleaved; uint64_t y = interleaved >> 1; x = (x | (x >> S[0])) & B[0]; y = (y | (y >> S[0])) & B[0]; x = (x | (x >> S[1])) & B[1]; y = (y | (y >> S[1])) & B[1]; x = (x | (x >> S[2])) & B[2]; y = (y | (y >> S[2])) & B[2]; x = (x | (x >> S[3])) & B[3]; y = (y | (y >> S[3])) & B[3]; x = (x | (x >> S[4])) & B[4]; y = (y | (y >> S[4])) & B[4]; x = (x | (x >> S[5])) & B[5]; y = (y | (y >> S[5])) & B[5]; return x | (y << 32); } void geohashGetCoordRange(GeoHashRange *long_range, GeoHashRange *lat_range) { /* These are constraints from EPSG:900913 / EPSG:3785 / OSGEO:41001 */ /* We can't geocode at the north/south pole. */ long_range->max = GEO_LONG_MAX; long_range->min = GEO_LONG_MIN; lat_range->max = GEO_LAT_MAX; lat_range->min = GEO_LAT_MIN; } int geohashEncode(const GeoHashRange *long_range, const GeoHashRange *lat_range, double longitude, double latitude, uint8_t step, GeoHashBits *hash) { /* Check basic arguments sanity. */ if (hash == NULL || step > 32 || step == 0 || RANGEPISZERO(lat_range) || RANGEPISZERO(long_range)) return 0; /* Return an error when trying to index outside the supported * constraints. 
*/ if (longitude > GEO_LONG_MAX || longitude < GEO_LONG_MIN || latitude > GEO_LAT_MAX || latitude < GEO_LAT_MIN) return 0; hash->bits = 0; hash->step = step; if (latitude < lat_range->min || latitude > lat_range->max || longitude < long_range->min || longitude > long_range->max) { return 0; } double lat_offset = (latitude - lat_range->min) / (lat_range->max - lat_range->min); double long_offset = (longitude - long_range->min) / (long_range->max - long_range->min); /* convert to fixed point based on the step size */ lat_offset *= (1ULL << step); long_offset *= (1ULL << step); hash->bits = interleave64(lat_offset, long_offset); return 1; } int geohashEncodeType(double longitude, double latitude, uint8_t step, GeoHashBits *hash) { GeoHashRange r[2] = {{0}}; geohashGetCoordRange(&r[0], &r[1]); return geohashEncode(&r[0], &r[1], longitude, latitude, step, hash); } int geohashEncodeWGS84(double longitude, double latitude, uint8_t step, GeoHashBits *hash) { return geohashEncodeType(longitude, latitude, step, hash); } int geohashDecode(const GeoHashRange long_range, const GeoHashRange lat_range, const GeoHashBits hash, GeoHashArea *area) { if (HASHISZERO(hash) || NULL == area || RANGEISZERO(lat_range) || RANGEISZERO(long_range)) { return 0; } area->hash = hash; uint8_t step = hash.step; uint64_t hash_sep = deinterleave64(hash.bits); /* hash = [LAT][LONG] */ double lat_scale = lat_range.max - lat_range.min; double long_scale = long_range.max - long_range.min; uint32_t ilato = hash_sep; /* get lat part of deinterleaved hash */ uint32_t ilono = hash_sep >> 32; /* shift over to get long part of hash */ /* divide by 2**step. * Then, for 0-1 coordinate, multiply times scale and add to the min to get the absolute coordinate. 
*/ area->latitude.min = lat_range.min + (ilato * 1.0 / (1ull << step)) * lat_scale; area->latitude.max = lat_range.min + ((ilato + 1) * 1.0 / (1ull << step)) * lat_scale; area->longitude.min = long_range.min + (ilono * 1.0 / (1ull << step)) * long_scale; area->longitude.max = long_range.min + ((ilono + 1) * 1.0 / (1ull << step)) * long_scale; return 1; } int geohashDecodeType(const GeoHashBits hash, GeoHashArea *area) { GeoHashRange r[2] = {{0}}; geohashGetCoordRange(&r[0], &r[1]); return geohashDecode(r[0], r[1], hash, area); } int geohashDecodeWGS84(const GeoHashBits hash, GeoHashArea *area) { return geohashDecodeType(hash, area); } int geohashDecodeAreaToLongLat(const GeoHashArea *area, double *xy) { if (!xy) return 0; xy[0] = (area->longitude.min + area->longitude.max) / 2; if (xy[0] > GEO_LONG_MAX) xy[0] = GEO_LONG_MAX; if (xy[0] < GEO_LONG_MIN) xy[0] = GEO_LONG_MIN; xy[1] = (area->latitude.min + area->latitude.max) / 2; if (xy[1] > GEO_LAT_MAX) xy[1] = GEO_LAT_MAX; if (xy[1] < GEO_LAT_MIN) xy[1] = GEO_LAT_MIN; return 1; } int geohashDecodeToLongLatType(const GeoHashBits hash, double *xy) { GeoHashArea area; memset(&area, 0, sizeof(area)); if (!xy || !geohashDecodeType(hash, &area)) return 0; return geohashDecodeAreaToLongLat(&area, xy); } int geohashDecodeToLongLatWGS84(const GeoHashBits hash, double *xy) { return geohashDecodeToLongLatType(hash, xy); } static void geohash_move_x(GeoHashBits *hash, int8_t d) { if (d == 0) return; uint64_t x = hash->bits & 0xaaaaaaaaaaaaaaaaULL; uint64_t y = hash->bits & 0x5555555555555555ULL; uint64_t zz = 0x5555555555555555ULL >> (64 - hash->step * 2); if (d > 0) { x = x + (zz + 1); } else { x = x | zz; x = x - (zz + 1); } x &= (0xaaaaaaaaaaaaaaaaULL >> (64 - hash->step * 2)); hash->bits = (x | y); } static void geohash_move_y(GeoHashBits *hash, int8_t d) { if (d == 0) return; uint64_t x = hash->bits & 0xaaaaaaaaaaaaaaaaULL; uint64_t y = hash->bits & 0x5555555555555555ULL; uint64_t zz = 0xaaaaaaaaaaaaaaaaULL >> (64 - 
hash->step * 2); if (d > 0) { y = y + (zz + 1); } else { y = y | zz; y = y - (zz + 1); } y &= (0x5555555555555555ULL >> (64 - hash->step * 2)); hash->bits = (x | y); } void geohashNeighbors(const GeoHashBits *hash, GeoHashNeighbors *neighbors) { neighbors->east = *hash; neighbors->west = *hash; neighbors->north = *hash; neighbors->south = *hash; neighbors->south_east = *hash; neighbors->south_west = *hash; neighbors->north_east = *hash; neighbors->north_west = *hash; geohash_move_x(&neighbors->east, 1); geohash_move_y(&neighbors->east, 0); geohash_move_x(&neighbors->west, -1); geohash_move_y(&neighbors->west, 0); geohash_move_x(&neighbors->south, 0); geohash_move_y(&neighbors->south, -1); geohash_move_x(&neighbors->north, 0); geohash_move_y(&neighbors->north, 1); geohash_move_x(&neighbors->north_west, -1); geohash_move_y(&neighbors->north_west, 1); geohash_move_x(&neighbors->north_east, 1); geohash_move_y(&neighbors->north_east, 1); geohash_move_x(&neighbors->south_east, 1); geohash_move_y(&neighbors->south_east, -1); geohash_move_x(&neighbors->south_west, -1); geohash_move_y(&neighbors->south_west, -1); } ================================================ FILE: src/redis/geohash.h ================================================ /* * Copyright (c) 2013-2014, yinqiwen * Copyright (c) 2014, Matt Stancliff . * Copyright (c) 2015, Salvatore Sanfilippo . * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. 
*   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef GEOHASH_H_
#define GEOHASH_H_

#include
#include

#if defined(__cplusplus)
extern "C" {
#endif

/* True when both the bits and the step of a hash value are zero. */
#define HASHISZERO(r) (!(r).bits && !(r).step)
/* True when both the max and the min of a range value are zero. */
#define RANGEISZERO(r) (!(r).max && !(r).min)
/* Pointer flavor of RANGEISZERO(): also true when the pointer is NULL. */
#define RANGEPISZERO(r) (r == NULL || RANGEISZERO(*r))

#define GEO_STEP_MAX 26 /* 26*2 = 52 bits.
*/

/* Limits from EPSG:900913 / EPSG:3785 / OSGEO:41001 */
#define GEO_LAT_MIN -85.05112878
#define GEO_LAT_MAX 85.05112878
#define GEO_LONG_MIN -180
#define GEO_LONG_MAX 180

typedef enum {
    GEOHASH_NORTH = 0,
    GEOHASH_EAST,
    GEOHASH_WEST,
    GEOHASH_SOUTH,
    GEOHASH_SOUTH_WEST,
    GEOHASH_SOUTH_EAST,
    GEOHASH_NORT_WEST,
    GEOHASH_NORT_EAST
} GeoDirection;

/* A geohash: 'step' bits per coordinate, interleaved into 'bits'. */
typedef struct {
    uint64_t bits;
    uint8_t step;
} GeoHashBits;

/* A closed interval of coordinate values. */
typedef struct {
    double min;
    double max;
} GeoHashRange;

/* The box denoted by a geohash, together with the hash itself. */
typedef struct {
    GeoHashBits hash;
    GeoHashRange longitude;
    GeoHashRange latitude;
} GeoHashArea;

/* The eight cells around a geohash cell, as filled by geohashNeighbors(). */
typedef struct {
    GeoHashBits north;
    GeoHashBits east;
    GeoHashBits west;
    GeoHashBits south;
    GeoHashBits north_east;
    GeoHashBits south_east;
    GeoHashBits north_west;
    GeoHashBits south_west;
} GeoHashNeighbors;

/* Values of GeoShape.type; they select the member of GeoShape.t in use. */
#define CIRCULAR_TYPE 1
#define RECTANGLE_TYPE 2

typedef struct {
    int type; /* search type */
    double xy[2]; /* search center point, xy[0]: lon, xy[1]: lat */
    double conversion; /* km: 1000 */
    double bounds[4]; /* bounds[0]: min_lon, bounds[1]: min_lat
                       * bounds[2]: max_lon, bounds[3]: max_lat */
    union {
        /* CIRCULAR_TYPE */
        double radius;
        /* RECTANGLE_TYPE */
        struct {
            double height;
            double width;
        } r;
    } t;
} GeoShape;

/*
 * Return value of the int functions below:
 * 1: success
 * 0: failed (invalid argument, or position outside the supported range)
 */
void geohashGetCoordRange(GeoHashRange *long_range, GeoHashRange *lat_range);
int geohashEncode(const GeoHashRange *long_range, const GeoHashRange *lat_range,
                  double longitude, double latitude, uint8_t step,
                  GeoHashBits *hash);
int geohashEncodeType(double longitude, double latitude,
                      uint8_t step, GeoHashBits *hash);
int geohashEncodeWGS84(double longitude, double latitude, uint8_t step,
                       GeoHashBits *hash);
int geohashDecode(const GeoHashRange long_range, const GeoHashRange lat_range,
                  const GeoHashBits hash, GeoHashArea *area);
int geohashDecodeType(const GeoHashBits hash, GeoHashArea *area);
int geohashDecodeWGS84(const GeoHashBits hash, GeoHashArea *area);
int geohashDecodeAreaToLongLat(const GeoHashArea *area, double *xy);
int geohashDecodeToLongLatType(const GeoHashBits hash, double *xy);
int geohashDecodeToLongLatWGS84(const GeoHashBits hash, double *xy);
void geohashNeighbors(const GeoHashBits *hash, GeoHashNeighbors *neighbors);

#if defined(__cplusplus)
}
#endif
#endif /* GEOHASH_H_ */

================================================
FILE: src/redis/geohash_helper.c
================================================
/*
 * Copyright (c) 2013-2014, yinqiwen
 * Copyright (c) 2014, Matt Stancliff .
 * Copyright (c) 2015-2016, Salvatore Sanfilippo .
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
*/ /* This is a C++ to C conversion from the ardb project. * This file started out as: * https://github.com/yinqiwen/ardb/blob/d42503/src/geo/geohash_helper.cpp */ #define __USE_XOPEN #include "geohash_helper.h" #include #define D_R (M_PI / 180.0) #define R_MAJOR 6378137.0 #define R_MINOR 6356752.3142 #define RATIO (R_MINOR / R_MAJOR) #define ECCENT (sqrt(1.0 - (RATIO *RATIO))) #define COM (0.5 * ECCENT) /// @brief The usual PI/180 constant const double DEG_TO_RAD = 0.017453292519943295769236907684886; /// @brief Earth's quatratic mean radius for WGS-84 const double EARTH_RADIUS_IN_METERS = 6372797.560856; const double MERCATOR_MAX = 20037726.37; const double MERCATOR_MIN = -20037726.37; static inline double deg_rad(double ang) { return ang * D_R; } static inline double rad_deg(double ang) { return ang / D_R; } /* This function is used in order to estimate the step (bits precision) * of the 9 search area boxes during radius queries. */ uint8_t geohashEstimateStepsByRadius(double range_meters, double lat) { if (range_meters == 0) return 26; int step = 1; while (range_meters < MERCATOR_MAX) { range_meters *= 2; step++; } step -= 2; /* Make sure range is included in most of the base cases. */ /* Wider range towards the poles... Note: it is possible to do better * than this approximation by computing the distance between meridians * at this latitude, but this does the trick for now. */ if (lat > 66 || lat < -66) { step--; if (lat > 80 || lat < -80) step--; } /* Frame to valid range. */ if (step < 1) step = 1; if (step > 26) step = 26; return step; } /* Return the bounding box of the search area by shape (see geohash.h GeoShape) * bounds[0] - bounds[2] is the minimum and maximum longitude * while bounds[1] - bounds[3] is the minimum and maximum latitude. 
* since the higher the latitude, the shorter the arc length, the box shape is as follows * (left and right edges are actually bent), as shown in the following diagram: * * \-----------------/ -------- \-----------------/ * \ / / \ \ / * \ (long,lat) / / (long,lat) \ \ (long,lat) / * \ / / \ / \ * --------- /----------------\ /---------------\ * Northern Hemisphere Southern Hemisphere Around the equator */ int geohashBoundingBox(GeoShape *shape, double *bounds) { if (!bounds) return 0; double longitude = shape->xy[0]; double latitude = shape->xy[1]; double height = shape->conversion * (shape->type == CIRCULAR_TYPE ? shape->t.radius : shape->t.r.height/2); double width = shape->conversion * (shape->type == CIRCULAR_TYPE ? shape->t.radius : shape->t.r.width/2); const double lat_delta = rad_deg(height/EARTH_RADIUS_IN_METERS); const double long_delta_top = rad_deg(width/EARTH_RADIUS_IN_METERS/cos(deg_rad(latitude+lat_delta))); const double long_delta_bottom = rad_deg(width/EARTH_RADIUS_IN_METERS/cos(deg_rad(latitude-lat_delta))); /* The directions of the northern and southern hemispheres * are opposite, so we choice different points as min/max long/lat */ int southern_hemisphere = latitude < 0 ? 1 : 0; bounds[0] = southern_hemisphere ? longitude-long_delta_bottom : longitude-long_delta_top; bounds[2] = southern_hemisphere ? longitude+long_delta_bottom : longitude+long_delta_top; bounds[1] = latitude - lat_delta; bounds[3] = latitude + lat_delta; return 1; } /* Calculate a set of areas (center + 8) that are able to cover a range query * for the specified position and shape (see geohash.h GeoShape). 
* the bounding box saved in shaple.bounds */ GeoHashRadius geohashCalculateAreasByShapeWGS84(GeoShape *shape) { GeoHashRange long_range, lat_range; GeoHashRadius radius; GeoHashBits hash; GeoHashNeighbors neighbors; GeoHashArea area; double min_lon, max_lon, min_lat, max_lat; int steps; geohashBoundingBox(shape, shape->bounds); min_lon = shape->bounds[0]; min_lat = shape->bounds[1]; max_lon = shape->bounds[2]; max_lat = shape->bounds[3]; double longitude = shape->xy[0]; double latitude = shape->xy[1]; /* radius_meters is calculated differently in different search types: * 1) CIRCULAR_TYPE, just use radius. * 2) RECTANGLE_TYPE, we use sqrt((width/2)^2 + (height/2)^2) to * calculate the distance from the center point to the corner */ double radius_meters = shape->type == CIRCULAR_TYPE ? shape->t.radius : sqrt((shape->t.r.width/2)*(shape->t.r.width/2) + (shape->t.r.height/2)*(shape->t.r.height/2)); radius_meters *= shape->conversion; steps = geohashEstimateStepsByRadius(radius_meters,latitude); geohashGetCoordRange(&long_range,&lat_range); geohashEncode(&long_range,&lat_range,longitude,latitude,steps,&hash); geohashNeighbors(&hash,&neighbors); geohashDecode(long_range,lat_range,hash,&area); /* Check if the step is enough at the limits of the covered area. * Sometimes when the search area is near an edge of the * area, the estimated step is not small enough, since one of the * north / south / west / east square is too near to the search area * to cover everything. 
*/ int decrease_step = 0; { GeoHashArea north, south, east, west; geohashDecode(long_range, lat_range, neighbors.north, &north); geohashDecode(long_range, lat_range, neighbors.south, &south); geohashDecode(long_range, lat_range, neighbors.east, &east); geohashDecode(long_range, lat_range, neighbors.west, &west); if (north.latitude.max < max_lat) decrease_step = 1; if (south.latitude.min > min_lat) decrease_step = 1; if (east.longitude.max < max_lon) decrease_step = 1; if (west.longitude.min > min_lon) decrease_step = 1; } if (steps > 1 && decrease_step) { steps--; geohashEncode(&long_range,&lat_range,longitude,latitude,steps,&hash); geohashNeighbors(&hash,&neighbors); geohashDecode(long_range,lat_range,hash,&area); } /* Exclude the search areas that are useless. */ if (steps >= 2) { if (area.latitude.min < min_lat) { GZERO(neighbors.south); GZERO(neighbors.south_west); GZERO(neighbors.south_east); } if (area.latitude.max > max_lat) { GZERO(neighbors.north); GZERO(neighbors.north_east); GZERO(neighbors.north_west); } if (area.longitude.min < min_lon) { GZERO(neighbors.west); GZERO(neighbors.south_west); GZERO(neighbors.north_west); } if (area.longitude.max > max_lon) { GZERO(neighbors.east); GZERO(neighbors.south_east); GZERO(neighbors.north_east); } } radius.hash = hash; radius.neighbors = neighbors; radius.area = area; return radius; } GeoHashFix52Bits geohashAlign52Bits(const GeoHashBits hash) { uint64_t bits = hash.bits; bits <<= (52 - hash.step * 2); return bits; } /* Calculate distance using simplified haversine great circle distance formula. * Given longitude diff is 0 the asin(sqrt(a)) on the haversine is asin(sin(abs(u))). * arcsin(sin(x)) equal to x when x ∈[−𝜋/2,𝜋/2]. Given latitude is between [−𝜋/2,𝜋/2] * we can simplify arcsin(sin(x)) to x. */ double geohashGetLatDistance(double lat1d, double lat2d) { return EARTH_RADIUS_IN_METERS * fabs(deg_rad(lat2d) - deg_rad(lat1d)); } /* Calculate distance using haversine great circle distance formula. 
*/ double geohashGetDistance(double lon1d, double lat1d, double lon2d, double lat2d) { double lat1r, lon1r, lat2r, lon2r, u, v, a; lon1r = deg_rad(lon1d); lon2r = deg_rad(lon2d); v = sin((lon2r - lon1r) / 2); /* if v == 0 we can avoid doing expensive math when lons are practically the same */ if (v == 0.0) return geohashGetLatDistance(lat1d, lat2d); lat1r = deg_rad(lat1d); lat2r = deg_rad(lat2d); u = sin((lat2r - lat1r) / 2); a = u * u + cos(lat1r) * cos(lat2r) * v * v; return 2.0 * EARTH_RADIUS_IN_METERS * asin(sqrt(a)); } int geohashGetDistanceIfInRadius(double x1, double y1, double x2, double y2, double radius, double *distance) { *distance = geohashGetDistance(x1, y1, x2, y2); if (*distance > radius) return 0; return 1; } int geohashGetDistanceIfInRadiusWGS84(double x1, double y1, double x2, double y2, double radius, double *distance) { return geohashGetDistanceIfInRadius(x1, y1, x2, y2, radius, distance); } /* Judge whether a point is in the axis-aligned rectangle, when the distance * between a searched point and the center point is less than or equal to * height/2 or width/2 in height and width, the point is in the rectangle. * * width_m, height_m: the rectangle * x1, y1 : the center of the box * x2, y2 : the point to be searched */ int geohashGetDistanceIfInRectangle(double width_m, double height_m, double x1, double y1, double x2, double y2, double *distance) { /* latitude distance is less expensive to compute than longitude distance * so we check first for the latitude condition */ double lat_distance = geohashGetLatDistance(y2, y1); if (lat_distance > height_m/2) { return 0; } double lon_distance = geohashGetDistance(x2, y2, x1, y2); if (lon_distance > width_m/2) { return 0; } *distance = geohashGetDistance(x1, y1, x2, y2); return 1; } ================================================ FILE: src/redis/geohash_helper.h ================================================ /* * Copyright (c) 2013-2014, yinqiwen * Copyright (c) 2014, Matt Stancliff . 
* Copyright (c) 2015, Salvatore Sanfilippo . * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF * THE POSSIBILITY OF SUCH DAMAGE. 
*/

#ifndef GEOHASH_HELPER_HPP_
#define GEOHASH_HELPER_HPP_

#include "geohash.h"

/* Reset both the 'bits' and 'step' members of 's' to zero.
 * Note that the expansion already ends with a ';'. */
#define GZERO(s) s.bits = s.step = 0;
/* True when both the 'bits' and 'step' members of 's' are zero. */
#define GISZERO(s) (!s.bits && !s.step)
/* True when at least one of the 'bits' and 'step' members of 's' is not zero. */
#define GISNOTZERO(s) (s.bits || s.step)

typedef uint64_t GeoHashFix52Bits;
typedef uint64_t GeoHashVarBits;

/* Result of geohashCalculateAreasByShapeWGS84(): the hash of the center
 * cell, its decoded area, and the neighboring cells. */
typedef struct {
    GeoHashBits hash;
    GeoHashArea area;
    GeoHashNeighbors neighbors;
} GeoHashRadius;

/* Estimate the step (bits precision) to use for a search of the given radius
 * centered at latitude 'lat'. The result is in the range [1, 26]. */
uint8_t geohashEstimateStepsByRadius(double range_meters, double lat);
/* Fill bounds[0..3] with min longitude, min latitude, max longitude and
 * max latitude of the search area. Returns 0 if 'bounds' is NULL, else 1. */
int geohashBoundingBox(GeoShape *shape, double *bounds);
/* Compute the center cell and its neighbors covering the search shape.
 * Also stores the bounding box in shape->bounds. */
GeoHashRadius geohashCalculateAreasByShapeWGS84(GeoShape *shape);
/* Shift the hash bits left by (52 - 2 * step) positions. */
GeoHashFix52Bits geohashAlign52Bits(const GeoHashBits hash);
/* Haversine great circle distance between two points given in degrees. */
double geohashGetDistance(double lon1d, double lat1d,
                          double lon2d, double lat2d);
/* Store the distance between the two points in '*distance' and return 1 if
 * it is not greater than 'radius', 0 otherwise. */
int geohashGetDistanceIfInRadius(double x1, double y1,
                                 double x2, double y2, double radius,
                                 double *distance);
/* Forwards to geohashGetDistanceIfInRadius(). */
int geohashGetDistanceIfInRadiusWGS84(double x1, double y1, double x2,
                                      double y2, double radius,
                                      double *distance);
/* Return 1 and store the distance from the center (x1,y1) in '*distance'
 * when (x2,y2) lies inside the width_m x height_m rectangle, 0 otherwise. */
int geohashGetDistanceIfInRectangle(double width_m, double height_m, double x1, double y1,
                                    double x2, double y2, double *distance);

#endif /* GEOHASH_HELPER_HPP_ */

================================================
FILE: src/redis/hiredis.c
================================================
/*
 * Copyright (c) 2009-2011, Salvatore Sanfilippo
 * Copyright (c) 2010-2014, Pieter Noordhuis
 * Copyright (c) 2015, Matt Stancliff ,
 * Jan-Erik Rediger
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
* * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include "hiredis.h" #include "sds.h" static redisReply *createReplyObject(int type); static void *createStringObject(const redisReadTask *task, char *str, size_t len); static void *createArrayObject(const redisReadTask *task, size_t elements); static void *createIntegerObject(const redisReadTask *task, long long value); static void *createDoubleObject(const redisReadTask *task, double value, char *str, size_t len); static void *createNilObject(const redisReadTask *task); static void *createBoolObject(const redisReadTask *task, int bval); /* Default set of functions to build the reply. Keep in mind that such a * function returning NULL is interpreted as OOM. 
*/ static redisReplyObjectFunctions defaultFunctions = { createStringObject, createArrayObject, createIntegerObject, createDoubleObject, createNilObject, createBoolObject, freeReplyObject }; /* Create a reply object */ static redisReply *createReplyObject(int type) { redisReply *r = s_calloc(sizeof(*r)); if (r == NULL) return NULL; r->type = type; return r; } /* Free a reply object */ void freeReplyObject(void *reply) { redisReply *r = reply; size_t j; if (r == NULL) return; switch(r->type) { case REDIS_REPLY_INTEGER: case REDIS_REPLY_NIL: case REDIS_REPLY_BOOL: break; /* Nothing to free */ case REDIS_REPLY_ARRAY: case REDIS_REPLY_MAP: case REDIS_REPLY_ATTR: case REDIS_REPLY_SET: case REDIS_REPLY_PUSH: if (r->element != NULL) { for (j = 0; j < r->elements; j++) freeReplyObject(r->element[j]); s_free(r->element); } break; case REDIS_REPLY_ERROR: case REDIS_REPLY_STATUS: case REDIS_REPLY_STRING: case REDIS_REPLY_DOUBLE: case REDIS_REPLY_VERB: case REDIS_REPLY_BIGNUM: s_free(r->str); break; } s_free(r); } static void *createStringObject(const redisReadTask *task, char *str, size_t len) { redisReply *r, *parent; char *buf; r = createReplyObject(task->type); if (r == NULL) return NULL; assert(task->type == REDIS_REPLY_ERROR || task->type == REDIS_REPLY_STATUS || task->type == REDIS_REPLY_STRING || task->type == REDIS_REPLY_VERB || task->type == REDIS_REPLY_BIGNUM); /* Copy string value */ if (task->type == REDIS_REPLY_VERB) { buf = s_malloc(len-4+1); /* Skip 4 bytes of verbatim type header. 
*/ if (buf == NULL) goto oom; memcpy(r->vtype,str,3); r->vtype[3] = '\0'; memcpy(buf,str+4,len-4); buf[len-4] = '\0'; r->len = len - 4; } else { buf = s_malloc(len+1); if (buf == NULL) goto oom; memcpy(buf,str,len); buf[len] = '\0'; r->len = len; } r->str = buf; if (task->parent) { parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY || parent->type == REDIS_REPLY_MAP || parent->type == REDIS_REPLY_ATTR || parent->type == REDIS_REPLY_SET || parent->type == REDIS_REPLY_PUSH); parent->element[task->idx] = r; } return r; oom: freeReplyObject(r); return NULL; } static void *createArrayObject(const redisReadTask *task, size_t elements) { redisReply *r, *parent; r = createReplyObject(task->type); if (r == NULL) return NULL; if (elements > 0) { r->element = s_calloc(elements * sizeof(redisReply*)); if (r->element == NULL) { freeReplyObject(r); return NULL; } } r->elements = elements; if (task->parent) { parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY || parent->type == REDIS_REPLY_MAP || parent->type == REDIS_REPLY_ATTR || parent->type == REDIS_REPLY_SET || parent->type == REDIS_REPLY_PUSH); parent->element[task->idx] = r; } return r; } static void *createIntegerObject(const redisReadTask *task, long long value) { redisReply *r, *parent; r = createReplyObject(REDIS_REPLY_INTEGER); if (r == NULL) return NULL; r->integer = value; if (task->parent) { parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY || parent->type == REDIS_REPLY_MAP || parent->type == REDIS_REPLY_ATTR || parent->type == REDIS_REPLY_SET || parent->type == REDIS_REPLY_PUSH); parent->element[task->idx] = r; } return r; } static void *createDoubleObject(const redisReadTask *task, double value, char *str, size_t len) { redisReply *r, *parent; if (len == SIZE_MAX) // Prevents s_malloc(0) if len equals to SIZE_MAX return NULL; r = createReplyObject(REDIS_REPLY_DOUBLE); if (r == NULL) return NULL; r->dval = value; r->str = s_malloc(len+1); if (r->str == NULL) 
{ freeReplyObject(r); return NULL; } /* The double reply also has the original protocol string representing a * double as a null terminated string. This way the caller does not need * to format back for string conversion, especially since Redis does efforts * to make the string more human readable avoiding the calssical double * decimal string conversion artifacts. */ memcpy(r->str, str, len); r->str[len] = '\0'; r->len = len; if (task->parent) { parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY || parent->type == REDIS_REPLY_MAP || parent->type == REDIS_REPLY_ATTR || parent->type == REDIS_REPLY_SET || parent->type == REDIS_REPLY_PUSH); parent->element[task->idx] = r; } return r; } static void *createNilObject(const redisReadTask *task) { int type = task->type; int is_aggregate = (type == REDIS_REPLY_ARRAY || type == REDIS_REPLY_MAP || type == REDIS_REPLY_SET || type == REDIS_REPLY_PUSH); /* For aggregate nils (*-1, etc.) preserve the original aggregate type * with SIZE_MAX elements as a sentinel, so callers can distinguish * null arrays from null bulk strings. 
*/ if (is_aggregate) { void *obj = createArrayObject(task, 0); if (obj == NULL) return NULL; ((redisReply*)obj)->elements = SIZE_MAX; return obj; } redisReply *r, *parent; r = createReplyObject(REDIS_REPLY_NIL); if (r == NULL) return NULL; if (task->parent) { parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY || parent->type == REDIS_REPLY_MAP || parent->type == REDIS_REPLY_ATTR || parent->type == REDIS_REPLY_SET || parent->type == REDIS_REPLY_PUSH); parent->element[task->idx] = r; } return r; } static void *createBoolObject(const redisReadTask *task, int bval) { redisReply *r, *parent; r = createReplyObject(REDIS_REPLY_BOOL); if (r == NULL) return NULL; r->integer = bval != 0; if (task->parent) { parent = task->parent->obj; assert(parent->type == REDIS_REPLY_ARRAY || parent->type == REDIS_REPLY_MAP || parent->type == REDIS_REPLY_ATTR || parent->type == REDIS_REPLY_SET || parent->type == REDIS_REPLY_PUSH); parent->element[task->idx] = r; } return r; } /* Return the number of digits of 'v' when converted to string in radix 10. * Implementation borrowed from link in redis/src/util.c:string2ll(). */ static uint32_t countDigits(uint64_t v) { uint32_t result = 1; for (;;) { if (v < 10) return result; if (v < 100) return result + 1; if (v < 1000) return result + 2; if (v < 10000) return result + 3; v /= 10000U; result += 4; } } /* Helper that calculates the bulk length given a certain string length. */ static size_t bulklen(size_t len) { return 1+countDigits(len)+2+len+2; } int redisvFormatCommand(char **target, const char *format, va_list ap) { const char *c = format; char *cmd = NULL; /* final command */ int pos; /* position in final command */ sds curarg, newarg; /* current argument */ int touched = 0; /* was the current argument touched? 
*/ char **curargv = NULL, **newargv = NULL; int argc = 0; int totlen = 0; int error_type = 0; /* 0 = no error; -1 = memory error; -2 = format error */ int j; /* Abort if there is not target to set */ if (target == NULL) return -1; /* Build the command string accordingly to protocol */ curarg = sdsempty(); if (curarg == NULL) return -1; while(*c != '\0') { if (*c != '%' || c[1] == '\0') { if (*c == ' ') { if (touched) { newargv = s_realloc(curargv,sizeof(char*)*(argc+1)); if (newargv == NULL) goto memory_err; curargv = newargv; curargv[argc++] = curarg; totlen += bulklen(sdslen(curarg)); /* curarg is put in argv so it can be overwritten. */ curarg = sdsempty(); if (curarg == NULL) goto memory_err; touched = 0; } } else { newarg = sdscatlen(curarg,c,1); if (newarg == NULL) goto memory_err; curarg = newarg; touched = 1; } } else { char *arg; size_t size; /* Set newarg so it can be checked even if it is not touched. */ newarg = curarg; switch(c[1]) { case 's': arg = va_arg(ap,char*); size = strlen(arg); if (size > 0) newarg = sdscatlen(curarg,arg,size); break; case 'b': arg = va_arg(ap,char*); size = va_arg(ap,size_t); if (size > 0) newarg = sdscatlen(curarg,arg,size); break; case '%': newarg = sdscat(curarg,"%"); break; default: /* Try to detect printf format */ { static const char intfmts[] = "diouxX"; static const char flags[] = "#0-+ "; char _format[16]; const char *_p = c+1; size_t _l = 0; va_list _cpy; /* Flags */ while (*_p != '\0' && strchr(flags,*_p) != NULL) _p++; /* Field width */ while (*_p != '\0' && isdigit((int) *_p)) _p++; /* Precision */ if (*_p == '.') { _p++; while (*_p != '\0' && isdigit((int) *_p)) _p++; } /* Copy va_list before consuming with va_arg */ va_copy(_cpy,ap); /* Make sure we have more characters otherwise strchr() accepts * '\0' as an integer specifier. This is checked after above * va_copy() to avoid UB in fmt_invalid's call to va_end(). 
*/ if (*_p == '\0') goto fmt_invalid; /* Integer conversion (without modifiers) */ if (strchr(intfmts,*_p) != NULL) { va_arg(ap,int); goto fmt_valid; } /* Double conversion (without modifiers) */ if (strchr("eEfFgGaA",*_p) != NULL) { va_arg(ap,double); goto fmt_valid; } /* Size: char */ if (_p[0] == 'h' && _p[1] == 'h') { _p += 2; if (*_p != '\0' && strchr(intfmts,*_p) != NULL) { va_arg(ap,int); /* char gets promoted to int */ goto fmt_valid; } goto fmt_invalid; } /* Size: short */ if (_p[0] == 'h') { _p += 1; if (*_p != '\0' && strchr(intfmts,*_p) != NULL) { va_arg(ap,int); /* short gets promoted to int */ goto fmt_valid; } goto fmt_invalid; } /* Size: long long */ if (_p[0] == 'l' && _p[1] == 'l') { _p += 2; if (*_p != '\0' && strchr(intfmts,*_p) != NULL) { va_arg(ap,long long); goto fmt_valid; } goto fmt_invalid; } /* Size: long */ if (_p[0] == 'l') { _p += 1; if (*_p != '\0' && strchr(intfmts,*_p) != NULL) { va_arg(ap,long); goto fmt_valid; } goto fmt_invalid; } fmt_invalid: va_end(_cpy); goto format_err; fmt_valid: _l = (_p+1)-c; if (_l < sizeof(_format)-2) { memcpy(_format,c,_l); _format[_l] = '\0'; newarg = sdscatvprintf(curarg,_format,_cpy); /* Update current position (note: outer blocks * increment c twice so compensate here) */ c = _p-1; } va_end(_cpy); break; } } if (newarg == NULL) goto memory_err; curarg = newarg; touched = 1; c++; if (*c == '\0') break; } c++; } /* Add the last argument if needed */ if (touched) { newargv = s_realloc(curargv,sizeof(char*)*(argc+1)); if (newargv == NULL) goto memory_err; curargv = newargv; curargv[argc++] = curarg; totlen += bulklen(sdslen(curarg)); } else { sdsfree(curarg); } /* Clear curarg because it was put in curargv or was free'd. 
*/ curarg = NULL; /* Add bytes needed to hold multi bulk count */ totlen += 1+countDigits(argc)+2; /* Build the command at protocol level */ cmd = s_malloc(totlen+1); if (cmd == NULL) goto memory_err; pos = sprintf(cmd,"*%d\r\n",argc); for (j = 0; j < argc; j++) { pos += sprintf(cmd+pos,"$%zu\r\n",sdslen(curargv[j])); memcpy(cmd+pos,curargv[j],sdslen(curargv[j])); pos += sdslen(curargv[j]); sdsfree(curargv[j]); cmd[pos++] = '\r'; cmd[pos++] = '\n'; } assert(pos == totlen); cmd[pos] = '\0'; s_free(curargv); *target = cmd; return totlen; format_err: error_type = -2; goto cleanup; memory_err: error_type = -1; goto cleanup; cleanup: if (curargv) { while(argc--) sdsfree(curargv[argc]); s_free(curargv); } sdsfree(curarg); s_free(cmd); return error_type; } /* Format a command according to the Redis protocol. This function * takes a format similar to printf: * * %s represents a C null terminated string you want to interpolate * %b represents a binary safe string * * When using %b you need to provide both the pointer to the string * and the length in bytes as a size_t. Examples: * * len = redisFormatCommand(target, "GET %s", mykey); * len = redisFormatCommand(target, "SET %s %b", mykey, myval, myvallen); */ int redisFormatCommand(char **target, const char *format, ...) { va_list ap; int len; va_start(ap,format); len = redisvFormatCommand(target,format,ap); va_end(ap); /* The API says "-1" means bad result, but we now also return "-2" in some * cases. Force the return value to always be -1. */ if (len < 0) len = -1; return len; } /* Format a command according to the Redis protocol using an sds string and * sdscatfmt for the processing of arguments. This function takes the * number of arguments, an array with arguments and an array with their * lengths. If the latter is set to NULL, strlen will be used to compute the * argument lengths. 
*/ long long redisFormatSdsCommandArgv(sds *target, int argc, const char **argv, const size_t *argvlen) { sds cmd, aux; unsigned long long totlen, len; int j; /* Abort on a NULL target */ if (target == NULL) return -1; /* Calculate our total size */ totlen = 1+countDigits(argc)+2; for (j = 0; j < argc; j++) { len = argvlen ? argvlen[j] : strlen(argv[j]); totlen += bulklen(len); } /* Use an SDS string for command construction */ cmd = sdsempty(); if (cmd == NULL) return -1; /* We already know how much storage we need */ aux = sdsMakeRoomFor(cmd, totlen); if (aux == NULL) { sdsfree(cmd); return -1; } cmd = aux; /* Construct command */ cmd = sdscatfmt(cmd, "*%i\r\n", argc); for (j=0; j < argc; j++) { len = argvlen ? argvlen[j] : strlen(argv[j]); cmd = sdscatfmt(cmd, "$%U\r\n", len); cmd = sdscatlen(cmd, argv[j], len); cmd = sdscatlen(cmd, "\r\n", sizeof("\r\n")-1); } assert(sdslen(cmd)==totlen); *target = cmd; return totlen; } void redisFreeSdsCommand(sds cmd) { sdsfree(cmd); } /* Format a command according to the Redis protocol. This function takes the * number of arguments, an array with arguments and an array with their * lengths. If the latter is set to NULL, strlen will be used to compute the * argument lengths. */ long long redisFormatCommandArgv(char **target, int argc, const char **argv, const size_t *argvlen) { char *cmd = NULL; /* final command */ size_t pos; /* position in final command */ size_t len, totlen; int j; /* Abort on a NULL target */ if (target == NULL) return -1; /* Calculate number of bytes needed for the command */ totlen = 1+countDigits(argc)+2; for (j = 0; j < argc; j++) { len = argvlen ? argvlen[j] : strlen(argv[j]); totlen += bulklen(len); } /* Build the command at protocol level */ cmd = s_malloc(totlen+1); if (cmd == NULL) return -1; pos = sprintf(cmd,"*%d\r\n",argc); for (j = 0; j < argc; j++) { len = argvlen ? 
argvlen[j] : strlen(argv[j]); pos += sprintf(cmd+pos,"$%zu\r\n",len); memcpy(cmd+pos,argv[j],len); pos += len; cmd[pos++] = '\r'; cmd[pos++] = '\n'; } assert(pos == totlen); cmd[pos] = '\0'; *target = cmd; return totlen; } void redisFreeCommand(char *cmd) { s_free(cmd); } redisReader *redisReaderCreate(void) { return redisReaderCreateWithFunctions(&defaultFunctions); } ================================================ FILE: src/redis/hiredis.h ================================================ /* * Copyright (c) 2009-2011, Salvatore Sanfilippo * Copyright (c) 2010-2014, Pieter Noordhuis * Copyright (c) 2015, Matt Stancliff , * Jan-Erik Rediger * * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __HIREDIS_H #define __HIREDIS_H #include "read.h" #include /* for va_list */ #ifndef _MSC_VER #include /* for struct timeval */ #else struct timeval; /* forward declaration */ typedef long long ssize_t; #endif #include /* uintXX_t, etc */ #include "sds.h" /* for sds */ #include "sdsalloc.h" /* for allocation wrappers */ #define HIREDIS_MAJOR 1 #define HIREDIS_MINOR 3 #define HIREDIS_PATCH 0 #define HIREDIS_SONAME 1.3.0 /* Connection type can be blocking or non-blocking and is set in the * least significant bit of the flags field in redisContext. */ #define REDIS_BLOCK 0x1 /* Connection may be disconnected before being free'd. The second bit * in the flags field is set when the context is connected. */ #define REDIS_CONNECTED 0x2 /* The async API might try to disconnect cleanly and flush the output * buffer and read all subsequent replies before disconnecting. * This flag means no new commands can come in and the connection * should be terminated once all replies have been read. */ #define REDIS_DISCONNECTING 0x4 /* Flag specific to the async API which means that the context should be clean * up as soon as possible. */ #define REDIS_FREEING 0x8 /* Flag that is set when an async callback is executed. */ #define REDIS_IN_CALLBACK 0x10 /* Flag that is set when the async context has one or more subscriptions. 
*/
#define REDIS_SUBSCRIBED 0x20

/* Flag that is set when monitor mode is active */
#define REDIS_MONITORING 0x40

/* Flag that is set when we should set SO_REUSEADDR before calling bind() */
#define REDIS_REUSEADDR 0x80

/* Flag that is set when the async connection supports push replies. */
#define REDIS_SUPPORTS_PUSH 0x100

/**
 * Flag that indicates the user does not want the context to
 * be automatically freed upon error
 */
#define REDIS_NO_AUTO_FREE 0x200

/* Flag that indicates the user does not want replies to be automatically freed */
#define REDIS_NO_AUTO_FREE_REPLIES 0x400

/* Flags to prefer IPv6 or IPv4 when doing DNS lookup. (If both are set,
 * AF_UNSPEC is used.) */
#define REDIS_PREFER_IPV4 0x800
#define REDIS_PREFER_IPV6 0x1000

#define REDIS_KEEPALIVE_INTERVAL 15 /* seconds */

/* Number of times we retry to connect in the case of EADDRNOTAVAIL and
 * SO_REUSEADDR is being used. */
#define REDIS_CONNECT_RETRIES 10

/* Forward declarations for structs defined elsewhere */
struct redisAsyncContext;
struct redisContext;

/* RESP3 push helpers and callback prototypes */
#define redisIsPushReply(r) (((redisReply*)(r))->type == REDIS_REPLY_PUSH)
typedef void (redisPushFn)(void *, void *);
typedef void (redisAsyncPushFn)(struct redisAsyncContext *, void *);

#ifdef __cplusplus
extern "C" {
#endif

/* This is the reply object returned by redisCommand() */
typedef struct redisReply {
    int type; /* REDIS_REPLY_* */
    long long integer; /* The integer when type is REDIS_REPLY_INTEGER;
                          0 or 1 when type is REDIS_REPLY_BOOL */
    double dval; /* The double when type is REDIS_REPLY_DOUBLE */
    size_t len; /* Length of string */
    char *str; /* Used for REDIS_REPLY_ERROR, REDIS_REPLY_STATUS,
                  REDIS_REPLY_STRING, REDIS_REPLY_VERB, REDIS_REPLY_DOUBLE
                  (in addition to dval), and REDIS_REPLY_BIGNUM. */
    char vtype[4]; /* Used for REDIS_REPLY_VERB, contains the null
                      terminated 3 character content type, such as "txt". */
    size_t elements; /* number of elements, for REDIS_REPLY_ARRAY and the
                        other aggregate types; SIZE_MAX (with a NULL
                        element vector) marks a nil aggregate */
    struct redisReply **element; /* elements vector for REDIS_REPLY_ARRAY
                                    and the other aggregate types */
} redisReply;

/* Create a reader that builds redisReply objects. */
redisReader *redisReaderCreate(void);

/* Function to free the reply objects hiredis returns by default. */
void freeReplyObject(void *reply);

/* Functions to format a command according to the protocol.
 * Each one stores the formatted command in '*target' and returns its length,
 * or a negative value on error. Commands are released with
 * redisFreeCommand() or, for the sds variant, redisFreeSdsCommand(). */
int redisvFormatCommand(char **target, const char *format, va_list ap);
int redisFormatCommand(char **target, const char *format, ...);
long long redisFormatCommandArgv(char **target, int argc, const char **argv, const size_t *argvlen);
long long redisFormatSdsCommandArgv(sds *target, int argc, const char ** argv, const size_t *argvlen);
void redisFreeCommand(char *cmd);
void redisFreeSdsCommand(sds cmd);

#ifdef __cplusplus
}
#endif

#endif

================================================
FILE: src/redis/hyperloglog.c
================================================
/* hyperloglog.c - Redis HyperLogLog probabilistic cardinality approximation.
 * This file implements the algorithm and the exported Redis commands.
 *
 * Copyright (c) 2014, Salvatore Sanfilippo
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include "redis/hyperloglog.h" #include #include #include "redis/redis_aux.h" #include "redis/util.h" #define min(a, b) ((a) < (b) ? (a) : (b)) /* The Redis HyperLogLog implementation is based on the following ideas: * * * The use of a 64 bit hash function as proposed in [1], in order to estimate * cardinalities larger than 10^9, at the cost of just 1 additional bit per * register. * * The use of 16384 6-bit registers for a great level of accuracy, using * a total of 12k per key. * * The use of the Redis string data type. No new type is introduced. * * No attempt is made to compress the data structure as in [1]. Also the * algorithm used is the original HyperLogLog Algorithm as in [2], with * the only difference that a 64 bit hash function is used, so no correction * is performed for values near 2^32 as in [1]. * * [1] Heule, Nunkesser, Hall: HyperLogLog in Practice: Algorithmic * Engineering of a State of The Art Cardinality Estimation Algorithm. * * [2] P. Flajolet, Éric Fusy, O. Gandouet, and F. Meunier. Hyperloglog: The * analysis of a near-optimal cardinality estimation algorithm. 
* * Redis uses two representations: * * 1) A "dense" representation where every entry is represented by * a 6-bit integer. * 2) A "sparse" representation using run length compression suitable * for representing HyperLogLogs with many registers set to 0 in * a memory efficient way. * * * HLL header * === * * Both the dense and sparse representation have a 16 byte header as follows: * * +------+---+-----+----------+ * | HYLL | E | N/U | Cardin. | * +------+---+-----+----------+ * * The first 4 bytes are a magic string set to the bytes "HYLL". * "E" is one byte encoding, currently set to HLL_DENSE or * HLL_SPARSE. N/U are three not used bytes. * * The "Cardin." field is a 64 bit integer stored in little endian format * with the latest cardinality computed that can be reused if the data * structure was not modified since the last computation (this is useful * because there are high probabilities that HLLADD operations don't * modify the actual data structure and hence the approximated cardinality). * * When the most significant bit in the most significant byte of the cached * cardinality is set, it means that the data structure was modified and * we can't reuse the cached value that must be recomputed. * * Dense representation * === * * The dense representation used by Redis is the following: * * +--------+--------+--------+------// //--+ * |11000000|22221111|33333322|55444444 .... | * +--------+--------+--------+------// //--+ * * The 6 bits counters are encoded one after the other starting from the * LSB to the MSB, and using the next bytes as needed. * * Sparse representation * === * * The sparse representation encodes registers using a run length * encoding composed of three opcodes, two using one byte, and one using * of two bytes. The opcodes are called ZERO, XZERO and VAL. * * ZERO opcode is represented as 00xxxxxx. The 6-bit integer represented * by the six bits 'xxxxxx', plus 1, means that there are N registers set * to 0. 
This opcode can represent from 1 to 64 contiguous registers set * to the value of 0. * * XZERO opcode is represented by two bytes 01xxxxxx yyyyyyyy. The 14-bit * integer represented by the bits 'xxxxxx' as most significant bits and * 'yyyyyyyy' as least significant bits, plus 1, means that there are N * registers set to 0. This opcode can represent from 1 to 16384 contiguous * registers set to the value of 0. * * VAL opcode is represented as 1vvvvvxx. It contains a 5-bit integer * representing the value of a register, and a 2-bit integer representing * the number of contiguous registers set to that value 'vvvvv'. * To obtain the value and run length, the integers vvvvv and xx must be * incremented by one. This opcode can represent values from 1 to 32, * repeated from 1 to 4 times. * * The sparse representation can't represent registers with a value greater * than 32, however it is very unlikely that we find such a register in an * HLL with a cardinality where the sparse representation is still more * memory efficient than the dense representation. When this happens the * HLL is converted to the dense representation. * * The sparse representation is purely positional. For example a sparse * representation of an empty HLL is just: XZERO:16384. * * An HLL having only 3 non-zero registers at position 1000, 1020, 1021 * respectively set to 2, 3, 3, is represented by the following five * opcodes: * * XZERO:1000 (Registers 0-999 are set to 0) * VAL:2,1 (1 register set to value 2, that is register 1000) * ZERO:19 (Registers 1001-1019 set to 0) * VAL:3,2 (2 registers set to value 3, that is registers 1020,1021) * XZERO:15362 (Registers 1022-16383 set to 0) * * In the example the sparse representation used just 7 bytes instead * of 12k in order to represent the HLL registers. In general for low * cardinality there is a big win in terms of space efficiency, traded * with CPU time since the sparse representation is slower to access.
*
 * The following table shows average cardinality vs bytes used, 100
 * samples per cardinality (when the set was not representable because
 * of registers with too big value, the dense representation size was used
 * as a sample).
 *
 * 100 267
 * 200 485
 * 300 678
 * 400 859
 * 500 1033
 * 600 1205
 * 700 1375
 * 800 1544
 * 900 1713
 * 1000 1882
 * 2000 3480
 * 3000 4879
 * 4000 6089
 * 5000 7138
 * 6000 8042
 * 7000 8823
 * 8000 9500
 * 9000 10088
 * 10000 10591
 *
 * The dense representation uses 12288 bytes, so there is a big win up to
 * a cardinality of ~2000-3000. For bigger cardinalities the constant times
 * involved in updating the sparse representation is not justified by the
 * memory savings. The exact maximum length of the sparse representation
 * when this implementation switches to the dense representation is
 * configured via the define HLL_SPARSE_MAX_BYTES. */
#define HLL_SPARSE_MAX_BYTES 3000

/* Header placed at the start of every HLL string; the register payload
 * (dense or sparse encoded) follows it in the flexible 'registers' member. */
struct hllhdr {
  char magic[4];       /* "HYLL" */
  uint8_t encoding;    /* HLL_DENSE or HLL_SPARSE. */
  uint8_t notused[3];  /* Reserved for future use, must be zero. */
  uint8_t card[8];     /* Cached cardinality, little endian. */
  uint8_t registers[]; /* Data bytes. */
};

/* The cached cardinality MSB is used to signal validity of the cached value:
 * HLL_INVALIDATE_CACHE sets bit 7 of card[7], HLL_VALID_CACHE tests that the
 * bit is clear. */
#define HLL_INVALIDATE_CACHE(hdr) (hdr)->card[7] |= (1 << 7)
#define HLL_VALID_CACHE(hdr) (((hdr)->card[7] & (1 << 7)) == 0)

#define HLL_P 14 /* The greater is P, the smaller the error. */
#define HLL_Q                                                    \
  (64 - HLL_P) /* The number of bits of the hash value used for \
                  determining the number of leading zeros. */
#define HLL_REGISTERS (1 << HLL_P)     /* With P=14, 16384 registers. */
#define HLL_P_MASK (HLL_REGISTERS - 1) /* Mask to index register. */
#define HLL_BITS 6                     /* Enough to count up to 63 leading zeroes. */
#define HLL_REGISTER_MAX ((1 << HLL_BITS) - 1) /* Largest value a register can hold. */
#define HLL_HDR_SIZE sizeof(struct hllhdr)     /* Bytes before the register payload. */
/* Header plus HLL_REGISTERS packed HLL_BITS-bit registers, rounded up to a byte. */
#define HLL_DENSE_SIZE (HLL_HDR_SIZE + ((HLL_REGISTERS * HLL_BITS + 7) / 8))
#define HLL_DENSE 0 /* Dense encoding. */
#define HLL_SPARSE 1 /* Sparse encoding. */
#define HLL_RAW 255 /* Only used internally, never exposed. */
#define HLL_MAX_ENCODING 1

/* =========================== Low level bit macros ========================= */

/* Macros to access the dense representation.
 *
 * We need to get and set 6 bit counters in an array of 8 bit bytes.
 * We use macros to make sure the code is inlined since speed is critical
 * especially in order to compute the approximated cardinality in
 * HLLCOUNT where we need to access all the registers at once.
 * For the same reason we also want to avoid conditionals in this code path.
 *
 * +--------+--------+--------+------//
 * |11000000|22221111|33333322|55444444
 * +--------+--------+--------+------//
 *
 * Note: in the above representation the most significant bit (MSB)
 * of every byte is on the left. We start using bits from the LSB to MSB,
 * and so forth passing to the next byte.
 *
 * Example, we want to access the counter at pos = 1 ("111111" in the
 * illustration above).
 *
 * The index of the first byte b0 containing our data is:
 *
 *  b0 = 6 * pos / 8 = 0
 *
 *   +--------+
 *   |11000000|  <- Our byte at b0
 *   +--------+
 *
 * The position of the first bit (counting from the LSB = 0) in the byte
 * is given by:
 *
 *  fb = 6 * pos % 8 -> 6
 *
 * Right shift b0 by 'fb' bits.
 *
 *   +--------+
 *   |11000000|  <- Initial value of b0
 *   |00000011|  <- After right shift of 6 pos.
 *   +--------+
 *
 * Left shift b1 by 8-fb bits (2 bits)
 *
 *   +--------+
 *   |22221111|  <- Initial value of b1
 *   |22111100|  <- After left shift of 2 bits.
 *   +--------+
 *
 * OR the two values, and finally AND with 111111 (63 in decimal) to
 * clean the higher order bits we are not interested in:
 *
 *   +--------+
 *   |00000011|  <- b0 right shifted
 *   |22111100|  <- b1 left shifted
 *   |22111111|  <- b0 OR b1
 *   |  111111|  <- (b0 OR b1) AND 63, our value.
 *   +--------+
 *
 * We can try with a different example, like pos = 0. In this case
 * the 6-bit counter is actually contained in a single byte.
* * b0 = 6 * pos / 8 = 0 * * +--------+ * |11000000| <- Our byte at b0 * +--------+ * * fb = 6 * pos % 8 = 0 * * So we right shift by 0 bits (no shift in practice) and * left shift the next byte by 8 bits, even if we don't use it, * but this has the effect of clearing the bits so the result * will not be affected after the OR. * * ------------------------------------------------------------------------- * * Setting the register is a bit more complex, let's assume that 'val' * is the value we want to set, already in the right range. * * We need two steps, in one we need to clear the bits, and in the other * we need to bitwise-OR the new bits. * * Let's try with 'pos' = 1, so our first byte at 'b' is 0, * * "fb" is 6 in this case. * * +--------+ * |11000000| <- Our byte at b0 * +--------+ * * To create an AND-mask to clear the bits about this position, we just * initialize the mask with the value 63, left shift it by "fb" bits, * and finally invert the result. * * +--------+ * |00111111| <- "mask" starts at 63 * |11000000| <- "mask" after left shift by "fb" bits. * |00111111| <- "mask" after invert. * +--------+ * * Now we can bitwise-AND the byte at "b" with the mask, and bitwise-OR * it with "val" left-shifted by "fb" bits to set the new bits. * * Now let's focus on the next byte b1: * * +--------+ * |22221111| <- Initial value of b1 * +--------+ * * To build the AND mask we start again with the 63 value, right shift * it by 8-fb bits, and invert it. * * +--------+ * |00111111| <- "mask" set at 2^6-1 * |00001111| <- "mask" after the right shift by 8-fb = 2 bits * |11110000| <- "mask" after bitwise not. * +--------+ * * Now we can mask it with b+1 to clear the old bits, and bitwise-OR * with "val" right-shifted by 8-fb bits to set the new value.
*/ /* Note: if we access the last counter, we will also access the b+1 byte * that is out of the array, but sds strings always have an implicit null * term, so the byte exists, and we can skip the conditional (or the need * to allocate 1 byte more explicitly). */ /* Store the value of the register at position 'regnum' into variable 'target'. * 'p' is an array of unsigned bytes. */ #define HLL_DENSE_GET_REGISTER(target, p, regnum) \ do { \ uint8_t* _p = (uint8_t*)p; \ unsigned long _byte = regnum * HLL_BITS / 8; \ unsigned long _fb = regnum * HLL_BITS & 7; \ unsigned long _fb8 = 8 - _fb; \ unsigned long b0 = _p[_byte]; \ unsigned long b1 = _p[_byte + 1]; \ target = ((b0 >> _fb) | (b1 << _fb8)) & HLL_REGISTER_MAX; \ } while (0) /* Set the value of the register at position 'regnum' to 'val'. * 'p' is an array of unsigned bytes. */ #define HLL_DENSE_SET_REGISTER(p, regnum, val) \ do { \ uint8_t* _p = (uint8_t*)p; \ unsigned long _byte = (regnum)*HLL_BITS / 8; \ unsigned long _fb = (regnum)*HLL_BITS & 7; \ unsigned long _fb8 = 8 - _fb; \ unsigned long _v = (val); \ _p[_byte] &= ~(HLL_REGISTER_MAX << _fb); \ _p[_byte] |= _v << _fb; \ _p[_byte + 1] &= ~(HLL_REGISTER_MAX >> _fb8); \ _p[_byte + 1] |= _v >> _fb8; \ } while (0) /* Macros to access the sparse representation. * The macros parameter is expected to be an uint8_t pointer. 
*/
/* Opcode tag bits found in the two most significant bits of the first byte. */
#define HLL_SPARSE_XZERO_BIT 0x40 /* 01xxxxxx */
#define HLL_SPARSE_VAL_BIT 0x80 /* 1vvvvvxx */

/* Opcode classification: test the two top bits of the byte at 'p'. */
#define HLL_SPARSE_IS_ZERO(p) (((*(p)) & 0xc0) == 0) /* 00xxxxxx */
#define HLL_SPARSE_IS_XZERO(p) (((*(p)) & 0xc0) == HLL_SPARSE_XZERO_BIT)
#define HLL_SPARSE_IS_VAL(p) ((*(p)) & HLL_SPARSE_VAL_BIT)

/* Opcode decoding: every stored length/value is biased by one, hence "+ 1".
 * XZERO_LEN reads two bytes: six high bits from *p and eight low bits from
 * *(p + 1). */
#define HLL_SPARSE_ZERO_LEN(p) (((*(p)) & 0x3f) + 1)
#define HLL_SPARSE_XZERO_LEN(p) (((((*(p)) & 0x3f) << 8) | (*((p) + 1))) + 1)
#define HLL_SPARSE_VAL_VALUE(p) ((((*(p)) >> 2) & 0x1f) + 1)
#define HLL_SPARSE_VAL_LEN(p) (((*(p)) & 0x3) + 1)

/* Largest value / run length each opcode can express. */
#define HLL_SPARSE_VAL_MAX_VALUE 32
#define HLL_SPARSE_VAL_MAX_LEN 4
#define HLL_SPARSE_ZERO_MAX_LEN 64
#define HLL_SPARSE_XZERO_MAX_LEN 16384

/* Opcode encoding: write a VAL opcode (one byte) at 'p' holding value 'val'
 * repeated 'len' times; both are stored minus one. */
#define HLL_SPARSE_VAL_SET(p, val, len)                         \
  do {                                                          \
    *(p) = (((val)-1) << 2 | ((len)-1)) | HLL_SPARSE_VAL_BIT;   \
  } while (0)

/* Write a ZERO opcode (one byte) at 'p' covering 'len' registers. */
#define HLL_SPARSE_ZERO_SET(p, len) \
  do {                              \
    *(p) = (len)-1;                 \
  } while (0)

/* Write an XZERO opcode (two bytes) at 'p' covering 'len' registers:
 * high bits plus the tag in the first byte, low eight bits in the second. */
#define HLL_SPARSE_XZERO_SET(p, len)         \
  do {                                       \
    int _l = (len)-1;                        \
    *(p) = (_l >> 8) | HLL_SPARSE_XZERO_BIT; \
    *((p) + 1) = (_l & 0xff);                \
  } while (0)

#define HLL_ALPHA_INF 0.721347520444481703680 /* constant for 0.5/ln(2) */

/* ========================= HyperLogLog algorithm ========================= */

/* Our hash function is MurmurHash2, 64 bit version.
 * It was modified for Redis in order to provide the same result in
 * big and little endian archs (endian neutral).
*/ uint64_t MurmurHash64A(const void* key, int len, unsigned int seed) { const uint64_t m = 0xc6a4a7935bd1e995; const int r = 47; uint64_t h = seed ^ (len * m); const uint8_t* data = (const uint8_t*)key; const uint8_t* end = data + (len - (len & 7)); while (data != end) { uint64_t k; #if (BYTE_ORDER == LITTLE_ENDIAN) #ifdef USE_ALIGNED_ACCESS memcpy(&k, data, sizeof(uint64_t)); #else k = *((uint64_t*)data); #endif #else k = (uint64_t)data[0]; k |= (uint64_t)data[1] << 8; k |= (uint64_t)data[2] << 16; k |= (uint64_t)data[3] << 24; k |= (uint64_t)data[4] << 32; k |= (uint64_t)data[5] << 40; k |= (uint64_t)data[6] << 48; k |= (uint64_t)data[7] << 56; #endif k *= m; k ^= k >> r; k *= m; h ^= k; h *= m; data += 8; } switch (len & 7) { case 7: h ^= (uint64_t)data[6] << 48; /* fall-thru */ case 6: h ^= (uint64_t)data[5] << 40; /* fall-thru */ case 5: h ^= (uint64_t)data[4] << 32; /* fall-thru */ case 4: h ^= (uint64_t)data[3] << 24; /* fall-thru */ case 3: h ^= (uint64_t)data[2] << 16; /* fall-thru */ case 2: h ^= (uint64_t)data[1] << 8; /* fall-thru */ case 1: h ^= (uint64_t)data[0]; h *= m; /* fall-thru */ }; h ^= h >> r; h *= m; h ^= h >> r; return h; } /* Given a string element to add to the HyperLogLog, returns the length * of the pattern 000..1 of the element hash. As a side effect 'regp' is * set to the register index this element hashes to. */ int hllPatLen(unsigned char* ele, size_t elesize, long* regp) { uint64_t hash, bit, index; int count; /* Count the number of zeroes starting from bit HLL_REGISTERS * (that is a power of two corresponding to the first bit we don't use * as index). The max run can be 64-P+1 = Q+1 bits. * * Note that the final "1" ending the sequence of zeroes must be * included in the count, so if we find "001" the count is 3, and * the smallest count possible is no zeroes at all, just a 1 bit * at the first position, that is a count of 1. 
* * This may sound like inefficient, but actually in the average case * there are high probabilities to find a 1 after a few iterations. */ hash = MurmurHash64A(ele, elesize, 0xadc83b19ULL); index = hash & HLL_P_MASK; /* Register index. */ hash >>= HLL_P; /* Remove bits used to address the register. */ hash |= ((uint64_t)1 << HLL_Q); /* Make sure the loop terminates and count will be <= Q+1. */ bit = 1; count = 1; /* Initialized to 1 since we count the "00000...1" pattern. */ while ((hash & bit) == 0) { count++; bit <<= 1; } *regp = (int)index; return count; } /* ================== Dense representation implementation ================== */ /* Low level function to set the dense HLL register at 'index' to the * specified value if the current value is smaller than 'count'. * * 'registers' is expected to have room for HLL_REGISTERS plus an * additional byte on the right. This requirement is met by sds strings * automatically since they are implicitly null terminated. * * The function always succeed, however if as a result of the operation * the approximated cardinality changed, 1 is returned. Otherwise 0 * is returned. */ int hllDenseSet(uint8_t* registers, long index, uint8_t count) { uint8_t oldcount; HLL_DENSE_GET_REGISTER(oldcount, registers, index); if (count > oldcount) { HLL_DENSE_SET_REGISTER(registers, index, count); return 1; } else { return 0; } } /* "Add" the element in the dense hyperloglog data structure. * Actually nothing is added, but the max 0 pattern counter of the subset * the element belongs to is incremented if needed. * * This is just a wrapper to hllDenseSet(), performing the hashing of the * element in order to retrieve the index and zero-run count. */ int hllDenseAdd(uint8_t* registers, unsigned char* ele, size_t elesize) { long index; uint8_t count = hllPatLen(ele, elesize, &index); /* Update the register if this element produced a longer run of zeroes. 
*/ return hllDenseSet(registers, index, count); } /* Compute the register histogram in the dense representation. */ void hllDenseRegHisto(uint8_t* registers, int* reghisto) { int j; /* Redis default is to use 16384 registers 6 bits each. The code works * with other values by modifying the defines, but for our target value * we take a faster path with unrolled loops. */ if (HLL_REGISTERS == 16384 && HLL_BITS == 6) { uint8_t* r = registers; unsigned long r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15; for (j = 0; j < 1024; j++) { /* Handle 16 registers per iteration. */ r0 = r[0] & 63; r1 = (r[0] >> 6 | r[1] << 2) & 63; r2 = (r[1] >> 4 | r[2] << 4) & 63; r3 = (r[2] >> 2) & 63; r4 = r[3] & 63; r5 = (r[3] >> 6 | r[4] << 2) & 63; r6 = (r[4] >> 4 | r[5] << 4) & 63; r7 = (r[5] >> 2) & 63; r8 = r[6] & 63; r9 = (r[6] >> 6 | r[7] << 2) & 63; r10 = (r[7] >> 4 | r[8] << 4) & 63; r11 = (r[8] >> 2) & 63; r12 = r[9] & 63; r13 = (r[9] >> 6 | r[10] << 2) & 63; r14 = (r[10] >> 4 | r[11] << 4) & 63; r15 = (r[11] >> 2) & 63; reghisto[r0]++; reghisto[r1]++; reghisto[r2]++; reghisto[r3]++; reghisto[r4]++; reghisto[r5]++; reghisto[r6]++; reghisto[r7]++; reghisto[r8]++; reghisto[r9]++; reghisto[r10]++; reghisto[r11]++; reghisto[r12]++; reghisto[r13]++; reghisto[r14]++; reghisto[r15]++; r += 12; } } else { for (j = 0; j < HLL_REGISTERS; j++) { unsigned long reg; HLL_DENSE_GET_REGISTER(reg, registers, j); reghisto[reg]++; } } } /* ================== Sparse representation implementation ================= */ /* Convert the HLL with sparse representation given as input in its dense * representation. Both representations are represented by SDS strings, and * the input representation is freed as a side effect. * * The function returns C_OK if the sparse representation was valid, * otherwise C_ERR is returned if the representation was corrupted. 
*/
int hllSparseToDense(sds* hll_ptr) {
  sds sparse = *hll_ptr, dense;
  struct hllhdr *hdr, *oldhdr = (struct hllhdr*)sparse;
  int idx = 0, runlen, regval;
  int valid = 1; /* Cleared when an opcode would run past HLL_REGISTERS. */
  uint8_t *p = (uint8_t*)sparse, *end = p + sdslen(sparse);

  /* If the representation is already the right one return ASAP. */
  hdr = (struct hllhdr*)sparse;
  if (hdr->encoding == HLL_DENSE)
    return C_OK;

  /* Create a string of the right size filled with zero bytes.
   * Note that the cached cardinality is set to 0 as a side effect
   * that is exactly the cardinality of an empty HLL. */
  dense = sdsnewlen(NULL, HLL_DENSE_SIZE);
  hdr = (struct hllhdr*)dense;
  *hdr = *oldhdr; /* This will copy the magic and cached cardinality. */
  hdr->encoding = HLL_DENSE;

  /* Now read the sparse representation and set non-zero registers
   * accordingly.
   *
   * Every opcode, ZERO and XZERO included, is checked against the number of
   * registers still available before 'idx' is advanced. Without the check a
   * crafted string made of many zero-run opcodes could overflow 'idx' to a
   * negative value, after which a VAL opcode would pass the bound test and
   * write outside of the dense buffer. */
  p += HLL_HDR_SIZE;
  while (p < end) {
    int oplen = 1;

    regval = 0; /* 0 means "zero run": the dense buffer is already zeroed. */
    if (HLL_SPARSE_IS_ZERO(p)) {
      runlen = HLL_SPARSE_ZERO_LEN(p);
    } else if (HLL_SPARSE_IS_XZERO(p)) {
      runlen = HLL_SPARSE_XZERO_LEN(p);
      oplen = 2;
    } else {
      runlen = HLL_SPARSE_VAL_LEN(p);
      regval = HLL_SPARSE_VAL_VALUE(p); /* Always >= 1. */
    }

    if (runlen > HLL_REGISTERS - idx) { /* Overflow. */
      valid = 0;
      break;
    }

    if (regval != 0) {
      while (runlen--) {
        HLL_DENSE_SET_REGISTER(hdr->registers, idx, regval);
        idx++;
      }
    } else {
      idx += runlen;
    }
    p += oplen;
  }

  /* If the sparse representation was valid, we expect to find idx
   * set to HLL_REGISTERS. On failure the input string is left untouched
   * and only the temporary dense string is released. */
  if (!valid || idx != HLL_REGISTERS) {
    sdsfree(dense);
    return C_ERR;
  }

  /* Free the old representation and set the new one. */
  sdsfree(*hll_ptr);
  *hll_ptr = dense;
  return C_OK;
}

/* Low level function to set the sparse HLL register at 'index' to the
 * specified value if the current value is smaller than 'count'.
 *
 * The object 'hll' is the SDS object holding the HLL. The function requires
 * a reference to the object in order to be able to enlarge the string if
 * needed.
 *
 * On success, the function returns 1 if the cardinality changed, or 0
 * if the register for this element was not updated.
 * On error (if the representation is invalid) -1 is returned.
* * As a side effect the function may promote the HLL representation from * sparse to dense: this happens when a register requires to be set to a value * not representable with the sparse representation, or when the resulting * size would be greater than HLL_SPARSE_MAX_BYTES. */ int hllSparseSet(sds* hll_ptr, long index, uint8_t count, int* promoted) { struct hllhdr *hdr; uint8_t oldcount, *sparse, *end, *p, *prev, *next; long first, span; long is_zero = 0, is_xzero = 0, is_val = 0, runlen = 0; /* If the count is too big to be representable by the sparse representation * switch to dense representation. */ if (count > HLL_SPARSE_VAL_MAX_VALUE) goto promote; /* When updating a sparse representation, sometimes we may need to enlarge the * buffer for up to 3 bytes in the worst case (XZERO split into XZERO-VAL-XZERO), * and the following code does the enlarge job. * Actually, we use a greedy strategy, enlarge more than 3 bytes to avoid the need * for future reallocates on incremental growth. But we do not allocate more than * 'HLL_SPARSE_MAX_BYTES' bytes for the sparse representation. * If the available size of hyperloglog sds string is not enough for the increment * we need, we promote the hypreloglog to dense representation in 'step 3'. */ sds hll = *hll_ptr; if (sdsalloc(hll) < HLL_SPARSE_MAX_BYTES && sdsavail(hll) < 3) { size_t newlen = sdslen(hll) + 3; newlen += min(newlen, 300); /* Greediness: double 'newlen' if it is smaller than 300, or add 300 to it when it exceeds 300 */ if (newlen > HLL_SPARSE_MAX_BYTES) newlen = HLL_SPARSE_MAX_BYTES; *hll_ptr = sdsResize(hll, newlen); hll = *hll_ptr; } /* Step 1: we need to locate the opcode we need to modify to check * if a value update is actually needed. */ sparse = p = ((uint8_t*)hll) + HLL_HDR_SIZE; end = p + sdslen(hll) - HLL_HDR_SIZE; first = 0; prev = NULL; /* Points to previous opcode at the end of the loop. */ next = NULL; /* Points to the next opcode at the end of the loop. 
*/ span = 0; while(p < end) { long oplen; /* Set span to the number of registers covered by this opcode. * * This is the most performance critical loop of the sparse * representation. Sorting the conditionals from the most to the * least frequent opcode in many-bytes sparse HLLs is faster. */ oplen = 1; if (HLL_SPARSE_IS_ZERO(p)) { span = HLL_SPARSE_ZERO_LEN(p); } else if (HLL_SPARSE_IS_VAL(p)) { span = HLL_SPARSE_VAL_LEN(p); } else { /* XZERO. */ span = HLL_SPARSE_XZERO_LEN(p); oplen = 2; } /* Break if this opcode covers the register as 'index'. */ if (index <= first+span-1) break; prev = p; p += oplen; first += span; } if (span == 0 || p >= end) return -1; /* Invalid format. */ next = HLL_SPARSE_IS_XZERO(p) ? p+2 : p+1; if (next >= end) next = NULL; /* Cache current opcode type to avoid using the macro again and * again for something that will not change. * Also cache the run-length of the opcode. */ if (HLL_SPARSE_IS_ZERO(p)) { is_zero = 1; runlen = HLL_SPARSE_ZERO_LEN(p); } else if (HLL_SPARSE_IS_XZERO(p)) { is_xzero = 1; runlen = HLL_SPARSE_XZERO_LEN(p); } else { is_val = 1; runlen = HLL_SPARSE_VAL_LEN(p); } /* Step 2: After the loop: * * 'first' stores to the index of the first register covered * by the current opcode, which is pointed by 'p'. * * 'next' ad 'prev' store respectively the next and previous opcode, * or NULL if the opcode at 'p' is respectively the last or first. * * 'span' is set to the number of registers covered by the current * opcode. * * There are different cases in order to update the data structure * in place without generating it from scratch: * * A) If it is a VAL opcode already set to a value >= our 'count' * no update is needed, regardless of the VAL run-length field. * In this case PFADD returns 0 since no changes are performed. * * B) If it is a VAL opcode with len = 1 (representing only our * register) and the value is less than 'count', we just update it * since this is a trivial case. 
*/ if (is_val) { oldcount = HLL_SPARSE_VAL_VALUE(p); /* Case A. */ if (oldcount >= count) return 0; /* Case B. */ if (runlen == 1) { HLL_SPARSE_VAL_SET(p,count,1); goto updated; } } /* C) Another trivial to handle case is a ZERO opcode with a len of 1. * We can just replace it with a VAL opcode with our value and len of 1. */ if (is_zero && runlen == 1) { HLL_SPARSE_VAL_SET(p,count,1); goto updated; } /* D) General case. * * The other cases are more complex: our register requires to be updated * and is either currently represented by a VAL opcode with len > 1, * by a ZERO opcode with len > 1, or by an XZERO opcode. * * In those cases the original opcode must be split into multiple * opcodes. The worst case is an XZERO split in the middle resulting into * XZERO - VAL - XZERO, so the resulting sequence max length is * 5 bytes. * * We perform the split writing the new sequence into the 'new' buffer * with 'newlen' as length. Later the new sequence is inserted in place * of the old one, possibly moving what is on the right a few bytes * if the new sequence is longer than the older one. */ uint8_t seq[5], *n = seq; int last = first+span-1; /* Last register covered by the sequence. */ int len; if (is_zero || is_xzero) { /* Handle splitting of ZERO / XZERO. */ if (index != first) { len = index-first; if (len > HLL_SPARSE_ZERO_MAX_LEN) { HLL_SPARSE_XZERO_SET(n,len); n += 2; } else { HLL_SPARSE_ZERO_SET(n,len); n++; } } HLL_SPARSE_VAL_SET(n,count,1); n++; if (index != last) { len = last-index; if (len > HLL_SPARSE_ZERO_MAX_LEN) { HLL_SPARSE_XZERO_SET(n,len); n += 2; } else { HLL_SPARSE_ZERO_SET(n,len); n++; } } } else { /* Handle splitting of VAL. */ int curval = HLL_SPARSE_VAL_VALUE(p); if (index != first) { len = index-first; HLL_SPARSE_VAL_SET(n,curval,len); n++; } HLL_SPARSE_VAL_SET(n,count,1); n++; if (index != last) { len = last-index; HLL_SPARSE_VAL_SET(n,curval,len); n++; } } /* Step 3: substitute the new sequence with the old one. 
* * Note that we already allocated space on the sds string * calling sdsResize(). */ int seqlen = n-seq; int oldlen = is_xzero ? 2 : 1; int deltalen = seqlen-oldlen; if (deltalen > 0 && sdslen(hll) + deltalen > HLL_SPARSE_MAX_BYTES) goto promote; serverAssert(sdslen(hll) + deltalen <= sdsalloc(hll)); if (deltalen && next) memmove(next+deltalen,next,end-next); sdsIncrLen(hll,deltalen); memcpy(p,seq,seqlen); end += deltalen; updated: /* Step 4: Merge adjacent values if possible. * * The representation was updated, however the resulting representation * may not be optimal: adjacent VAL opcodes can sometimes be merged into * a single one. */ p = prev ? prev : sparse; int scanlen = 5; /* Scan up to 5 upcodes starting from prev. */ while (p < end && scanlen--) { if (HLL_SPARSE_IS_XZERO(p)) { p += 2; continue; } else if (HLL_SPARSE_IS_ZERO(p)) { p++; continue; } /* We need two adjacent VAL opcodes to try a merge, having * the same value, and a len that fits the VAL opcode max len. */ if (p+1 < end && HLL_SPARSE_IS_VAL(p+1)) { int v1 = HLL_SPARSE_VAL_VALUE(p); int v2 = HLL_SPARSE_VAL_VALUE(p+1); if (v1 == v2) { int len = HLL_SPARSE_VAL_LEN(p)+HLL_SPARSE_VAL_LEN(p+1); if (len <= HLL_SPARSE_VAL_MAX_LEN) { HLL_SPARSE_VAL_SET(p+1,v1,len); memmove(p,p+1,end-p); sdsIncrLen(hll,-1); end--; /* After a merge we reiterate without incrementing 'p' * in order to try to merge the just merged value with * a value on its right. */ continue; } } } p++; } /* Invalidate the cached cardinality. */ hdr = (struct hllhdr *)hll; HLL_INVALIDATE_CACHE(hdr); return 1; promote: /* Promote to dense representation. */ if (hllSparseToDense(&hll) == C_ERR) return -1; /* Corrupted HLL. */ *hll_ptr = hll; hdr = (struct hllhdr *)hll; /* We need to call hllDenseAdd() to perform the operation after the * conversion. However the result must be 1, since if we need to * convert from sparse to dense a register requires to be updated. 
* * Note that this in turn means that PFADD will make sure the command * is propagated to slaves / AOF, so if there is a sparse -> dense * conversion, it will be performed in all the slaves as well. */ int dense_retval = hllDenseSet(hdr->registers,index,count); serverAssert(dense_retval == 1); *promoted = 1; return dense_retval; } /* "Add" the element in the sparse hyperloglog data structure. * Actually nothing is added, but the max 0 pattern counter of the subset * the element belongs to is incremented if needed. * * This function is actually a wrapper for hllSparseSet(), it only performs * the hashing of the element to obtain the index and zeros run length. */ int hllSparseAdd(sds* hll_ptr, unsigned char *ele, size_t elesize, int* promoted) { long index; uint8_t count = hllPatLen(ele,elesize,&index); /* Update the register if this element produced a longer run of zeroes. */ return hllSparseSet(hll_ptr,index,count, promoted); } /* Compute the register histogram in the sparse representation. */ void hllSparseRegHisto(uint8_t* sparse, int sparselen, int* invalid, int* reghisto) { int idx = 0, runlen, regval; uint8_t *end = sparse + sparselen, *p = sparse; while (p < end) { if (HLL_SPARSE_IS_ZERO(p)) { runlen = HLL_SPARSE_ZERO_LEN(p); idx += runlen; reghisto[0] += runlen; p++; } else if (HLL_SPARSE_IS_XZERO(p)) { runlen = HLL_SPARSE_XZERO_LEN(p); idx += runlen; reghisto[0] += runlen; p += 2; } else { runlen = HLL_SPARSE_VAL_LEN(p); regval = HLL_SPARSE_VAL_VALUE(p); idx += runlen; reghisto[regval] += runlen; p++; } } if (idx != HLL_REGISTERS && invalid) *invalid = 1; } /* ========================= HyperLogLog Count ============================== * This is the core of the algorithm where the approximated count is computed. * The function uses the lower level hllDenseRegHisto() and hllSparseRegHisto() * functions as helpers to compute histogram of register values part of the * computation, which is representation-specific, while all the rest is common. 
*/

/* Register histogram for the raw encoding, where each register is stored in
 * its own uint8_t. Only used internally to speed up PFCOUNT against multiple
 * keys. Registers are examined eight at a time so that an all-zero group is
 * accounted with a single comparison. */
void hllRawRegHisto(uint8_t* registers, int* reghisto) {
  uint64_t* group = (uint64_t*)registers;
  int remaining = HLL_REGISTERS / 8;

  for (; remaining > 0; remaining--, group++) {
    if (*group == 0) {
      reghisto[0] += 8;
      continue;
    }

    const uint8_t* regs = (const uint8_t*)group;
    for (int k = 0; k < 8; k++) {
      reghisto[regs[k]]++;
    }
  }
}

/* Sigma helper from
 * "New cardinality estimation algorithms for HyperLogLog sketches"
 * Otmar Ertl, arXiv:1702.01284
 *
 * Accumulates x^(2^k) * 2^(k-1) terms until the sum stops changing in
 * double precision; an input of exactly 1 yields INFINITY. */
double hllSigma(double x) {
  if (x == 1.)
    return INFINITY;

  double weight = 1;
  double sum = x;
  for (;;) {
    const double before = sum;
    x *= x;
    sum += x * weight;
    weight += weight;
    if (before == sum)
      break;
  }
  return sum;
}

/* Tau helper from
 * "New cardinality estimation algorithms for HyperLogLog sketches"
 * Otmar Ertl, arXiv:1702.01284
 *
 * Inputs of exactly 0 or 1 yield 0; otherwise terms are subtracted until
 * the accumulator stops changing in double precision. */
double hllTau(double x) {
  if (x == 0. || x == 1.)
    return 0.;

  double weight = 1.0;
  double acc = 1 - x;
  for (;;) {
    const double before = acc;
    x = sqrt(x);
    weight *= 0.5;
    acc -= pow(1 - x, 2) * weight;
    if (before == acc)
      break;
  }
  return acc / 3;
}

/* Return the approximated cardinality of the set based on the harmonic
 * mean of the registers values. 'hdr' points to the start of the SDS
 * representing the String object holding the HLL representation.
 *
 * If the sparse representation of the HLL object is not valid, the integer
 * pointed by 'invalid' is set to non-zero, otherwise it is left untouched.
 *
 * hllCount() supports a special internal-only encoding of HLL_RAW, that
 * is, hdr->registers will point to an uint8_t array of HLL_REGISTERS element.
 * This is useful in order to speedup PFCOUNT when called against multiple
 * keys (no need to work with 6-bit integers encoding).
*/ uint64_t hllCount(struct hllhdr* hdr, int* invalid) { double m = HLL_REGISTERS; double E; int j; /* Note that reghisto size could be just HLL_Q+2, because HLL_Q+1 is * the maximum frequency of the "000...1" sequence the hash function is * able to return. However it is slow to check for sanity of the * input: instead we history array at a safe size: overflows will * just write data to wrong, but correctly allocated, places. */ int reghisto[64] = {0}; /* Compute register histogram */ if (hdr->encoding == HLL_DENSE) { hllDenseRegHisto(hdr->registers, reghisto); } else if (hdr->encoding == HLL_SPARSE) { hllSparseRegHisto(hdr->registers, sdslen((sds)hdr) - HLL_HDR_SIZE, invalid, reghisto); } else if (hdr->encoding == HLL_RAW) { hllRawRegHisto(hdr->registers, reghisto); } else { serverPanic("Unknown HyperLogLog encoding in hllCount()"); } /* Estimate cardinality from register histogram. See: * "New cardinality estimation algorithms for HyperLogLog sketches" * Otmar Ertl, arXiv:1702.01284 */ double z = m * hllTau((m - reghisto[HLL_Q + 1]) / (double)m); for (j = HLL_Q; j >= 1; --j) { z += reghisto[j]; z *= 0.5; } z += m * hllSigma(reghisto[0] / (double)m); E = llroundl(HLL_ALPHA_INF * m * m / z); return (uint64_t)E; } #if 0 /* Merge by computing MAX(registers[i],hll[i]) the HyperLogLog 'hll' * with an array of uint8_t HLL_REGISTERS registers pointed by 'max'. * * The hll object must be already validated via isHLLObjectOrReply() * or in some other way. * * If the HyperLogLog is sparse and is found to be invalid, C_ERR * is returned, otherwise the function always succeeds. 
*/
int hllMerge(uint8_t* max, robj* hll) {
    /* NOTE: this function sits inside an #if 0 region and is not compiled. */
    struct hllhdr* hdr = hll->ptr;
    int i;

    if (hdr->encoding == HLL_DENSE) {
        uint8_t val;

        /* Dense: compare every 6-bit register against the running maximum. */
        for (i = 0; i < HLL_REGISTERS; i++) {
            HLL_DENSE_GET_REGISTER(val, hdr->registers, i);
            if (val > max[i]) max[i] = val;
        }
    } else {
        uint8_t *p = hll->ptr, *end = p + sdslen(hll->ptr);
        long runlen, regval;

        /* Sparse: walk the opcodes; only VAL runs can raise a maximum. */
        p += HLL_HDR_SIZE;
        i = 0;
        while (p < end) {
            if (HLL_SPARSE_IS_ZERO(p)) {
                runlen = HLL_SPARSE_ZERO_LEN(p);
                i += runlen;
                p++;
            } else if (HLL_SPARSE_IS_XZERO(p)) {
                runlen = HLL_SPARSE_XZERO_LEN(p);
                i += runlen;
                p += 2;
            } else {
                runlen = HLL_SPARSE_VAL_LEN(p);
                regval = HLL_SPARSE_VAL_VALUE(p);
                if ((runlen + i) > HLL_REGISTERS) break; /* Overflow. */
                while (runlen--) {
                    if (regval > max[i]) max[i] = regval;
                    i++;
                }
                p++;
            }
        }
        /* A valid payload covers exactly HLL_REGISTERS registers. */
        if (i != HLL_REGISTERS) return C_ERR;
    }
    return C_OK;
}

/* ========================== HyperLogLog commands ========================== */

/* Create an empty sparse-encoded HLL string object made only of XZERO
 * opcodes. Also part of the disabled (#if 0) region. */
robj* createHLLObject(void) {
    robj* o;
    struct hllhdr* hdr;
    sds s;
    uint8_t* p;
    int sparselen = HLL_HDR_SIZE + (((HLL_REGISTERS + (HLL_SPARSE_XZERO_MAX_LEN - 1)) / HLL_SPARSE_XZERO_MAX_LEN) * 2);
    int aux;

    /* Populate the sparse representation with as many XZERO opcodes as
     * needed to represent all the registers. */
    aux = HLL_REGISTERS;
    s = sdsnewlen(NULL, sparselen);
    p = (uint8_t*)s + HLL_HDR_SIZE;
    while (aux) {
        int xzero = HLL_SPARSE_XZERO_MAX_LEN;
        if (xzero > aux) xzero = aux;
        HLL_SPARSE_XZERO_SET(p, xzero);
        p += 2;
        aux -= xzero;
    }
    serverAssert((p - (uint8_t*)s) == sparselen);

    /* Create the actual object. */
    o = createObject(OBJ_STRING, s);
    hdr = o->ptr;
    memcpy(hdr->magic, "HYLL", 4);
    hdr->encoding = HLL_SPARSE;
    return o;
}
#endif

/* ========================== Dragonfly custom functions ===================== */

/* Classify the buffer described by 'hll_buffer'.
 *
 * Returns HLL_INVALID when the buffer is shorter than the header, the magic
 * is not "HYLL", the encoding is above HLL_MAX_ENCODING, or a dense buffer
 * is not exactly HLL_DENSE_SIZE bytes long. Otherwise returns
 * HLL_VALID_DENSE or HLL_VALID_SPARSE according to the encoding byte.
 * The sparse opcodes themselves are not inspected here. */
enum HllValidness isValidHLL(struct HllBufferPtr hll_buffer) {
    struct hllhdr* hdr;

    /* The header must be fully readable before any field is touched. */
    if (hll_buffer.size < sizeof(*hdr)) {
        return HLL_INVALID;
    }

    hdr = (struct hllhdr*)hll_buffer.hll;
    /* Magic should be "HYLL".
*/ if (hdr->magic[0] != 'H' || hdr->magic[1] != 'Y' || hdr->magic[2] != 'L' || hdr->magic[3] != 'L') { return HLL_INVALID; } if (hdr->encoding > HLL_MAX_ENCODING) { return HLL_INVALID; } switch (hdr->encoding) { case HLL_DENSE: /* Dense representation string length should match exactly. */ return (hll_buffer.size == HLL_DENSE_SIZE) ? HLL_VALID_DENSE : HLL_INVALID; case HLL_SPARSE: return HLL_VALID_SPARSE; default: return HLL_INVALID; } } size_t getDenseHllSize() { return HLL_DENSE_SIZE; } size_t getSparseHllInitSize() { return HLL_HDR_SIZE + (((HLL_REGISTERS+(HLL_SPARSE_XZERO_MAX_LEN-1)) / HLL_SPARSE_XZERO_MAX_LEN)*2); } int initSparseHll(struct HllBufferPtr hll_ptr) { if (hll_ptr.size != getSparseHllInitSize()) { return C_ERR; } memset(hll_ptr.hll, 0, hll_ptr.size); /* Populate the sparse representation with as many XZERO opcodes as * needed to represent all the registers. */ int aux = HLL_REGISTERS; uint8_t* p = (uint8_t*)hll_ptr.hll + HLL_HDR_SIZE; while(aux) { int xzero = HLL_SPARSE_XZERO_MAX_LEN; if (xzero > aux) xzero = aux; HLL_SPARSE_XZERO_SET(p,xzero); p += 2; aux -= xzero; } struct hllhdr* hdr = (struct hllhdr*)hll_ptr.hll; memcpy(hdr->magic, "HYLL", 4); hdr->encoding = HLL_SPARSE; return C_OK; } int createDenseHll(struct HllBufferPtr hll_ptr) { if (hll_ptr.size != getDenseHllSize()) { return C_ERR; } memset(hll_ptr.hll, 0, hll_ptr.size); struct hllhdr* hdr = (struct hllhdr*)hll_ptr.hll; memcpy(hdr->magic, "HYLL", 4); hdr->encoding = HLL_DENSE; return C_OK; } /* This is a copied & modified version of hllSparseToDense() above that does not use robj */ int convertSparseToDenseHll(struct HllBufferPtr in_hll, struct HllBufferPtr out_hll) { struct hllhdr *hdr, *oldhdr = (struct hllhdr*)in_hll.hll; int idx = 0, runlen, regval; uint8_t *p = (uint8_t*)in_hll.hll, *end = p + in_hll.size; if (oldhdr->encoding != HLL_SPARSE) return C_ERR; if (out_hll.size != getDenseHllSize()) return C_ERR; /* Create a string of the right size filled with zero bytes. 
* Note that the cached cardinality is set to 0 as a side effect * that is exactly the cardinality of an empty HLL. */ hdr = (struct hllhdr*)out_hll.hll; *hdr = *oldhdr; /* This will copy the magic and cached cardinality. */ hdr->encoding = HLL_DENSE; /* Now read the sparse representation and set non-zero registers * accordingly. */ p += HLL_HDR_SIZE; while (p < end) { if (HLL_SPARSE_IS_ZERO(p)) { runlen = HLL_SPARSE_ZERO_LEN(p); idx += runlen; p++; } else if (HLL_SPARSE_IS_XZERO(p)) { runlen = HLL_SPARSE_XZERO_LEN(p); idx += runlen; p += 2; } else { runlen = HLL_SPARSE_VAL_LEN(p); regval = HLL_SPARSE_VAL_VALUE(p); if ((runlen + idx) > HLL_REGISTERS) break; /* Overflow. */ while (runlen--) { HLL_DENSE_SET_REGISTER(hdr->registers, idx, regval); idx++; } p++; } } /* If the sparse representation was valid, we expect to find idx * set to HLL_REGISTERS. */ if (idx != HLL_REGISTERS) { return C_ERR; } return C_OK; } int pfadd_sparse(sds* hll_ptr, const unsigned char* value, size_t size, int* promoted) { struct hllhdr* hdr = (struct hllhdr*)(*hll_ptr); int retval = hllSparseAdd(hll_ptr, (unsigned char*)value, size, promoted); switch (retval) { case 1: HLL_INVALIDATE_CACHE(hdr); return 1; default: return retval; } } int pfadd_dense(struct HllBufferPtr hll_ptr, const unsigned char* value, size_t size) { if (isValidHLL(hll_ptr) != HLL_VALID_DENSE) return C_ERR; struct hllhdr* hdr = (struct hllhdr*)hll_ptr.hll; /* Perform the low level ADD operation for every element. */ int retval = hllDenseAdd(hdr->registers, (unsigned char*)value, size); switch (retval) { case 1: HLL_INVALIDATE_CACHE(hdr); return 1; default: return retval; } } int64_t pfcountSingle(struct HllBufferPtr hll_ptr) { uint64_t card; if (isValidHLL(hll_ptr) != HLL_VALID_DENSE) return C_ERR; /* Check if the cached cardinality is valid. */ struct hllhdr* hdr = (struct hllhdr*)hll_ptr.hll; if (HLL_VALID_CACHE(hdr)) { /* Just return the cached value. 
*/ card = (uint64_t)hdr->card[0]; card |= (uint64_t)hdr->card[1] << 8; card |= (uint64_t)hdr->card[2] << 16; card |= (uint64_t)hdr->card[3] << 24; card |= (uint64_t)hdr->card[4] << 32; card |= (uint64_t)hdr->card[5] << 40; card |= (uint64_t)hdr->card[6] << 48; card |= (uint64_t)hdr->card[7] << 56; } else { int invalid = 0; /* Recompute it and update the cached value. */ card = hllCount(hdr, &invalid); if (invalid) { return -1; } hdr->card[0] = card & 0xff; hdr->card[1] = (card >> 8) & 0xff; hdr->card[2] = (card >> 16) & 0xff; hdr->card[3] = (card >> 24) & 0xff; hdr->card[4] = (card >> 32) & 0xff; hdr->card[5] = (card >> 40) & 0xff; hdr->card[6] = (card >> 48) & 0xff; hdr->card[7] = (card >> 56) & 0xff; } return card; } /* Merge dense-encoded HLL */ static void hllMergeDense(uint8_t* registers, struct HllBufferPtr to) { uint8_t val; struct hllhdr* hll_hdr = (struct hllhdr*)to.hll; for (int i = 0; i < HLL_REGISTERS; i++) { HLL_DENSE_GET_REGISTER(val, hll_hdr->registers, i); if (val > registers[i]) { registers[i] = val; } } } int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count) { struct hllhdr* hdr; uint8_t max[HLL_HDR_SIZE + HLL_REGISTERS]; /* Compute an HLL with M[i] = MAX(M[i]_j). */ memset(max, 0, sizeof(max)); hdr = (struct hllhdr*)max; hdr->encoding = HLL_RAW; /* Special internal-only encoding. */ for (size_t j = 0; j < hlls_count; j++) { /* Check type and size. */ struct HllBufferPtr hll = hlls[j]; if (isValidHLL(hll) != HLL_VALID_DENSE) { return C_ERR; } hllMergeDense(max, hll); } /* Compute cardinality of the resulting set. */ return hllCount(hdr, NULL); } int pfmerge(struct HllBufferPtr* in_hlls, size_t in_hlls_count, struct HllBufferPtr out_hll) { if (isValidHLL(out_hll) != HLL_VALID_DENSE) { return C_ERR; } uint8_t max[HLL_REGISTERS]; /* Compute an HLL with M[i] = MAX(M[i]_j). * We store the maximum into the max array of registers. We'll write * it to the target variable later. 
*/
    memset(max, 0, sizeof(max));
    for (size_t j = 0; j < in_hlls_count; j++) {
        struct HllBufferPtr hll = in_hlls[j];
        /* Every input must be dense; bail out before touching out_hll. */
        if (isValidHLL(hll) != HLL_VALID_DENSE) {
            return C_ERR;
        }

        hllMergeDense(max, hll);
    }

    /* Store the collected maxima into the destination, register by register. */
    struct hllhdr* hdr = (struct hllhdr*)out_hll.hll;
    for (size_t j = 0; j < HLL_REGISTERS; j++) {
        hllDenseSet(hdr->registers, j, max[j]);
    }
    /* The registers may have changed, so the cached cardinality is stale. */
    HLL_INVALIDATE_CACHE(hdr);
    return C_OK;
}

================================================
FILE: src/redis/hyperloglog.h
================================================
#ifndef __REDIS_HYPERLOGLOG_H
#define __REDIS_HYPERLOGLOG_H

#include
#include
#include "redis/sds.h"

/* This version of hyperloglog is forked from Redis. Counting and merging
 * (pfcountSingle, pfcountMulti, pfmerge) only accept the dense format of HLL.
 * The reason is that it is of a fixed size, which makes it easier to integrate into Dragonfly.
 * Sparse-encoded HLLs can still be created (initSparseHll), added to (pfadd_sparse) and
 * converted into dense-encoded ones (convertSparseToDenseHll), which can be useful
 * for replication, serialization, etc. */

/* Result of isValidHLL(): either invalid, or valid with the given encoding. */
enum HllValidness {
  HLL_INVALID,
  HLL_VALID_SPARSE,
  HLL_VALID_DENSE,
};

/* Convenience struct for pointing to an Hll buffer along with its size */
struct HllBufferPtr {
  unsigned char* hll; /* First byte of the buffer (the HLL header). */
  size_t size;        /* Number of bytes available at `hll`. */
};

/* Checks the header of `hll_ptr` (length, magic, encoding) and, for the dense encoding,
 * that the size matches getDenseHllSize(). */
enum HllValidness isValidHLL(struct HllBufferPtr hll_ptr);

/* Size in bytes of a dense-encoded HLL, header included. */
size_t getDenseHllSize();

/* Size in bytes of an empty sparse-encoded HLL, header included. */
size_t getSparseHllInitSize();

/* Writes into `hll_ptr` an empty sparse-encoded HLL.
 * Fails when `hll_ptr.size` is different from getSparseHllInitSize(). */
int initSparseHll(struct HllBufferPtr hll_ptr);

/* Writes into `hll_ptr` an empty dense-encoded HLL.
 * Returns 0 upon success, or a negative number when `hll_ptr.size` is different from
 * getDenseHllSize() */
int createDenseHll(struct HllBufferPtr hll_ptr);

/* Converts an existing sparse-encoded HLL pointed by `in_hll`, and writes the converted result into
 * `out_hll`.
 * Returns 0 upon success, otherwise a negative number.
 * Failures can occur when `out_hll.size` is different from getDenseHllSize() or when input is not a
 * valid sparse-encoded HLL. */
int convertSparseToDenseHll(struct HllBufferPtr in_hll, struct HllBufferPtr out_hll);

/* Adds `value` of size `size`, to `hll_ptr`.
* If `obj` does not have an underlying type of HLL a negative number is returned. */
/* Sparse variant: `hll_ptr` is passed by address; `promoted` is an output flag
 * that the implementation sets to 1 when the HLL was converted to dense. */
int pfadd_sparse(sds* hll_ptr, const unsigned char* value, size_t size, int* promoted);

/* Dense variant: `hll_ptr` must be a valid dense-encoded HLL. */
int pfadd_dense(struct HllBufferPtr hll_ptr, const unsigned char* value, size_t size);

/* Returns the estimated count of elements for `hll_ptr`.
 * If `hll_ptr` is not a valid dense-encoded HLL, a negative number is returned. */
int64_t pfcountSingle(struct HllBufferPtr hll_ptr);

/* Returns the estimated count for all HLLs in `hlls` array of size `hlls_count`.
 * All `hlls` elements must be valid, dense-encoded HLLs. */
int64_t pfcountMulti(struct HllBufferPtr* hlls, size_t hlls_count);

/* Merges the array of HLLs pointed to by `in_hlls` of size `in_hlls_count` into `out_hll`.
 * Returns 0 upon success, otherwise a negative number.
 * Failure can occur when any of `in_hlls` or `out_hll` is not a dense-encoded HLL.
 * `out_hll` *can* be one of the elements in `in_hlls`. */
int pfmerge(struct HllBufferPtr* in_hlls, size_t in_hlls_count, struct HllBufferPtr out_hll);

#endif

================================================
FILE: src/redis/intset.c
================================================
/* * Copyright (c) 2009-2012, Pieter Noordhuis * Copyright (c) 2009-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission.
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include "intset.h" #include "zmalloc.h" #include "endianconv.h" /* Note that these encodings are ordered, so: * INTSET_ENC_INT16 < INTSET_ENC_INT32 < INTSET_ENC_INT64. */ #define INTSET_ENC_INT16 (sizeof(int16_t)) #define INTSET_ENC_INT32 (sizeof(int32_t)) #define INTSET_ENC_INT64 (sizeof(int64_t)) /* Return the required encoding for the provided value. */ static uint8_t _intsetValueEncoding(int64_t v) { if (v < INT32_MIN || v > INT32_MAX) return INTSET_ENC_INT64; else if (v < INT16_MIN || v > INT16_MAX) return INTSET_ENC_INT32; else return INTSET_ENC_INT16; } /* Return the value at pos, given an encoding. */ static int64_t _intsetGetEncoded(intset *is, int pos, uint8_t enc) { int64_t v64; int32_t v32; int16_t v16; if (enc == INTSET_ENC_INT64) { memcpy(&v64,((int64_t*)is->contents)+pos,sizeof(v64)); memrev64ifbe(&v64); return v64; } else if (enc == INTSET_ENC_INT32) { memcpy(&v32,((int32_t*)is->contents)+pos,sizeof(v32)); memrev32ifbe(&v32); return v32; } else { memcpy(&v16,((int16_t*)is->contents)+pos,sizeof(v16)); memrev16ifbe(&v16); return v16; } } /* Return the value at pos, using the configured encoding. 
*/ static int64_t _intsetGet(intset *is, int pos) { return _intsetGetEncoded(is,pos,intrev32ifbe(is->encoding)); } /* Set the value at pos, using the configured encoding. */ static void _intsetSet(intset *is, int pos, int64_t value) { uint32_t encoding = intrev32ifbe(is->encoding); if (encoding == INTSET_ENC_INT64) { ((int64_t*)is->contents)[pos] = value; memrev64ifbe(((int64_t*)is->contents)+pos); } else if (encoding == INTSET_ENC_INT32) { ((int32_t*)is->contents)[pos] = value; memrev32ifbe(((int32_t*)is->contents)+pos); } else { ((int16_t*)is->contents)[pos] = value; memrev16ifbe(((int16_t*)is->contents)+pos); } } /* Create an empty intset. */ intset *intsetNew(void) { intset *is = zmalloc(sizeof(intset)); is->encoding = intrev32ifbe(INTSET_ENC_INT16); is->length = 0; return is; } /* Resize the intset */ static intset *intsetResize(intset *is, uint32_t len) { uint64_t size = (uint64_t)len*intrev32ifbe(is->encoding); assert(size <= SIZE_MAX - sizeof(intset)); is = zrealloc(is,sizeof(intset)+size); return is; } /* Search for the position of "value". Return 1 when the value was found and * sets "pos" to the position of the value within the intset. Return 0 when * the value is not present in the intset and sets "pos" to the position * where "value" can be inserted. */ static uint8_t intsetSearch(intset *is, int64_t value, uint32_t *pos) { int min = 0, max = intrev32ifbe(is->length)-1, mid = -1; int64_t cur = -1; /* The value can never be found when the set is empty */ if (intrev32ifbe(is->length) == 0) { if (pos) *pos = 0; return 0; } else { /* Check for the case where we know we cannot find the value, * but do know the insert position. 
*/
        if (value > _intsetGet(is,max)) {
            /* Larger than the last element: insert position is the end. */
            if (pos) *pos = intrev32ifbe(is->length);
            return 0;
        } else if (value < _intsetGet(is,0)) {
            /* Smaller than the first element: insert position is the start. */
            if (pos) *pos = 0;
            return 0;
        }
    }

    /* Binary search over the sorted elements. */
    while(max >= min) {
        /* Unsigned arithmetic keeps min+max from overflowing a signed int. */
        mid = ((unsigned int)min + (unsigned int)max) >> 1;
        cur = _intsetGet(is,mid);
        if (value > cur) {
            min = mid+1;
        } else if (value < cur) {
            max = mid-1;
        } else {
            break;
        }
    }

    if (value == cur) {
        if (pos) *pos = mid;
        return 1;
    } else {
        /* Not found: 'min' is the index where the value would be inserted. */
        if (pos) *pos = min;
        return 0;
    }
}

/* Upgrades the intset to a larger encoding and inserts the given integer.
 * Returns the (possibly reallocated) intset. */
static intset *intsetUpgradeAndAdd(intset *is, int64_t value) {
    uint8_t curenc = intrev32ifbe(is->encoding);
    uint8_t newenc = _intsetValueEncoding(value);
    int length = intrev32ifbe(is->length);
    /* A negative value goes to the front, any other value to the back. */
    int prepend = value < 0 ? 1 : 0;

    /* First set new encoding and resize */
    is->encoding = intrev32ifbe(newenc);
    is = intsetResize(is,intrev32ifbe(is->length)+1);

    /* Upgrade back-to-front so we don't overwrite values.
     * Note that the "prepend" variable is used to make sure we have an empty
     * space at either the beginning or the end of the intset. */
    while(length--)
        _intsetSet(is,length+prepend,_intsetGetEncoded(is,length,curenc));

    /* Set the value at the beginning or the end.
*/ if (prepend) _intsetSet(is,0,value); else _intsetSet(is,intrev32ifbe(is->length),value); is->length = intrev32ifbe(intrev32ifbe(is->length)+1); return is; } static void intsetMoveTail(intset *is, uint32_t from, uint32_t to) { void *src, *dst; uint32_t bytes = intrev32ifbe(is->length)-from; uint32_t encoding = intrev32ifbe(is->encoding); if (encoding == INTSET_ENC_INT64) { src = (int64_t*)is->contents+from; dst = (int64_t*)is->contents+to; bytes *= sizeof(int64_t); } else if (encoding == INTSET_ENC_INT32) { src = (int32_t*)is->contents+from; dst = (int32_t*)is->contents+to; bytes *= sizeof(int32_t); } else { src = (int16_t*)is->contents+from; dst = (int16_t*)is->contents+to; bytes *= sizeof(int16_t); } memmove(dst,src,bytes); } /* Insert an integer in the intset */ intset *intsetAdd(intset *is, int64_t value, uint8_t *success) { uint8_t valenc = _intsetValueEncoding(value); uint32_t pos; if (success) *success = 1; /* Upgrade encoding if necessary. If we need to upgrade, we know that * this value should be either appended (if > 0) or prepended (if < 0), * because it lies outside the range of existing values. */ if (valenc > intrev32ifbe(is->encoding)) { /* This always succeeds, so we don't need to curry *success. */ return intsetUpgradeAndAdd(is,value); } else { /* Abort if the value is already present in the set. * This call will populate "pos" with the right position to insert * the value when it cannot be found. 
*/ if (intsetSearch(is,value,&pos)) { if (success) *success = 0; return is; } is = intsetResize(is,intrev32ifbe(is->length)+1); if (pos < intrev32ifbe(is->length)) intsetMoveTail(is,pos,pos+1); } _intsetSet(is,pos,value); is->length = intrev32ifbe(intrev32ifbe(is->length)+1); return is; } /* Delete integer from intset */ intset *intsetRemove(intset *is, int64_t value, int *success) { uint8_t valenc = _intsetValueEncoding(value); uint32_t pos; if (success) *success = 0; if (valenc <= intrev32ifbe(is->encoding) && intsetSearch(is,value,&pos)) { uint32_t len = intrev32ifbe(is->length); /* We know we can delete */ if (success) *success = 1; /* Overwrite value with tail and update length */ if (pos < (len-1)) intsetMoveTail(is,pos+1,pos); is = intsetResize(is,len-1); is->length = intrev32ifbe(len-1); } return is; } intset *intsetTrimTail(intset *is, uint32_t tail_len) { uint32_t len = intrev32ifbe(is->length); uint32_t new_len = tail_len >= len ? 0 : len - tail_len; is->length = intrev32ifbe(new_len); return intsetResize(is, new_len); } /* Determine whether a value belongs to this set */ uint8_t intsetFind(intset *is, int64_t value) { uint8_t valenc = _intsetValueEncoding(value); return valenc <= intrev32ifbe(is->encoding) && intsetSearch(is,value,NULL); } /* Return random member */ int64_t intsetRandom(intset *is) { uint32_t len = intrev32ifbe(is->length); assert(len); /* avoid division by zero on corrupt intset payload. */ return _intsetGet(is,rand()%len); } /* Get the value at the given position. When this position is * out of range the function returns 0, when in range it returns 1. */ uint8_t intsetGet(intset *is, uint32_t pos, int64_t *value) { if (pos < intrev32ifbe(is->length)) { *value = _intsetGet(is,pos); return 1; } return 0; } /* Return intset length */ uint32_t intsetLen(const intset *is) { return intrev32ifbe(is->length); } /* Return intset blob size in bytes. 
*/ size_t intsetBlobLen(intset *is) { return sizeof(intset)+(size_t)intrev32ifbe(is->length)*intrev32ifbe(is->encoding); } /* Validate the integrity of the data structure. * when `deep` is 0, only the integrity of the header is validated. * when `deep` is 1, we make sure there are no duplicate or out of order records. */ int intsetValidateIntegrity(const unsigned char *p, size_t size, int deep) { intset *is = (intset *)p; /* check that we can actually read the header. */ if (size < sizeof(*is)) return 0; uint32_t encoding = intrev32ifbe(is->encoding); size_t record_size; if (encoding == INTSET_ENC_INT64) { record_size = INTSET_ENC_INT64; } else if (encoding == INTSET_ENC_INT32) { record_size = INTSET_ENC_INT32; } else if (encoding == INTSET_ENC_INT16){ record_size = INTSET_ENC_INT16; } else { return 0; } /* check that the size matches (all records are inside the buffer). */ uint32_t count = intrev32ifbe(is->length); if (sizeof(*is) + count*record_size != size) return 0; /* check that the set is not empty. */ if (count==0) return 0; if (!deep) return 1; /* check that there are no dup or out of order records. 
*/ int64_t prev = _intsetGet(is,0); for (uint32_t i=1; i #include #if 0 static void intsetRepr(intset *is) { for (uint32_t i = 0; i < intrev32ifbe(is->length); i++) { printf("%lld\n", (uint64_t)_intsetGet(is,i)); } printf("\n"); } static void error(char *err) { printf("%s\n", err); exit(1); } #endif static void ok(void) { printf("OK\n"); } static long long usec(void) { struct timeval tv; gettimeofday(&tv,NULL); return (((long long)tv.tv_sec)*1000000)+tv.tv_usec; } static intset *createSet(int bits, int size) { uint64_t mask = (1< 32) { value = (rand()*rand()) & mask; } else { value = rand() & mask; } is = intsetAdd(is,value,NULL); } return is; } static void checkConsistency(intset *is) { for (uint32_t i = 0; i < (intrev32ifbe(is->length)-1); i++) { uint32_t encoding = intrev32ifbe(is->encoding); if (encoding == INTSET_ENC_INT16) { int16_t *i16 = (int16_t*)is->contents; assert(i16[i] < i16[i+1]); } else if (encoding == INTSET_ENC_INT32) { int32_t *i32 = (int32_t*)is->contents; assert(i32[i] < i32[i+1]); } else { int64_t *i64 = (int64_t*)is->contents; assert(i64[i] < i64[i+1]); } } } #define UNUSED(x) (void)(x) int intsetTest(int argc, char **argv, int flags) { uint8_t success; int i; intset *is; srand(time(NULL)); UNUSED(argc); UNUSED(argv); UNUSED(flags); printf("Value encodings: "); { assert(_intsetValueEncoding(-32768) == INTSET_ENC_INT16); assert(_intsetValueEncoding(+32767) == INTSET_ENC_INT16); assert(_intsetValueEncoding(-32769) == INTSET_ENC_INT32); assert(_intsetValueEncoding(+32768) == INTSET_ENC_INT32); assert(_intsetValueEncoding(-2147483648) == INTSET_ENC_INT32); assert(_intsetValueEncoding(+2147483647) == INTSET_ENC_INT32); assert(_intsetValueEncoding(-2147483649) == INTSET_ENC_INT64); assert(_intsetValueEncoding(+2147483648) == INTSET_ENC_INT64); assert(_intsetValueEncoding(-9223372036854775808ull) == INTSET_ENC_INT64); assert(_intsetValueEncoding(+9223372036854775807ull) == INTSET_ENC_INT64); ok(); } printf("Basic adding: "); { is = intsetNew(); is = 
intsetAdd(is,5,&success); assert(success); is = intsetAdd(is,6,&success); assert(success); is = intsetAdd(is,4,&success); assert(success); is = intsetAdd(is,4,&success); assert(!success); ok(); zfree(is); } printf("Large number of random adds: "); { uint32_t inserts = 0; is = intsetNew(); for (i = 0; i < 1024; i++) { is = intsetAdd(is,rand()%0x800,&success); if (success) inserts++; } assert(intrev32ifbe(is->length) == inserts); checkConsistency(is); ok(); zfree(is); } printf("Upgrade from int16 to int32: "); { is = intsetNew(); is = intsetAdd(is,32,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT16); is = intsetAdd(is,65535,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT32); assert(intsetFind(is,32)); assert(intsetFind(is,65535)); checkConsistency(is); zfree(is); is = intsetNew(); is = intsetAdd(is,32,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT16); is = intsetAdd(is,-65535,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT32); assert(intsetFind(is,32)); assert(intsetFind(is,-65535)); checkConsistency(is); ok(); zfree(is); } printf("Upgrade from int16 to int64: "); { is = intsetNew(); is = intsetAdd(is,32,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT16); is = intsetAdd(is,4294967295,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT64); assert(intsetFind(is,32)); assert(intsetFind(is,4294967295)); checkConsistency(is); zfree(is); is = intsetNew(); is = intsetAdd(is,32,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT16); is = intsetAdd(is,-4294967295,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT64); assert(intsetFind(is,32)); assert(intsetFind(is,-4294967295)); checkConsistency(is); ok(); zfree(is); } printf("Upgrade from int32 to int64: "); { is = intsetNew(); is = intsetAdd(is,65535,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT32); is = intsetAdd(is,4294967295,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT64); assert(intsetFind(is,65535)); 
assert(intsetFind(is,4294967295)); checkConsistency(is); zfree(is); is = intsetNew(); is = intsetAdd(is,65535,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT32); is = intsetAdd(is,-4294967295,NULL); assert(intrev32ifbe(is->encoding) == INTSET_ENC_INT64); assert(intsetFind(is,65535)); assert(intsetFind(is,-4294967295)); checkConsistency(is); ok(); zfree(is); } printf("Stress lookups: "); { long num = 100000, size = 10000; int i, bits = 20; long long start; is = createSet(bits,size); checkConsistency(is); start = usec(); for (i = 0; i < num; i++) intsetSearch(is,rand() % ((1< * Copyright (c) 2009-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */

#ifndef __INTSET_H
#define __INTSET_H
#include

/* A set of integers kept in one contiguous allocation: a small header
 * followed by the elements, all stored with the same width. */
typedef struct intset {
    uint32_t encoding; /* Width in bytes of every element; accessed through intrev32ifbe(). */
    uint32_t length;   /* Number of elements; accessed through intrev32ifbe(). */
    int8_t contents[]; /* The elements, `encoding` bytes each. */
} intset;

/* Functions returning `intset *` may reallocate the set: always continue
 * with the returned pointer. */
intset *intsetNew(void);
/* `success` may be NULL; otherwise it is set to 0 when the value was already present. */
intset *intsetAdd(intset *is, int64_t value, uint8_t *success);
/* `success` may be NULL; otherwise it is set to 1 when the value was removed. */
intset *intsetRemove(intset *is, int64_t value, int *success);
intset *intsetTrimTail(intset *is, uint32_t trim_len);  // Removes last trim_len elements.
uint8_t intsetFind(intset *is, int64_t value);
/* The set must not be empty. */
int64_t intsetRandom(intset *is);
/* Returns 1 and fills `value` when `pos` is in range, 0 otherwise. */
uint8_t intsetGet(intset *is, uint32_t pos, int64_t *value);
uint32_t intsetLen(const intset *is);
/* Total size in bytes of the header plus the elements. */
size_t intsetBlobLen(intset *is);
/* Returns 1 when the blob of `size` bytes looks like a well-formed intset, 0 otherwise. */
int intsetValidateIntegrity(const unsigned char *is, size_t size, int deep);

#ifdef REDIS_TEST
int intsetTest(int argc, char *argv[], int flags);
#endif

#endif // __INTSET_H

================================================
FILE: src/redis/listpack.c
================================================
/* Listpack -- A lists of strings serialization format * * This file implements the specification you can find at: * * https://github.com/antirez/listpack * * Copyright (c) 2017,2020, Redis Ltd. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include #include "config.h" #include "listpack.h" #include "util.h" #include "zmalloc.h" #define LP_HDR_SIZE 6 /* 32 bit total len + 16 bit number of elements. 
*/ #define LP_HDR_NUMELE_UNKNOWN UINT16_MAX #define LP_MAX_INT_ENCODING_LEN 9 #define LP_MAX_BACKLEN_SIZE 5 #define LP_ENCODING_INT 0 #define LP_ENCODING_STRING 1 #define LP_ENCODING_7BIT_UINT 0 #define LP_ENCODING_7BIT_UINT_MASK 0x80 #define LP_ENCODING_IS_7BIT_UINT(byte) (((byte)&LP_ENCODING_7BIT_UINT_MASK)==LP_ENCODING_7BIT_UINT) #define LP_ENCODING_7BIT_UINT_ENTRY_SIZE 2 #define LP_ENCODING_6BIT_STR 0x80 #define LP_ENCODING_6BIT_STR_MASK 0xC0 #define LP_ENCODING_IS_6BIT_STR(byte) (((byte)&LP_ENCODING_6BIT_STR_MASK)==LP_ENCODING_6BIT_STR) #define LP_ENCODING_13BIT_INT 0xC0 #define LP_ENCODING_13BIT_INT_MASK 0xE0 #define LP_ENCODING_IS_13BIT_INT(byte) (((byte)&LP_ENCODING_13BIT_INT_MASK)==LP_ENCODING_13BIT_INT) #define LP_ENCODING_13BIT_INT_ENTRY_SIZE 3 #define LP_ENCODING_12BIT_STR 0xE0 #define LP_ENCODING_12BIT_STR_MASK 0xF0 #define LP_ENCODING_IS_12BIT_STR(byte) (((byte)&LP_ENCODING_12BIT_STR_MASK)==LP_ENCODING_12BIT_STR) #define LP_ENCODING_16BIT_INT 0xF1 #define LP_ENCODING_16BIT_INT_MASK 0xFF #define LP_ENCODING_IS_16BIT_INT(byte) (((byte)&LP_ENCODING_16BIT_INT_MASK)==LP_ENCODING_16BIT_INT) #define LP_ENCODING_16BIT_INT_ENTRY_SIZE 4 #define LP_ENCODING_24BIT_INT 0xF2 #define LP_ENCODING_24BIT_INT_MASK 0xFF #define LP_ENCODING_IS_24BIT_INT(byte) (((byte)&LP_ENCODING_24BIT_INT_MASK)==LP_ENCODING_24BIT_INT) #define LP_ENCODING_24BIT_INT_ENTRY_SIZE 5 #define LP_ENCODING_32BIT_INT 0xF3 #define LP_ENCODING_32BIT_INT_MASK 0xFF #define LP_ENCODING_IS_32BIT_INT(byte) (((byte)&LP_ENCODING_32BIT_INT_MASK)==LP_ENCODING_32BIT_INT) #define LP_ENCODING_32BIT_INT_ENTRY_SIZE 6 #define LP_ENCODING_64BIT_INT 0xF4 #define LP_ENCODING_64BIT_INT_MASK 0xFF #define LP_ENCODING_IS_64BIT_INT(byte) (((byte)&LP_ENCODING_64BIT_INT_MASK)==LP_ENCODING_64BIT_INT) #define LP_ENCODING_64BIT_INT_ENTRY_SIZE 10 #define LP_ENCODING_32BIT_STR 0xF0 #define LP_ENCODING_32BIT_STR_MASK 0xFF #define LP_ENCODING_IS_32BIT_STR(byte) (((byte)&LP_ENCODING_32BIT_STR_MASK)==LP_ENCODING_32BIT_STR) #define 
LP_EOF 0xFF #define LP_ENCODING_6BIT_STR_LEN(p) ((p)[0] & 0x3F) #define LP_ENCODING_12BIT_STR_LEN(p) ((((p)[0] & 0xF) << 8) | (p)[1]) #define LP_ENCODING_32BIT_STR_LEN(p) \ (((uint32_t)(p)[1] << 0) | ((uint32_t)(p)[2] << 8) | ((uint32_t)(p)[3] << 16) | ((uint32_t)(p)[4] << 24)) #define lpGetTotalBytes(p) \ (((uint32_t)(p)[0] << 0) | ((uint32_t)(p)[1] << 8) | ((uint32_t)(p)[2] << 16) | ((uint32_t)(p)[3] << 24)) #define lpGetNumElements(p) (((uint32_t)(p)[4] << 0) | ((uint32_t)(p)[5] << 8)) #define lpSetTotalBytes(p, v) \ do { \ (p)[0] = (v)&0xff; \ (p)[1] = ((v)>>8)&0xff; \ (p)[2] = ((v)>>16)&0xff; \ (p)[3] = ((v)>>24)&0xff; \ } while(0) /* TODO: delete this function once corruption in the stream code is identified */ static void lpSetTotalBytesChecked(unsigned char *p, uint32_t v) { uint32_t current = lpGetTotalBytes(p); if (current == 0) { fprintf(stderr, "Error: corrupted listpack size."); abort(); } else if (current > 4194304) { /* 4 MiB */ /* suspicous size, lets check its validity*/ size_t block_size = zmalloc_size(p); if (block_size < current) { fprintf(stderr, "Error: listpack size (%u) is larger than allocated " "block size (%lu).", current, block_size); abort(); } } lpSetTotalBytes(p, v); } #define lpSetNumElements(p, v) \ do { \ (p)[4] = (v)&0xff; \ (p)[5] = ((v)>>8)&0xff; \ } while(0) /* Validates that 'p' is not outside the listpack. * All function that return a pointer to an element in the listpack will assert * that this element is valid, so it can be freely used. * Generally functions such lpNext and lpDelete assume the input pointer is * already validated (since it's the return value of another function). */ #define ASSERT_INTEGRITY(lp, p) \ do { \ assert((p) >= (lp)+LP_HDR_SIZE && (p) < (lp)+lpGetTotalBytes((lp))); \ } while (0) /* Similar to the above, but validates the entire element length rather than just * it's pointer. 
*/ #define ASSERT_INTEGRITY_LEN(lp, p, len) \ do { \ assert((p) >= (lp)+LP_HDR_SIZE && (p)+(len) < (lp)+lpGetTotalBytes((lp))); \ } while (0) static inline void lpAssertValidEntry(unsigned char* lp, size_t lpbytes, unsigned char *p); /* Don't let listpacks grow over 1GB in any case, don't wanna risk overflow in * Total Bytes header field */ #define LISTPACK_MAX_SAFETY_SIZE (1<<30) int lpSafeToAdd(unsigned char* lp, size_t add) { size_t len = lp? lpGetTotalBytes(lp): 0; if (len + add > LISTPACK_MAX_SAFETY_SIZE) return 0; return 1; } /* Convert a string into a signed 64 bit integer. * The function returns 1 if the string could be parsed into a (non-overflowing) * signed 64 bit int, 0 otherwise. The 'value' will be set to the parsed value * when the function returns success. * * Note that this function demands that the string strictly represents * a int64 value: no spaces or other characters before or after the string * representing the number are accepted, nor zeroes at the start if not * for the string "0" representing the zero number. * * Because of its strictness, it is safe to use this function to check if * you can convert a string into a long long, and obtain back the string * from the number without any loss in the string representation. * * * ----------------------------------------------------------------------------- * * Credits: this function was adapted from the Redis OSS source code, file * "utils.c", function string2ll(), and is copyright: * * Copyright(C) 2011, Pieter Noordhuis * Copyright(C) 2011, Redis Ltd. * * The function is released under the BSD 3-clause license. */ int lpStringToInt64(const char *s, unsigned long slen, int64_t *value) { const char *p = s; unsigned long plen = 0; int negative = 0; uint64_t v; /* Abort if length indicates this cannot possibly be an int */ if (slen == 0 || slen >= LONG_STR_SIZE) return 0; /* Special case: first and only digit is 0. 
*/ if (slen == 1 && p[0] == '0') { if (value != NULL) *value = 0; return 1; } if (p[0] == '-') { negative = 1; p++; plen++; /* Abort on only a negative sign. */ if (plen == slen) return 0; } /* First digit should be 1-9, otherwise the string should just be 0. */ if (p[0] >= '1' && p[0] <= '9') { v = p[0]-'0'; p++; plen++; } else { return 0; } while (plen < slen && p[0] >= '0' && p[0] <= '9') { if (v > (UINT64_MAX / 10)) /* Overflow. */ return 0; v *= 10; if (v > (UINT64_MAX - (p[0]-'0'))) /* Overflow. */ return 0; v += p[0]-'0'; p++; plen++; } /* Return if not all bytes were used. */ if (plen < slen) return 0; if (negative) { if (v > ((uint64_t)(-(INT64_MIN+1))+1)) /* Overflow. */ return 0; if (value != NULL) *value = -v; } else { if (v > INT64_MAX) /* Overflow. */ return 0; if (value != NULL) *value = v; } return 1; } /* Create a new, empty listpack. * On success the new listpack is returned, otherwise an error is returned. * Pre-allocate at least `capacity` bytes of memory, * over-allocated memory can be shrunk by `lpShrinkToFit`. * */ unsigned char *lpNew(size_t capacity) { unsigned char *lp = zmalloc(capacity > LP_HDR_SIZE+1 ? capacity : LP_HDR_SIZE+1); if (lp == NULL) return NULL; lpSetTotalBytes(lp,LP_HDR_SIZE+1); lpSetNumElements(lp,0); lp[LP_HDR_SIZE] = LP_EOF; return lp; } /* Free the specified listpack. */ void lpFree(unsigned char *lp) { zfree(lp); } /* Shrink the memory to fit. */ unsigned char* lpShrinkToFit(unsigned char *lp) { size_t size = lpGetTotalBytes(lp); if (size < zmalloc_size(lp)) { return zrealloc(lp, size); } else { return lp; } } /* Stores the integer encoded representation of 'v' in the 'intenc' buffer. */ static inline void lpEncodeIntegerGetType(int64_t v, unsigned char *intenc, uint64_t *enclen) { if (v >= 0 && v <= 127) { /* Single byte 0-127 integer. */ intenc[0] = v; *enclen = 1; } else if (v >= -4096 && v <= 4095) { /* 13 bit integer. 
*/ if (v < 0) v = ((int64_t)1<<13)+v; intenc[0] = (v>>8)|LP_ENCODING_13BIT_INT; intenc[1] = v&0xff; *enclen = 2; } else if (v >= -32768 && v <= 32767) { /* 16 bit integer. */ if (v < 0) v = ((int64_t)1<<16)+v; intenc[0] = LP_ENCODING_16BIT_INT; intenc[1] = v&0xff; intenc[2] = v>>8; *enclen = 3; } else if (v >= -8388608 && v <= 8388607) { /* 24 bit integer. */ if (v < 0) v = ((int64_t)1<<24)+v; intenc[0] = LP_ENCODING_24BIT_INT; intenc[1] = v&0xff; intenc[2] = (v>>8)&0xff; intenc[3] = v>>16; *enclen = 4; } else if (v >= -2147483648 && v <= 2147483647) { /* 32 bit integer. */ if (v < 0) v = ((int64_t)1<<32)+v; intenc[0] = LP_ENCODING_32BIT_INT; intenc[1] = v&0xff; intenc[2] = (v>>8)&0xff; intenc[3] = (v>>16)&0xff; intenc[4] = v>>24; *enclen = 5; } else { /* 64 bit integer. */ uint64_t uv = v; intenc[0] = LP_ENCODING_64BIT_INT; intenc[1] = uv&0xff; intenc[2] = (uv>>8)&0xff; intenc[3] = (uv>>16)&0xff; intenc[4] = (uv>>24)&0xff; intenc[5] = (uv>>32)&0xff; intenc[6] = (uv>>40)&0xff; intenc[7] = (uv>>48)&0xff; intenc[8] = uv>>56; *enclen = 9; } } /* Given an element 'ele' of size 'size', determine if the element can be * represented inside the listpack encoded as integer, and returns * LP_ENCODING_INT if so. Otherwise returns LP_ENCODING_STR if no integer * encoding is possible. * * If the LP_ENCODING_INT is returned, the function stores the integer encoded * representation of the element in the 'intenc' buffer. * * Regardless of the returned encoding, 'enclen' is populated by reference to * the number of bytes that the string or integer encoded element will require * in order to be represented. 
*/
static inline int lpEncodeGetType(const unsigned char *ele, uint32_t size, unsigned char *intenc, uint64_t *enclen) {
    int64_t parsed;
    if (lpStringToInt64((const char*)ele, size, &parsed)) {
        lpEncodeIntegerGetType(parsed, intenc, enclen);
        return LP_ENCODING_INT;
    }

    /* String: the header takes 1, 2 or 5 bytes depending on the length. */
    uint64_t header_bytes;
    if (size < 64) header_bytes = 1;
    else if (size < 4096) header_bytes = 2;
    else header_bytes = 5;
    *enclen = header_bytes + (uint64_t)size;
    return LP_ENCODING_STRING;
}

/* Compute, and optionally store in 'buf', the back-length field for an
 * element whose encoded size is 'l'. The field is made of 7 bit groups, the
 * most significant group first; every byte but the first one has its high
 * bit set, so that a reader walking backward from the last byte knows when
 * to stop. Returns the number of bytes of the field, from 1 to 5. When 'buf'
 * is NULL nothing is written and only the size is returned. */
static inline unsigned long lpEncodeBacklen(unsigned char *buf, uint64_t l) {
    unsigned long nbytes;
    if (l <= 127) nbytes = 1;
    else if (l < 16383) nbytes = 2;
    else if (l < 2097151) nbytes = 3;
    else if (l < 268435455) nbytes = 4;
    else nbytes = 5;

    if (buf) {
        unsigned long shift = 7*(nbytes-1);
        /* Leading byte: the topmost group, high bit left as is. */
        buf[0] = l>>shift;
        for (unsigned long i = 1; i < nbytes; i++) {
            shift -= 7;
            buf[i] = ((l>>shift)&127)|128;
        }
    }
    return nbytes;
}

/* Decode the back-length whose last byte is at 'p', reading backward.
 * If more than 5 bytes would be needed the encoding is invalid and
 * UINT64_MAX is returned to report the problem. */
static inline uint64_t lpDecodeBacklen(unsigned char *p) {
    uint64_t decoded = 0;
    for (uint64_t shift = 0; shift <= 28; shift += 7, p--) {
        decoded |= (uint64_t)(p[0] & 127) << shift;
        /* A byte with the high bit clear is the first byte of the field. */
        if ((p[0] & 128) == 0) return decoded;
    }
    return UINT64_MAX;
}

/* Encode the string element pointed by 's' of size 'len' in the target
 * buffer 'buf'. The function should be called with 'buf' having always enough
 * space for encoding the string.
This is done by calling lpEncodeGetType() * before calling this function. */ static inline void lpEncodeString(unsigned char *buf, const unsigned char *s, uint32_t len) { if (len < 64) { buf[0] = len | LP_ENCODING_6BIT_STR; memcpy(buf+1,s,len); } else if (len < 4096) { buf[0] = (len >> 8) | LP_ENCODING_12BIT_STR; buf[1] = len & 0xff; memcpy(buf+2,s,len); } else { buf[0] = LP_ENCODING_32BIT_STR; buf[1] = len & 0xff; buf[2] = (len >> 8) & 0xff; buf[3] = (len >> 16) & 0xff; buf[4] = (len >> 24) & 0xff; memcpy(buf+5,s,len); } } /* Return the encoded length of the listpack element pointed by 'p'. * This includes the encoding byte, length bytes, and the element data itself. * If the element encoding is wrong then 0 is returned. * Note that this method may access additional bytes (in case of 12 and 32 bit * str), so should only be called when we know 'p' was already validated by * lpCurrentEncodedSizeBytes or ASSERT_INTEGRITY_LEN (possibly since 'p' is * a return value of another function that validated its return. */ static inline uint32_t lpCurrentEncodedSizeUnsafe(unsigned char *p) { if (LP_ENCODING_IS_7BIT_UINT(p[0])) return 1; if (LP_ENCODING_IS_6BIT_STR(p[0])) return 1+LP_ENCODING_6BIT_STR_LEN(p); if (LP_ENCODING_IS_13BIT_INT(p[0])) return 2; if (LP_ENCODING_IS_16BIT_INT(p[0])) return 3; if (LP_ENCODING_IS_24BIT_INT(p[0])) return 4; if (LP_ENCODING_IS_32BIT_INT(p[0])) return 5; if (LP_ENCODING_IS_64BIT_INT(p[0])) return 9; if (LP_ENCODING_IS_12BIT_STR(p[0])) return 2+LP_ENCODING_12BIT_STR_LEN(p); if (LP_ENCODING_IS_32BIT_STR(p[0])) return 5+LP_ENCODING_32BIT_STR_LEN(p); if (p[0] == LP_EOF) return 1; return 0; } /* Return bytes needed to encode the length of the listpack element pointed by 'p'. * This includes just the encoding byte, and the bytes needed to encode the length * of the element (excluding the element data itself) * If the element encoding is wrong then 0 is returned. 
*/ static inline uint32_t lpCurrentEncodedSizeBytes(unsigned char *p) { if (LP_ENCODING_IS_7BIT_UINT(p[0])) return 1; if (LP_ENCODING_IS_6BIT_STR(p[0])) return 1; if (LP_ENCODING_IS_13BIT_INT(p[0])) return 1; if (LP_ENCODING_IS_16BIT_INT(p[0])) return 1; if (LP_ENCODING_IS_24BIT_INT(p[0])) return 1; if (LP_ENCODING_IS_32BIT_INT(p[0])) return 1; if (LP_ENCODING_IS_64BIT_INT(p[0])) return 1; if (LP_ENCODING_IS_12BIT_STR(p[0])) return 2; if (LP_ENCODING_IS_32BIT_STR(p[0])) return 5; if (p[0] == LP_EOF) return 1; return 0; } /* Skip the current entry returning the next. It is invalid to call this * function if the current element is the EOF element at the end of the * listpack, however, while this function is used to implement lpNext(), * it does not return NULL when the EOF element is encountered. */ unsigned char *lpSkip(unsigned char *p) { unsigned long entrylen = lpCurrentEncodedSizeUnsafe(p); entrylen += lpEncodeBacklen(NULL,entrylen); p += entrylen; return p; } /* If 'p' points to an element of the listpack, calling lpNext() will return * the pointer to the next element (the one on the right), or NULL if 'p' * already pointed to the last element of the listpack. */ unsigned char *lpNext(unsigned char *lp, unsigned char *p) { assert(p); p = lpSkip(p); if (p[0] == LP_EOF) return NULL; lpAssertValidEntry(lp, lpBytes(lp), p); return p; } /* If 'p' points to an element of the listpack, calling lpPrev() will return * the pointer to the previous element (the one on the left), or NULL if 'p' * already pointed to the first element of the listpack. */ unsigned char *lpPrev(unsigned char *lp, unsigned char *p) { assert(p); if (p-lp == LP_HDR_SIZE) return NULL; p--; /* Seek the first backlen byte of the last element. */ uint64_t prevlen = lpDecodeBacklen(p); prevlen += lpEncodeBacklen(NULL,prevlen); p -= prevlen-1; /* Seek the first byte of the previous entry. 
*/ lpAssertValidEntry(lp, lpBytes(lp), p); return p; } /* Return a pointer to the first element of the listpack, or NULL if the * listpack has no elements. */ unsigned char *lpFirst(unsigned char *lp) { unsigned char *p = lp + LP_HDR_SIZE; /* Skip the header. */ if (p[0] == LP_EOF) return NULL; lpAssertValidEntry(lp, lpBytes(lp), p); return p; } /* Return a pointer to the last element of the listpack, or NULL if the * listpack has no elements. */ unsigned char *lpLast(unsigned char *lp) { unsigned char *p = lp+lpGetTotalBytes(lp)-1; /* Seek EOF element. */ return lpPrev(lp,p); /* Will return NULL if EOF is the only element. */ } /* Return the number of elements inside the listpack. This function attempts * to use the cached value when within range, otherwise a full scan is * needed. As a side effect of calling this function, the listpack header * could be modified, because if the count is found to be already within * the 'numele' header field range, the new value is set. */ unsigned long lpLength(unsigned char *lp) { uint32_t numele = lpGetNumElements(lp); if (numele != LP_HDR_NUMELE_UNKNOWN) return numele; /* Too many elements inside the listpack. We need to scan in order * to get the total number. */ uint32_t count = 0; unsigned char *p = lpFirst(lp); while(p) { count++; p = lpNext(lp,p); } /* If the count is again within range of the header numele field, * set it. */ if (count < LP_HDR_NUMELE_UNKNOWN) lpSetNumElements(lp,count); return count; } /* Return the listpack element pointed by 'p'. * * The function changes behavior depending on the passed 'intbuf' value. * Specifically, if 'intbuf' is NULL: * * If the element is internally encoded as an integer, the function returns * NULL and populates the integer value by reference in 'count'. Otherwise if * the element is encoded as a string a pointer to the string (pointing inside * the listpack itself) is returned, and 'count' is set to the length of the * string. 
* * If instead 'intbuf' points to a buffer passed by the caller, that must be * at least LP_INTBUF_SIZE bytes, the function always returns the element as * it was a string (returning the pointer to the string and setting the * 'count' argument to the string length by reference). However if the element * is encoded as an integer, the 'intbuf' buffer is used in order to store * the string representation. * * The user should use one or the other form depending on what the value will * be used for. If there is immediate usage for an integer value returned * by the function, than to pass a buffer (and convert it back to a number) * is of course useless. * * If 'entry_size' is not NULL, *entry_size is set to the entry length of the * listpack element pointed by 'p'. This includes the encoding bytes, length * bytes, the element data itself, and the backlen bytes. * * If the function is called against a badly encoded ziplist, so that there * is no valid way to parse it, the function returns like if there was an * integer encoded with value 12345678900000000 + , this may * be an hint to understand that something is wrong. To crash in this case is * not sensible because of the different requirements of the application using * this lib. * * Similarly, there is no error returned since the listpack normally can be * assumed to be valid, so that would be a very high API cost. */ static inline unsigned char * lpGetWithSize(unsigned char *p, int64_t *count, unsigned char *intbuf, uint64_t *entry_size) { int64_t val; uint64_t uval, negstart, negmax; assert(p); /* assertion for valgrind (avoid NPD) */ if (LP_ENCODING_IS_7BIT_UINT(p[0])) { negstart = UINT64_MAX; /* 7 bit ints are always positive. 
*/ negmax = 0; uval = p[0] & 0x7f; if (entry_size) *entry_size = LP_ENCODING_7BIT_UINT_ENTRY_SIZE; } else if (LP_ENCODING_IS_6BIT_STR(p[0])) { *count = LP_ENCODING_6BIT_STR_LEN(p); if (entry_size) *entry_size = 1 + *count + lpEncodeBacklen(NULL, *count + 1); return p+1; } else if (LP_ENCODING_IS_13BIT_INT(p[0])) { uval = ((p[0]&0x1f)<<8) | p[1]; negstart = (uint64_t)1<<12; negmax = 8191; if (entry_size) *entry_size = LP_ENCODING_13BIT_INT_ENTRY_SIZE; } else if (LP_ENCODING_IS_16BIT_INT(p[0])) { uval = (uint64_t)p[1] | (uint64_t)p[2] << 8; negstart = (uint64_t)1<<15; negmax = UINT16_MAX; if (entry_size) *entry_size = LP_ENCODING_16BIT_INT_ENTRY_SIZE; } else if (LP_ENCODING_IS_24BIT_INT(p[0])) { uval = (uint64_t)p[1] | (uint64_t)p[2] << 8 | (uint64_t)p[3] << 16; negstart = (uint64_t)1<<23; negmax = UINT32_MAX>>8; if (entry_size) *entry_size = LP_ENCODING_24BIT_INT_ENTRY_SIZE; } else if (LP_ENCODING_IS_32BIT_INT(p[0])) { uval = (uint64_t)p[1] | (uint64_t)p[2] << 8 | (uint64_t)p[3] << 16 | (uint64_t)p[4] << 24; negstart = (uint64_t)1<<31; negmax = UINT32_MAX; if (entry_size) *entry_size = LP_ENCODING_32BIT_INT_ENTRY_SIZE; } else if (LP_ENCODING_IS_64BIT_INT(p[0])) { uval = (uint64_t)p[1] | (uint64_t)p[2] << 8 | (uint64_t)p[3] << 16 | (uint64_t)p[4] << 24 | (uint64_t)p[5] << 32 | (uint64_t)p[6] << 40 | (uint64_t)p[7] << 48 | (uint64_t)p[8] << 56; negstart = (uint64_t)1<<63; negmax = UINT64_MAX; if (entry_size) *entry_size = LP_ENCODING_64BIT_INT_ENTRY_SIZE; } else if (LP_ENCODING_IS_12BIT_STR(p[0])) { *count = LP_ENCODING_12BIT_STR_LEN(p); if (entry_size) *entry_size = 2 + *count + lpEncodeBacklen(NULL, *count + 2); return p+2; } else if (LP_ENCODING_IS_32BIT_STR(p[0])) { *count = LP_ENCODING_32BIT_STR_LEN(p); if (entry_size) *entry_size = 5 + *count + lpEncodeBacklen(NULL, *count + 5); return p+5; } else { uval = 12345678900000000ULL + p[0]; negstart = UINT64_MAX; negmax = 0; } /* We reach this code path only for integer encodings. 
* Convert the unsigned value to the signed one using two's complement * rule. */ if (uval >= negstart) { /* This three steps conversion should avoid undefined behaviors * in the unsigned -> signed conversion. */ uval = negmax-uval; val = uval; val = -val-1; } else { val = uval; } /* Return the string representation of the integer or the value itself * depending on intbuf being NULL or not. */ if (intbuf) { *count = ll2string((char*)intbuf,LP_INTBUF_SIZE,(long long)val); return intbuf; } else { *count = val; return NULL; } } int lpGetInteger(unsigned char *p, int64_t *ival) { int64_t val; uint64_t uval = 0, negstart = UINT64_MAX, negmax = 0; uint8_t encoding = p[0]; // Prioritize checking for integers first. if (encoding < LP_ENCODING_7BIT_UINT_MASK) { uval = encoding & 0x7f; } else if (encoding > LP_ENCODING_32BIT_STR) { switch (encoding) { case LP_ENCODING_16BIT_INT: uval = (uint64_t)p[1] | (uint64_t)p[2] << 8; negstart = (uint64_t)1<<15; negmax = UINT16_MAX; break; case LP_ENCODING_24BIT_INT: uval = (uint64_t)p[1] | (uint64_t)p[2] << 8 | (uint64_t)p[3] << 16; negstart = (uint64_t)1<<23; negmax = UINT32_MAX>>8; break; case LP_ENCODING_32BIT_INT: uval = (uint64_t)p[1] | (uint64_t)p[2] << 8 | (uint64_t)p[3] << 16 | (uint64_t)p[4] << 24; negstart = (uint64_t)1<<31; negmax = UINT32_MAX; break; case LP_ENCODING_64BIT_INT: uval = (uint64_t)p[1] | (uint64_t)p[2] << 8 | (uint64_t)p[3] << 16 | (uint64_t)p[4] << 24 | (uint64_t)p[5] << 32 | (uint64_t)p[6] << 40 | (uint64_t)p[7] << 48 | (uint64_t)p[8] << 56; negstart = (uint64_t)1<<63; negmax = UINT64_MAX; break; default: return 0; } } else if (encoding < LP_ENCODING_13BIT_INT_MASK && encoding >= LP_ENCODING_6BIT_STR_MASK) { uval = ((encoding & 0x1f) << 8) | p[1]; negstart = (uint64_t)1 << 12; negmax = 8191; } else { // string encodings. return 0; } /* We reach this code path only for integer encodings. * Convert the unsigned value to the signed one using two's complement * rule. 
*/ if (uval >= negstart) { /* This three steps conversion should avoid undefined behaviors * in the unsigned -> signed conversion. */ uval = negmax-uval; val = uval; val = -val-1; } else { val = uval; } *ival = val; return 1; } unsigned char *lpGet(unsigned char *p, int64_t *count, unsigned char *intbuf) { return lpGetWithSize(p, count, intbuf, NULL); } /* This is just a wrapper to lpGet() that is able to get entry value directly. * When the function returns NULL, it populates the integer value by reference in 'lval'. * Otherwise if the element is encoded as a string a pointer to the string (pointing * inside the listpack itself) is returned, and 'slen' is set to the length of the * string. */ unsigned char *lpGetValue(unsigned char *p, unsigned int *slen, long long *lval) { unsigned char *vstr; int64_t ele_len; vstr = lpGet(p, &ele_len, NULL); if (vstr) { *slen = ele_len; } else { *lval = ele_len; } return vstr; } /* Find pointer to the entry equal to the specified entry. Skip 'skip' entries * between every comparison. Returns NULL when the field could not be found. */ unsigned char *lpFind(unsigned char *lp, unsigned char *p, unsigned char *s, uint32_t slen, unsigned int skip) { int skipcnt = 0; unsigned char vencoding = 0; unsigned char *value; int64_t ll, vll; uint64_t entry_size = 123456789; /* initialized to avoid warning. */ uint32_t lp_bytes = lpBytes(lp); assert(p); while (p) { if (skipcnt == 0) { value = lpGetWithSize(p, &ll, NULL, &entry_size); if (value) { /* check the value doesn't reach outside the listpack before accessing it */ assert(p >= lp + LP_HDR_SIZE && p + entry_size < lp + lp_bytes); if (slen == ll && memcmp(value, s, slen) == 0) { return p; } } else { /* Find out if the searched field can be encoded. Note that * we do it only the first time, once done vencoding is set * to non-zero and vll is set to the integer value. 
*/ if (vencoding == 0) { /* If the entry can be encoded as integer we set it to * 1, else set it to UCHAR_MAX, so that we don't retry * again the next time. */ if (slen >= 32 || slen == 0 || !lpStringToInt64((const char*)s, slen, &vll)) { vencoding = UCHAR_MAX; } else { vencoding = 1; } } /* Compare current entry with specified entry, do it only * if vencoding != UCHAR_MAX because if there is no encoding * possible for the field it can't be a valid integer. */ if (vencoding != UCHAR_MAX && ll == vll) { return p; } } /* Reset skip count */ skipcnt = skip; p += entry_size; } else { /* Skip entry */ skipcnt--; /* Move to next entry, avoid use `lpNext` due to `lpAssertValidEntry` in * `lpNext` will call `lpBytes`, will cause performance degradation */ p = lpSkip(p); } /* The next call to lpGetWithSize could read at most 8 bytes past `p` * We use the slower validation call only when necessary. */ if (p + 8 >= lp + lp_bytes) lpAssertValidEntry(lp, lp_bytes, p); else assert(p >= lp + LP_HDR_SIZE && p < lp + lp_bytes); if (p[0] == LP_EOF) break; } return NULL; } /* Insert, delete or replace the specified string element 'elestr' of length * 'size' or integer element 'eleint' at the specified position 'p', with 'p' * being a listpack element pointer obtained with lpFirst(), lpLast(), lpNext(), * lpPrev() or lpSeek(). * * The element is inserted before, after, or replaces the element pointed * by 'p' depending on the 'where' argument, that can be LP_BEFORE, LP_AFTER * or LP_REPLACE. * * If both 'elestr' and `eleint` are NULL, the function removes the element * pointed by 'p' instead of inserting one. * If `eleint` is non-NULL, 'size' is the length of 'eleint', the function insert * or replace with a 64 bit integer, which is stored in the 'eleint' buffer. * If 'elestr` is non-NULL, 'size' is the length of 'elestr', the function insert * or replace with a string, which is stored in the 'elestr' buffer. 
*
 * Returns NULL on out of memory or when the listpack total length would exceed
 * the max allowed size of 2^32-1, otherwise the new pointer to the listpack
 * holding the new element is returned (and the old pointer passed is no longer
 * considered valid)
 *
 * If 'newp' is not NULL, at the end of a successful call '*newp' will be set
 * to the address of the element just added, so that it will be possible to
 * continue an interaction with lpNext() and lpPrev().
 *
 * For deletion operations (both 'elestr' and 'eleint' set to NULL) 'newp' is
 * set to the next element, on the right of the deleted one, or to NULL if the
 * deleted element was the last one. */
unsigned char *lpInsert(unsigned char *lp, const unsigned char *elestr, unsigned char *eleint,
                        uint32_t size, unsigned char *p, int where, unsigned char **newp)
{
    /* Scratch space for the integer encoding of 'elestr' (when it parses as
     * a number) and for the back-length bytes of the new element. */
    unsigned char intenc[LP_MAX_INT_ENCODING_LEN];
    unsigned char backlen[LP_MAX_BACKLEN_SIZE];

    uint64_t enclen; /* The length of the encoded element. */
    int del_ele = (elestr == NULL && eleint == NULL);

    /* A deletion is conceptually the replacement of the element with a
     * zero-length element. So whatever we get passed as 'where', set
     * it to LP_REPLACE. */
    if (del_ele) where = LP_REPLACE;

    /* If we need to insert after the current element, we just jump to the
     * next element (that could be the EOF one) and handle the case of
     * inserting before. So the function will actually deal with just two
     * cases: LP_BEFORE and LP_REPLACE. */
    if (where == LP_AFTER) {
        p = lpSkip(p);
        where = LP_BEFORE;
        ASSERT_INTEGRITY(lp, p);
    }

    /* Store the offset of the element 'p', so that we can obtain its
     * address again after a reallocation. */
    unsigned long poff = p-lp;

    int enctype;
    if (elestr) {
        /* Calling lpEncodeGetType() results into the encoded version of the
         * element to be stored into 'intenc' in case it is representable as
         * an integer: in that case, the function returns LP_ENCODING_INT.
         * Otherwise if LP_ENCODING_STRING is returned, we'll have to call
         * lpEncodeString() to actually write the encoded string in place later.
         *
         * Whatever the returned encoding is, 'enclen' is populated with the
         * length of the encoded element. */
        enctype = lpEncodeGetType(elestr,size,intenc,&enclen);
        if (enctype == LP_ENCODING_INT) eleint = intenc;
    } else if (eleint) {
        /* The caller passed an already encoded integer. */
        enctype = LP_ENCODING_INT;
        enclen = size; /* 'size' is the length of the encoded integer element. */
    } else {
        /* Deletion: there is no new element to encode. */
        enctype = -1;
        enclen = 0;
    }

    /* We need to also encode the backward-parsable length of the element
     * and append it to the end: this allows to traverse the listpack from
     * the end to the start. */
    unsigned long backlen_size = (!del_ele) ? lpEncodeBacklen(backlen, enclen) : 0;
    uint64_t old_listpack_bytes = lpGetTotalBytes(lp);
    uint32_t replaced_len = 0;
    if (where == LP_REPLACE) {
        /* Full size (encoded bytes plus back-length) of the entry that is
         * going away. */
        replaced_len = lpCurrentEncodedSizeUnsafe(p);
        replaced_len += lpEncodeBacklen(NULL,replaced_len);
        ASSERT_INTEGRITY_LEN(lp, p, replaced_len);
    }

    /* Computed with 64 bit math so that a result not fitting the 32 bit
     * total bytes header field can be detected and refused. */
    uint64_t new_listpack_bytes = old_listpack_bytes + enclen + backlen_size
                                  - replaced_len;
    if (new_listpack_bytes > UINT32_MAX) return NULL;

    /* We now need to reallocate in order to make space or shrink the
     * allocation (in case the 'where' value is LP_REPLACE and the new element
     * is smaller). However we do that before memmoving the memory to
     * make room for the new element if the final allocation will get
     * larger, or we do it after if the final allocation will get smaller. */

    unsigned char *dst = lp + poff; /* May be updated after reallocation. */

    /* Realloc before: we need more room. The reallocation is skipped when
     * the size reported by zmalloc_size() is already large enough. */
    if (new_listpack_bytes > old_listpack_bytes &&
        new_listpack_bytes > zmalloc_size(lp)) {
        if ((lp = zrealloc(lp, new_listpack_bytes)) == NULL) return NULL;
        dst = lp + poff;
    }

    /* Setup the listpack relocating the elements to make the exact room
     * we need to store the new one. */
    if (where == LP_BEFORE) {
        memmove(dst+enclen+backlen_size,dst,old_listpack_bytes-poff);
    } else { /* LP_REPLACE. */
        /* Everything after the replaced entry is moved right after the
         * place where the new entry will end. */
        memmove(dst + enclen + backlen_size,
                dst + replaced_len,
                old_listpack_bytes - poff - replaced_len);
    }

    /* Realloc after: we need to free space. */
    if (new_listpack_bytes < old_listpack_bytes) {
        if ((lp = zrealloc(lp,new_listpack_bytes)) == NULL) return NULL;
        dst = lp + poff;
    }

    /* Store the entry. */
    if (newp) {
        *newp = dst;
        /* In case of deletion, set 'newp' to NULL if the next element is
         * the EOF element. */
        if (del_ele && dst[0] == LP_EOF) *newp = NULL;
    }
    if (!del_ele) {
        if (enctype == LP_ENCODING_INT) {
            memcpy(dst,eleint,enclen);
        } else if (elestr) {
            lpEncodeString(dst,elestr,size);
        } else {
            /* A non integer encoding is only ever selected for 'elestr'. */
            valkey_unreachable();
        }
        dst += enclen;
        memcpy(dst,backlen,backlen_size);
        dst += backlen_size;
    }

    /* Update header. The element count changes on insertion and deletion
     * only, and is left alone when it is the "unknown" marker. */
    if (where != LP_REPLACE || del_ele) {
        uint32_t num_elements = lpGetNumElements(lp);
        if (num_elements != LP_HDR_NUMELE_UNKNOWN) {
            if (!del_ele)
                lpSetNumElements(lp,num_elements+1);
            else
                lpSetNumElements(lp,num_elements-1);
        }
    }
    lpSetTotalBytesChecked(lp,new_listpack_bytes);

#if 0
    /* This code path is normally disabled: what it does is to force listpack
     * to return *always* a new pointer after performing some modification to
     * the listpack, even if the previous allocation was enough. This is useful
     * in order to spot bugs in code using listpacks: by doing so we can find
     * if the caller forgets to set the new pointer where the listpack reference
     * is stored, after an update. */
    unsigned char *oldlp = lp;
    lp = zmalloc(new_listpack_bytes);
    memcpy(lp,oldlp,new_listpack_bytes);
    if (newp) {
        unsigned long offset = (*newp)-oldlp;
        *newp = lp + offset;
    }
    /* Make sure the old allocation contains garbage. */
    memset(oldlp,'A',new_listpack_bytes);
    zfree(oldlp);
#endif

    return lp;
}

/* This is just a wrapper for lpInsert() to directly use a string.
*/ unsigned char *lpInsertString(unsigned char *lp, const unsigned char *s, uint32_t slen, unsigned char *p, int where, unsigned char **newp) { return lpInsert(lp, s, NULL, slen, p, where, newp); } /* This is just a wrapper for lpInsert() to directly use a 64 bit integer * instead of a string. */ unsigned char *lpInsertInteger(unsigned char *lp, long long lval, unsigned char *p, int where, unsigned char **newp) { uint64_t enclen; /* The length of the encoded element. */ unsigned char intenc[LP_MAX_INT_ENCODING_LEN]; lpEncodeIntegerGetType(lval, intenc, &enclen); return lpInsert(lp, NULL, intenc, enclen, p, where, newp); } /* Append the specified element 's' of length 'slen' at the head of the listpack. */ unsigned char *lpPrepend(unsigned char *lp, const unsigned char *s, uint32_t slen) { unsigned char *p = lpFirst(lp); if (!p) return lpAppend(lp, s, slen); return lpInsert(lp, s, NULL, slen, p, LP_BEFORE, NULL); } /* Append the specified integer element 'lval' at the head of the listpack. */ unsigned char *lpPrependInteger(unsigned char *lp, long long lval) { unsigned char *p = lpFirst(lp); if (!p) return lpAppendInteger(lp, lval); return lpInsertInteger(lp, lval, p, LP_BEFORE, NULL); } /* Append the specified element 'ele' of length 'size' at the end of the * listpack. It is implemented in terms of lpInsert(), so the return value is * the same as lpInsert(). */ unsigned char *lpAppend(unsigned char *lp, const unsigned char *ele, uint32_t size) { uint64_t listpack_bytes = lpGetTotalBytes(lp); unsigned char *eofptr = lp + listpack_bytes - 1; return lpInsert(lp,ele,NULL,size,eofptr,LP_BEFORE,NULL); } /* Append the specified integer element 'lval' at the end of the listpack. 
*/
unsigned char *lpAppendInteger(unsigned char *lp, long long lval) {
    uint64_t listpack_bytes = lpGetTotalBytes(lp);
    /* Last byte of the listpack: the new element is inserted right before it. */
    unsigned char *eofptr = lp + listpack_bytes - 1;
    return lpInsertInteger(lp, lval, eofptr, LP_BEFORE, NULL);
}

/* This is just a wrapper for lpInsert() to directly use a string to replace
 * the current element. The function returns the new listpack as return
 * value, and also updates the current cursor by updating '*p'.
 *
 * '*p' is used both as the position of the element to replace and as the
 * output cursor handed to lpInsert(). */
unsigned char *lpReplace(unsigned char *lp, unsigned char **p, const unsigned char *s, uint32_t slen) {
    return lpInsert(lp, s, NULL, slen, *p, LP_REPLACE, p);
}

/* This is just a wrapper for lpInsertInteger() to directly use a 64 bit integer
 * instead of a string to replace the current element. The function returns
 * the new listpack as return value, and also updates the current cursor
 * by updating '*p'. */
unsigned char *lpReplaceInteger(unsigned char *lp, unsigned char **p, long long lval) {
    return lpInsertInteger(lp, lval, *p, LP_REPLACE, p);
}

/* Remove the element pointed by 'p', and return the resulting listpack.
 * If 'newp' is not NULL, the next element pointer (to the right of the
 * deleted one) is returned by reference. If the deleted element was the
 * last one, '*newp' is set to NULL.
 *
 * Deletion is expressed as an lpInsert() call in LP_REPLACE mode with no
 * replacement element (both element pointers NULL, size 0). */
unsigned char *lpDelete(unsigned char *lp, unsigned char *p, unsigned char **newp) {
    return lpInsert(lp,NULL,NULL,0,p,LP_REPLACE,newp);
}

/* Delete up to 'num' consecutive entries from the listpack, starting with the
 * element pointed to by '*p'.
 *
 * Returns the (possibly reallocated) listpack. On return '*p' points to the
 * entry that now occupies the position of the first deleted one, or is set to
 * NULL when the deletion reached the end of the listpack. When 'num' is 0 the
 * listpack and '*p' are left untouched. */
unsigned char *lpDeleteRangeWithEntry(unsigned char *lp, unsigned char **p, unsigned long num) {
    size_t bytes = lpBytes(lp);
    unsigned long deleted = 0;
    unsigned char *eofptr = lp + bytes - 1;
    unsigned char *first, *tail;
    first = tail = *p;

    if (num == 0) return lp;  /* Nothing to delete, return ASAP. */

    /* Find the next entry to the last entry that needs to be deleted.
     * lpLength may be unreliable due to corrupt data, so we cannot
     * treat 'num' as the number of elements to be deleted. */
    while (num--) {
        deleted++;
        tail = lpSkip(tail);
        /* Stop as soon as the terminator is reached: fewer than 'num'
         * entries were available. */
        if (tail[0] == LP_EOF) break;
        lpAssertValidEntry(lp, bytes, tail);
    }

    /* Store the offset of the element 'first', so that we can obtain its
     * address again after a reallocation. */
    unsigned long poff = first-lp;

    /* Move tail to the front of the listpack. The copied span runs from
     * 'tail' up to and including the terminator byte. */
    memmove(first, tail, eofptr - tail + 1);
    lpSetTotalBytesChecked(lp, bytes - (tail - first));
    uint32_t numele = lpGetNumElements(lp);
    /* The header count is only adjusted when it holds a known value. */
    if (numele != LP_HDR_NUMELE_UNKNOWN)
        lpSetNumElements(lp, numele - deleted);
    lp = lpShrinkToFit(lp);

    /* Store the entry. */
    *p = lp+poff;
    if ((*p)[0] == LP_EOF) *p = NULL;

    return lp;
}

/* Delete up to 'num' entries from the listpack, starting at position 'index'
 * (same indexing rules as lpSeek(), negative values count from the tail).
 * Returns the resulting listpack; if 'num' is 0 or 'index' cannot be seeked
 * the listpack is returned unmodified. */
unsigned char *lpDeleteRange(unsigned char *lp, long index, unsigned long num) {
    unsigned char *p;
    uint32_t numele = lpGetNumElements(lp);

    if (num == 0) return lp; /* Nothing to delete, return ASAP. */
    if ((p = lpSeek(lp, index)) == NULL) return lp;

    /* If we know we're gonna delete beyond the end of the listpack, we can just move
     * the EOF marker, and there's no need to iterate through the entries,
     * but if we can't be sure how many entries there are, we rather avoid calling lpLength
     * since that means an additional iteration on all elements.
     *
     * Note that index could overflow, but we use the value after seek, so when we
     * use it no overflow happens. */
    if (numele != LP_HDR_NUMELE_UNKNOWN && index < 0) index = (long)numele + index;
    if (numele != LP_HDR_NUMELE_UNKNOWN && (numele - (unsigned long)index) <= num) {
        /* Truncate in place: the first deleted entry becomes the terminator
         * and 'index' entries remain. */
        p[0] = LP_EOF;
        lpSetTotalBytesChecked(lp, p - lp + 1);
        lpSetNumElements(lp, index);
        lp = lpShrinkToFit(lp);
    } else {
        lp = lpDeleteRangeWithEntry(lp, &p, num);
    }

    return lp;
}

/* Merge listpacks 'first' and 'second' by appending 'second' to 'first'.
 *
 * NOTE: The larger listpack is reallocated to contain the new merged listpack.
 * Either 'first' or 'second' can be used for the result. The parameter not
 * used will be free'd and set to NULL.
* * After calling this function, the input parameters are no longer valid since * they are changed and free'd in-place. * * The result listpack is the contents of 'first' followed by 'second'. * * On failure: returns NULL if the merge is impossible. * On success: returns the merged listpack (which is expanded version of either * 'first' or 'second', also frees the other unused input listpack, and sets the * input listpack argument equal to newly reallocated listpack return value. */ unsigned char *lpMerge(unsigned char **first, unsigned char **second) { /* If any params are null, we can't merge, so NULL. */ if (first == NULL || *first == NULL || second == NULL || *second == NULL) return NULL; /* Can't merge same list into itself. */ if (*first == *second) return NULL; size_t first_bytes = lpBytes(*first); unsigned long first_len = lpLength(*first); size_t second_bytes = lpBytes(*second); unsigned long second_len = lpLength(*second); int append; unsigned char *source, *target; size_t target_bytes, source_bytes; /* Pick the largest listpack so we can resize easily in-place. * We must also track if we are now appending or prepending to * the target listpack. */ if (first_bytes >= second_bytes) { /* retain first, append second to first. */ target = *first; target_bytes = first_bytes; source = *second; source_bytes = second_bytes; append = 1; } else { /* else, retain second, prepend first to second. */ target = *second; target_bytes = second_bytes; source = *first; source_bytes = first_bytes; append = 0; } /* Calculate final bytes (subtract one pair of metadata) */ unsigned long long lpbytes = (unsigned long long)first_bytes + second_bytes - LP_HDR_SIZE - 1; assert(lpbytes < UINT32_MAX); /* larger values can't be stored */ unsigned long lplength = first_len + second_len; /* Combined lp length should be limited within UINT16_MAX */ lplength = lplength < UINT16_MAX ? lplength : UINT16_MAX; /* Extend target to new lpbytes then append or prepend source. 
*/ target = zrealloc(target, lpbytes); if (append) { /* append == appending to target */ /* Copy source after target (copying over original [END]): * [TARGET - END, SOURCE - HEADER] */ memcpy(target + target_bytes - 1, source + LP_HDR_SIZE, source_bytes - LP_HDR_SIZE); } else { /* !append == prepending to target */ /* Move target *contents* exactly size of (source - [END]), * then copy source into vacated space (source - [END]): * [SOURCE - END, TARGET - HEADER] */ memmove(target + source_bytes - 1, target + LP_HDR_SIZE, target_bytes - LP_HDR_SIZE); memcpy(target, source, source_bytes - 1); } lpSetNumElements(target, lplength); lpSetTotalBytesChecked(target, lpbytes); /* Now free and NULL out what we didn't realloc */ if (append) { zfree(*second); *second = NULL; *first = target; } else { zfree(*first); *first = NULL; *second = target; } return target; } /* Return the total number of bytes the listpack is composed of. */ size_t lpBytes(unsigned char *lp) { return lpGetTotalBytes(lp); } /* Seek the specified element and returns the pointer to the seeked element. * Positive indexes specify the zero-based element to seek from the head to * the tail, negative indexes specify elements starting from the tail, where * -1 means the last element, -2 the penultimate and so forth. If the index * is out of range, NULL is returned. */ unsigned char *lpSeek(unsigned char *lp, long index) { int forward = 1; /* Seek forward by default. */ /* We want to seek from left to right or the other way around * depending on the listpack length and the element position. * However if the listpack length cannot be obtained in constant time, * we always seek from left to right. */ uint32_t numele = lpGetNumElements(lp); if (numele != LP_HDR_NUMELE_UNKNOWN) { if (index < 0) index = (long)numele+index; if (index < 0) return NULL; /* Index still < 0 means out of range. */ if (index >= (long)numele) return NULL; /* Out of range the other side. 
*/ /* We want to scan right-to-left if the element we are looking for * is past the half of the listpack. */ if (index > (long)numele/2) { forward = 0; /* Right to left scanning always expects a negative index. Convert * our index to negative form. */ index -= numele; } } else { /* If the listpack length is unspecified, for negative indexes we * want to always scan right-to-left. */ if (index < 0) forward = 0; } /* Forward and backward scanning is trivially based on lpNext()/lpPrev(). */ if (forward) { unsigned char *ele = lpFirst(lp); while (index > 0 && ele) { ele = lpNext(lp,ele); index--; } return ele; } else { unsigned char *ele = lpLast(lp); while (index < -1 && ele) { ele = lpPrev(lp,ele); index++; } return ele; } } /* Same as lpFirst but without validation assert, to be used right before lpValidateNext. */ unsigned char *lpValidateFirst(unsigned char *lp) { unsigned char *p = lp + LP_HDR_SIZE; /* Skip the header. */ if (p[0] == LP_EOF) return NULL; return p; } /* Validate the integrity of a single listpack entry and move to the next one. * The input argument 'pp' is a reference to the current record and is advanced on exit. * Returns 1 if valid, 0 if invalid. */ int lpValidateNext(unsigned char *lp, unsigned char **pp, size_t lpbytes) { #define OUT_OF_RANGE(p) ((p) < lp + LP_HDR_SIZE || (p) > lp + lpbytes - 1) unsigned char *p = *pp; if (!p) return 0; /* Before accessing p, make sure it's valid. */ if (OUT_OF_RANGE(p)) return 0; if (*p == LP_EOF) { *pp = NULL; return 1; } /* check that we can read the encoded size */ uint32_t lenbytes = lpCurrentEncodedSizeBytes(p); if (!lenbytes) return 0; /* make sure the encoded entry length doesn't reach outside the edge of the listpack */ if (OUT_OF_RANGE(p + lenbytes)) return 0; /* get the entry length and encoded backlen. 
*/ unsigned long entrylen = lpCurrentEncodedSizeUnsafe(p); unsigned long encodedBacklen = lpEncodeBacklen(NULL,entrylen); entrylen += encodedBacklen; /* make sure the entry doesn't reach outside the edge of the listpack */ if (OUT_OF_RANGE(p + entrylen)) return 0; /* move to the next entry */ p += entrylen; /* make sure the encoded length at the end patches the one at the beginning. */ uint64_t prevlen = lpDecodeBacklen(p-1); if (prevlen + encodedBacklen != entrylen) return 0; *pp = p; return 1; #undef OUT_OF_RANGE } /* Validate that the entry doesn't reach outside the listpack allocation. */ static inline void lpAssertValidEntry(unsigned char* lp, size_t lpbytes, unsigned char *p) { assert(lpValidateNext(lp, &p, lpbytes)); } /* Validate the integrity of the data structure. * when `deep` is 0, only the integrity of the header is validated. * when `deep` is 1, we scan all the entries one by one. */ int lpValidateIntegrity(unsigned char *lp, size_t size, int deep, listpackValidateEntryCB entry_cb, void *cb_userdata) { /* Check that we can actually read the header. (and EOF) */ if (size < LP_HDR_SIZE + 1) return 0; /* Check that the encoded size in the header must match the allocated size. */ size_t bytes = lpGetTotalBytes(lp); if (bytes != size) return 0; /* The last byte must be the terminator. */ if (lp[size - 1] != LP_EOF) return 0; if (!deep) return 1; /* Validate the individual entries. */ uint32_t count = 0; uint32_t numele = lpGetNumElements(lp); unsigned char *p = lp + LP_HDR_SIZE; while(p && p[0] != LP_EOF) { unsigned char *prev = p; /* Validate this entry and move to the next entry in advance * to avoid callback crash due to corrupt listpack. */ if (!lpValidateNext(lp, &p, bytes)) return 0; /* Optionally let the caller validate the entry too. */ if (entry_cb && !entry_cb(prev, numele, cb_userdata)) return 0; count++; } /* Make sure 'p' really does point to the end of the listpack. 
*/ if (p != lp + size - 1) return 0; /* Check that the count in the header is correct */ if (numele != LP_HDR_NUMELE_UNKNOWN && numele != count) return 0; return 1; } /* Compare entry pointer to by 'p' with string 's' of length 'slen'. * Return 1 if equal. */ unsigned int lpCompare(unsigned char *p, const unsigned char *s, uint32_t slen) { unsigned char *value; int64_t sz; if (p[0] == LP_EOF) return 0; value = lpGet(p, &sz, NULL); if (value) { return (slen == sz) && memcmp(value,s,slen) == 0; } else { /* We use lpStringToInt64() to get an integer representation of the * string 's' and compare it to 'sval', it's much faster than convert * integer to string and comparing. */ int64_t sval; if (lpStringToInt64((const char *)s, slen, &sval)) return sz == sval; } return 0; } /* uint compare for qsort */ static int uintCompare(const void *a, const void *b) { return (*(unsigned int *) a - *(unsigned int *) b); } /* Helper method to store a string into from val or lval into dest */ static inline void lpSaveValue(unsigned char *val, unsigned int len, int64_t lval, listpackEntry *dest) { dest->sval = val; dest->slen = len; dest->lval = lval; } /* Randomly select a pair of key and value. * total_count is a pre-computed length/2 of the listpack (to avoid calls to lpLength) * 'key' and 'val' are used to store the result key value pair. * 'val' can be NULL if the value is not needed. */ void lpRandomPair(unsigned char *lp, unsigned long total_count, listpackEntry *key, listpackEntry *val) { unsigned char *p; /* Avoid div by zero on corrupt listpack */ assert(total_count); /* Generate even numbers, because listpack saved K-V pair */ int r = (rand() % total_count) * 2; p = lpSeek(lp, r); assert(p); key->sval = lpGetValue(p, &(key->slen), &(key->lval)); if (!val) return; p = lpNext(lp, p); assert(p); val->sval = lpGetValue(p, &(val->slen), &(val->lval)); } /* Randomly select count of key value pairs and store into 'keys' and * 'vals' args. 
The order of the picked entries is random, and the selections
 * are non-unique (repetitions are possible).
 * The 'vals' arg can be NULL in which case we skip these.
 * When 'count' is 0 nothing is stored and the function returns right after
 * checking that the listpack holds at least one pair. */
void lpRandomPairs(unsigned char *lp, unsigned int count, listpackEntry *keys, listpackEntry *vals) {
    unsigned char *p, *key, *value;
    unsigned int klen = 0, vlen = 0;
    long long klval = 0, vlval = 0;

    /* Notice: the index member must be first due to the use in uintCompare */
    typedef struct {
        unsigned int index;
        unsigned int order;
    } rand_pick;

    unsigned int total_size = lpLength(lp)/2;

    /* Avoid div by zero on corrupt listpack */
    assert(total_size);

    /* Nothing requested: bail out before allocating, otherwise picks[0]
     * below would be read from a zero-length allocation. */
    if (count == 0) return;

    rand_pick *picks = zmalloc(sizeof(rand_pick)*count);

    /* create a pool of random indexes (some may be duplicate). */
    for (unsigned int i = 0; i < count; i++) {
        picks[i].index = (rand() % total_size) * 2; /* Generate even indexes */
        /* keep track of the order we picked them */
        picks[i].order = i;
    }

    /* sort by indexes. */
    qsort(picks, count, sizeof(rand_pick), uintCompare);

    /* fetch the elements from the listpack into the output arrays respecting
     * the original order. The listpack is walked once, starting from the
     * lowest picked index and advancing one pair at a time. */
    unsigned int lpindex = picks[0].index, pickindex = 0;
    p = lpSeek(lp, lpindex);
    while (p && pickindex < count) {
        key = lpGetValue(p, &klen, &klval);
        p = lpNext(lp, p);
        assert(p);
        value = lpGetValue(p, &vlen, &vlval);
        /* Duplicate picks of the same index are all served by this pair. */
        while (pickindex < count && lpindex == picks[pickindex].index) {
            int storeorder = picks[pickindex].order;
            lpSaveValue(key, klen, klval, &keys[storeorder]);
            if (vals)
                lpSaveValue(value, vlen, vlval, &vals[storeorder]);
            pickindex++;
        }
        lpindex += 2;
        p = lpNext(lp, p);
    }

    zfree(picks);
}

/* Randomly select count of key value pairs and store into 'keys' and
 * 'vals' args. The selections are unique (no repetitions), and the order of
 * the picked entries is NOT-random.
 * The 'vals' arg can be NULL in which case we skip these.
 * The return value is the number of items picked which can be lower than the
 * requested count if the listpack doesn't hold enough pairs.
*/ unsigned int lpRandomPairsUnique(unsigned char *lp, unsigned int count, listpackEntry *keys, listpackEntry *vals) { unsigned char *p, *key; unsigned int klen = 0; long long klval = 0; unsigned int total_size = lpLength(lp)/2; unsigned int index = 0; if (count > total_size) count = total_size; /* To only iterate once, every time we try to pick a member, the probability * we pick it is the quotient of the count left we want to pick and the * count still we haven't visited in the dict, this way, we could make every * member be equally picked.*/ p = lpFirst(lp); unsigned int picked = 0, remaining = count; while (picked < count && p) { double randomDouble = ((double)rand()) / RAND_MAX; double threshold = ((double)remaining) / (total_size - index); if (randomDouble <= threshold) { key = lpGetValue(p, &klen, &klval); lpSaveValue(key, klen, klval, &keys[picked]); p = lpNext(lp, p); assert(p); if (vals) { key = lpGetValue(p, &klen, &klval); lpSaveValue(key, klen, klval, &vals[picked]); } remaining--; picked++; } else { p = lpNext(lp, p); assert(p); } p = lpNext(lp, p); index++; } return picked; } /* Print info of listpack which is used in debugCommand */ void lpRepr(unsigned char *lp) { unsigned char *p, *vstr; int64_t vlen; unsigned char intbuf[LP_INTBUF_SIZE]; int index = 0; printf("{total bytes %zu} {num entries %lu}\n", lpBytes(lp), lpLength(lp)); p = lpFirst(lp); while(p) { uint32_t encoded_size_bytes = lpCurrentEncodedSizeBytes(p); uint32_t encoded_size = lpCurrentEncodedSizeUnsafe(p); unsigned long back_len = lpEncodeBacklen(NULL, encoded_size); printf("{\n" "\taddr: 0x%08lx,\n" "\tindex: %2d,\n" "\toffset: %1lu,\n" "\thdr+entrylen+backlen: %2lu,\n" "\thdrlen: %3u,\n" "\tbacklen: %2lu,\n" "\tpayload: %1u\n", (long unsigned)p, index, (unsigned long)(p - lp), encoded_size + back_len, encoded_size_bytes, back_len, encoded_size - encoded_size_bytes); printf("\tbytes: "); for (unsigned int i = 0; i < (encoded_size + back_len); i++) { printf("%02x|",p[i]); } 
printf("\n"); vstr = lpGet(p, &vlen, intbuf); printf("\t[str]"); if (vlen > 40) { if (fwrite(vstr, 40, 1, stdout) == 0) perror("fwrite"); printf("..."); } else { if (fwrite(vstr, vlen, 1, stdout) == 0) perror("fwrite"); } printf("\n}\n"); index++; p = lpNext(lp, p); } printf("{end}\n\n"); } #ifdef REDIS_TEST #include #include "adlist.h" #include "sds.h" #include "testhelp.h" #define UNUSED(x) (void)(x) #define TEST(name) printf("test — %s\n", name); char *mixlist[] = {"hello", "foo", "quux", "1024"}; char *intlist[] = {"4294967296", "-100", "100", "128000", "non integer", "much much longer non integer"}; static unsigned char *createList() { unsigned char *lp = lpNew(0); lp = lpAppend(lp, (unsigned char*)mixlist[1], strlen(mixlist[1])); lp = lpAppend(lp, (unsigned char*)mixlist[2], strlen(mixlist[2])); lp = lpPrepend(lp, (unsigned char*)mixlist[0], strlen(mixlist[0])); lp = lpAppend(lp, (unsigned char*)mixlist[3], strlen(mixlist[3])); return lp; } static unsigned char *createIntList() { unsigned char *lp = lpNew(0); lp = lpAppend(lp, (unsigned char*)intlist[2], strlen(intlist[2])); lp = lpAppend(lp, (unsigned char*)intlist[3], strlen(intlist[3])); lp = lpPrepend(lp, (unsigned char*)intlist[1], strlen(intlist[1])); lp = lpPrepend(lp, (unsigned char*)intlist[0], strlen(intlist[0])); lp = lpAppend(lp, (unsigned char*)intlist[4], strlen(intlist[4])); lp = lpAppend(lp, (unsigned char*)intlist[5], strlen(intlist[5])); return lp; } static long long usec(void) { struct timeval tv; gettimeofday(&tv, NULL); return (((long long)tv.tv_sec)*1000000)+tv.tv_usec; } static void stress(int pos, int num, int maxsize, int dnum) { int i, j, k; unsigned char *lp; char posstr[2][5] = { "HEAD", "TAIL" }; long long start; for (i = 0; i < maxsize; i+=dnum) { lp = lpNew(0); for (j = 0; j < i; j++) { lp = lpAppend(lp, (unsigned char*)"quux", 4); } /* Do num times a push+pop from pos */ start = usec(); for (k = 0; k < num; k++) { if (pos == 0) { lp = lpPrepend(lp, (unsigned char*)"quux", 4); } 
else { lp = lpAppend(lp, (unsigned char*)"quux", 4); } lp = lpDelete(lp, lpFirst(lp), NULL); } printf("List size: %8d, bytes: %8zu, %dx push+pop (%s): %6lld usec\n", i, lpBytes(lp), num, posstr[pos], usec()-start); lpFree(lp); } } static unsigned char *pop(unsigned char *lp, int where) { unsigned char *p, *vstr; int64_t vlen; p = lpSeek(lp, where == 0 ? 0 : -1); vstr = lpGet(p, &vlen, NULL); if (where == 0) printf("Pop head: "); else printf("Pop tail: "); if (vstr) { if (vlen && fwrite(vstr, vlen, 1, stdout) == 0) perror("fwrite"); } else { printf("%lld", (long long)vlen); } printf("\n"); return lpDelete(lp, p, &p); } static int randstring(char *target, unsigned int min, unsigned int max) { int p = 0; int len = min+rand()%(max-min+1); int minval, maxval; switch(rand() % 3) { case 0: minval = 0; maxval = 255; break; case 1: minval = 48; maxval = 122; break; case 2: minval = 48; maxval = 52; break; default: assert(NULL); } while(p < len) target[p++] = minval+rand()%(maxval-minval+1); return len; } static void verifyEntry(unsigned char *p, unsigned char *s, size_t slen) { assert(lpCompare(p, s, slen)); } static int lpValidation(unsigned char *p, unsigned int head_count, void *userdata) { UNUSED(p); UNUSED(head_count); int ret; long *count = userdata; ret = lpCompare(p, (unsigned char *)mixlist[*count], strlen(mixlist[*count])); (*count)++; return ret; } int listpackTest(int argc, char *argv[], int flags) { UNUSED(argc); UNUSED(argv); int i; unsigned char *lp, *p, *vstr; int64_t vlen; unsigned char intbuf[LP_INTBUF_SIZE]; int accurate = (flags & REDIS_TEST_ACCURATE); TEST("Create int list") { lp = createIntList(); assert(lpLength(lp) == 6); lpFree(lp); } TEST("Create list") { lp = createList(); assert(lpLength(lp) == 4); lpFree(lp); } TEST("Test lpPrepend") { lp = lpNew(0); lp = lpPrepend(lp, (unsigned char*)"abc", 3); lp = lpPrepend(lp, (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp, 0), (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp, 1), (unsigned char*)"abc", 
3); lpFree(lp); } TEST("Test lpPrependInteger") { lp = lpNew(0); lp = lpPrependInteger(lp, 127); lp = lpPrependInteger(lp, 4095); lp = lpPrependInteger(lp, 32767); lp = lpPrependInteger(lp, 8388607); lp = lpPrependInteger(lp, 2147483647); lp = lpPrependInteger(lp, 9223372036854775807); verifyEntry(lpSeek(lp, 0), (unsigned char*)"9223372036854775807", 19); verifyEntry(lpSeek(lp, -1), (unsigned char*)"127", 3); lpFree(lp); } TEST("Get element at index") { lp = createList(); verifyEntry(lpSeek(lp, 0), (unsigned char*)"hello", 5); verifyEntry(lpSeek(lp, 3), (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp, -1), (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp, -4), (unsigned char*)"hello", 5); assert(lpSeek(lp, 4) == NULL); assert(lpSeek(lp, -5) == NULL); lpFree(lp); } TEST("Pop list") { lp = createList(); lp = pop(lp, 1); lp = pop(lp, 0); lp = pop(lp, 1); lp = pop(lp, 1); lpFree(lp); } TEST("Get element at index") { lp = createList(); verifyEntry(lpSeek(lp, 0), (unsigned char*)"hello", 5); verifyEntry(lpSeek(lp, 3), (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp, -1), (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp, -4), (unsigned char*)"hello", 5); assert(lpSeek(lp, 4) == NULL); assert(lpSeek(lp, -5) == NULL); lpFree(lp); } TEST("Iterate list from 0 to end") { lp = createList(); p = lpFirst(lp); i = 0; while (p) { verifyEntry(p, (unsigned char*)mixlist[i], strlen(mixlist[i])); p = lpNext(lp, p); i++; } lpFree(lp); } TEST("Iterate list from 1 to end") { lp = createList(); i = 1; p = lpSeek(lp, i); while (p) { verifyEntry(p, (unsigned char*)mixlist[i], strlen(mixlist[i])); p = lpNext(lp, p); i++; } lpFree(lp); } TEST("Iterate list from 2 to end") { lp = createList(); i = 2; p = lpSeek(lp, i); while (p) { verifyEntry(p, (unsigned char*)mixlist[i], strlen(mixlist[i])); p = lpNext(lp, p); i++; } lpFree(lp); } TEST("Iterate from back to front") { lp = createList(); p = lpLast(lp); i = 3; while (p) { verifyEntry(p, (unsigned char*)mixlist[i], strlen(mixlist[i])); p = 
lpPrev(lp, p); i--; } lpFree(lp); } TEST("Iterate from back to front, deleting all items") { lp = createList(); p = lpLast(lp); i = 3; while ((p = lpLast(lp))) { verifyEntry(p, (unsigned char*)mixlist[i], strlen(mixlist[i])); lp = lpDelete(lp, p, &p); assert(p == NULL); i--; } lpFree(lp); } TEST("Delete whole listpack when num == -1"); { lp = createList(); lp = lpDeleteRange(lp, 0, -1); assert(lpLength(lp) == 0); assert(lp[LP_HDR_SIZE] == LP_EOF); assert(lpBytes(lp) == (LP_HDR_SIZE + 1)); zfree(lp); lp = createList(); unsigned char *ptr = lpFirst(lp); lp = lpDeleteRangeWithEntry(lp, &ptr, -1); assert(lpLength(lp) == 0); assert(lp[LP_HDR_SIZE] == LP_EOF); assert(lpBytes(lp) == (LP_HDR_SIZE + 1)); zfree(lp); } TEST("Delete whole listpack with negative index"); { lp = createList(); lp = lpDeleteRange(lp, -4, 4); assert(lpLength(lp) == 0); assert(lp[LP_HDR_SIZE] == LP_EOF); assert(lpBytes(lp) == (LP_HDR_SIZE + 1)); zfree(lp); lp = createList(); unsigned char *ptr = lpSeek(lp, -4); lp = lpDeleteRangeWithEntry(lp, &ptr, 4); assert(lpLength(lp) == 0); assert(lp[LP_HDR_SIZE] == LP_EOF); assert(lpBytes(lp) == (LP_HDR_SIZE + 1)); zfree(lp); } TEST("Delete inclusive range 0,0"); { lp = createList(); lp = lpDeleteRange(lp, 0, 1); assert(lpLength(lp) == 3); assert(lpSkip(lpLast(lp))[0] == LP_EOF); /* check set LP_EOF correctly */ zfree(lp); lp = createList(); unsigned char *ptr = lpFirst(lp); lp = lpDeleteRangeWithEntry(lp, &ptr, 1); assert(lpLength(lp) == 3); assert(lpSkip(lpLast(lp))[0] == LP_EOF); /* check set LP_EOF correctly */ zfree(lp); } TEST("Delete inclusive range 0,1"); { lp = createList(); lp = lpDeleteRange(lp, 0, 2); assert(lpLength(lp) == 2); verifyEntry(lpFirst(lp), (unsigned char*)mixlist[2], strlen(mixlist[2])); zfree(lp); lp = createList(); unsigned char *ptr = lpFirst(lp); lp = lpDeleteRangeWithEntry(lp, &ptr, 2); assert(lpLength(lp) == 2); verifyEntry(lpFirst(lp), (unsigned char*)mixlist[2], strlen(mixlist[2])); zfree(lp); } TEST("Delete inclusive range 
1,2"); { lp = createList(); lp = lpDeleteRange(lp, 1, 2); assert(lpLength(lp) == 2); verifyEntry(lpFirst(lp), (unsigned char*)mixlist[0], strlen(mixlist[0])); zfree(lp); lp = createList(); unsigned char *ptr = lpSeek(lp, 1); lp = lpDeleteRangeWithEntry(lp, &ptr, 2); assert(lpLength(lp) == 2); verifyEntry(lpFirst(lp), (unsigned char*)mixlist[0], strlen(mixlist[0])); zfree(lp); } TEST("Delete with start index out of range"); { lp = createList(); lp = lpDeleteRange(lp, 5, 1); assert(lpLength(lp) == 4); zfree(lp); } TEST("Delete with num overflow"); { lp = createList(); lp = lpDeleteRange(lp, 1, 5); assert(lpLength(lp) == 1); verifyEntry(lpFirst(lp), (unsigned char*)mixlist[0], strlen(mixlist[0])); zfree(lp); lp = createList(); unsigned char *ptr = lpSeek(lp, 1); lp = lpDeleteRangeWithEntry(lp, &ptr, 5); assert(lpLength(lp) == 1); verifyEntry(lpFirst(lp), (unsigned char*)mixlist[0], strlen(mixlist[0])); zfree(lp); } TEST("Delete foo while iterating") { lp = createList(); p = lpFirst(lp); while (p) { if (lpCompare(p, (unsigned char*)"foo", 3)) { lp = lpDelete(lp, p, &p); } else { p = lpNext(lp, p); } } lpFree(lp); } TEST("Replace with same size") { lp = createList(); /* "hello", "foo", "quux", "1024" */ unsigned char *orig_lp = lp; p = lpSeek(lp, 0); lp = lpReplace(lp, &p, (unsigned char*)"zoink", 5); p = lpSeek(lp, 3); lp = lpReplace(lp, &p, (unsigned char*)"y", 1); p = lpSeek(lp, 1); lp = lpReplace(lp, &p, (unsigned char*)"65536", 5); p = lpSeek(lp, 0); assert(!memcmp((char*)p, "\x85zoink\x06" "\xf2\x00\x00\x01\x04" /* 65536 as int24 */ "\x84quux\05" "\x81y\x02" "\xff", 22)); assert(lp == orig_lp); /* no reallocations have happened */ lpFree(lp); } TEST("Replace with different size") { lp = createList(); /* "hello", "foo", "quux", "1024" */ p = lpSeek(lp, 1); lp = lpReplace(lp, &p, (unsigned char*)"squirrel", 8); p = lpSeek(lp, 0); assert(!strncmp((char*)p, "\x85hello\x06" "\x88squirrel\x09" "\x84quux\x05" "\xc4\x00\x02" "\xff", 27)); lpFree(lp); } TEST("Regression 
test for >255 byte strings") { char v1[257] = {0}, v2[257] = {0}; memset(v1,'x',256); memset(v2,'y',256); lp = lpNew(0); lp = lpAppend(lp, (unsigned char*)v1 ,strlen(v1)); lp = lpAppend(lp, (unsigned char*)v2 ,strlen(v2)); /* Pop values again and compare their value. */ p = lpFirst(lp); vstr = lpGet(p, &vlen, NULL); assert(strncmp(v1, (char*)vstr, vlen) == 0); p = lpSeek(lp, 1); vstr = lpGet(p, &vlen, NULL); assert(strncmp(v2, (char*)vstr, vlen) == 0); lpFree(lp); } TEST("Create long list and check indices") { lp = lpNew(0); char buf[32]; int i,len; for (i = 0; i < 1000; i++) { len = sprintf(buf, "%d", i); lp = lpAppend(lp, (unsigned char*)buf, len); } for (i = 0; i < 1000; i++) { p = lpSeek(lp, i); vstr = lpGet(p, &vlen, NULL); assert(i == vlen); p = lpSeek(lp, -i-1); vstr = lpGet(p, &vlen, NULL); assert(999-i == vlen); } lpFree(lp); } TEST("Compare strings with listpack entries") { lp = createList(); p = lpSeek(lp,0); assert(lpCompare(p,(unsigned char*)"hello",5)); assert(!lpCompare(p,(unsigned char*)"hella",5)); p = lpSeek(lp,3); assert(lpCompare(p,(unsigned char*)"1024",4)); assert(!lpCompare(p,(unsigned char*)"1025",4)); lpFree(lp); } TEST("lpMerge two empty listpacks") { unsigned char *lp1 = lpNew(0); unsigned char *lp2 = lpNew(0); /* Merge two empty listpacks, get empty result back. 
*/ lp1 = lpMerge(&lp1, &lp2); assert(lpLength(lp1) == 0); zfree(lp1); } TEST("lpMerge two listpacks - first larger than second") { unsigned char *lp1 = createIntList(); unsigned char *lp2 = createList(); size_t lp1_bytes = lpBytes(lp1); size_t lp2_bytes = lpBytes(lp2); unsigned long lp1_len = lpLength(lp1); unsigned long lp2_len = lpLength(lp2); unsigned char *lp3 = lpMerge(&lp1, &lp2); assert(lp3 == lp1); assert(lp2 == NULL); assert(lpLength(lp3) == (lp1_len + lp2_len)); assert(lpBytes(lp3) == (lp1_bytes + lp2_bytes - LP_HDR_SIZE - 1)); verifyEntry(lpSeek(lp3, 0), (unsigned char*)"4294967296", 10); verifyEntry(lpSeek(lp3, 5), (unsigned char*)"much much longer non integer", 28); verifyEntry(lpSeek(lp3, 6), (unsigned char*)"hello", 5); verifyEntry(lpSeek(lp3, -1), (unsigned char*)"1024", 4); zfree(lp3); } TEST("lpMerge two listpacks - second larger than first") { unsigned char *lp1 = createList(); unsigned char *lp2 = createIntList(); size_t lp1_bytes = lpBytes(lp1); size_t lp2_bytes = lpBytes(lp2); unsigned long lp1_len = lpLength(lp1); unsigned long lp2_len = lpLength(lp2); unsigned char *lp3 = lpMerge(&lp1, &lp2); assert(lp3 == lp2); assert(lp1 == NULL); assert(lpLength(lp3) == (lp1_len + lp2_len)); assert(lpBytes(lp3) == (lp1_bytes + lp2_bytes - LP_HDR_SIZE - 1)); verifyEntry(lpSeek(lp3, 0), (unsigned char*)"hello", 5); verifyEntry(lpSeek(lp3, 3), (unsigned char*)"1024", 4); verifyEntry(lpSeek(lp3, 4), (unsigned char*)"4294967296", 10); verifyEntry(lpSeek(lp3, -1), (unsigned char*)"much much longer non integer", 28); zfree(lp3); } TEST("Random pair with one element") { listpackEntry key, val; unsigned char *lp = lpNew(0); lp = lpAppend(lp, (unsigned char*)"abc", 3); lp = lpAppend(lp, (unsigned char*)"123", 3); lpRandomPair(lp, 1, &key, &val); assert(memcmp(key.sval, "abc", key.slen) == 0); assert(val.lval == 123); lpFree(lp); } TEST("Random pair with many elements") { listpackEntry key, val; unsigned char *lp = lpNew(0); lp = lpAppend(lp, (unsigned char*)"abc", 
3); lp = lpAppend(lp, (unsigned char*)"123", 3); lp = lpAppend(lp, (unsigned char*)"456", 3); lp = lpAppend(lp, (unsigned char*)"def", 3); lpRandomPair(lp, 2, &key, &val); if (key.sval) { assert(!memcmp(key.sval, "abc", key.slen)); assert(key.slen == 3); assert(val.lval == 123); } if (!key.sval) { assert(key.lval == 456); assert(!memcmp(val.sval, "def", val.slen)); } lpFree(lp); } TEST("Random pairs with one element") { int count = 5; unsigned char *lp = lpNew(0); listpackEntry *keys = zmalloc(sizeof(listpackEntry) * count); listpackEntry *vals = zmalloc(sizeof(listpackEntry) * count); lp = lpAppend(lp, (unsigned char*)"abc", 3); lp = lpAppend(lp, (unsigned char*)"123", 3); lpRandomPairs(lp, count, keys, vals); assert(memcmp(keys[4].sval, "abc", keys[4].slen) == 0); assert(vals[4].lval == 123); zfree(keys); zfree(vals); lpFree(lp); } TEST("Random pairs with many elements") { int count = 5; lp = lpNew(0); listpackEntry *keys = zmalloc(sizeof(listpackEntry) * count); listpackEntry *vals = zmalloc(sizeof(listpackEntry) * count); lp = lpAppend(lp, (unsigned char*)"abc", 3); lp = lpAppend(lp, (unsigned char*)"123", 3); lp = lpAppend(lp, (unsigned char*)"456", 3); lp = lpAppend(lp, (unsigned char*)"def", 3); lpRandomPairs(lp, count, keys, vals); for (int i = 0; i < count; i++) { if (keys[i].sval) { assert(!memcmp(keys[i].sval, "abc", keys[i].slen)); assert(keys[i].slen == 3); assert(vals[i].lval == 123); } if (!keys[i].sval) { assert(keys[i].lval == 456); assert(!memcmp(vals[i].sval, "def", vals[i].slen)); } } zfree(keys); zfree(vals); lpFree(lp); } TEST("Random pairs unique with one element") { unsigned picked; int count = 5; lp = lpNew(0); listpackEntry *keys = zmalloc(sizeof(listpackEntry) * count); listpackEntry *vals = zmalloc(sizeof(listpackEntry) * count); lp = lpAppend(lp, (unsigned char*)"abc", 3); lp = lpAppend(lp, (unsigned char*)"123", 3); picked = lpRandomPairsUnique(lp, count, keys, vals); assert(picked == 1); assert(memcmp(keys[0].sval, "abc", 
keys[0].slen) == 0); assert(vals[0].lval == 123); zfree(keys); zfree(vals); lpFree(lp); } TEST("Random pairs unique with many elements") { unsigned picked; int count = 5; lp = lpNew(0); listpackEntry *keys = zmalloc(sizeof(listpackEntry) * count); listpackEntry *vals = zmalloc(sizeof(listpackEntry) * count); lp = lpAppend(lp, (unsigned char*)"abc", 3); lp = lpAppend(lp, (unsigned char*)"123", 3); lp = lpAppend(lp, (unsigned char*)"456", 3); lp = lpAppend(lp, (unsigned char*)"def", 3); picked = lpRandomPairsUnique(lp, count, keys, vals); assert(picked == 2); for (int i = 0; i < 2; i++) { if (keys[i].sval) { assert(!memcmp(keys[i].sval, "abc", keys[i].slen)); assert(keys[i].slen == 3); assert(vals[i].lval == 123); } if (!keys[i].sval) { assert(keys[i].lval == 456); assert(!memcmp(vals[i].sval, "def", vals[i].slen)); } } zfree(keys); zfree(vals); lpFree(lp); } TEST("push various encodings") { lp = lpNew(0); /* Push integer encode element using lpAppend */ lp = lpAppend(lp, (unsigned char*)"127", 3); assert(LP_ENCODING_IS_7BIT_UINT(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)"4095", 4); assert(LP_ENCODING_IS_13BIT_INT(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)"32767", 5); assert(LP_ENCODING_IS_16BIT_INT(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)"8388607", 7); assert(LP_ENCODING_IS_24BIT_INT(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)"2147483647", 10); assert(LP_ENCODING_IS_32BIT_INT(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)"9223372036854775807", 19); assert(LP_ENCODING_IS_64BIT_INT(lpLast(lp)[0])); /* Push integer encode element using lpAppendInteger */ lp = lpAppendInteger(lp, 127); assert(LP_ENCODING_IS_7BIT_UINT(lpLast(lp)[0])); verifyEntry(lpLast(lp), (unsigned char*)"127", 3); lp = lpAppendInteger(lp, 4095); verifyEntry(lpLast(lp), (unsigned char*)"4095", 4); assert(LP_ENCODING_IS_13BIT_INT(lpLast(lp)[0])); lp = lpAppendInteger(lp, 32767); verifyEntry(lpLast(lp), (unsigned char*)"32767", 5); 
assert(LP_ENCODING_IS_16BIT_INT(lpLast(lp)[0])); lp = lpAppendInteger(lp, 8388607); verifyEntry(lpLast(lp), (unsigned char*)"8388607", 7); assert(LP_ENCODING_IS_24BIT_INT(lpLast(lp)[0])); lp = lpAppendInteger(lp, 2147483647); verifyEntry(lpLast(lp), (unsigned char*)"2147483647", 10); assert(LP_ENCODING_IS_32BIT_INT(lpLast(lp)[0])); lp = lpAppendInteger(lp, 9223372036854775807); verifyEntry(lpLast(lp), (unsigned char*)"9223372036854775807", 19); assert(LP_ENCODING_IS_64BIT_INT(lpLast(lp)[0])); /* string encode */ unsigned char *str = zmalloc(65535); memset(str, 0, 65535); lp = lpAppend(lp, (unsigned char*)str, 63); assert(LP_ENCODING_IS_6BIT_STR(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)str, 4095); assert(LP_ENCODING_IS_12BIT_STR(lpLast(lp)[0])); lp = lpAppend(lp, (unsigned char*)str, 65535); assert(LP_ENCODING_IS_32BIT_STR(lpLast(lp)[0])); zfree(str); lpFree(lp); } TEST("Test lpFind") { lp = createList(); assert(lpFind(lp, lpFirst(lp), (unsigned char*)"abc", 3, 0) == NULL); verifyEntry(lpFind(lp, lpFirst(lp), (unsigned char*)"hello", 5, 0), (unsigned char*)"hello", 5); verifyEntry(lpFind(lp, lpFirst(lp), (unsigned char*)"1024", 4, 0), (unsigned char*)"1024", 4); lpFree(lp); } TEST("Test lpValidateIntegrity") { lp = createList(); long count = 0; assert(lpValidateIntegrity(lp, lpBytes(lp), 1, lpValidation, &count) == 1); lpFree(lp); } TEST("Test number of elements exceeds LP_HDR_NUMELE_UNKNOWN") { lp = lpNew(0); for (int i = 0; i < LP_HDR_NUMELE_UNKNOWN + 1; i++) lp = lpAppend(lp, (unsigned char*)"1", 1); assert(lpGetNumElements(lp) == LP_HDR_NUMELE_UNKNOWN); assert(lpLength(lp) == LP_HDR_NUMELE_UNKNOWN+1); lp = lpDeleteRange(lp, -2, 2); assert(lpGetNumElements(lp) == LP_HDR_NUMELE_UNKNOWN); assert(lpLength(lp) == LP_HDR_NUMELE_UNKNOWN-1); assert(lpGetNumElements(lp) == LP_HDR_NUMELE_UNKNOWN-1); /* update length after lpLength */ lpFree(lp); } TEST("Stress with random payloads of different encoding") { unsigned long long start = usec(); int i,j,len,where; 
unsigned char *p; char buf[1024]; int buflen; list *ref; listNode *refnode; int iteration = accurate ? 20000 : 20; for (i = 0; i < iteration; i++) { lp = lpNew(0); ref = listCreate(); listSetFreeMethod(ref,(void (*)(void*))sdsfree); len = rand() % 256; /* Create lists */ for (j = 0; j < len; j++) { where = (rand() & 1) ? 0 : 1; if (rand() % 2) { buflen = randstring(buf,1,sizeof(buf)-1); } else { switch(rand() % 3) { case 0: buflen = sprintf(buf,"%lld",(0LL + rand()) >> 20); break; case 1: buflen = sprintf(buf,"%lld",(0LL + rand())); break; case 2: buflen = sprintf(buf,"%lld",(0LL + rand()) << 20); break; default: assert(NULL); } } /* Add to listpack */ if (where == 0) { lp = lpPrepend(lp, (unsigned char*)buf, buflen); } else { lp = lpAppend(lp, (unsigned char*)buf, buflen); } /* Add to reference list */ if (where == 0) { listAddNodeHead(ref,sdsnewlen(buf, buflen)); } else if (where == 1) { listAddNodeTail(ref,sdsnewlen(buf, buflen)); } else { assert(NULL); } } assert(listLength(ref) == lpLength(lp)); for (j = 0; j < len; j++) { /* Naive way to get elements, but similar to the stresser * executed from the Tcl test suite. */ p = lpSeek(lp,j); refnode = listIndex(ref,j); vstr = lpGet(p, &vlen, intbuf); assert(memcmp(vstr,listNodeValue(refnode),vlen) == 0); } lpFree(lp); listRelease(ref); } printf("Done. usec=%lld\n\n", usec()-start); } TEST("Stress with variable listpack size") { unsigned long long start = usec(); int maxsize = accurate ? 16384 : 16; stress(0,100000,maxsize,256); stress(1,100000,maxsize,256); printf("Done. usec=%lld\n\n", usec()-start); } /* Benchmarks */ { int iteration = accurate ? 100000 : 100; lp = lpNew(0); TEST("Benchmark lpAppend") { unsigned long long start = usec(); for (int i=0; i * All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __LISTPACK_H #define __LISTPACK_H #include #include #define LP_INTBUF_SIZE 21 /* 20 digits of -2^63 + 1 null term = 21. */ /* lpInsert() where argument possible values: */ #define LP_BEFORE 0 #define LP_AFTER 1 #define LP_REPLACE 2 /* Each entry in the listpack is either a string or an integer. */ typedef struct { /* When string is used, it is provided with the length (slen). */ unsigned char *sval; uint32_t slen; /* When integer is used, 'sval' is NULL, and lval holds the value. 
*/ long long lval; } listpackEntry; unsigned char *lpNew(size_t capacity); void lpFree(unsigned char *lp); unsigned char* lpShrinkToFit(unsigned char *lp); unsigned char *lpInsertString(unsigned char *lp, const unsigned char *s, uint32_t slen, unsigned char *p, int where, unsigned char **newp); unsigned char *lpPrepend(unsigned char *lp, const unsigned char *s, uint32_t slen); unsigned char *lpPrependInteger(unsigned char *lp, long long lval); unsigned char *lpAppend(unsigned char *lp, const unsigned char *s, uint32_t slen); unsigned char *lpAppendInteger(unsigned char *lp, long long lval); unsigned char *lpInsertInteger(unsigned char *lp, long long lval, unsigned char *p, int where, unsigned char **newp); unsigned char *lpReplace(unsigned char *lp, unsigned char **p, const unsigned char *s, uint32_t slen); unsigned char *lpReplaceInteger(unsigned char *lp, unsigned char **p, long long lval); unsigned char *lpDelete(unsigned char *lp, unsigned char *p, unsigned char **newp); unsigned char *lpDeleteRangeWithEntry(unsigned char *lp, unsigned char **p, unsigned long num); unsigned char *lpDeleteRange(unsigned char *lp, long index, unsigned long num); unsigned char *lpMerge(unsigned char **first, unsigned char **second); unsigned long lpLength(unsigned char *lp); unsigned char *lpGet(unsigned char *p, int64_t *count, unsigned char *intbuf); // Fills *ival and returns 1 if the item is an integer, 0 otherwise.
int lpGetInteger(unsigned char *p, int64_t *ival); int lpStringToInt64(const char *s, unsigned long slen, int64_t *value); unsigned char *lpGetValue(unsigned char *p, unsigned int *slen, long long *lval); unsigned char *lpFind(unsigned char *lp, unsigned char *p, unsigned char *s, uint32_t slen, unsigned int skip); unsigned char *lpFirst(unsigned char *lp); unsigned char *lpLast(unsigned char *lp); unsigned char *lpNext(unsigned char *lp, unsigned char *p); unsigned char *lpPrev(unsigned char *lp, unsigned char *p); size_t lpBytes(unsigned char *lp); unsigned char *lpSeek(unsigned char *lp, long index); typedef int (*listpackValidateEntryCB)(unsigned char *p, unsigned int head_count, void *userdata); int lpValidateIntegrity(unsigned char *lp, size_t size, int deep, listpackValidateEntryCB entry_cb, void *cb_userdata); unsigned char *lpValidateFirst(unsigned char *lp); int lpValidateNext(unsigned char *lp, unsigned char **pp, size_t lpbytes); unsigned int lpCompare(unsigned char *p, const unsigned char *s, uint32_t slen); void lpRandomPair(unsigned char *lp, unsigned long total_count, listpackEntry *key, listpackEntry *val); void lpRandomPairs(unsigned char *lp, unsigned int count, listpackEntry *keys, listpackEntry *vals); unsigned int lpRandomPairsUnique(unsigned char *lp, unsigned int count, listpackEntry *keys, listpackEntry *vals); int lpSafeToAdd(unsigned char* lp, size_t add); void lpRepr(unsigned char *lp); #ifdef REDIS_TEST int listpackTest(int argc, char *argv[], int flags); #endif #endif ================================================ FILE: src/redis/lua/CMakeLists.txt ================================================ add_library(lua_modules STATIC cjson/fpconv.c cjson/strbuf.c cjson/lua_cjson.c cmsgpack/lua_cmsgpack.c struct/lua_struct.c bit/bit.c ) target_compile_options(lua_modules PRIVATE -Wno-sign-compare -Wno-misleading-indentation -Wno-implicit-fallthrough -Wno-undefined-inline -Wno-stringop-overflow) target_link_libraries(lua_modules TRDP::lua) 
================================================ FILE: src/redis/lua/README.md ================================================ Since Lua 5.2, `luaL_register` has been deprecated and removed. The new `luaL_newlib` function doesn't make the module globally available upon registration and is meant to be used with the `require` function. To provide the modules globally, `luaL_newlib` is followed by a `lua_setglobal` for bit and struct. ================================================ FILE: src/redis/lua/bit/bit.c ================================================ /* ** Lua BitOp -- a bit operations library for Lua 5.1/5.2. ** http://bitop.luajit.org/ ** ** Copyright (C) 2008-2012 Mike Pall. All rights reserved. ** ** Permission is hereby granted, free of charge, to any person obtaining ** a copy of this software and associated documentation files (the ** "Software"), to deal in the Software without restriction, including ** without limitation the rights to use, copy, modify, merge, publish, ** distribute, sublicense, and/or sell copies of the Software, and to ** permit persons to whom the Software is furnished to do so, subject to ** the following conditions: ** ** The above copyright notice and this permission notice shall be ** included in all copies or substantial portions of the Software. ** ** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, ** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF ** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. ** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY ** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, ** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE ** SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
** ** [ MIT license: http://www.opensource.org/licenses/mit-license.php ] */ #define LUA_BITOP_VERSION "1.0.3" #define LUA_LIB #include "lua.h" #include "lauxlib.h" #ifdef _MSC_VER /* MSVC is stuck in the last century and doesn't have C99's stdint.h. */ typedef __int32 int32_t; typedef unsigned __int32 uint32_t; typedef unsigned __int64 uint64_t; #else #include #endif typedef int32_t SBits; typedef uint32_t UBits; typedef union { lua_Number n; #if defined(LUA_NUMBER_DOUBLE) || defined(LUA_FLOAT_DOUBLE) uint64_t b; #else UBits b; #endif } BitNum; /* Convert argument to bit type. */ static UBits barg(lua_State *L, int idx) { BitNum bn; UBits b; #if LUA_VERSION_NUM < 502 bn.n = lua_tonumber(L, idx); #else bn.n = luaL_checknumber(L, idx); #endif #if defined(LUA_NUMBER_DOUBLE) || defined(LUA_FLOAT_DOUBLE) bn.n += 6755399441055744.0; /* 2^52+2^51 */ #ifdef SWAPPED_DOUBLE b = (UBits)(bn.b >> 32); #else b = (UBits)bn.b; #endif #elif defined(LUA_NUMBER_INT) || defined(LUA_INT_INT) || \ defined(LUA_NUMBER_LONG) || defined(LUA_INT_LONG) || \ defined(LUA_NUMBER_LONGLONG) || defined(LUA_INT_LONGLONG) || \ defined(LUA_NUMBER_LONG_LONG) || defined(LUA_NUMBER_LLONG) if (sizeof(UBits) == sizeof(lua_Number)) b = bn.b; else b = (UBits)(SBits)bn.n; #elif defined(LUA_NUMBER_FLOAT) || defined(LUA_FLOAT_FLOAT) #error "A 'float' lua_Number type is incompatible with this library" #else #error "Unknown number type, check LUA_NUMBER_*, LUA_FLOAT_*, LUA_INT_* in luaconf.h" #endif #if LUA_VERSION_NUM < 502 if (b == 0 && !lua_isnumber(L, idx)) { luaL_typerror(L, idx, "number"); } #endif return b; } /* Return bit type. 
*/ #if LUA_VERSION_NUM < 503 #define BRET(b) lua_pushnumber(L, (lua_Number)(SBits)(b)); return 1; #else #define BRET(b) lua_pushinteger(L, (lua_Integer)(SBits)(b)); return 1; #endif static int bit_tobit(lua_State *L) { BRET(barg(L, 1)) } static int bit_bnot(lua_State *L) { BRET(~barg(L, 1)) } #define BIT_OP(func, opr) \ static int func(lua_State *L) { int i; UBits b = barg(L, 1); \ for (i = lua_gettop(L); i > 1; i--) b opr barg(L, i); BRET(b) } BIT_OP(bit_band, &=) BIT_OP(bit_bor, |=) BIT_OP(bit_bxor, ^=) #define bshl(b, n) (b << n) #define bshr(b, n) (b >> n) #define bsar(b, n) ((SBits)b >> n) #define brol(b, n) ((b << n) | (b >> (32-n))) #define bror(b, n) ((b << (32-n)) | (b >> n)) #define BIT_SH(func, fn) \ static int func(lua_State *L) { \ UBits b = barg(L, 1); UBits n = barg(L, 2) & 31; BRET(fn(b, n)) } BIT_SH(bit_lshift, bshl) BIT_SH(bit_rshift, bshr) BIT_SH(bit_arshift, bsar) BIT_SH(bit_rol, brol) BIT_SH(bit_ror, bror) static int bit_bswap(lua_State *L) { UBits b = barg(L, 1); b = (b >> 24) | ((b >> 8) & 0xff00) | ((b & 0xff00) << 8) | (b << 24); BRET(b) } static int bit_tohex(lua_State *L) { UBits b = barg(L, 1); SBits n = lua_isnone(L, 2) ? 8 : (SBits)barg(L, 2); const char *hexdigits = "0123456789abcdef"; char buf[8]; int i; if (n == INT32_MIN) n = INT32_MIN+1; if (n < 0) { n = -n; hexdigits = "0123456789ABCDEF"; } if (n > 8) n = 8; for (i = (int)n; --i >= 0; ) { buf[i] = hexdigits[b & 15]; b >>= 4; } lua_pushlstring(L, buf, (size_t)n); return 1; } static const struct luaL_Reg bit_funcs[] = { { "tobit", bit_tobit }, { "bnot", bit_bnot }, { "band", bit_band }, { "bor", bit_bor }, { "bxor", bit_bxor }, { "lshift", bit_lshift }, { "rshift", bit_rshift }, { "arshift", bit_arshift }, { "rol", bit_rol }, { "ror", bit_ror }, { "bswap", bit_bswap }, { "tohex", bit_tohex }, { NULL, NULL } }; /* Signed right-shifts are implementation-defined per C89/C99. ** But the de facto standard are arithmetic right-shifts on two's ** complement CPUs. 
This behaviour is required here, so test for it. */ #define BAD_SAR (bsar(-8, 2) != (SBits)-2) LUALIB_API int luaopen_bit(lua_State *L) { UBits b; #if LUA_VERSION_NUM < 503 lua_pushnumber(L, (lua_Number)1437217655L); #else lua_pushinteger(L, (lua_Integer)1437217655L); #endif b = barg(L, -1); if (b != (UBits)1437217655L || BAD_SAR) { /* Perform a simple self-test. */ const char *msg = "compiled with incompatible luaconf.h"; #if defined(LUA_NUMBER_DOUBLE) || defined(LUA_FLOAT_DOUBLE) #ifdef _WIN32 if (b == (UBits)1610612736L) msg = "use D3DCREATE_FPU_PRESERVE with DirectX"; #endif if (b == (UBits)1127743488L) msg = "not compiled with SWAPPED_DOUBLE"; #endif if (BAD_SAR) msg = "arithmetic right-shift broken"; luaL_error(L, "bit library self-test failed (%s)", msg); } luaL_newlib(L, bit_funcs); lua_setglobal(L, "bit"); return 1; } ================================================ FILE: src/redis/lua/cjson/fpconv.c ================================================ /* fpconv - Floating point conversion routines * * Copyright (c) 2011-2012 Mark Pulford * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /* JSON uses a '.' decimal separator. strtod() / sprintf() under C libraries * with locale support will break when the decimal separator is a comma. * * fpconv_* will work around these issues with a translation buffer if required. */ #include #include #include #include #include "fpconv.h" /* Lua CJSON assumes the locale is the same for all threads within a * process and doesn't change after initialisation. * * This avoids the need for per thread storage or expensive checks * on every call. */ static char locale_decimal_point = '.'; /* In theory multibyte decimal_points are possible, but * Lua CJSON only supports UTF-8 and known locales only have * single byte decimal points ([.,]). * * localeconv() may not be thread safe (=>crash), and nl_langinfo() is * not supported on some platforms. Use sprintf() instead - if the * locale does change, at least Lua CJSON won't crash. */ static void fpconv_update_locale() { char buf[8]; /* Format 0.5 and read back the separator printed between the digits. */ snprintf(buf, sizeof(buf), "%g", 0.5); /* Failing this test might imply the platform has a buggy dtoa * implementation or wide characters */ if (buf[0] != '0' || buf[2] != '5' || buf[3] != 0) { fprintf(stderr, "Error: wide characters found or printf() bug."); abort(); } locale_decimal_point = buf[1]; } /* Check for a valid number character: [-+0-9a-yA-Y.] * Eg: -0.6e+5, infinity, 0xF0.F0pF0 * * Used to find the probable end of a number. It doesn't matter if * invalid characters are counted - strtod() will find the valid * number if it exists. The risk is that slightly more memory might * be allocated before a parse error occurs. 
*/ static inline int valid_number_character(char ch) { char lower_ch; if ('0' <= ch && ch <= '9') return 1; if (ch == '-' || ch == '+' || ch == '.') return 1; /* Hex digits, exponent (e), base (p), "infinity",.. */ lower_ch = ch | 0x20; if ('a' <= lower_ch && lower_ch <= 'y') return 1; return 0; } /* Calculate the size of the buffer required for a strtod locale * conversion. */ static int strtod_buffer_size(const char *s) { const char *p = s; while (valid_number_character(*p)) p++; return p - s; } /* Similar to strtod(), but must be passed the current locale's decimal point * character. Guaranteed to be called at the start of any valid number in a string */ double fpconv_strtod(const char *nptr, char **endptr) { char localbuf[FPCONV_G_FMT_BUFSIZE]; char *buf, *endbuf, *dp; int buflen; double value; /* System strtod() is fine when decimal point is '.' */ if (locale_decimal_point == '.') return strtod(nptr, endptr); buflen = strtod_buffer_size(nptr); if (!buflen) { /* No valid characters found, standard strtod() return */ *endptr = (char *)nptr; return 0; } /* Duplicate number into buffer */ if (buflen >= FPCONV_G_FMT_BUFSIZE) { /* Handle unusually large numbers */ buf = malloc(buflen + 1); if (!buf) { fprintf(stderr, "Out of memory"); abort(); } } else { /* This is the common case.. 
*/ buf = localbuf; } memcpy(buf, nptr, buflen); buf[buflen] = 0; /* Update decimal point character if found */ dp = strchr(buf, '.'); if (dp) *dp = locale_decimal_point; value = strtod(buf, &endbuf); *endptr = (char *)&nptr[endbuf - buf]; if (buflen >= FPCONV_G_FMT_BUFSIZE) free(buf); return value; } /* "fmt" must point to a buffer of at least 6 characters */ static void set_number_format(char *fmt, int precision) { int d1, d2, i; assert(1 <= precision && precision <= 14); /* Create printf format (%.14g) from precision */ d1 = precision / 10; d2 = precision % 10; fmt[0] = '%'; fmt[1] = '.'; i = 2; if (d1) { fmt[i++] = '0' + d1; } fmt[i++] = '0' + d2; fmt[i++] = 'g'; fmt[i] = 0; } /* Assumes there is always at least 32 characters available in the target buffer */ int fpconv_g_fmt(char *str, double num, int precision) { char buf[FPCONV_G_FMT_BUFSIZE]; char fmt[6]; int len; char *b; set_number_format(fmt, precision); /* Pass through when decimal point character is dot. */ if (locale_decimal_point == '.') return snprintf(str, FPCONV_G_FMT_BUFSIZE, fmt, num); /* snprintf() to a buffer then translate for other decimal point characters */ len = snprintf(buf, FPCONV_G_FMT_BUFSIZE, fmt, num); /* Copy into target location. Translate decimal point if required */ b = buf; do { *str++ = (*b == locale_decimal_point ? '.' : *b); } while(*b++); return len; } void fpconv_init() { fpconv_update_locale(); } /* vi:ai et sw=4 ts=4: */ ================================================ FILE: src/redis/lua/cjson/fpconv.h ================================================ /* Lua CJSON floating point conversion routines */ /* Buffer required to store the largest string representation of a double. 
* * Longest double printed with %.14g is 21 characters long: * -1.7976931348623e+308 */ # define FPCONV_G_FMT_BUFSIZE 32 #ifdef USE_INTERNAL_FPCONV static inline void fpconv_init() { /* Do nothing - not required */ } #else extern void fpconv_init(); #endif extern int fpconv_g_fmt(char*, double, int); extern double fpconv_strtod(const char*, char**); /* vi:ai et sw=4 ts=4: */ ================================================ FILE: src/redis/lua/cjson/lua_cjson.c ================================================ /* Lua CJSON - JSON support for Lua * * Copyright (c) 2010-2012 Mark Pulford * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /* Caveats: * - JSON "null" values are represented as lightuserdata since Lua * tables cannot contain "nil". Compare with cjson.null. * - Invalid UTF-8 characters are not detected and will be passed * untouched. If required, UTF-8 error checking should be done * outside this library. 
* - Javascript comments are not part of the JSON spec, and are not * currently supported. * * Note: Decoding is slower than encoding. Lua spends significant * time (30%) managing tables when parsing JSON since it is * difficult to know object/array sizes ahead of time. */ #include #include #include #include #include #include #include "strbuf.h" #include "fpconv.h" #ifndef CJSON_MODNAME #define CJSON_MODNAME "cjson" #endif #ifndef CJSON_VERSION #define CJSON_VERSION "2.1devel" #endif /* Workaround for Solaris platforms missing isinf() */ #if !defined(isinf) && (defined(USE_INTERNAL_ISINF) || defined(MISSING_ISINF)) #define isinf(x) (!isnan(x) && isnan((x) - (x))) #endif #define DEFAULT_SPARSE_CONVERT 0 #define DEFAULT_SPARSE_RATIO 2 #define DEFAULT_SPARSE_SAFE 10 #define DEFAULT_ENCODE_MAX_DEPTH 1000 #define DEFAULT_DECODE_MAX_DEPTH 1000 #define DEFAULT_ENCODE_INVALID_NUMBERS 0 #define DEFAULT_DECODE_INVALID_NUMBERS 1 #define DEFAULT_ENCODE_KEEP_BUFFER 1 #define DEFAULT_ENCODE_NUMBER_PRECISION 14 #ifdef DISABLE_INVALID_NUMBERS #undef DEFAULT_DECODE_INVALID_NUMBERS #define DEFAULT_DECODE_INVALID_NUMBERS 0 #endif typedef enum { T_OBJ_BEGIN, T_OBJ_END, T_ARR_BEGIN, T_ARR_END, T_STRING, T_NUMBER, T_BOOLEAN, T_NULL, T_COLON, T_COMMA, T_END, T_WHITESPACE, T_ERROR, T_UNKNOWN } json_token_type_t; static const char *json_token_type_name[] = { "T_OBJ_BEGIN", "T_OBJ_END", "T_ARR_BEGIN", "T_ARR_END", "T_STRING", "T_NUMBER", "T_BOOLEAN", "T_NULL", "T_COLON", "T_COMMA", "T_END", "T_WHITESPACE", "T_ERROR", "T_UNKNOWN", NULL }; typedef struct { json_token_type_t ch2token[256]; char escape2char[256]; /* Decoding */ /* encode_buf is only allocated and used when * encode_keep_buffer is set */ strbuf_t encode_buf; int encode_sparse_convert; int encode_sparse_ratio; int encode_sparse_safe; int encode_max_depth; int encode_invalid_numbers; /* 2 => Encode as "null" */ int encode_number_precision; int encode_keep_buffer; int decode_invalid_numbers; int decode_max_depth; } json_config_t; 
typedef struct { const char *data; const char *ptr; strbuf_t *tmp; /* Temporary storage for strings */ json_config_t *cfg; int current_depth; } json_parse_t; typedef struct { json_token_type_t type; int index; union { const char *string; double number; int boolean; } value; int string_len; } json_token_t; static const char *char2escape[256] = { "\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b", "\\t", "\\n", "\\u000b", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012", "\\u0013", "\\u0014", "\\u0015", "\\u0016", "\\u0017", "\\u0018", "\\u0019", "\\u001a", "\\u001b", "\\u001c", "\\u001d", "\\u001e", "\\u001f", NULL, NULL, "\\\"", NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, "\\/", NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, "\\\\", NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, "\\u007f", NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, 
NULL, NULL, NULL, NULL, NULL, NULL, }; /* ===== CONFIGURATION ===== */ static json_config_t *json_fetch_config(lua_State *l) { json_config_t *cfg; cfg = lua_touserdata(l, lua_upvalueindex(1)); if (!cfg) luaL_error(l, "BUG: Unable to fetch CJSON configuration"); return cfg; } /* Ensure the correct number of arguments have been provided. * Pad with nil to allow other functions to simply check arg[i] * to find whether an argument was provided */ static json_config_t *json_arg_init(lua_State *l, int args) { luaL_argcheck(l, lua_gettop(l) <= args, args + 1, "found too many arguments"); while (lua_gettop(l) < args) lua_pushnil(l); return json_fetch_config(l); } /* Process integer options for configuration functions */ static int json_integer_option(lua_State *l, int optindex, int *setting, int min, int max) { char errmsg[64]; int value; if (!lua_isnil(l, optindex)) { value = luaL_checkinteger(l, optindex); snprintf(errmsg, sizeof(errmsg), "expected integer between %d and %d", min, max); luaL_argcheck(l, min <= value && value <= max, 1, errmsg); *setting = value; } lua_pushinteger(l, *setting); return 1; } /* Process enumerated arguments for a configuration function */ static int json_enum_option(lua_State *l, int optindex, int *setting, const char **options, int bool_true) { static const char *bool_options[] = { "off", "on", NULL }; if (!options) { options = bool_options; bool_true = 1; } if (!lua_isnil(l, optindex)) { if (bool_true && lua_isboolean(l, optindex)) *setting = lua_toboolean(l, optindex) * bool_true; else *setting = luaL_checkoption(l, optindex, NULL, options); } if (bool_true && (*setting == 0 || *setting == bool_true)) lua_pushboolean(l, *setting); else lua_pushstring(l, options[*setting]); return 1; } /* Configures handling of extremely sparse arrays: * convert: Convert extremely sparse arrays into objects? Otherwise error. 
* ratio: 0: always allow sparse; 1: never allow sparse; >1: use ratio * safe: Always use an array when the max index <= safe */ static int json_cfg_encode_sparse_array(lua_State *l) { json_config_t *cfg = json_arg_init(l, 3); json_enum_option(l, 1, &cfg->encode_sparse_convert, NULL, 1); json_integer_option(l, 2, &cfg->encode_sparse_ratio, 0, INT_MAX); json_integer_option(l, 3, &cfg->encode_sparse_safe, 0, INT_MAX); return 3; } /* Configures the maximum number of nested arrays/objects allowed when * encoding */ static int json_cfg_encode_max_depth(lua_State *l) { json_config_t *cfg = json_arg_init(l, 1); return json_integer_option(l, 1, &cfg->encode_max_depth, 1, INT_MAX); } /* Configures the maximum number of nested arrays/objects allowed when * decoding */ static int json_cfg_decode_max_depth(lua_State *l) { json_config_t *cfg = json_arg_init(l, 1); return json_integer_option(l, 1, &cfg->decode_max_depth, 1, INT_MAX); } /* Configures number precision (1 to 14 digits) when converting doubles to text */ static int json_cfg_encode_number_precision(lua_State *l) { json_config_t *cfg = json_arg_init(l, 1); return json_integer_option(l, 1, &cfg->encode_number_precision, 1, 14); } /* Configures JSON encoding buffer persistence */ static int json_cfg_encode_keep_buffer(lua_State *l) { json_config_t *cfg = json_arg_init(l, 1); int old_value; old_value = cfg->encode_keep_buffer; json_enum_option(l, 1, &cfg->encode_keep_buffer, NULL, 1); /* Init / free the buffer if the setting has changed */ if (old_value ^ cfg->encode_keep_buffer) { if (cfg->encode_keep_buffer) strbuf_init(&cfg->encode_buf, 0); else strbuf_free(&cfg->encode_buf); } return 1; } /* When invalid-number support is compiled out, undo an attempt to set the * option to 1 and raise a Lua error; otherwise the check is a no-op. */ #if defined(DISABLE_INVALID_NUMBERS) && !defined(USE_INTERNAL_FPCONV) void json_verify_invalid_number_setting(lua_State *l, int *setting) { if (*setting == 1) { *setting = 0; luaL_error(l, "Infinity, NaN, and/or hexadecimal numbers are not supported."); } } #else #define json_verify_invalid_number_setting(l, s) do { } while(0) #endif static int 
json_cfg_encode_invalid_numbers(lua_State *l)
{
    /* Accepts "off", "on" or "null" (or a boolean) for how NaN/Inf are
     * encoded and pushes the resulting setting. */
    static const char *options[] = { "off", "on", "null", NULL };
    json_config_t *cfg = json_arg_init(l, 1);

    json_enum_option(l, 1, &cfg->encode_invalid_numbers, options, 1);

    json_verify_invalid_number_setting(l, &cfg->encode_invalid_numbers);

    return 1;
}

/* Configures whether invalid numbers are accepted when decoding and pushes
 * the resulting setting. */
static int json_cfg_decode_invalid_numbers(lua_State *l)
{
    json_config_t *cfg = json_arg_init(l, 1);

    json_enum_option(l, 1, &cfg->decode_invalid_numbers, NULL, 1);

    /* Verify the option that was just changed. Checking the encode option
     * here would let an unsupported decode setting slip through. */
    json_verify_invalid_number_setting(l, &cfg->decode_invalid_numbers);

    return 1;
}

/* __gc handler of the configuration userdata: releases the encode buffer. */
static int json_destroy_config(lua_State *l)
{
    json_config_t *cfg;

    cfg = lua_touserdata(l, 1);
    if (cfg)
        strbuf_free(&cfg->encode_buf);

    return 0;
}

/* Pushes a new configuration userdata initialised with the default options
 * and the lookup tables used by the decoder. */
static void json_create_config(lua_State *l)
{
    json_config_t *cfg;
    int i;

    cfg = lua_newuserdata(l, sizeof(*cfg));

    /* Create GC method to clean up strbuf */
    lua_newtable(l);
    lua_pushcfunction(l, json_destroy_config);
    lua_setfield(l, -2, "__gc");
    lua_setmetatable(l, -2);

    cfg->encode_sparse_convert = DEFAULT_SPARSE_CONVERT;
    cfg->encode_sparse_ratio = DEFAULT_SPARSE_RATIO;
    cfg->encode_sparse_safe = DEFAULT_SPARSE_SAFE;
    cfg->encode_max_depth = DEFAULT_ENCODE_MAX_DEPTH;
    cfg->decode_max_depth = DEFAULT_DECODE_MAX_DEPTH;
    cfg->encode_invalid_numbers = DEFAULT_ENCODE_INVALID_NUMBERS;
    cfg->decode_invalid_numbers = DEFAULT_DECODE_INVALID_NUMBERS;
    cfg->encode_keep_buffer = DEFAULT_ENCODE_KEEP_BUFFER;
    cfg->encode_number_precision = DEFAULT_ENCODE_NUMBER_PRECISION;

#if DEFAULT_ENCODE_KEEP_BUFFER > 0
    strbuf_init(&cfg->encode_buf, 0);
#endif

    /* Decoding init */

    /* Tag all characters as an error */
    for (i = 0; i < 256; i++)
        cfg->ch2token[i] = T_ERROR;

    /* Set tokens that require no further processing */
    cfg->ch2token['{'] = T_OBJ_BEGIN;
    cfg->ch2token['}'] = T_OBJ_END;
    cfg->ch2token['['] = T_ARR_BEGIN;
    cfg->ch2token[']'] = T_ARR_END;
    cfg->ch2token[','] = T_COMMA;
    cfg->ch2token[':'] = T_COLON;
    cfg->ch2token['\0'] = T_END;
    cfg->ch2token[' '] = T_WHITESPACE;
    cfg->ch2token['\t'] = T_WHITESPACE;
cfg->ch2token['\n'] = T_WHITESPACE; cfg->ch2token['\r'] = T_WHITESPACE; /* Update characters that require further processing */ cfg->ch2token['f'] = T_UNKNOWN; /* false? */ cfg->ch2token['i'] = T_UNKNOWN; /* inf, ininity? */ cfg->ch2token['I'] = T_UNKNOWN; cfg->ch2token['n'] = T_UNKNOWN; /* null, nan? */ cfg->ch2token['N'] = T_UNKNOWN; cfg->ch2token['t'] = T_UNKNOWN; /* true? */ cfg->ch2token['"'] = T_UNKNOWN; /* string? */ cfg->ch2token['+'] = T_UNKNOWN; /* number? */ cfg->ch2token['-'] = T_UNKNOWN; for (i = 0; i < 10; i++) cfg->ch2token['0' + i] = T_UNKNOWN; /* Lookup table for parsing escape characters */ for (i = 0; i < 256; i++) cfg->escape2char[i] = 0; /* String error */ cfg->escape2char['"'] = '"'; cfg->escape2char['\\'] = '\\'; cfg->escape2char['/'] = '/'; cfg->escape2char['b'] = '\b'; cfg->escape2char['t'] = '\t'; cfg->escape2char['n'] = '\n'; cfg->escape2char['f'] = '\f'; cfg->escape2char['r'] = '\r'; cfg->escape2char['u'] = 'u'; /* Unicode parsing required */ } /* ===== ENCODING ===== */ static void json_encode_exception(lua_State *l, json_config_t *cfg, strbuf_t *json, int lindex, const char *reason) { if (!cfg->encode_keep_buffer) strbuf_free(json); luaL_error(l, "Cannot serialise %s: %s", lua_typename(l, lua_type(l, lindex)), reason); } /* json_append_string args: * - lua_State * - JSON strbuf * - String (Lua stack index) * * Returns nothing. Doesn't remove string from Lua stack */ static void json_append_string(lua_State *l, strbuf_t *json, int lindex) { const char *escstr; int i; const char *str; size_t len; str = lua_tolstring(l, lindex, &len); /* Worst case is len * 6 (all unicode escapes). * This buffer is reused constantly for small strings * If there are any excess pages, they won't be hit anyway. * This gains ~5% speedup. 
*/ strbuf_ensure_empty_length(json, len * 6 + 2); strbuf_append_char_unsafe(json, '\"'); for (i = 0; i < len; i++) { escstr = char2escape[(unsigned char)str[i]]; if (escstr) strbuf_append_string(json, escstr); else strbuf_append_char_unsafe(json, str[i]); } strbuf_append_char_unsafe(json, '\"'); } /* Find the size of the array on the top of the Lua stack * -1 object (not a pure array) * >=0 elements in array */ static int lua_array_length(lua_State *l, json_config_t *cfg, strbuf_t *json) { double k; int max; int items; max = 0; items = 0; lua_pushnil(l); /* table, startkey */ while (lua_next(l, -2) != 0) { /* table, key, value */ if (lua_type(l, -2) == LUA_TNUMBER && (k = lua_tonumber(l, -2))) { /* Integer >= 1 ? */ if (floor(k) == k && k >= 1) { if (k > max) max = k; items++; lua_pop(l, 1); continue; } } /* Must not be an array (non integer key) */ lua_pop(l, 2); return -1; } /* Encode excessively sparse arrays as objects (if enabled) */ if (cfg->encode_sparse_ratio > 0 && max > items * cfg->encode_sparse_ratio && max > cfg->encode_sparse_safe) { if (!cfg->encode_sparse_convert) json_encode_exception(l, cfg, json, -1, "excessively sparse array"); return -1; } return max; } static void json_check_encode_depth(lua_State *l, json_config_t *cfg, int current_depth, strbuf_t *json) { /* Ensure there are enough slots free to traverse a table (key, * value) and push a string for a potential error message. * * Unlike "decode", the key and value are still on the stack when * lua_checkstack() is called. Hence an extra slot for luaL_error() * below is required just in case the next check to lua_checkstack() * fails. * * While this won't cause a crash due to the EXTRA_STACK reserve * slots, it would still be an improper use of the API. 
*/ if (current_depth <= cfg->encode_max_depth && lua_checkstack(l, 3)) return; if (!cfg->encode_keep_buffer) strbuf_free(json); luaL_error(l, "Cannot serialise, excessive nesting (%d)", current_depth); } static void json_append_data(lua_State *l, json_config_t *cfg, int current_depth, strbuf_t *json); /* json_append_array args: * - lua_State * - JSON strbuf * - Size of passwd Lua array (top of stack) */ static void json_append_array(lua_State *l, json_config_t *cfg, int current_depth, strbuf_t *json, int array_length) { int comma, i; strbuf_append_char(json, '['); comma = 0; for (i = 1; i <= array_length; i++) { if (comma) strbuf_append_char(json, ','); else comma = 1; lua_rawgeti(l, -1, i); json_append_data(l, cfg, current_depth, json); lua_pop(l, 1); } strbuf_append_char(json, ']'); } static void json_append_number(lua_State *l, json_config_t *cfg, strbuf_t *json, int lindex) { double num = lua_tonumber(l, lindex); int len; if (cfg->encode_invalid_numbers == 0) { /* Prevent encoding invalid numbers */ if (isinf(num) || isnan(num)) json_encode_exception(l, cfg, json, lindex, "must not be NaN or Inf"); } else if (cfg->encode_invalid_numbers == 1) { /* Encode invalid numbers, but handle "nan" separately * since some platforms may encode as "-nan". 
*/ if (isnan(num)) { strbuf_append_mem(json, "nan", 3); return; } } else { /* Encode invalid numbers as "null" */ if (isinf(num) || isnan(num)) { strbuf_append_mem(json, "null", 4); return; } } strbuf_ensure_empty_length(json, FPCONV_G_FMT_BUFSIZE); len = fpconv_g_fmt(strbuf_empty_ptr(json), num, cfg->encode_number_precision); strbuf_extend_length(json, len); } static void json_append_object(lua_State *l, json_config_t *cfg, int current_depth, strbuf_t *json) { int comma, keytype; /* Object */ strbuf_append_char(json, '{'); lua_pushnil(l); /* table, startkey */ comma = 0; while (lua_next(l, -2) != 0) { if (comma) strbuf_append_char(json, ','); else comma = 1; /* table, key, value */ keytype = lua_type(l, -2); if (keytype == LUA_TNUMBER) { strbuf_append_char(json, '"'); json_append_number(l, cfg, json, -2); strbuf_append_mem(json, "\":", 2); } else if (keytype == LUA_TSTRING) { json_append_string(l, json, -2); strbuf_append_char(json, ':'); } else { json_encode_exception(l, cfg, json, -2, "table key must be a number or string"); /* never returns */ } /* table, key, value */ json_append_data(l, cfg, current_depth, json); lua_pop(l, 1); /* table, key */ } strbuf_append_char(json, '}'); } /* Serialise Lua data into JSON string. 
*/ static void json_append_data(lua_State *l, json_config_t *cfg, int current_depth, strbuf_t *json) { int len; switch (lua_type(l, -1)) { case LUA_TSTRING: json_append_string(l, json, -1); break; case LUA_TNUMBER: json_append_number(l, cfg, json, -1); break; case LUA_TBOOLEAN: if (lua_toboolean(l, -1)) strbuf_append_mem(json, "true", 4); else strbuf_append_mem(json, "false", 5); break; case LUA_TTABLE: current_depth++; json_check_encode_depth(l, cfg, current_depth, json); len = lua_array_length(l, cfg, json); if (len > 0) json_append_array(l, cfg, current_depth, json, len); else json_append_object(l, cfg, current_depth, json); break; case LUA_TNIL: strbuf_append_mem(json, "null", 4); break; case LUA_TLIGHTUSERDATA: if (lua_touserdata(l, -1) == NULL) { strbuf_append_mem(json, "null", 4); break; } default: /* Remaining types (LUA_TFUNCTION, LUA_TUSERDATA, LUA_TTHREAD, * and LUA_TLIGHTUSERDATA) cannot be serialised */ json_encode_exception(l, cfg, json, -1, "type not supported"); /* never returns */ } } static int json_encode(lua_State *l) { json_config_t *cfg = json_fetch_config(l); strbuf_t local_encode_buf; strbuf_t *encode_buf; char *json; int len; luaL_argcheck(l, lua_gettop(l) == 1, 1, "expected 1 argument"); if (!cfg->encode_keep_buffer) { /* Use private buffer */ encode_buf = &local_encode_buf; strbuf_init(encode_buf, 0); } else { /* Reuse existing buffer */ encode_buf = &cfg->encode_buf; strbuf_reset(encode_buf); } json_append_data(l, cfg, 0, encode_buf); json = strbuf_string(encode_buf, &len); lua_pushlstring(l, json, len); if (!cfg->encode_keep_buffer) strbuf_free(encode_buf); return 1; } /* ===== DECODING ===== */ static void json_process_value(lua_State *l, json_parse_t *json, json_token_t *token); static int hexdigit2int(char hex) { if ('0' <= hex && hex <= '9') return hex - '0'; /* Force lowercase */ hex |= 0x20; if ('a' <= hex && hex <= 'f') return 10 + hex - 'a'; return -1; } static int decode_hex4(const char *hex) { int digit[4]; int i; /* Convert 
ASCII hex digit to numeric digit * Note: this returns an error for invalid hex digits, including * NULL */ for (i = 0; i < 4; i++) { digit[i] = hexdigit2int(hex[i]); if (digit[i] < 0) { return -1; } } return (digit[0] << 12) + (digit[1] << 8) + (digit[2] << 4) + digit[3]; } /* Converts a Unicode codepoint to UTF-8. * Returns UTF-8 string length, and up to 4 bytes in *utf8 */ static int codepoint_to_utf8(char *utf8, int codepoint) { /* 0xxxxxxx */ if (codepoint <= 0x7F) { utf8[0] = codepoint; return 1; } /* 110xxxxx 10xxxxxx */ if (codepoint <= 0x7FF) { utf8[0] = (codepoint >> 6) | 0xC0; utf8[1] = (codepoint & 0x3F) | 0x80; return 2; } /* 1110xxxx 10xxxxxx 10xxxxxx */ if (codepoint <= 0xFFFF) { utf8[0] = (codepoint >> 12) | 0xE0; utf8[1] = ((codepoint >> 6) & 0x3F) | 0x80; utf8[2] = (codepoint & 0x3F) | 0x80; return 3; } /* 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx */ if (codepoint <= 0x1FFFFF) { utf8[0] = (codepoint >> 18) | 0xF0; utf8[1] = ((codepoint >> 12) & 0x3F) | 0x80; utf8[2] = ((codepoint >> 6) & 0x3F) | 0x80; utf8[3] = (codepoint & 0x3F) | 0x80; return 4; } return 0; } /* Called when index pointing to beginning of UTF-16 code escape: \uXXXX * \u is guaranteed to exist, but the remaining hex characters may be * missing. * Translate to UTF-8 and append to temporary token string. * Must advance index to the next character to be processed. 
* Returns: 0 success * -1 error */ static int json_append_unicode_escape(json_parse_t *json) { char utf8[4]; /* Surrogate pairs require 4 UTF-8 bytes */ int codepoint; int surrogate_low; int len; int escape_len = 6; /* Fetch UTF-16 code unit */ codepoint = decode_hex4(json->ptr + 2); if (codepoint < 0) return -1; /* UTF-16 surrogate pairs take the following 2 byte form: * 11011 x yyyyyyyyyy * When x = 0: y is the high 10 bits of the codepoint * x = 1: y is the low 10 bits of the codepoint * * Check for a surrogate pair (high or low) */ if ((codepoint & 0xF800) == 0xD800) { /* Error if the 1st surrogate is not high */ if (codepoint & 0x400) return -1; /* Ensure the next code is a unicode escape */ if (*(json->ptr + escape_len) != '\\' || *(json->ptr + escape_len + 1) != 'u') { return -1; } /* Fetch the next codepoint */ surrogate_low = decode_hex4(json->ptr + 2 + escape_len); if (surrogate_low < 0) return -1; /* Error if the 2nd code is not a low surrogate */ if ((surrogate_low & 0xFC00) != 0xDC00) return -1; /* Calculate Unicode codepoint */ codepoint = (codepoint & 0x3FF) << 10; surrogate_low &= 0x3FF; codepoint = (codepoint | surrogate_low) + 0x10000; escape_len = 12; } /* Convert codepoint to UTF-8 */ len = codepoint_to_utf8(utf8, codepoint); if (!len) return -1; /* Append bytes and advance parse index */ strbuf_append_mem_unsafe(json->tmp, utf8, len); json->ptr += escape_len; return 0; } static void json_set_token_error(json_token_t *token, json_parse_t *json, const char *errtype) { token->type = T_ERROR; token->index = json->ptr - json->data; token->value.string = errtype; } static void json_next_string_token(json_parse_t *json, json_token_t *token) { char *escape2char = json->cfg->escape2char; char ch; /* Caller must ensure a string is next */ assert(*json->ptr == '"'); /* Skip " */ json->ptr++; /* json->tmp is the temporary strbuf used to accumulate the * decoded string value. * json->tmp is sized to handle JSON containing only a string value. 
*/ strbuf_reset(json->tmp); while ((ch = *json->ptr) != '"') { if (!ch) { /* Premature end of the string */ json_set_token_error(token, json, "unexpected end of string"); return; } /* Handle escapes */ if (ch == '\\') { /* Fetch escape character */ ch = *(json->ptr + 1); /* Translate escape code and append to tmp string */ ch = escape2char[(unsigned char)ch]; if (ch == 'u') { if (json_append_unicode_escape(json) == 0) continue; json_set_token_error(token, json, "invalid unicode escape code"); return; } if (!ch) { json_set_token_error(token, json, "invalid escape code"); return; } /* Skip '\' */ json->ptr++; } /* Append normal character or translated single character * Unicode escapes are handled above */ strbuf_append_char_unsafe(json->tmp, ch); json->ptr++; } json->ptr++; /* Eat final quote (") */ strbuf_ensure_null(json->tmp); token->type = T_STRING; token->value.string = strbuf_string(json->tmp, &token->string_len); } /* JSON numbers should take the following form: * -?(0|[1-9]|[1-9][0-9]+)(.[0-9]+)?([eE][-+]?[0-9]+)? * * json_next_number_token() uses strtod() which allows other forms: * - numbers starting with '+' * - NaN, -NaN, infinity, -infinity * - hexadecimal numbers * - numbers with leading zeros * * json_is_invalid_number() detects "numbers" which may pass strtod()'s * error checking, but should not be allowed with strict JSON. * * json_is_invalid_number() may pass numbers which cause strtod() * to generate an error. 
*/ static int json_is_invalid_number(json_parse_t *json) { const char *p = json->ptr; /* Reject numbers starting with + */ if (*p == '+') return 1; /* Skip minus sign if it exists */ if (*p == '-') p++; /* Reject numbers starting with 0x, or leading zeros */ if (*p == '0') { int ch2 = *(p + 1); if ((ch2 | 0x20) == 'x' || /* Hex */ ('0' <= ch2 && ch2 <= '9')) /* Leading zero */ return 1; return 0; } else if (*p <= '9') { return 0; /* Ordinary number */ } /* Reject inf/nan */ if (!strncasecmp(p, "inf", 3)) return 1; if (!strncasecmp(p, "nan", 3)) return 1; /* Pass all other numbers which may still be invalid, but * strtod() will catch them. */ return 0; } static void json_next_number_token(json_parse_t *json, json_token_t *token) { char *endptr; token->type = T_NUMBER; token->value.number = fpconv_strtod(json->ptr, &endptr); if (json->ptr == endptr) json_set_token_error(token, json, "invalid number"); else json->ptr = endptr; /* Skip the processed number */ return; } /* Fills in the token struct. * T_STRING will return a pointer to the json_parse_t temporary string * T_ERROR will leave the json->ptr pointer at the error. */ static void json_next_token(json_parse_t *json, json_token_t *token) { const json_token_type_t *ch2token = json->cfg->ch2token; int ch; /* Eat whitespace. */ while (1) { ch = (unsigned char)*(json->ptr); token->type = ch2token[ch]; if (token->type != T_WHITESPACE) break; json->ptr++; } /* Store location of new token. Required when throwing errors * for unexpected tokens (syntax errors). */ token->index = json->ptr - json->data; /* Don't advance the pointer for an error or the end */ if (token->type == T_ERROR) { json_set_token_error(token, json, "invalid token"); return; } if (token->type == T_END) { return; } /* Found a known single character token, advance index and return */ if (token->type != T_UNKNOWN) { json->ptr++; return; } /* Process characters which triggered T_UNKNOWN * * Must use strncmp() to match the front of the JSON string. 
* JSON identifier must be lowercase. * When strict_numbers if disabled, either case is allowed for * Infinity/NaN (since we are no longer following the spec..) */ if (ch == '"') { json_next_string_token(json, token); return; } else if (ch == '-' || ('0' <= ch && ch <= '9')) { if (!json->cfg->decode_invalid_numbers && json_is_invalid_number(json)) { json_set_token_error(token, json, "invalid number"); return; } json_next_number_token(json, token); return; } else if (!strncmp(json->ptr, "true", 4)) { token->type = T_BOOLEAN; token->value.boolean = 1; json->ptr += 4; return; } else if (!strncmp(json->ptr, "false", 5)) { token->type = T_BOOLEAN; token->value.boolean = 0; json->ptr += 5; return; } else if (!strncmp(json->ptr, "null", 4)) { token->type = T_NULL; json->ptr += 4; return; } else if (json->cfg->decode_invalid_numbers && json_is_invalid_number(json)) { /* When decode_invalid_numbers is enabled, only attempt to process * numbers we know are invalid JSON (Inf, NaN, hex) * This is required to generate an appropriate token error, * otherwise all bad tokens will register as "invalid number" */ json_next_number_token(json, token); return; } /* Token starts with t/f/n but isn't recognised above. */ json_set_token_error(token, json, "invalid token"); } /* This function does not return. * DO NOT CALL WITH DYNAMIC MEMORY ALLOCATED. * The only supported exception is the temporary parser string * json->tmp struct. * json and token should exist on the stack somewhere. 
* luaL_error() will long_jmp and release the stack */ static void json_throw_parse_error(lua_State *l, json_parse_t *json, const char *exp, json_token_t *token) { const char *found; strbuf_free(json->tmp); if (token->type == T_ERROR) found = token->value.string; else found = json_token_type_name[token->type]; /* Note: token->index is 0 based, display starting from 1 */ luaL_error(l, "Expected %s but found %s at character %d", exp, found, token->index + 1); } static inline void json_decode_ascend(json_parse_t *json) { json->current_depth--; } static void json_decode_descend(lua_State *l, json_parse_t *json, int slots) { json->current_depth++; if (json->current_depth <= json->cfg->decode_max_depth && lua_checkstack(l, slots)) { return; } strbuf_free(json->tmp); luaL_error(l, "Found too many nested data structures (%d) at character %d", json->current_depth, json->ptr - json->data); } static void json_parse_object_context(lua_State *l, json_parse_t *json) { json_token_t token; /* 3 slots required: * .., table, key, value */ json_decode_descend(l, json, 3); lua_newtable(l); json_next_token(json, &token); /* Handle empty objects */ if (token.type == T_OBJ_END) { json_decode_ascend(json); return; } while (1) { if (token.type != T_STRING) json_throw_parse_error(l, json, "object key string", &token); /* Push key */ lua_pushlstring(l, token.value.string, token.string_len); json_next_token(json, &token); if (token.type != T_COLON) json_throw_parse_error(l, json, "colon", &token); /* Fetch value */ json_next_token(json, &token); json_process_value(l, json, &token); /* Set key = value */ lua_rawset(l, -3); json_next_token(json, &token); if (token.type == T_OBJ_END) { json_decode_ascend(json); return; } if (token.type != T_COMMA) json_throw_parse_error(l, json, "comma or object end", &token); json_next_token(json, &token); } } /* Handle the array context */ static void json_parse_array_context(lua_State *l, json_parse_t *json) { json_token_t token; int i; /* 2 slots required: * 
.., table, value */ json_decode_descend(l, json, 2); lua_newtable(l); json_next_token(json, &token); /* Handle empty arrays */ if (token.type == T_ARR_END) { json_decode_ascend(json); return; } for (i = 1; ; i++) { json_process_value(l, json, &token); lua_rawseti(l, -2, i); /* arr[i] = value */ json_next_token(json, &token); if (token.type == T_ARR_END) { json_decode_ascend(json); return; } if (token.type != T_COMMA) json_throw_parse_error(l, json, "comma or array end", &token); json_next_token(json, &token); } } /* Handle the "value" context */ static void json_process_value(lua_State *l, json_parse_t *json, json_token_t *token) { switch (token->type) { case T_STRING: lua_pushlstring(l, token->value.string, token->string_len); break;; case T_NUMBER: { double num = token->value.number; double intpart; /* Convert to integer when possible for Lua 5.1 compatibility. * This ensures tostring(cjson.decode('{"id":42}').id) returns "42" not "42.0" */ if (modf(num, &intpart) == 0.0 && intpart >= LUA_MININTEGER && intpart <= LUA_MAXINTEGER) { lua_pushinteger(l, (lua_Integer)intpart); } else { lua_pushnumber(l, num); } break; } case T_BOOLEAN: lua_pushboolean(l, token->value.boolean); break;; case T_OBJ_BEGIN: json_parse_object_context(l, json); break;; case T_ARR_BEGIN: json_parse_array_context(l, json); break;; case T_NULL: /* In Lua, setting "t[k] = nil" will delete k from the table. 
* Hence a NULL pointer lightuserdata object is used instead */ lua_pushlightuserdata(l, NULL); break;; default: json_throw_parse_error(l, json, "value", token); } } static int json_decode(lua_State *l) { json_parse_t json; json_token_t token; size_t json_len; luaL_argcheck(l, lua_gettop(l) == 1, 1, "expected 1 argument"); json.cfg = json_fetch_config(l); json.data = luaL_checklstring(l, 1, &json_len); json.current_depth = 0; json.ptr = json.data; /* Detect Unicode other than UTF-8 (see RFC 4627, Sec 3) * * CJSON can support any simple data type, hence only the first * character is guaranteed to be ASCII (at worst: '"'). This is * still enough to detect whether the wrong encoding is in use. */ if (json_len >= 2 && (!json.data[0] || !json.data[1])) luaL_error(l, "JSON parser does not support UTF-16 or UTF-32"); /* Ensure the temporary buffer can hold the entire string. * This means we no longer need to do length checks since the decoded * string must be smaller than the entire json string */ json.tmp = strbuf_new(json_len); json_next_token(&json, &token); json_process_value(l, &json, &token); /* Ensure there is no more input left */ json_next_token(&json, &token); if (token.type != T_END) json_throw_parse_error(l, &json, "the end", &token); strbuf_free(json.tmp); return 1; } /* ===== INITIALISATION ===== */ #if !defined(LUA_VERSION_NUM) || LUA_VERSION_NUM < 502 /* Compatibility for Lua 5.1. * * luaL_setfuncs() is used to create a module table where the functions have * json_config_t as their first upvalue. Code borrowed from Lua 5.2 source. 
*/ static void luaL_setfuncs (lua_State *l, const luaL_Reg *reg, int nup) { int i; luaL_checkstack(l, nup, "too many upvalues"); for (; reg->name != NULL; reg++) { /* fill the table with given functions */ for (i = 0; i < nup; i++) /* copy upvalues to the top */ lua_pushvalue(l, -nup); lua_pushcclosure(l, reg->func, nup); /* closure with those upvalues */ lua_setfield(l, -(nup + 2), reg->name); } lua_pop(l, nup); /* remove upvalues */ } #endif /* Call target function in protected mode with all supplied args. * Assumes target function only returns a single non-nil value. * Convert and return thrown errors as: nil, "error message" */ static int json_protect_conversion(lua_State *l) { int err; /* Deliberately throw an error for invalid arguments */ luaL_argcheck(l, lua_gettop(l) == 1, 1, "expected 1 argument"); /* pcall() the function stored as upvalue(1) */ lua_pushvalue(l, lua_upvalueindex(1)); lua_insert(l, 1); err = lua_pcall(l, 1, 1, 0); if (!err) return 1; if (err == LUA_ERRRUN) { lua_pushnil(l); lua_insert(l, -2); return 2; } /* Since we are not using a custom error handler, the only remaining * errors are memory related */ return luaL_error(l, "Memory allocation error in CJSON protected call"); } /* Return cjson module table */ static int lua_cjson_new(lua_State *l) { luaL_Reg reg[] = { { "encode", json_encode }, { "decode", json_decode }, { "encode_sparse_array", json_cfg_encode_sparse_array }, { "encode_max_depth", json_cfg_encode_max_depth }, { "decode_max_depth", json_cfg_decode_max_depth }, { "encode_number_precision", json_cfg_encode_number_precision }, { "encode_keep_buffer", json_cfg_encode_keep_buffer }, { "encode_invalid_numbers", json_cfg_encode_invalid_numbers }, { "decode_invalid_numbers", json_cfg_decode_invalid_numbers }, { "new", lua_cjson_new }, { NULL, NULL } }; /* Initialise number conversions */ fpconv_init(); /* cjson module table */ lua_newtable(l); /* Register functions with config data as upvalue */ json_create_config(l); 
luaL_setfuncs(l, reg, 1); /* Set cjson.null */ lua_pushlightuserdata(l, NULL); lua_setfield(l, -2, "null"); /* Set module name / version fields */ lua_pushliteral(l, CJSON_MODNAME); lua_setfield(l, -2, "_NAME"); lua_pushliteral(l, CJSON_VERSION); lua_setfield(l, -2, "_VERSION"); return 1; } /* Return cjson.safe module table */ static int lua_cjson_safe_new(lua_State *l) { const char *func[] = { "decode", "encode", NULL }; int i; lua_cjson_new(l); /* Fix new() method */ lua_pushcfunction(l, lua_cjson_safe_new); lua_setfield(l, -2, "new"); for (i = 0; func[i]; i++) { lua_getfield(l, -1, func[i]); lua_pushcclosure(l, json_protect_conversion, 1); lua_setfield(l, -2, func[i]); } return 1; } int luaopen_cjson(lua_State *l) { lua_cjson_new(l); lua_pushvalue(l, -1); lua_setglobal(l, CJSON_MODNAME); /* Return cjson table */ return 1; } int luaopen_cjson_safe(lua_State *l) { lua_cjson_safe_new(l); /* Return cjson.safe table */ return 1; } /* vi:ai et sw=4 ts=4: */ ================================================ FILE: src/redis/lua/cjson/strbuf.c ================================================ /* strbuf - String buffer routines * * Copyright (c) 2010-2012 Mark Pulford * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include #include #include #include #include "strbuf.h" static void die(const char *fmt, ...) { va_list arg; va_start(arg, fmt); vfprintf(stderr, fmt, arg); va_end(arg); fprintf(stderr, "\n"); exit(-1); } void strbuf_init(strbuf_t *s, int len) { int size; if (len <= 0) size = STRBUF_DEFAULT_SIZE; else size = len + 1; /* \0 terminator */ s->buf = NULL; s->size = size; s->length = 0; s->increment = STRBUF_DEFAULT_INCREMENT; s->dynamic = 0; s->reallocs = 0; s->debug = 0; s->buf = malloc(size); if (!s->buf) die("Out of memory"); strbuf_ensure_null(s); } strbuf_t *strbuf_new(int len) { strbuf_t *s; s = malloc(sizeof(strbuf_t)); if (!s) die("Out of memory"); strbuf_init(s, len); /* Dynamic strbuf allocation / deallocation */ s->dynamic = 1; return s; } void strbuf_set_increment(strbuf_t *s, int increment) { /* Increment > 0: Linear buffer growth rate * Increment < -1: Exponential buffer growth rate */ if (increment == 0 || increment == -1) die("BUG: Invalid string increment"); s->increment = increment; } static inline void debug_stats(strbuf_t *s) { if (s->debug) { fprintf(stderr, "strbuf(%lx) reallocs: %d, length: %d, size: %d\n", (long)s, s->reallocs, s->length, s->size); } } /* If strbuf_t has not been dynamically allocated, strbuf_free() can * be called any number of times strbuf_init() */ void strbuf_free(strbuf_t *s) { debug_stats(s); if (s->buf) { free(s->buf); s->buf = NULL; } if (s->dynamic) free(s); } char *strbuf_free_to_string(strbuf_t *s, int *len) { char *buf; debug_stats(s); strbuf_ensure_null(s); buf = s->buf; if (len) *len = s->length; if (s->dynamic) free(s); return buf; } static int calculate_new_size(strbuf_t *s, int len) { int reqsize, newsize; if (len <= 0) die("BUG: Invalid strbuf 
length requested"); /* Ensure there is room for optional NULL termination */ reqsize = len + 1; /* If the user has requested to shrink the buffer, do it exactly */ if (s->size > reqsize) return reqsize; newsize = s->size; if (s->increment < 0) { /* Exponential sizing */ while (newsize < reqsize) newsize *= -s->increment; } else { /* Linear sizing */ newsize = ((newsize + s->increment - 1) / s->increment) * s->increment; } return newsize; } /* Ensure strbuf can handle a string length bytes long (ignoring NULL * optional termination). */ void strbuf_resize(strbuf_t *s, int len) { int newsize; newsize = calculate_new_size(s, len); if (s->debug > 1) { fprintf(stderr, "strbuf(%lx) resize: %d => %d\n", (long)s, s->size, newsize); } s->size = newsize; s->buf = realloc(s->buf, s->size); if (!s->buf) die("Out of memory"); s->reallocs++; } void strbuf_append_string(strbuf_t *s, const char *str) { int space, i; space = strbuf_empty_length(s); for (i = 0; str[i]; i++) { if (space < 1) { strbuf_resize(s, s->length + 1); space = strbuf_empty_length(s); } s->buf[s->length] = str[i]; s->length++; space--; } } /* strbuf_append_fmt() should only be used when an upper bound * is known for the output string. */ void strbuf_append_fmt(strbuf_t *s, int len, const char *fmt, ...) { va_list arg; int fmt_len; strbuf_ensure_empty_length(s, len); va_start(arg, fmt); fmt_len = vsnprintf(s->buf + s->length, len, fmt, arg); va_end(arg); if (fmt_len < 0) die("BUG: Unable to convert number"); /* This should never happen.. */ s->length += fmt_len; } /* strbuf_append_fmt_retry() can be used when the there is no known * upper bound for the output string. */ void strbuf_append_fmt_retry(strbuf_t *s, const char *fmt, ...) 
{ va_list arg; int fmt_len, try; int empty_len; /* If the first attempt to append fails, resize the buffer appropriately * and try again */ for (try = 0; ; try++) { va_start(arg, fmt); /* Append the new formatted string */ /* fmt_len is the length of the string required, excluding the * trailing NULL */ empty_len = strbuf_empty_length(s); /* Add 1 since there is also space to store the terminating NULL. */ fmt_len = vsnprintf(s->buf + s->length, empty_len + 1, fmt, arg); va_end(arg); if (fmt_len <= empty_len) break; /* SUCCESS */ if (try > 0) die("BUG: length of formatted string changed"); strbuf_resize(s, s->length + fmt_len); } s->length += fmt_len; } /* vi:ai et sw=4 ts=4: */ ================================================ FILE: src/redis/lua/cjson/strbuf.h ================================================ /* strbuf - String buffer routines * * Copyright (c) 2010-2012 Mark Pulford * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #include #include /* Size: Total bytes allocated to *buf * Length: String length, excluding optional NULL terminator. * Increment: Allocation increments when resizing the string buffer. * Dynamic: True if created via strbuf_new() */ typedef struct { char *buf; int size; int length; int increment; int dynamic; int reallocs; int debug; } strbuf_t; #ifndef STRBUF_DEFAULT_SIZE #define STRBUF_DEFAULT_SIZE 1023 #endif #ifndef STRBUF_DEFAULT_INCREMENT #define STRBUF_DEFAULT_INCREMENT -2 #endif /* Initialise */ extern strbuf_t *strbuf_new(int len); extern void strbuf_init(strbuf_t *s, int len); extern void strbuf_set_increment(strbuf_t *s, int increment); /* Release */ extern void strbuf_free(strbuf_t *s); extern char *strbuf_free_to_string(strbuf_t *s, int *len); /* Management */ extern void strbuf_resize(strbuf_t *s, int len); static int strbuf_empty_length(strbuf_t *s); static int strbuf_length(strbuf_t *s); static char *strbuf_string(strbuf_t *s, int *len); static void strbuf_ensure_empty_length(strbuf_t *s, int len); static char *strbuf_empty_ptr(strbuf_t *s); static void strbuf_extend_length(strbuf_t *s, int len); /* Update */ extern void strbuf_append_fmt(strbuf_t *s, int len, const char *fmt, ...); extern void strbuf_append_fmt_retry(strbuf_t *s, const char *format, ...); static void strbuf_append_mem(strbuf_t *s, const char *c, int len); extern void strbuf_append_string(strbuf_t *s, const char *str); static void strbuf_append_char(strbuf_t *s, const char c); static void strbuf_ensure_null(strbuf_t *s); /* Reset string for before use */ static inline void strbuf_reset(strbuf_t *s) { s->length = 0; } static inline int strbuf_allocated(strbuf_t *s) { return s->buf != NULL; } /* Return bytes remaining in the string buffer * Ensure there is space for a NULL terminator. 
*/ static inline int strbuf_empty_length(strbuf_t *s) { return s->size - s->length - 1; } static inline void strbuf_ensure_empty_length(strbuf_t *s, int len) { if (len > strbuf_empty_length(s)) strbuf_resize(s, s->length + len); } static inline char *strbuf_empty_ptr(strbuf_t *s) { return s->buf + s->length; } static inline void strbuf_extend_length(strbuf_t *s, int len) { s->length += len; } static inline int strbuf_length(strbuf_t *s) { return s->length; } static inline void strbuf_append_char(strbuf_t *s, const char c) { strbuf_ensure_empty_length(s, 1); s->buf[s->length++] = c; } static inline void strbuf_append_char_unsafe(strbuf_t *s, const char c) { s->buf[s->length++] = c; } static inline void strbuf_append_mem(strbuf_t *s, const char *c, int len) { strbuf_ensure_empty_length(s, len); memcpy(s->buf + s->length, c, len); s->length += len; } static inline void strbuf_append_mem_unsafe(strbuf_t *s, const char *c, int len) { memcpy(s->buf + s->length, c, len); s->length += len; } static inline void strbuf_ensure_null(strbuf_t *s) { s->buf[s->length] = 0; } static inline char *strbuf_string(strbuf_t *s, int *len) { if (len) *len = s->length; return s->buf; } /* vi:ai et sw=4 ts=4: */ ================================================ FILE: src/redis/lua/cmsgpack/lua_cmsgpack.c ================================================ #include #include #include #include #include #include "lua.h" #include "lauxlib.h" #define LUACMSGPACK_NAME "cmsgpack" #define LUACMSGPACK_SAFE_NAME "cmsgpack_safe" #define LUACMSGPACK_VERSION "lua-cmsgpack 0.4.0" #define LUACMSGPACK_COPYRIGHT "Copyright (C) 2012, Salvatore Sanfilippo" #define LUACMSGPACK_DESCRIPTION "MessagePack C implementation for Lua" /* Allows a preprocessor directive to override MAX_NESTING */ #ifndef LUACMSGPACK_MAX_NESTING #define LUACMSGPACK_MAX_NESTING 16 /* Max tables nesting. 
*/
#endif

/* Check if float or double can be an integer without loss of precision:
 * the value must not be infinite and must survive a round trip through the
 * integer type T unchanged.
 * NOTE(review): converting a NaN or an out-of-range value to an integer type
 * is undefined behaviour in C; the macro relies on the comparison failing in
 * that case - confirm this holds on the supported compilers. */
#define IS_INT_TYPE_EQUIVALENT(x, T) (!isinf(x) && (T)(x) == (x))

#define IS_INT64_EQUIVALENT(x) IS_INT_TYPE_EQUIVALENT(x, int64_t)
#define IS_INT_EQUIVALENT(x) IS_INT_TYPE_EQUIVALENT(x, int)

/* If size of pointer is equal to a 4 byte integer, we're on 32 bits. */
#if UINTPTR_MAX == UINT_MAX
#define BITS_32 1
#else
#define BITS_32 0
#endif

/* lua_pushunsigned() pushes through lua_pushnumber() on 32 bit builds and
 * through lua_pushinteger() everywhere else. */
#if BITS_32
#define lua_pushunsigned(L, n) lua_pushnumber(L, n)
#else
#define lua_pushunsigned(L, n) lua_pushinteger(L, n)
#endif

/* =============================================================================
 * MessagePack implementation and bindings for Lua 5.1/5.2.
 * Copyright(C) 2012 Salvatore Sanfilippo
 *
 * http://github.com/antirez/lua-cmsgpack
 *
 * For MessagePack specification check the following web site:
 * http://wiki.msgpack.org/display/MSGPACK/Format+specification
 *
 * See Copyright Notice at the end of this file.
 *
 * CHANGELOG:
 * 19-Feb-2012 (ver 0.1.0): Initial release.
 * 20-Feb-2012 (ver 0.2.0): Tables encoding improved.
 * 20-Feb-2012 (ver 0.2.1): Minor bug fixing.
 * 20-Feb-2012 (ver 0.3.0): Module renamed lua-cmsgpack (was lua-msgpack).
 * 04-Apr-2014 (ver 0.3.1): Lua 5.2 support and minor bug fix.
 * 07-Apr-2014 (ver 0.4.0): Multiple pack/unpack, lua allocator, efficiency.
 * ========================================================================== */

/* -------------------------- Endian conversion --------------------------------
 * We use it only for floats and doubles; all the other conversions are
 * performed in an endian independent fashion. So the only thing we need is a
 * function that swaps a binary string if the arch is little endian (and
 * leaves it untouched otherwise). */

/* Reverse memory bytes if arch is little endian. Given the conceptual
 * simplicity of the Lua build system we prefer to check for endianness at
 * runtime.
 * The performance difference should be acceptable.
*/ void memrevifle(void *ptr, size_t len) { unsigned char *p = (unsigned char *)ptr, *e = (unsigned char *)p+len-1, aux; int test = 1; unsigned char *testp = (unsigned char*) &test; if (testp[0] == 0) return; /* Big endian, nothing to do. */ len /= 2; while(len--) { aux = *p; *p = *e; *e = aux; p++; e--; } } /* ---------------------------- String buffer ---------------------------------- * This is a simple implementation of string buffers. The only operation * supported is creating empty buffers and appending bytes to it. * The string buffer uses 2x preallocation on every realloc for O(N) append * behavior. */ typedef struct mp_buf { unsigned char *b; size_t len, free; } mp_buf; void *mp_realloc(lua_State *L, void *target, size_t osize,size_t nsize) { void *(*local_realloc) (void *, void *, size_t osize, size_t nsize) = NULL; void *ud; local_realloc = lua_getallocf(L, &ud); return local_realloc(ud, target, osize, nsize); } mp_buf *mp_buf_new(lua_State *L) { mp_buf *buf = NULL; /* Old size = 0; new size = sizeof(*buf) */ buf = (mp_buf*)mp_realloc(L, NULL, 0, sizeof(*buf)); buf->b = NULL; buf->len = buf->free = 0; return buf; } void mp_buf_append(lua_State *L, mp_buf *buf, const unsigned char *s, size_t len) { if (buf->free < len) { size_t newsize = (buf->len+len)*2; buf->b = (unsigned char*)mp_realloc(L, buf->b, buf->len + buf->free, newsize); buf->free = newsize - buf->len; } memcpy(buf->b+buf->len,s,len); buf->len += len; buf->free -= len; } void mp_buf_free(lua_State *L, mp_buf *buf) { mp_realloc(L, buf->b, buf->len + buf->free, 0); /* realloc to 0 = free */ mp_realloc(L, buf, sizeof(*buf), 0); } /* ---------------------------- String cursor ---------------------------------- * This simple data structure is used for parsing. 
Basically you create a cursor * using a string pointer and a length, then it is possible to access the * current string position with cursor->p, check the remaining length * in cursor->left, and finally consume more string using * mp_cur_consume(cursor,len), to advance 'p' and subtract 'left'. * An additional field cursor->error is set to zero on initialization and can * be used to report errors. */ #define MP_CUR_ERROR_NONE 0 #define MP_CUR_ERROR_EOF 1 /* Not enough data to complete operation. */ #define MP_CUR_ERROR_BADFMT 2 /* Bad data format */ typedef struct mp_cur { const unsigned char *p; size_t left; int err; } mp_cur; void mp_cur_init(mp_cur *cursor, const unsigned char *s, size_t len) { cursor->p = s; cursor->left = len; cursor->err = MP_CUR_ERROR_NONE; } #define mp_cur_consume(_c,_len) do { _c->p += _len; _c->left -= _len; } while(0) /* When there is not enough room we set an error in the cursor and return. This * is very common across the code so we have a macro to make the code look * a bit simpler. */ #define mp_cur_need(_c,_len) do { \ if (_c->left < _len) { \ _c->err = MP_CUR_ERROR_EOF; \ return; \ } \ } while(0) /* ------------------------- Low level MP encoding -------------------------- */ void mp_encode_bytes(lua_State *L, mp_buf *buf, const unsigned char *s, size_t len) { unsigned char hdr[5]; int hdrlen; if (len < 32) { hdr[0] = 0xa0 | (len&0xff); /* fix raw */ hdrlen = 1; } else if (len <= 0xff) { hdr[0] = 0xd9; hdr[1] = len; hdrlen = 2; } else if (len <= 0xffff) { hdr[0] = 0xda; hdr[1] = (len&0xff00)>>8; hdr[2] = len&0xff; hdrlen = 3; } else { hdr[0] = 0xdb; hdr[1] = (len&0xff000000)>>24; hdr[2] = (len&0xff0000)>>16; hdr[3] = (len&0xff00)>>8; hdr[4] = len&0xff; hdrlen = 5; } mp_buf_append(L,buf,hdr,hdrlen); mp_buf_append(L,buf,s,len); } /* we assume IEEE 754 internal format for single and double precision floats. 
*/ void mp_encode_double(lua_State *L, mp_buf *buf, double d) { unsigned char b[9]; float f = d; assert(sizeof(f) == 4 && sizeof(d) == 8); if (d == (double)f) { b[0] = 0xca; /* float IEEE 754 */ memcpy(b+1,&f,4); memrevifle(b+1,4); mp_buf_append(L,buf,b,5); } else if (sizeof(d) == 8) { b[0] = 0xcb; /* double IEEE 754 */ memcpy(b+1,&d,8); memrevifle(b+1,8); mp_buf_append(L,buf,b,9); } } void mp_encode_int(lua_State *L, mp_buf *buf, int64_t n) { unsigned char b[9]; int enclen; if (n >= 0) { if (n <= 127) { b[0] = n & 0x7f; /* positive fixnum */ enclen = 1; } else if (n <= 0xff) { b[0] = 0xcc; /* uint 8 */ b[1] = n & 0xff; enclen = 2; } else if (n <= 0xffff) { b[0] = 0xcd; /* uint 16 */ b[1] = (n & 0xff00) >> 8; b[2] = n & 0xff; enclen = 3; } else if (n <= 0xffffffffLL) { b[0] = 0xce; /* uint 32 */ b[1] = (n & 0xff000000) >> 24; b[2] = (n & 0xff0000) >> 16; b[3] = (n & 0xff00) >> 8; b[4] = n & 0xff; enclen = 5; } else { b[0] = 0xcf; /* uint 64 */ b[1] = (n & 0xff00000000000000LL) >> 56; b[2] = (n & 0xff000000000000LL) >> 48; b[3] = (n & 0xff0000000000LL) >> 40; b[4] = (n & 0xff00000000LL) >> 32; b[5] = (n & 0xff000000) >> 24; b[6] = (n & 0xff0000) >> 16; b[7] = (n & 0xff00) >> 8; b[8] = n & 0xff; enclen = 9; } } else { if (n >= -32) { b[0] = ((signed char)n); /* negative fixnum */ enclen = 1; } else if (n >= -128) { b[0] = 0xd0; /* int 8 */ b[1] = n & 0xff; enclen = 2; } else if (n >= -32768) { b[0] = 0xd1; /* int 16 */ b[1] = (n & 0xff00) >> 8; b[2] = n & 0xff; enclen = 3; } else if (n >= -2147483648LL) { b[0] = 0xd2; /* int 32 */ b[1] = (n & 0xff000000) >> 24; b[2] = (n & 0xff0000) >> 16; b[3] = (n & 0xff00) >> 8; b[4] = n & 0xff; enclen = 5; } else { b[0] = 0xd3; /* int 64 */ b[1] = (n & 0xff00000000000000LL) >> 56; b[2] = (n & 0xff000000000000LL) >> 48; b[3] = (n & 0xff0000000000LL) >> 40; b[4] = (n & 0xff00000000LL) >> 32; b[5] = (n & 0xff000000) >> 24; b[6] = (n & 0xff0000) >> 16; b[7] = (n & 0xff00) >> 8; b[8] = n & 0xff; enclen = 9; } } 
mp_buf_append(L,buf,b,enclen); } void mp_encode_array(lua_State *L, mp_buf *buf, int64_t n) { unsigned char b[5]; int enclen; if (n <= 15) { b[0] = 0x90 | (n & 0xf); /* fix array */ enclen = 1; } else if (n <= 65535) { b[0] = 0xdc; /* array 16 */ b[1] = (n & 0xff00) >> 8; b[2] = n & 0xff; enclen = 3; } else { b[0] = 0xdd; /* array 32 */ b[1] = (n & 0xff000000) >> 24; b[2] = (n & 0xff0000) >> 16; b[3] = (n & 0xff00) >> 8; b[4] = n & 0xff; enclen = 5; } mp_buf_append(L,buf,b,enclen); } void mp_encode_map(lua_State *L, mp_buf *buf, int64_t n) { unsigned char b[5]; int enclen; if (n <= 15) { b[0] = 0x80 | (n & 0xf); /* fix map */ enclen = 1; } else if (n <= 65535) { b[0] = 0xde; /* map 16 */ b[1] = (n & 0xff00) >> 8; b[2] = n & 0xff; enclen = 3; } else { b[0] = 0xdf; /* map 32 */ b[1] = (n & 0xff000000) >> 24; b[2] = (n & 0xff0000) >> 16; b[3] = (n & 0xff00) >> 8; b[4] = n & 0xff; enclen = 5; } mp_buf_append(L,buf,b,enclen); } /* --------------------------- Lua types encoding --------------------------- */ void mp_encode_lua_string(lua_State *L, mp_buf *buf) { size_t len; const char *s; s = lua_tolstring(L,-1,&len); mp_encode_bytes(L,buf,(const unsigned char*)s,len); } void mp_encode_lua_bool(lua_State *L, mp_buf *buf) { unsigned char b = lua_toboolean(L,-1) ? 
0xc3 : 0xc2; mp_buf_append(L,buf,&b,1); } /* Lua 5.3 has a built in 64-bit integer type */ void mp_encode_lua_integer(lua_State *L, mp_buf *buf) { #if (LUA_VERSION_NUM < 503) && BITS_32 lua_Number i = lua_tonumber(L,-1); #else lua_Integer i = lua_tointeger(L,-1); #endif mp_encode_int(L, buf, (int64_t)i); } /* Lua 5.2 and lower only has 64-bit doubles, so we need to * detect if the double may be representable as an int * for Lua < 5.3 */ void mp_encode_lua_number(lua_State *L, mp_buf *buf) { lua_Number n = lua_tonumber(L,-1); if (IS_INT64_EQUIVALENT(n)) { mp_encode_lua_integer(L, buf); } else { mp_encode_double(L,buf,(double)n); } } void mp_encode_lua_type(lua_State *L, mp_buf *buf, int level); /* Convert a lua table into a message pack list. */ void mp_encode_lua_table_as_array(lua_State *L, mp_buf *buf, int level) { #if LUA_VERSION_NUM < 502 size_t len = lua_objlen(L,-1), j; #else size_t len = lua_rawlen(L,-1), j; #endif mp_encode_array(L,buf,len); luaL_checkstack(L, 1, "in function mp_encode_lua_table_as_array"); for (j = 1; j <= len; j++) { lua_pushnumber(L,j); lua_gettable(L,-2); mp_encode_lua_type(L,buf,level+1); } } /* Convert a lua table into a message pack key-value map. */ void mp_encode_lua_table_as_map(lua_State *L, mp_buf *buf, int level) { size_t len = 0; /* First step: count keys into table. No other way to do it with the * Lua API, we need to iterate a first time. Note that an alternative * would be to do a single run, and then hack the buffer to insert the * map opcodes for message pack. Too hackish for this lib. */ luaL_checkstack(L, 3, "in function mp_encode_lua_table_as_map"); lua_pushnil(L); while(lua_next(L,-2)) { lua_pop(L,1); /* remove value, keep key for next iteration. */ len++; } /* Step two: actually encoding of the map. */ mp_encode_map(L,buf,len); lua_pushnil(L); while(lua_next(L,-2)) { /* Stack: ... key value */ lua_pushvalue(L,-2); /* Stack: ... 
key value key */ mp_encode_lua_type(L,buf,level+1); /* encode key */ mp_encode_lua_type(L,buf,level+1); /* encode val */ } } /* Returns true if the Lua table on top of the stack is exclusively composed * of keys from numerical keys from 1 up to N, with N being the total number * of elements, without any hole in the middle. */ int table_is_an_array(lua_State *L) { int count = 0, max = 0; #if LUA_VERSION_NUM < 503 lua_Number n; #else lua_Integer n; #endif /* Stack top on function entry */ int stacktop; stacktop = lua_gettop(L); lua_pushnil(L); while(lua_next(L,-2)) { /* Stack: ... key value */ lua_pop(L,1); /* Stack: ... key */ /* The <= 0 check is valid here because we're comparing indexes. */ #if LUA_VERSION_NUM < 503 if ((LUA_TNUMBER != lua_type(L,-1)) || (n = lua_tonumber(L, -1)) <= 0 || !IS_INT_EQUIVALENT(n)) #else if (!lua_isinteger(L,-1) || (n = lua_tointeger(L, -1)) <= 0) #endif { lua_settop(L, stacktop); return 0; } max = (n > max ? n : max); count++; } /* We have the total number of elements in "count". Also we have * the max index encountered in "max". We can't reach this code * if there are indexes <= 0. If you also note that there can not be * repeated keys into a table, you have that if max==count you are sure * that there are all the keys form 1 to count (both included). */ lua_settop(L, stacktop); return max == count; } /* If the length operator returns non-zero, that is, there is at least * an object at key '1', we serialize to message pack list. Otherwise * we use a map. 
*/ void mp_encode_lua_table(lua_State *L, mp_buf *buf, int level) { if (table_is_an_array(L)) mp_encode_lua_table_as_array(L,buf,level); else mp_encode_lua_table_as_map(L,buf,level); } void mp_encode_lua_null(lua_State *L, mp_buf *buf) { unsigned char b[1]; b[0] = 0xc0; mp_buf_append(L,buf,b,1); } void mp_encode_lua_type(lua_State *L, mp_buf *buf, int level) { int t = lua_type(L,-1); /* Limit the encoding of nested tables to a specified maximum depth, so that * we survive when called against circular references in tables. */ if (t == LUA_TTABLE && level == LUACMSGPACK_MAX_NESTING) t = LUA_TNIL; switch(t) { case LUA_TSTRING: mp_encode_lua_string(L,buf); break; case LUA_TBOOLEAN: mp_encode_lua_bool(L,buf); break; case LUA_TNUMBER: #if LUA_VERSION_NUM < 503 mp_encode_lua_number(L,buf); break; #else if (lua_isinteger(L, -1)) { mp_encode_lua_integer(L, buf); } else { mp_encode_lua_number(L, buf); } break; #endif case LUA_TTABLE: mp_encode_lua_table(L,buf,level); break; default: mp_encode_lua_null(L,buf); break; } lua_pop(L,1); } /* * Packs all arguments as a stream for multiple upacking later. * Returns error if no arguments provided. */ int mp_pack(lua_State *L) { int nargs = lua_gettop(L); int i; mp_buf *buf; if (nargs == 0) return luaL_argerror(L, 0, "MessagePack pack needs input."); if (!lua_checkstack(L, nargs)) return luaL_argerror(L, 0, "Too many arguments for MessagePack pack."); buf = mp_buf_new(L); for(i = 1; i <= nargs; i++) { /* Copy argument i to top of stack for _encode processing; * the encode function pops it from the stack when complete. */ luaL_checkstack(L, 1, "in function mp_check"); lua_pushvalue(L, i); mp_encode_lua_type(L,buf,0); lua_pushlstring(L,(char*)buf->b,buf->len); /* Reuse the buffer for the next operation by * setting its free count to the total buffer size * and the current position to zero. 
*/ buf->free += buf->len; buf->len = 0; } mp_buf_free(L, buf); /* Concatenate all nargs buffers together */ lua_concat(L, nargs); return 1; } /* ------------------------------- Decoding --------------------------------- */ void mp_decode_to_lua_type(lua_State *L, mp_cur *c); void mp_decode_to_lua_array(lua_State *L, mp_cur *c, size_t len) { assert(len <= UINT_MAX); int index = 1; lua_newtable(L); luaL_checkstack(L, 1, "in function mp_decode_to_lua_array"); while(len--) { lua_pushnumber(L,index++); mp_decode_to_lua_type(L,c); if (c->err) return; lua_settable(L,-3); } } void mp_decode_to_lua_hash(lua_State *L, mp_cur *c, size_t len) { assert(len <= UINT_MAX); lua_newtable(L); while(len--) { mp_decode_to_lua_type(L,c); /* key */ if (c->err) return; mp_decode_to_lua_type(L,c); /* value */ if (c->err) return; lua_settable(L,-3); } } /* Decode a Message Pack raw object pointed by the string cursor 'c' to * a Lua type, that is left as the only result on the stack. */ void mp_decode_to_lua_type(lua_State *L, mp_cur *c) { mp_cur_need(c,1); /* If we return more than 18 elements, we must resize the stack to * fit all our return values. 
But, there is no way to * determine how many objects a msgpack will unpack to up front, so * we request a +1 larger stack on each iteration (noop if stack is * big enough, and when stack does require resize it doubles in size) */ luaL_checkstack(L, 1, "too many return values at once; " "use unpack_one or unpack_limit instead."); switch(c->p[0]) { case 0xcc: /* uint 8 */ mp_cur_need(c,2); lua_pushunsigned(L,c->p[1]); mp_cur_consume(c,2); break; case 0xd0: /* int 8 */ mp_cur_need(c,2); lua_pushinteger(L,(signed char)c->p[1]); mp_cur_consume(c,2); break; case 0xcd: /* uint 16 */ mp_cur_need(c,3); lua_pushunsigned(L, (c->p[1] << 8) | c->p[2]); mp_cur_consume(c,3); break; case 0xd1: /* int 16 */ mp_cur_need(c,3); lua_pushinteger(L,(int16_t) (c->p[1] << 8) | c->p[2]); mp_cur_consume(c,3); break; case 0xce: /* uint 32 */ mp_cur_need(c,5); lua_pushunsigned(L, ((uint32_t)c->p[1] << 24) | ((uint32_t)c->p[2] << 16) | ((uint32_t)c->p[3] << 8) | (uint32_t)c->p[4]); mp_cur_consume(c,5); break; case 0xd2: /* int 32 */ mp_cur_need(c,5); lua_pushinteger(L, ((int32_t)c->p[1] << 24) | ((int32_t)c->p[2] << 16) | ((int32_t)c->p[3] << 8) | (int32_t)c->p[4]); mp_cur_consume(c,5); break; case 0xcf: /* uint 64 */ mp_cur_need(c,9); lua_pushunsigned(L, ((uint64_t)c->p[1] << 56) | ((uint64_t)c->p[2] << 48) | ((uint64_t)c->p[3] << 40) | ((uint64_t)c->p[4] << 32) | ((uint64_t)c->p[5] << 24) | ((uint64_t)c->p[6] << 16) | ((uint64_t)c->p[7] << 8) | (uint64_t)c->p[8]); mp_cur_consume(c,9); break; case 0xd3: /* int 64 */ mp_cur_need(c,9); #if LUA_VERSION_NUM < 503 lua_pushnumber(L, #else lua_pushinteger(L, #endif ((int64_t)c->p[1] << 56) | ((int64_t)c->p[2] << 48) | ((int64_t)c->p[3] << 40) | ((int64_t)c->p[4] << 32) | ((int64_t)c->p[5] << 24) | ((int64_t)c->p[6] << 16) | ((int64_t)c->p[7] << 8) | (int64_t)c->p[8]); mp_cur_consume(c,9); break; case 0xc0: /* nil */ lua_pushnil(L); mp_cur_consume(c,1); break; case 0xc3: /* true */ lua_pushboolean(L,1); mp_cur_consume(c,1); break; case 0xc2: /* false 
*/ lua_pushboolean(L,0); mp_cur_consume(c,1); break; case 0xca: /* float */ mp_cur_need(c,5); assert(sizeof(float) == 4); { float f; memcpy(&f,c->p+1,4); memrevifle(&f,4); lua_pushnumber(L,f); mp_cur_consume(c,5); } break; case 0xcb: /* double */ mp_cur_need(c,9); assert(sizeof(double) == 8); { double d; memcpy(&d,c->p+1,8); memrevifle(&d,8); lua_pushnumber(L,d); mp_cur_consume(c,9); } break; case 0xd9: /* raw 8 */ mp_cur_need(c,2); { size_t l = c->p[1]; mp_cur_need(c,2+l); lua_pushlstring(L,(char*)c->p+2,l); mp_cur_consume(c,2+l); } break; case 0xda: /* raw 16 */ mp_cur_need(c,3); { size_t l = (c->p[1] << 8) | c->p[2]; mp_cur_need(c,3+l); lua_pushlstring(L,(char*)c->p+3,l); mp_cur_consume(c,3+l); } break; case 0xdb: /* raw 32 */ mp_cur_need(c,5); { size_t l = ((size_t)c->p[1] << 24) | ((size_t)c->p[2] << 16) | ((size_t)c->p[3] << 8) | (size_t)c->p[4]; mp_cur_consume(c,5); mp_cur_need(c,l); lua_pushlstring(L,(char*)c->p,l); mp_cur_consume(c,l); } break; case 0xdc: /* array 16 */ mp_cur_need(c,3); { size_t l = (c->p[1] << 8) | c->p[2]; mp_cur_consume(c,3); mp_decode_to_lua_array(L,c,l); } break; case 0xdd: /* array 32 */ mp_cur_need(c,5); { size_t l = ((size_t)c->p[1] << 24) | ((size_t)c->p[2] << 16) | ((size_t)c->p[3] << 8) | (size_t)c->p[4]; mp_cur_consume(c,5); mp_decode_to_lua_array(L,c,l); } break; case 0xde: /* map 16 */ mp_cur_need(c,3); { size_t l = (c->p[1] << 8) | c->p[2]; mp_cur_consume(c,3); mp_decode_to_lua_hash(L,c,l); } break; case 0xdf: /* map 32 */ mp_cur_need(c,5); { size_t l = ((size_t)c->p[1] << 24) | ((size_t)c->p[2] << 16) | ((size_t)c->p[3] << 8) | (size_t)c->p[4]; mp_cur_consume(c,5); mp_decode_to_lua_hash(L,c,l); } break; default: /* types that can't be idenitified by first byte value. 
*/ if ((c->p[0] & 0x80) == 0) { /* positive fixnum */ lua_pushunsigned(L,c->p[0]); mp_cur_consume(c,1); } else if ((c->p[0] & 0xe0) == 0xe0) { /* negative fixnum */ lua_pushinteger(L,(signed char)c->p[0]); mp_cur_consume(c,1); } else if ((c->p[0] & 0xe0) == 0xa0) { /* fix raw */ size_t l = c->p[0] & 0x1f; mp_cur_need(c,1+l); lua_pushlstring(L,(char*)c->p+1,l); mp_cur_consume(c,1+l); } else if ((c->p[0] & 0xf0) == 0x90) { /* fix map */ size_t l = c->p[0] & 0xf; mp_cur_consume(c,1); mp_decode_to_lua_array(L,c,l); } else if ((c->p[0] & 0xf0) == 0x80) { /* fix map */ size_t l = c->p[0] & 0xf; mp_cur_consume(c,1); mp_decode_to_lua_hash(L,c,l); } else { c->err = MP_CUR_ERROR_BADFMT; } } } int mp_unpack_full(lua_State *L, int limit, int offset) { size_t len; const char *s; mp_cur c; int cnt; /* Number of objects unpacked */ int decode_all = (!limit && !offset); s = luaL_checklstring(L,1,&len); /* if no match, exits */ if (offset < 0 || limit < 0) /* requesting negative off or lim is invalid */ return luaL_error(L, "Invalid request to unpack with offset of %d and limit of %d.", offset, len); else if (offset > len) return luaL_error(L, "Start offset %d greater than input length %d.", offset, len); if (decode_all) limit = INT_MAX; mp_cur_init(&c,(const unsigned char *)s+offset,len-offset); /* We loop over the decode because this could be a stream * of multiple top-level values serialized together */ for(cnt = 0; c.left > 0 && cnt < limit; cnt++) { mp_decode_to_lua_type(L,&c); if (c.err == MP_CUR_ERROR_EOF) { return luaL_error(L,"Missing bytes in input."); } else if (c.err == MP_CUR_ERROR_BADFMT) { return luaL_error(L,"Bad data format in input."); } } if (!decode_all) { /* c->left is the remaining size of the input buffer. * subtract the entire buffer size from the unprocessed size * to get our next start offset */ int offset = len - c.left; luaL_checkstack(L, 1, "in function mp_unpack_full"); /* Return offset -1 when we have have processed the entire buffer. 
*/ lua_pushinteger(L, c.left == 0 ? -1 : offset); /* Results are returned with the arg elements still * in place. Lua takes care of only returning * elements above the args for us. * In this case, we have one arg on the stack * for this function, so we insert our first return * value at position 2. */ lua_insert(L, 2); cnt += 1; /* increase return count by one to make room for offset */ } return cnt; } int mp_unpack(lua_State *L) { return mp_unpack_full(L, 0, 0); } int mp_unpack_one(lua_State *L) { int offset = luaL_optinteger(L, 2, 0); /* Variable pop because offset may not exist */ lua_pop(L, lua_gettop(L)-1); return mp_unpack_full(L, 1, offset); } int mp_unpack_limit(lua_State *L) { int limit = luaL_checkinteger(L, 2); int offset = luaL_optinteger(L, 3, 0); /* Variable pop because offset may not exist */ lua_pop(L, lua_gettop(L)-1); return mp_unpack_full(L, limit, offset); } int mp_safe(lua_State *L) { int argc, err, total_results; argc = lua_gettop(L); /* This adds our function to the bottom of the stack * (the "call this function" position) */ lua_pushvalue(L, lua_upvalueindex(1)); lua_insert(L, 1); err = lua_pcall(L, argc, LUA_MULTRET, 0); total_results = lua_gettop(L); if (!err) { return total_results; } else { lua_pushnil(L); lua_insert(L,-2); return 2; } } /* -------------------------------------------------------------------------- */ const struct luaL_Reg cmds[] = { {"pack", mp_pack}, {"unpack", mp_unpack}, {"unpack_one", mp_unpack_one}, {"unpack_limit", mp_unpack_limit}, {0} }; int luaopen_create(lua_State *L) { int i; /* Manually construct our module table instead of * relying on _register or _newlib */ lua_newtable(L); for (i = 0; i < (sizeof(cmds)/sizeof(*cmds) - 1); i++) { lua_pushcfunction(L, cmds[i].func); lua_setfield(L, -2, cmds[i].name); } /* Add metadata */ lua_pushliteral(L, LUACMSGPACK_NAME); lua_setfield(L, -2, "_NAME"); lua_pushliteral(L, LUACMSGPACK_VERSION); lua_setfield(L, -2, "_VERSION"); lua_pushliteral(L, LUACMSGPACK_COPYRIGHT); 
lua_setfield(L, -2, "_COPYRIGHT"); lua_pushliteral(L, LUACMSGPACK_DESCRIPTION); lua_setfield(L, -2, "_DESCRIPTION"); return 1; } LUALIB_API int luaopen_cmsgpack(lua_State *L) { luaopen_create(L); lua_pushvalue(L, -1); lua_setglobal(L, LUACMSGPACK_NAME); return 1; } LUALIB_API int luaopen_cmsgpack_safe(lua_State *L) { int i; luaopen_cmsgpack(L); /* Wrap all functions in the safe handler */ for (i = 0; i < (sizeof(cmds)/sizeof(*cmds) - 1); i++) { lua_getfield(L, -1, cmds[i].name); lua_pushcclosure(L, mp_safe, 1); lua_setfield(L, -2, cmds[i].name); } #if LUA_VERSION_NUM < 502 /* Register name globally for 5.1 */ lua_pushvalue(L, -1); lua_setglobal(L, LUACMSGPACK_SAFE_NAME); #endif return 1; } /****************************************************************************** * Copyright (C) 2012 Salvatore Sanfilippo. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
******************************************************************************/ ================================================ FILE: src/redis/lua/struct/lua_struct.c ================================================ /* ** {====================================================== ** Library for packing/unpacking structures. ** $Id: struct.c,v 1.7 2018/05/11 22:04:31 roberto Exp $ ** See Copyright Notice at the end of this file ** ======================================================= */ /* ** Valid formats: ** > - big endian ** < - little endian ** ![num] - alignment ** x - pading ** b/B - signed/unsigned byte ** h/H - signed/unsigned short ** l/L - signed/unsigned long ** T - size_t ** i/In - signed/unsigned integer with size 'n' (default is size of int) ** cn - sequence of 'n' chars (from/to a string); when packing, n==0 means the whole string; when unpacking, n==0 means use the previous read number as the string length ** s - zero-terminated string ** f - float ** d - double ** ' ' - ignored */ #include #include #include #include #include #include "lua.h" #include "lauxlib.h" /* basic integer type */ #if !defined(STRUCT_INT) #define STRUCT_INT long #endif typedef STRUCT_INT Inttype; /* corresponding unsigned version */ typedef unsigned STRUCT_INT Uinttype; /* maximum size (in bytes) for integral types */ #define MAXINTSIZE 32 /* is 'x' a power of 2? */ #define isp2(x) ((x) > 0 && ((x) & ((x) - 1)) == 0) /* dummy structure to get alignment requirements */ struct cD { char c; double d; }; #define PADDING (sizeof(struct cD) - sizeof(double)) #define MAXALIGN (PADDING > sizeof(int) ? PADDING : sizeof(int)) /* endian options */ #define BIG 0 #define LITTLE 1 static union { int dummy; char endian; } const native = {1}; typedef struct Header { int endian; int align; } Header; static int getnum (lua_State *L, const char **fmt, int df) { if (!isdigit(**fmt)) /* no number? 
*/ return df; /* return default value */ else { int a = 0; do { if (a > (INT_MAX / 10) || a * 10 > (INT_MAX - (**fmt - '0'))) luaL_error(L, "integral size overflow"); a = a*10 + *((*fmt)++) - '0'; } while (isdigit(**fmt)); return a; } } #define defaultoptions(h) ((h)->endian = native.endian, (h)->align = 1) static size_t optsize (lua_State *L, char opt, const char **fmt) { switch (opt) { case 'B': case 'b': return sizeof(char); case 'H': case 'h': return sizeof(short); case 'L': case 'l': return sizeof(long); case 'T': return sizeof(size_t); case 'f': return sizeof(float); case 'd': return sizeof(double); case 'x': return 1; case 'c': return getnum(L, fmt, 1); case 'i': case 'I': { int sz = getnum(L, fmt, sizeof(int)); if (sz > MAXINTSIZE) luaL_error(L, "integral size %d is larger than limit of %d", sz, MAXINTSIZE); return sz; } default: return 0; /* other cases do not need alignment */ } } /* ** return number of bytes needed to align an element of size 'size' ** at current position 'len' */ static int gettoalign (size_t len, Header *h, int opt, size_t size) { if (size == 0 || opt == 'c') return 0; if (size > (size_t)h->align) size = h->align; /* respect max. 
alignment */ return (size - (len & (size - 1))) & (size - 1); } /* ** options to control endianess and alignment */ static void controloptions (lua_State *L, int opt, const char **fmt, Header *h) { switch (opt) { case ' ': return; /* ignore white spaces */ case '>': h->endian = BIG; return; case '<': h->endian = LITTLE; return; case '!': { int a = getnum(L, fmt, MAXALIGN); if (!isp2(a)) luaL_error(L, "alignment %d is not a power of 2", a); h->align = a; return; } default: { const char *msg = lua_pushfstring(L, "invalid format option '%c'", opt); luaL_argerror(L, 1, msg); } } } static void putinteger (lua_State *L, luaL_Buffer *b, int arg, int endian, int size) { lua_Number n = luaL_checknumber(L, arg); Uinttype value; char buff[MAXINTSIZE]; if (n < 0) value = (Uinttype)(Inttype)n; else value = (Uinttype)n; if (endian == LITTLE) { int i; for (i = 0; i < size; i++) { buff[i] = (value & 0xff); value >>= 8; } } else { int i; for (i = size - 1; i >= 0; i--) { buff[i] = (value & 0xff); value >>= 8; } } luaL_addlstring(b, buff, size); } static void correctbytes (char *b, int size, int endian) { if (endian != native.endian) { int i = 0; while (i < --size) { char temp = b[i]; b[i++] = b[size]; b[size] = temp; } } } static int b_pack (lua_State *L) { luaL_Buffer b; const char *fmt = luaL_checkstring(L, 1); Header h; int arg = 2; size_t totalsize = 0; defaultoptions(&h); lua_pushnil(L); /* mark to separate arguments from string buffer */ luaL_buffinit(L, &b); while (*fmt != '\0') { int opt = *fmt++; size_t size = optsize(L, opt, &fmt); int toalign = gettoalign(totalsize, &h, opt, size); totalsize += toalign; while (toalign-- > 0) luaL_addchar(&b, '\0'); switch (opt) { case 'b': case 'B': case 'h': case 'H': case 'l': case 'L': case 'T': case 'i': case 'I': { /* integer types */ putinteger(L, &b, arg++, h.endian, size); break; } case 'x': { luaL_addchar(&b, '\0'); break; } case 'f': { float f = (float)luaL_checknumber(L, arg++); correctbytes((char *)&f, size, h.endian); 
luaL_addlstring(&b, (char *)&f, size); break; } case 'd': { double d = luaL_checknumber(L, arg++); correctbytes((char *)&d, size, h.endian); luaL_addlstring(&b, (char *)&d, size); break; } case 'c': case 's': { size_t l; const char *s = luaL_checklstring(L, arg++, &l); if (size == 0) size = l; luaL_argcheck(L, l >= (size_t)size, arg, "string too short"); luaL_addlstring(&b, s, size); if (opt == 's') { luaL_addchar(&b, '\0'); /* add zero at the end */ size++; } break; } default: controloptions(L, opt, &fmt, &h); } totalsize += size; } luaL_pushresult(&b); return 1; } static lua_Number getinteger (const char *buff, int endian, int issigned, int size) { Uinttype l = 0; int i; if (endian == BIG) { for (i = 0; i < size; i++) { l <<= 8; l |= (Uinttype)(unsigned char)buff[i]; } } else { for (i = size - 1; i >= 0; i--) { l <<= 8; l |= (Uinttype)(unsigned char)buff[i]; } } if (!issigned) return (lua_Number)l; else { /* signed format */ Uinttype mask = (Uinttype)(~((Uinttype)0)) << (size*8 - 1); if (l & mask) /* negative value? */ l |= mask; /* signal extension */ return (lua_Number)(Inttype)l; } } static int b_unpack (lua_State *L) { Header h; const char *fmt = luaL_checkstring(L, 1); size_t ld; const char *data = luaL_checklstring(L, 2, &ld); size_t pos = luaL_optinteger(L, 3, 1); luaL_argcheck(L, pos > 0, 3, "offset must be 1 or greater"); pos--; /* Lua indexes are 1-based, but here we want 0-based for C * pointer math. 
*/ int n = 0; /* number of results */ defaultoptions(&h); while (*fmt) { int opt = *fmt++; size_t size = optsize(L, opt, &fmt); pos += gettoalign(pos, &h, opt, size); luaL_argcheck(L, size <= ld && pos <= ld - size, 2, "data string too short"); /* stack space for item + next position */ luaL_checkstack(L, 2, "too many results"); switch (opt) { case 'b': case 'B': case 'h': case 'H': case 'l': case 'L': case 'T': case 'i': case 'I': { /* integer types */ int issigned = islower(opt); lua_Number res = getinteger(data+pos, h.endian, issigned, size); lua_pushnumber(L, res); n++; break; } case 'x': { break; } case 'f': { float f; memcpy(&f, data+pos, size); correctbytes((char *)&f, sizeof(f), h.endian); lua_pushnumber(L, f); n++; break; } case 'd': { double d; memcpy(&d, data+pos, size); correctbytes((char *)&d, sizeof(d), h.endian); lua_pushnumber(L, d); n++; break; } case 'c': { if (size == 0) { if (n == 0 || !lua_isnumber(L, -1)) luaL_error(L, "format 'c0' needs a previous size"); size = lua_tonumber(L, -1); lua_pop(L, 1); n--; luaL_argcheck(L, size <= ld && pos <= ld - size, 2, "data string too short"); } lua_pushlstring(L, data+pos, size); n++; break; } case 's': { const char *e = (const char *)memchr(data+pos, '\0', ld - pos); if (e == NULL) luaL_error(L, "unfinished string in data"); size = (e - (data+pos)) + 1; lua_pushlstring(L, data+pos, size - 1); n++; break; } default: controloptions(L, opt, &fmt, &h); } pos += size; } lua_pushinteger(L, pos + 1); /* next position */ return n + 1; } static int b_size (lua_State *L) { Header h; const char *fmt = luaL_checkstring(L, 1); size_t pos = 0; defaultoptions(&h); while (*fmt) { int opt = *fmt++; size_t size = optsize(L, opt, &fmt); pos += gettoalign(pos, &h, opt, size); if (opt == 's') luaL_argerror(L, 1, "option 's' has no fixed size"); else if (opt == 'c' && size == 0) luaL_argerror(L, 1, "option 'c0' has no fixed size"); if (!isalnum(opt)) controloptions(L, opt, &fmt, &h); pos += size; } lua_pushinteger(L, pos); 
return 1; } /* }====================================================== */ static const struct luaL_Reg thislib[] = { {"pack", b_pack}, {"unpack", b_unpack}, {"size", b_size}, {NULL, NULL} }; LUALIB_API int luaopen_struct (lua_State *L); LUALIB_API int luaopen_struct (lua_State *L) { luaL_newlib(L, thislib); lua_setglobal(L, "struct"); return 1; } /****************************************************************************** * Copyright (C) 2010-2018 Lua.org, PUC-Rio. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining * a copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ******************************************************************************/ ================================================ FILE: src/redis/lzf.h ================================================ /* * Copyright (c) 2000-2008 Marc Alexander Lehmann * * Redistribution and use in source and binary forms, with or without modifica- * tion, are permitted provided that the following conditions are met: * * 1. 
Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MER- * CHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTH- * ERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * * Alternatively, the contents of this file may be used under the terms of * the GNU General Public License ("GPL") version 2 or any later version, * in which case the provisions of the GPL are applicable instead of * the above. If you wish to allow the use of your version of this file * only under the terms of the GPL and not to allow others to use your * version of this file under the BSD license, indicate your decision * by deleting the provisions above and replace them with the notice * and other provisions required by the GPL. If you do not delete the * provisions above, a recipient may use your version of this file under * either the BSD or the GPL. */ #ifndef LZF_H #define LZF_H /*********************************************************************** ** ** lzf -- an extremely fast/free compression/decompression-method ** http://liblzf.plan9.de/ ** ** This algorithm is believed to be patent-free. 
**
***********************************************************************/

#define LZF_VERSION 0x0105 /* 1.5, API version */

/*
 * Compress in_len bytes stored at the memory block starting at
 * in_data and write the result to out_data, up to a maximum length
 * of out_len bytes.
 *
 * If the output buffer is not large enough or any error occurs return 0,
 * otherwise return the number of bytes used, which might be considerably
 * more than in_len (but less than 104% of the original size), so it
 * makes sense to always use out_len == in_len - 1, to ensure _some_
 * compression, and store the data uncompressed otherwise (with a flag, of
 * course).
 *
 * lzf_compress might use different algorithms on different systems and
 * even different runs, thus might result in different compressed strings
 * depending on the phase of the moon or similar factors. However, all
 * these strings are architecture-independent and will result in the
 * original data when decompressed using lzf_decompress.
 *
 * The buffers must not be overlapping.
 *
 * If the option LZF_STATE_ARG is enabled, an extra LZF_STATE argument
 * (htab) must be supplied as the last parameter; the prototype below
 * declares it only in that configuration. Refer to lzfP.h and lzf_c.c.
 *
 */
size_t
lzf_compress (const void *const in_data, size_t in_len,
              void *out_data, size_t out_len
#if LZF_STATE_ARG
              , LZF_STATE htab
#endif
              );

/*
 * Decompress data compressed with some version of the lzf_compress
 * function and stored at location in_data and length in_len. The result
 * will be stored at out_data up to a maximum of out_len characters.
 *
 * If the output buffer is not large enough to hold the decompressed
 * data, a 0 is returned and errno is set to E2BIG. Otherwise the number
 * of decompressed bytes (i.e. the original length of the data) is
 * returned.
 *
 * If an error in the compressed data is detected, a zero is returned and
 * errno is set to EINVAL.
 *
 * This function is very fast, about as fast as a copying loop.
*/ size_t lzf_decompress (const void *const in_data, size_t in_len, void *out_data, size_t out_len); #endif ================================================ FILE: src/redis/lzfP.h ================================================ /* * Copyright (c) 2000-2007 Marc Alexander Lehmann * * Redistribution and use in source and binary forms, with or without modifica- * tion, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MER- * CHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTH- * ERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * * Alternatively, the contents of this file may be used under the terms of * the GNU General Public License ("GPL") version 2 or any later version, * in which case the provisions of the GPL are applicable instead of * the above. 
If you wish to allow the use of your version of this file * only under the terms of the GPL and not to allow others to use your * version of this file under the BSD license, indicate your decision * by deleting the provisions above and replace them with the notice * and other provisions required by the GPL. If you do not delete the * provisions above, a recipient may use your version of this file under * either the BSD or the GPL. */ #ifndef LZFP_h #define LZFP_h // ROMAN: #define STANDALONE 1 /* at the moment, this is ok. */ /* ROMAN: Moved below since it depends on LZF_STATE #ifndef STANDALONE # include "lzf.h" #endif */ /* * Size of hashtable is (1 << HLOG) * sizeof (char *) * decompression is independent of the hash table size * the difference between 15 and 14 is very small * for small blocks (and 14 is usually a bit faster). * For a low-memory/faster configuration, use HLOG == 13; * For best compression, use 15 or 16 (or more, up to 22). */ #ifndef HLOG # define HLOG 16 #endif /* * Sacrifice very little compression quality in favour of compression speed. * This gives almost the same compression as the default code, and is * (very roughly) 15% faster. This is the preferred mode of operation. */ #ifndef VERY_FAST # define VERY_FAST 1 #endif /* * Sacrifice some more compression quality in favour of compression speed. * (roughly 1-2% worse compression for large blocks and * 9-10% for small, redundant, blocks and >>20% better speed in both cases) * In short: when in need for speed, enable this for binary data, * possibly disable this for text data. 
*/ #ifndef ULTRA_FAST # define ULTRA_FAST 0 #endif /* * Unconditionally aligning does not cost very much, so do it if unsure */ #ifndef STRICT_ALIGN # if !(defined(__i386) || defined (__amd64)) # define STRICT_ALIGN 1 # else # define STRICT_ALIGN 0 # endif #endif /* * You may choose to pre-set the hash table (might be faster on some * modern cpus and large (>>64k) blocks, and also makes compression * deterministic/repeatable when the configuration otherwise is the same). */ #ifndef INIT_HTAB # define INIT_HTAB 0 #endif /* * Avoid assigning values to errno variable? for some embedding purposes * (linux kernel for example), this is necessary. NOTE: this breaks * the documentation in lzf.h. Avoiding errno has no speed impact. */ #ifndef AVOID_ERRNO # define AVOID_ERRNO 0 #endif /* * Whether to pass the LZF_STATE variable as argument, or allocate it * on the stack. For small-stack environments, define this to 1. * NOTE: this breaks the prototype in lzf.h. */ #ifndef LZF_STATE_ARG # define LZF_STATE_ARG 1 // ROMAN #endif /* * Whether to add extra checks for input validity in lzf_decompress * and return EINVAL if the input stream has been corrupted. This * only shields against overflowing the input buffer and will not * detect most corrupted streams. * This check is not normally noticeable on modern hardware * (<1% slowdown), but might slow down older cpus considerably. */ #ifndef CHECK_INPUT # define CHECK_INPUT 1 #endif /* * Whether to store pointers or offsets inside the hash table. On * 64 bit architectures, pointers take up twice as much space, * and might also be slower. Default is to autodetect. * Notice: Don't set this value to 1, it will result in 'LZF_HSLOT' * not being able to store offset above UINT32_MAX in 64bit. 
*/ #define LZF_USE_OFFSETS 0 /*****************************************************************************/ /* nothing should be changed below */ #ifdef __cplusplus # include # include using namespace std; #else # include # include #endif #ifndef LZF_USE_OFFSETS # if defined (WIN32) # define LZF_USE_OFFSETS defined(_M_X64) # else # if __cplusplus > 199711L # include # else # include # endif # define LZF_USE_OFFSETS (UINTPTR_MAX > 0xffffffffU) # endif #endif typedef unsigned char u8; #if LZF_USE_OFFSETS # define LZF_HSLOT_BIAS ((const u8 *)in_data) typedef unsigned int LZF_HSLOT; #else # define LZF_HSLOT_BIAS 0 typedef const u8 *LZF_HSLOT; #endif typedef LZF_HSLOT LZF_STATE[1 << (HLOG)]; // ROMAN: moved here deliberately because we depend on LZF_STATE. #ifndef STANDALONE # include "lzf.h" #endif #if !STRICT_ALIGN /* for unaligned accesses we need a 16 bit datatype. */ # if USHRT_MAX == 65535 typedef unsigned short u16; # elif UINT_MAX == 65535 typedef unsigned int u16; # else # undef STRICT_ALIGN # define STRICT_ALIGN 1 # endif #endif #if ULTRA_FAST # undef VERY_FAST #endif #endif ================================================ FILE: src/redis/lzf_c.c ================================================ /* * Copyright (c) 2000-2010 Marc Alexander Lehmann * * Redistribution and use in source and binary forms, with or without modifica- * tion, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MER- * CHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTH- * ERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * * Alternatively, the contents of this file may be used under the terms of * the GNU General Public License ("GPL") version 2 or any later version, * in which case the provisions of the GPL are applicable instead of * the above. If you wish to allow the use of your version of this file * only under the terms of the GPL and not to allow others to use your * version of this file under the BSD license, indicate your decision * by deleting the provisions above and replace them with the notice * and other provisions required by the GPL. If you do not delete the * provisions above, a recipient may use your version of this file under * either the BSD or the GPL. */ #include "lzfP.h" #define HSIZE (1 << (HLOG)) /* * don't play with this unless you benchmark! * the data format is not dependent on the hash function. * the hash function might seem strange, just believe me, * it works ;) */ #ifndef FRST # define FRST(p) (((p[0]) << 8) | p[1]) # define NEXT(v,p) (((v) << 8) | p[2]) # if ULTRA_FAST # define IDX(h) ((( h >> (3*8 - HLOG)) - h ) & (HSIZE - 1)) # elif VERY_FAST # define IDX(h) ((( h >> (3*8 - HLOG)) - h*5) & (HSIZE - 1)) # else # define IDX(h) ((((h ^ (h << 5)) >> (3*8 - HLOG)) - h*5) & (HSIZE - 1)) # endif #endif /* * IDX works because it is very similar to a multiplicative hash, e.g. * ((h * 57321 >> (3*8 - HLOG)) & (HSIZE - 1)) * the latter is also quite fast on newer CPUs, and compresses similarly. 
*
 * the next one is also quite good, albeit slow ;)
 * (int)(cos(h & 0xffffff) * 1e6)
 */

#if 0
/* original lzv-like hash function, much worse and thus slower */
# define FRST(p) (p[0] << 5) ^ p[1]
# define NEXT(v,p) ((v) << 5) ^ p[2]
# define IDX(h) ((h) & (HSIZE - 1))
#endif

/* A literal run is flushed as soon as it holds MAX_LIT bytes. */
#define MAX_LIT (1 << 5)
/* A hash-table candidate is only accepted as a match when its distance
 * (off) is below MAX_OFF. */
#define MAX_OFF (1 << 13)
/* Upper bound applied to the length of a single back reference. */
#define MAX_REF ((1 << 8) + (1 << 3))

/* Branch-prediction hints; plain expressions on compilers without
 * __builtin_expect. */
#if __GNUC__ >= 3
# define expect(expr,value) __builtin_expect ((expr),(value))
# define inline inline
#else
# define expect(expr,value) (expr)
# define inline static
#endif

#define expect_false(expr) expect ((expr) != 0, 0)
#define expect_true(expr) expect ((expr) != 0, 1)

/* NO_SANITIZE(x) expands to the no_sanitize attribute when the compiler
 * advertises it through __has_attribute, and to nothing otherwise. */
#if defined(__has_attribute)
# if __has_attribute(no_sanitize)
# define NO_SANITIZE(sanitizer) __attribute__((no_sanitize(sanitizer)))
# endif
#endif

#if !defined(NO_SANITIZE)
# define NO_SANITIZE(sanitizer)
#endif

/*
 * compressed format
 *
 * 000LLLLL ; literal, L+1=1..33 octets
 * LLLooooo oooooooo ; backref L+1=1..7 octets, o+1=1..4096 offset
 * 111ooooo LLLLLLLL oooooooo ; backref L+8 octets, o+1=1..4096 offset
 *
 */

/*
 * Compress in_len bytes at in_data into out_data, which has room for
 * out_len bytes.
 *
 * Returns the number of bytes written to out_data. Returns 0 when in_len
 * or out_len is zero, or when the output buffer turns out to be too small.
 *
 * htab is the table of candidate match positions indexed by IDX(hval).
 * It is a parameter when LZF_STATE_ARG is set and a local otherwise; it is
 * only cleared here when INIT_HTAB is set.
 *
 * The "alignment" sanitizer is disabled because the !STRICT_ALIGN variant
 * compares the first two octets through u16 pointers.
 */
NO_SANITIZE("alignment")
size_t
lzf_compress (const void *const in_data, size_t in_len,
              void *out_data, size_t out_len
#if LZF_STATE_ARG
              , LZF_STATE htab
#endif
              )
{
#if !LZF_STATE_ARG
  LZF_STATE htab;
#endif
  const u8 *ip = (const u8 *)in_data;
  u8 *op = (u8 *)out_data;
  const u8 *in_end = ip + in_len;
  u8 *out_end = op + out_len;
  const u8 *ref;

  /* off requires a type wide enough to hold a general pointer difference.
   * ISO C doesn't have that (size_t might not be enough and ptrdiff_t only
   * works for differences within a single object). We also assume that
   * no bit pattern traps. Since the only platform that is both non-POSIX
   * and fails to support both assumptions is windows 64 bit, we make a
   * special workaround for it.
   */
#if defined (WIN32) && defined (_M_X64)
  unsigned _int64 off; /* workaround for missing POSIX compliance */
#else
  size_t off;
#endif
  unsigned int hval;
  int lit;

  if (!in_len || !out_len)
    return 0;

#if INIT_HTAB
  memset (htab, 0, sizeof (htab));
#endif

  /* Skip one output byte: it becomes the header of the literal run and is
   * filled in through op [- lit - 1] once the run length is known. */
  lit = 0; op++; /* start run */

  hval = FRST (ip);
  while (ip < in_end - 2)
    {
      LZF_HSLOT *hslot;

      /* Look up the previous position stored for this hash and replace it
       * with the current position. */
      hval = NEXT (hval, ip);
      hslot = htab + IDX (hval);
      ref = *hslot ? (*hslot + LZF_HSLOT_BIAS) : NULL; /* avoid applying zero offset to null pointer */
      *hslot = ip - LZF_HSLOT_BIAS;

      if (1
#if INIT_HTAB
          && ref < ip /* the next test will actually take care of this, but this is faster */
#endif
          && (off = ip - ref - 1) < MAX_OFF
          && ref > (u8 *)in_data
          && ref[2] == ip[2]
#if STRICT_ALIGN
          && ((ref[1] << 8) | ref[0]) == ((ip[1] << 8) | ip[0])
#else
          && *(u16 *)ref == *(u16 *)ip
#endif
        )
        {
          /* match found at *ref++ */
          unsigned int len = 2;
          size_t maxlen = in_end - ip - len;
          maxlen = maxlen > MAX_REF ? MAX_REF : maxlen;

          if (expect_false (op + 3 + 1 >= out_end)) /* first a faster conservative test */
            if (op - !lit + 3 + 1 >= out_end) /* second the exact but rare test */
              return 0;

          /* Close the pending literal run by writing its header byte. */
          op [- lit - 1] = lit - 1; /* stop run */
          op -= !lit; /* undo run if length is zero */

          /* Extend the match: an unrolled comparison while plenty of input
           * remains, then a bounded byte-by-byte loop. */
          for (;;)
            {
              if (expect_true (maxlen > 16))
                {
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;

                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;

                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;

                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                  len++; if (ref [len] != ip [len]) break;
                }

              do
                len++;
              while (len < maxlen && ref[len] == ip[len]);

              break;
            }

          len -= 2; /* len is now #octets - 1 */
          ip++;

          /* Short form: length in the top three bits of the first byte.
           * Long form: top bits set to 7 and len - 7 in an extra byte. */
          if (len < 7)
            {
              *op++ = (off >> 8) + (len << 5);
            }
          else
            {
              *op++ = (off >> 8) + ( 7 << 5);
              *op++ = len - 7;
            }

          *op++ = off;

          lit = 0; op++; /* start run */

          ip += len + 1;

          if (expect_false (ip >= in_end - 2))
            break;

          /* Re-seed the hash table: the fast variants insert only the last
           * one or two positions of the match, the default variant inserts
           * every position covered by it. */
#if ULTRA_FAST || VERY_FAST
          --ip;
# if VERY_FAST && !ULTRA_FAST
          --ip;
# endif
          hval = FRST (ip);

          hval = NEXT (hval, ip);
          htab[IDX (hval)] = ip - LZF_HSLOT_BIAS;
          ip++;

# if VERY_FAST && !ULTRA_FAST
          hval = NEXT (hval, ip);
          htab[IDX (hval)] = ip - LZF_HSLOT_BIAS;
          ip++;
# endif
#else
          ip -= len + 1;

          do
            {
              hval = NEXT (hval, ip);
              htab[IDX (hval)] = ip - LZF_HSLOT_BIAS;
              ip++;
            }
          while (len--);
#endif
        }
      else
        {
          /* one more literal byte we must copy */
          if (expect_false (op >= out_end))
            return 0;

          lit++; *op++ = *ip++;

          if (expect_false (lit == MAX_LIT))
            {
              op [- lit - 1] = lit - 1; /* stop run */
              lit = 0; op++; /* start run */
            }
        }
    }

  if (op + 3 > out_end) /* at most 3 bytes can be missing here */
    return 0;

  /* Copy whatever input is left as literals. */
  while (ip < in_end)
    {
      lit++; *op++ = *ip++;

      if (expect_false (lit == MAX_LIT))
        {
          op [- lit - 1] = lit - 1; /* stop run */
          lit = 0; op++; /* start run */
        }
    }

  op [- lit - 1] = lit - 1; /* end run */
  op -= !lit; /* undo run if length is zero */

  return op - (u8 *)out_data;
}

================================================
FILE: src/redis/lzf_d.c
================================================
/*
 * Copyright (c) 2000-2010 Marc Alexander Lehmann
 *
 * Redistribution and use in source and binary forms, with or without modifica-
 * tion, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
* * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MER- * CHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE- * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTH- * ERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * * Alternatively, the contents of this file may be used under the terms of * the GNU General Public License ("GPL") version 2 or any later version, * in which case the provisions of the GPL are applicable instead of * the above. If you wish to allow the use of your version of this file * only under the terms of the GPL and not to allow others to use your * version of this file under the BSD license, indicate your decision * by deleting the provisions above and replace them with the notice * and other provisions required by the GPL. If you do not delete the * provisions above, a recipient may use your version of this file under * either the BSD or the GPL. 
*/ #include "lzfP.h" #if AVOID_ERRNO # define SET_ERRNO(n) #else # include # define SET_ERRNO(n) errno = (n) #endif #if USE_REP_MOVSB /* small win on amd, big loss on intel */ #if (__i386 || __amd64) && __GNUC__ >= 3 # define lzf_movsb(dst, src, len) \ asm ("rep movsb" \ : "=D" (dst), "=S" (src), "=c" (len) \ : "0" (dst), "1" (src), "2" (len)); #endif #endif #if defined(__GNUC__) && __GNUC__ >= 7 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wimplicit-fallthrough" #endif size_t lzf_decompress (const void *const in_data, size_t in_len, void *out_data, size_t out_len) { u8 const *ip = (const u8 *)in_data; u8 *op = (u8 *)out_data; u8 const *const in_end = ip + in_len; u8 *const out_end = op + out_len; while (ip < in_end) { unsigned int ctrl; ctrl = *ip++; if (ctrl < (1 << 5)) /* literal run */ { ctrl++; if (op + ctrl > out_end) { SET_ERRNO (E2BIG); return 0; } #if CHECK_INPUT if (ip + ctrl > in_end) { SET_ERRNO (EINVAL); return 0; } #endif #ifdef lzf_movsb lzf_movsb (op, ip, ctrl); #else switch (ctrl) { case 32: *op++ = *ip++; case 31: *op++ = *ip++; case 30: *op++ = *ip++; case 29: *op++ = *ip++; case 28: *op++ = *ip++; case 27: *op++ = *ip++; case 26: *op++ = *ip++; case 25: *op++ = *ip++; case 24: *op++ = *ip++; case 23: *op++ = *ip++; case 22: *op++ = *ip++; case 21: *op++ = *ip++; case 20: *op++ = *ip++; case 19: *op++ = *ip++; case 18: *op++ = *ip++; case 17: *op++ = *ip++; case 16: *op++ = *ip++; case 15: *op++ = *ip++; case 14: *op++ = *ip++; case 13: *op++ = *ip++; case 12: *op++ = *ip++; case 11: *op++ = *ip++; case 10: *op++ = *ip++; case 9: *op++ = *ip++; case 8: *op++ = *ip++; case 7: *op++ = *ip++; case 6: *op++ = *ip++; case 5: *op++ = *ip++; case 4: *op++ = *ip++; case 3: *op++ = *ip++; case 2: *op++ = *ip++; case 1: *op++ = *ip++; } #endif } else /* back reference */ { unsigned int len = ctrl >> 5; u8 *ref = op - ((ctrl & 0x1f) << 8) - 1; #if CHECK_INPUT if (ip >= in_end) { SET_ERRNO (EINVAL); return 0; } #endif if (len == 7) { len += 
*ip++; #if CHECK_INPUT if (ip >= in_end) { SET_ERRNO (EINVAL); return 0; } #endif } ref -= *ip++; if (op + len + 2 > out_end) { SET_ERRNO (E2BIG); return 0; } if (ref < (u8 *)out_data) { SET_ERRNO (EINVAL); return 0; } #ifdef lzf_movsb len += 2; lzf_movsb (op, ref, len); #else switch (len) { default: len += 2; if (op >= ref + len) { /* disjunct areas */ memcpy (op, ref, len); op += len; } else { /* overlapping, use octte by octte copying */ do *op++ = *ref++; while (--len); } break; case 9: *op++ = *ref++; /* fall-thru */ case 8: *op++ = *ref++; /* fall-thru */ case 7: *op++ = *ref++; /* fall-thru */ case 6: *op++ = *ref++; /* fall-thru */ case 5: *op++ = *ref++; /* fall-thru */ case 4: *op++ = *ref++; /* fall-thru */ case 3: *op++ = *ref++; /* fall-thru */ case 2: *op++ = *ref++; /* fall-thru */ case 1: *op++ = *ref++; /* fall-thru */ case 0: *op++ = *ref++; /* two octets more */ *op++ = *ref++; /* fall-thru */ } #endif } } return op - (u8 *)out_data; } #if defined(__GNUC__) && __GNUC__ >= 5 #pragma GCC diagnostic pop #endif ================================================ FILE: src/redis/rax.c ================================================ /* Rax -- A radix tree implementation. * * Version 1.2 -- 7 February 2019 * * Copyright (c) 2017-2019, Redis Ltd. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include "rax.h" #ifndef RAX_MALLOC_INCLUDE #define RAX_MALLOC_INCLUDE "rax_malloc.h" #endif #include RAX_MALLOC_INCLUDE /* -------------------------------- Debugging ------------------------------ */ void raxDebugShowNode(const char *msg, raxNode *n); /* Turn debugging messages on/off by compiling with RAX_DEBUG_MSG macro on. * When RAX_DEBUG_MSG is defined by default Rax operations will emit a lot * of debugging info to the standard output, however you can still turn * debugging on/off in order to enable it only when you suspect there is an * operation causing a bug using the function raxSetDebugMsg(). */ #ifdef RAX_DEBUG_MSG #define debugf(...) \ if (raxDebugMsg) { \ printf("%s:%s:%d:\t", __FILE__, __func__, __LINE__); \ printf(__VA_ARGS__); \ fflush(stdout); \ } #define debugnode(msg, n) raxDebugShowNode(msg, n) #else #define debugf(...) #define debugnode(msg, n) #endif /* By default log debug info if RAX_DEBUG_MSG is defined. */ static int raxDebugMsg = 1; /* When debug messages are enabled, turn them on/off dynamically. By * default they are enabled. Set the state to 0 to disable, and 1 to * re-enable. 
*/ void raxSetDebugMsg(int onoff) { raxDebugMsg = onoff; } /* ------------------------- raxStack functions -------------------------- * The raxStack is a simple stack of pointers that is capable of switching * from using a stack-allocated array to dynamic heap once a given number of * items are reached. It is used in order to retain the list of parent nodes * while walking the radix tree in order to implement certain operations that * need to navigate the tree upward. * ------------------------------------------------------------------------- */ /* Initialize the stack. */ static inline void raxStackInit(raxStack *ts) { ts->stack = ts->static_items; ts->items = 0; ts->maxitems = RAX_STACK_STATIC_ITEMS; ts->oom = 0; } /* Push an item into the stack, returns 1 on success, 0 on out of memory. */ static inline int raxStackPush(raxStack *ts, void *ptr) { if (ts->items == ts->maxitems) { if (ts->stack == ts->static_items) { ts->stack = rax_malloc(sizeof(void *) * ts->maxitems * 2); if (ts->stack == NULL) { ts->stack = ts->static_items; ts->oom = 1; errno = ENOMEM; return 0; } memcpy(ts->stack, ts->static_items, sizeof(void *) * ts->maxitems); } else { void **newalloc = rax_realloc(ts->stack, sizeof(void *) * ts->maxitems * 2); if (newalloc == NULL) { ts->oom = 1; errno = ENOMEM; return 0; } ts->stack = newalloc; } ts->maxitems *= 2; } ts->stack[ts->items] = ptr; ts->items++; return 1; } /* Pop an item from the stack, the function returns NULL if there are no * items to pop. */ static inline void *raxStackPop(raxStack *ts) { if (ts->items == 0) return NULL; ts->items--; return ts->stack[ts->items]; } /* Return the stack item at the top of the stack without actually consuming * it. */ static inline void *raxStackPeek(raxStack *ts) { if (ts->items == 0) return NULL; return ts->stack[ts->items - 1]; } /* Free the stack in case we used heap allocation. 
*/
static inline void raxStackFree(raxStack *ts) {
    /* static_items is a member of the raxStack, so only a buffer that
     * replaced it has to be released. */
    if (ts->stack != ts->static_items) rax_free(ts->stack);
}

/* ----------------------------------------------------------------------------
 * Radix tree implementation
 * --------------------------------------------------------------------------*/

/* Return the padding needed in the characters section of a node having size
 * 'nodesize'. The padding is needed to store the child pointers to aligned
 * addresses. Note that we add 4 to the node size because the node has a four
 * bytes header.
 *
 * Worked example with sizeof(void *) == 8: nodesize 1 gives
 * (8 - (5 % 8)) & 7 = 3, and nodesize 4 gives (8 - (8 % 8)) & 7 = 0. The
 * final mask turns a full pointer-size result into zero padding. */
#define raxPadding(nodesize) ((sizeof(void *) - (((nodesize) + 4) % sizeof(void *))) & (sizeof(void *) - 1))

/* Return the pointer to the last child pointer in a node. For the compressed
 * nodes this is the only child pointer.
 *
 * It is located by going to the end of the node (raxNodeCurrentLength) and
 * stepping back over one child pointer, plus the value pointer when the node
 * is a key with a non-NULL value. */
#define raxNodeLastChildPtr(n) \
    ((raxNode **)(((char *)(n)) + raxNodeCurrentLength(n) - sizeof(raxNode *) - \
                  (((n)->iskey && !(n)->isnull) ? sizeof(void *) : 0)))

/* Return the pointer to the first child pointer: it follows the 'size'
 * characters in 'data' and their padding. */
#define raxNodeFirstChildPtr(n) ((raxNode **)((n)->data + (n)->size + raxPadding((n)->size)))

/* Return the current total size of the node. Note that the second line
 * computes the padding after the string of characters, needed in order to
 * save pointers to aligned addresses.
 *
 * The terms are: the header, the characters and their padding, the child
 * pointers (one for a compressed node, 'size' of them otherwise), and one
 * value pointer when the node is a key with a non-NULL value. */
#define raxNodeCurrentLength(n) \
    (sizeof(raxNode) + (n)->size + raxPadding((n)->size) + \
     ((n)->iscompr ? sizeof(raxNode *) : sizeof(raxNode *) * (n)->size) + \
     (((n)->iskey && !(n)->isnull) * sizeof(void *)))

/* Allocate a new non compressed node with the specified number of children.
 * If datafield is true, the allocation is made large enough to hold the
 * associated data pointer.
 * Returns the new node pointer. On out of memory NULL is returned.
*/ raxNode *raxNewNode(size_t children, int datafield) { size_t nodesize = sizeof(raxNode) + children + raxPadding(children) + sizeof(raxNode *) * children; if (datafield) nodesize += sizeof(void *); raxNode *node = rax_malloc(nodesize); if (node == NULL) return NULL; node->iskey = 0; node->isnull = 0; node->iscompr = 0; node->size = children; return node; } /* Allocate a new rax and return its pointer. On out of memory the function * returns NULL. */ rax *raxNew(void) { rax *rax = rax_malloc(sizeof(*rax)); if (rax == NULL) return NULL; rax->numele = 0; rax->numnodes = 1; rax->head = raxNewNode(0, 0); if (rax->head == NULL) { rax_free(rax); return NULL; } else { rax->alloc_size = rax_ptr_alloc_size(rax) + rax_ptr_alloc_size(rax->head); return rax; } } /* realloc the node to make room for auxiliary data in order * to store an item in that node. On out of memory NULL is returned. */ raxNode *raxReallocForData(raxNode *n, void *data) { if (data == NULL) return n; /* No reallocation needed, setting isnull=1 */ size_t curlen = raxNodeCurrentLength(n); return rax_realloc(n, curlen + sizeof(void *)); } /* Set the node auxiliary data to the specified pointer. */ void raxSetData(raxNode *n, void *data) { n->iskey = 1; if (data != NULL) { n->isnull = 0; void **ndata = (void **)((char *)n + raxNodeCurrentLength(n) - sizeof(void *)); memcpy(ndata, &data, sizeof(data)); } else { n->isnull = 1; } } /* Get the node auxiliary data. */ void *raxGetData(raxNode *n) { if (n->isnull) return NULL; void **ndata = (void **)((char *)n + raxNodeCurrentLength(n) - sizeof(void *)); void *data; memcpy(&data, ndata, sizeof(data)); return data; } /* Add a new child to the node 'n' representing the character 'c' and return * its new pointer, as well as the child pointer by reference. Additionally * '***parentlink' is populated with the raxNode pointer-to-pointer of where * the new child was stored, which is useful for the caller to replace the * child pointer if it gets reallocated. 
*
 * On success the new parent node pointer is returned (it may change because
 * of the realloc, so the caller should discard 'n' and use the new value).
 * On out of memory NULL is returned, and the old node is still valid. */
raxNode *raxAddChild(raxNode *n, unsigned char c, raxNode **childptr, raxNode ***parentlink) {
  assert(n->iscompr == 0);

  /* Temporarily bump the size so that the length macro reports how large the
   * node will be once it has one more child. */
  size_t curlen = raxNodeCurrentLength(n);
  n->size++;
  size_t newlen = raxNodeCurrentLength(n);
  n->size--; /* For now restore the original size. We'll update it only on success at the end. */

  /* Alloc the new child we will link to 'n'. */
  raxNode *child = raxNewNode(0, 0);
  if (child == NULL) return NULL;

  /* Make space in the original node. */
  raxNode *newn = rax_realloc(n, newlen);
  if (newn == NULL) {
    rax_free(child);
    return NULL;
  }
  n = newn;

  /* After the reallocation, we have up to 8/16 (depending on the system
   * pointer size, and the required node padding) bytes at the end, that is,
   * the additional char in the 'data' section, plus one pointer to the new
   * child, plus the padding needed in order to store addresses into aligned
   * locations.
   *
   * So if we start with the following node, having "abde" edges.
   *
   * Note:
   * - We assume 4 bytes pointer for simplicity.
   * - Each space below corresponds to one byte
   *
   * [HDR*][abde][Aptr][Bptr][Dptr][Eptr]|AUXP|
   *
   * After the reallocation we need: 1 byte for the new edge character
   * plus 4 bytes for a new child pointer (assuming 32 bit machine).
   * However after adding 1 byte to the edge char, the header + the edge
   * characters are no longer aligned, so we also need 3 bytes of padding.
   * In total the reallocation will add 1+4+3 bytes = 8 bytes:
   *
   * (Blank bytes are represented by ".")
   *
   * [HDR*][abde][Aptr][Bptr][Dptr][Eptr]|AUXP|[....][....]
   *
   * Let's find where to insert the new child in order to make sure
   * it is inserted in-place lexicographically. Assuming we are adding
   * a child "c" in our case pos will be = 2 after the end of the following
   * loop. */
  int pos;
  for (pos = 0; pos < n->size; pos++) {
    if (n->data[pos] > c) break;
  }

  /* Now, if present, move auxiliary data pointer at the end
   * so that we can mess with the other data without overwriting it.
   * We will obtain something like that:
   *
   * [HDR*][abde][Aptr][Bptr][Dptr][Eptr][....][....]|AUXP| */
  unsigned char *src, *dst;
  if (n->iskey && !n->isnull) {
    src = ((unsigned char *)n + curlen - sizeof(void *));
    dst = ((unsigned char *)n + newlen - sizeof(void *));
    memmove(dst, src, sizeof(void *));
  }

  /* Compute the "shift", that is, how many bytes we need to move the
   * pointers section forward because of the addition of the new child
   * byte in the string section. Note that if we had no padding, that
   * would be always "1", since we are adding a single byte in the string
   * section of the node (where now there is "abde" basically).
   *
   * However we have padding, so the shift is either zero (the new byte
   * fits in the existing padding) or one whole pointer-sized word.
   *
   * Another way to think about the shift is, how many bytes we need to
   * move child pointers forward *other than* the obvious sizeof(void*)
   * needed for the additional pointer itself. */
  size_t shift = newlen - curlen - sizeof(void *);

  /* We said we are adding a node with edge 'c'. The insertion
   * point is between 'b' and 'd', so the 'pos' variable value is
   * the index of the first child pointer that we need to move forward
   * to make space for our new pointer.
   *
   * To start, move all the child pointers after the insertion point
   * of shift+sizeof(pointer) bytes on the right, to obtain:
   *
   * [HDR*][abde][Aptr][Bptr][....][....][Dptr][Eptr]|AUXP| */
  src = n->data + n->size + raxPadding(n->size) + sizeof(raxNode *) * pos;
  memmove(src + shift + sizeof(raxNode *), src, sizeof(raxNode *) * (n->size - pos));

  /* Move the pointers that precede the insertion position as well. Often
   * we don't need to do anything if there was already some padding to use. In
   * that case the final destination of the pointers will be the same, however
   * in our example there was no pre-existing padding, so we added one byte
   * plus three bytes of padding. After the next memmove() things will look
   * like that:
   *
   * [HDR*][abde][....][Aptr][Bptr][....][Dptr][Eptr]|AUXP| */
  if (shift) {
    src = (unsigned char *)raxNodeFirstChildPtr(n);
    memmove(src + shift, src, sizeof(raxNode *) * pos);
  }

  /* Now make the space for the additional char in the data section. Only
   * the edge characters from 'pos' onward are moved here: the child pointers
   * were already relocated by the two memmove() calls above. We obtain the
   * following:
   *
   * [HDR*][ab.d][e...][Aptr][Bptr][....][Dptr][Eptr]|AUXP| */
  src = n->data + pos;
  memmove(src + 1, src, n->size - pos);

  /* We can now set the character and its child node pointer to get:
   *
   * [HDR*][abcd][e...][Aptr][Bptr][....][Dptr][Eptr]|AUXP|
   * [HDR*][abcd][e...][Aptr][Bptr][Cptr][Dptr][Eptr]|AUXP| */
  n->data[pos] = c;
  n->size++; /* From now on the layout macros describe the enlarged node. */
  src = (unsigned char *)raxNodeFirstChildPtr(n);
  raxNode **childfield = (raxNode **)(src + sizeof(raxNode *) * pos);
  memcpy(childfield, &child, sizeof(child));
  *childptr = child;
  *parentlink = childfield;
  return n;
}

/* Turn the node 'n', that must be a node without any children, into a
 * compressed node representing a set of nodes linked one after the other
 * and having exactly one child each. The node can be a key or not: this
 * property and the associated value if any will be preserved.
 *
 * The function also returns a child node, since the last node of the
 * compressed chain cannot be part of the chain: it has zero children while
 * we can only compress inner nodes with exactly one child each.
 *
 * 's' and 'len' are the characters to store in the compressed node. The
 * return value is the (possibly reallocated) node, or NULL on out of memory.
 * In the failure case the freshly allocated child is released, but '*child'
 * is not reset, so the caller must not use it. */
raxNode *raxCompressNode(raxNode *n, unsigned char *s, size_t len, raxNode **child) {
  assert(n->size == 0 && n->iscompr == 0);
  void *data = NULL; /* Initialized only to avoid warnings. */
  size_t newsize;

  debugf("Compress node: %.*s\n", (int)len, s);

  /* Allocate the child to link to this node. */
  *child = raxNewNode(0, 0);
  if (*child == NULL) return NULL;

  /* Make space in the parent node. */
  newsize = sizeof(raxNode) + len + raxPadding(len) + sizeof(raxNode *);
  if (n->iskey) {
    data = raxGetData(n); /* To restore it later. */
    if (!n->isnull) newsize += sizeof(void *);
  }
  raxNode *newn = rax_realloc(n, newsize);
  if (newn == NULL) {
    rax_free(*child);
    return NULL;
  }
  n = newn;

  n->iscompr = 1;
  n->size = len;
  memcpy(n->data, s, len);
  /* The value slot moved because the node length changed: write it again. */
  if (n->iskey) raxSetData(n, data);
  raxNode **childfield = raxNodeLastChildPtr(n);
  memcpy(childfield, child, sizeof(*child));
  return n;
}

/* Low level function that walks the tree looking for the string
 * 's' of 'len' bytes. The function returns the number of characters
 * of the key that was possible to process: if the returned integer
 * is the same as 'len', then it means that the node corresponding to the
 * string was found (however it may not be a key in case the node->iskey is
 * zero or if simply we stopped in the middle of a compressed node, so that
 * 'splitpos' is non zero).
 *
 * Otherwise if the returned integer is not the same as 'len', there was an
 * early stop during the tree walk because of a character mismatch.
 *
 * The node where the search ended (because the full string was processed
 * or because there was an early stop) is returned by reference as
 * '*stopnode' if the passed pointer is not NULL. This node link in the
 * parent's node is returned as '*plink' if not NULL. Finally, if the
 * search stopped in a compressed node, '*splitpos' returns the index
 * inside the compressed node where the search ended. This is useful to
 * know where to split the node for insertion.
 *
 * Note that when we stop in the middle of a compressed node with
 * a perfect match, this function will return a length equal to the
 * 'len' argument (all the key matched), and will return a *splitpos which is
 * always positive (that will represent the index of the character immediately
 * *after* the last match in the current compressed node).
*
 * When instead we stop at a compressed node and *splitpos is zero, it
 * means that the current node represents the key (that is, none of the
 * compressed node characters are needed to represent the key, just all
 * its parents nodes). */
static inline size_t raxLowWalk(rax *rax, unsigned char *s, size_t len, raxNode **stopnode, raxNode ***plink, int *splitpos, raxStack *ts) {
  raxNode *cur = rax->head;
  raxNode **link = &rax->head; /* Where the pointer to 'cur' is stored. */
  size_t matched = 0;          /* Bytes of 's' consumed so far. */
  size_t idx = 0;              /* Child index, or byte offset if compressed. */

  while (cur->size != 0 && matched < len) {
    debugnode("Lookup current node", cur);
    unsigned char *edges = cur->data;

    idx = 0;
    if (cur->iscompr) {
      /* Consume the compressed string for as long as it matches the key. */
      while (idx < cur->size && matched < len && edges[idx] == s[matched]) {
        idx++;
        matched++;
      }
      /* Stopped inside the node: mismatch, or the key is exhausted. */
      if (idx != cur->size) break;
    } else {
      /* Plain linear search of the edge: nodes are small enough that this
       * performs well in practice, so no binary search is attempted. */
      while (idx < cur->size && edges[idx] != s[matched]) idx++;
      if (idx == cur->size) break; /* No edge for this character. */
      matched++;
    }

    /* Remember the path, for callers that need to walk back up. */
    if (ts) raxStackPush(ts, cur);

    /* Descend: a compressed node has its only child at index 0. */
    raxNode **slot = raxNodeFirstChildPtr(cur) + (cur->iscompr ? 0 : idx);
    memcpy(&cur, slot, sizeof(cur));
    link = slot;

    /* Reset the position: should the loop end here because the key is fully
     * consumed, a zero split position tells the caller that this node is
     * the one representing the key. */
    idx = 0;
  }

  debugnode("Lookup stop node is", cur);
  if (stopnode != NULL) *stopnode = cur;
  if (plink != NULL) *plink = link;
  if (splitpos != NULL && cur->iscompr) *splitpos = idx;
  return matched;
}

/* Insert the element 's' of size 'len', setting as auxiliary data
 * the pointer 'data'. If the element is already present, the associated
 * data is updated (only if 'overwrite' is set to 1), and 0 is returned,
 * otherwise the element is inserted and 1 is returned. On out of memory the
 * function returns 0 as well but sets errno to ENOMEM, otherwise errno will
 * be set to 0.
*/
/* Implementation notes:
 * - '*old', when 'old' is not NULL, is written only if the key already
 *   existed, and receives the value stored before the call.
 * - errno is reset to 0 only on the "element already exists" path; the
 *   paths returning 1 leave errno untouched. */
int raxGenericInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old, int overwrite) {
  size_t i;
  int j = 0; /* Split position. If raxLowWalk() stops in a compressed node, the index 'j' represents the char we stopped within the compressed node, that is, the position where to split the node for insertion. */
  raxNode *h, **parentlink;

  debugf("### Insert %.*s with value %p\n", (int)len, s, data);
  i = raxLowWalk(rax, s, len, &h, &parentlink, &j, NULL);

  /* If i == len we walked following the whole string. If we are not
   * in the middle of a compressed node, the string is either already
   * inserted or this middle node is currently not a key, but can represent
   * our key. We have just to reallocate the node and make space for the
   * data pointer. */
  if (i == len && (!h->iscompr || j == 0 /* not in the middle if j is 0 */)) {
    debugf("### Insert: node representing key exists\n");
    /* Make space for the value pointer if needed. */
    if (!h->iskey || (h->isnull && overwrite)) {
      size_t oldalloc = rax_ptr_alloc_size(h);
      h = raxReallocForData(h, data);
      if (h) {
        memcpy(parentlink, &h, sizeof(h));
        rax->alloc_size = rax->alloc_size - oldalloc + rax_ptr_alloc_size(h);
      }
    }
    if (h == NULL) {
      errno = ENOMEM;
      return 0;
    }

    /* Update the existing key if there is already one. */
    if (h->iskey) {
      if (old) *old = raxGetData(h);
      if (overwrite) raxSetData(h, data);
      errno = 0;
      return 0; /* Element already exists. */
    }

    /* Otherwise set the node as a key. Note that raxSetData()
     * will set h->iskey. */
    raxSetData(h, data);
    rax->numele++;
    return 1; /* Element inserted. */
  }

  /* If the node we stopped at is a compressed node, we need to
   * split it before continuing.
   *
   * Splitting a compressed node has a few possible cases.
   * Imagine that the node 'h' we are currently at is a compressed
   * node containing the string "ANNIBALE" (it means that it represents
   * nodes A -> N -> N -> I -> B -> A -> L -> E with the only child
   * pointer of this node pointing at the 'E' node, because remember that
   * we have characters at the edges of the graph, not inside the nodes
   * themselves).
   *
   * In order to show a real case imagine our node to also point to
   * another compressed node, that finally points at the node without
   * children, representing 'O':
   *
   *     "ANNIBALE" -> "SCO" -> []
   *
   * When inserting we may face the following cases. Note that all the cases
   * require the insertion of a non compressed node with exactly two
   * children, except for the last case which just requires splitting a
   * compressed node.
   *
   * 1) Inserting "ANNIENTARE"
   *
   *               |B| -> "ALE" -> "SCO" -> []
   *     "ANNI" -> |-|
   *               |E| -> (... continue algo ...) "NTARE" -> []
   *
   * 2) Inserting "ANNIBALI"
   *
   *                  |E| -> "SCO" -> []
   *     "ANNIBAL" -> |-|
   *                  |I| -> (... continue algo ...) []
   *
   * 3) Inserting "AGO" (Like case 1, but set iscompr = 0 into original node)
   *
   *            |N| -> "NIBALE" -> "SCO" -> []
   *     |A| -> |-|
   *            |G| -> (... continue algo ...) |O| -> []
   *
   * 4) Inserting "CIAO"
   *
   *     |A| -> "NNIBALE" -> "SCO" -> []
   *     |-|
   *     |C| -> (... continue algo ...) "IAO" -> []
   *
   * 5) Inserting "ANNI"
   *
   *     "ANNI" -> "BALE" -> "SCO" -> []
   *
   * The final algorithm for insertion covering all the above cases is as
   * follows.
   *
   * ============================= ALGO 1 =============================
   *
   * For the above cases 1 to 4, that is, all cases where we stopped in
   * the middle of a compressed node for a character mismatch, do:
   *
   * Let $SPLITPOS be the zero-based index at which, in the
   * compressed node array of characters, we found the mismatching
   * character. For example if the node contains "ANNIBALE" and we add
   * "ANNIENTARE" the $SPLITPOS is 4, that is, the index at which the
   * mismatching character is found.
   *
   * 1. Save the current compressed node $NEXT pointer (the pointer to the
   *    child element, that is always present in compressed nodes).
   *
   * 2. Create "split node" having as child the non common letter
   *    at the compressed node. The other non common letter (at the key)
   *    will be added later as we continue the normal insertion algorithm
   *    at step "6".
   *
   * 3a. IF $SPLITPOS == 0:
   *     Replace the old node with the split node, by copying the auxiliary
   *     data if any. Fix parent's reference. Free old node eventually
   *     (we still need its data for the next steps of the algorithm).
   *
   * 3b. IF $SPLITPOS != 0:
   *     Trim the compressed node (reallocating it as well) in order to
   *     contain $splitpos characters. Change child pointer in order to link
   *     to the split node. If new compressed node len is just 1, set
   *     iscompr to 0 (layout is the same). Fix parent's reference.
   *
   * 4a. IF the postfix len (the length of the remaining string of the
   *     original compressed node after the split character) is non zero,
   *     create a "postfix node". If the postfix node has just one character
   *     set iscompr to 0, otherwise iscompr to 1. Set the postfix node
   *     child pointer to $NEXT.
   *
   * 4b. IF the postfix len is zero, just use $NEXT as postfix pointer.
   *
   * 5. Set child[0] of split node to postfix node.
   *
   * 6. Set the split node as the current node, set current index at child[1]
   *    and continue insertion algorithm as usually.
   *
   * ============================= ALGO 2 =============================
   *
   * For case 5, that is, if we stopped in the middle of a compressed
   * node but no mismatch was found, do:
   *
   * Let $SPLITPOS be the zero-based index at which, in the
   * compressed node array of characters, we stopped iterating because
   * there were no more keys character to match. So in the example of
   * the node "ANNIBALE", adding the string "ANNI", the $SPLITPOS is 4.
   *
   * 1. Save the current compressed node $NEXT pointer (the pointer to the
   *    child element, that is always present in compressed nodes).
   *
   * 2. Create a "postfix node" containing all the characters from $SPLITPOS
   *    to the end. Use $NEXT as the postfix node child pointer.
   *    If the postfix node length is 1, set iscompr to 0.
   *    Set the node as a key with the associated value of the new
   *    inserted key.
   *
   * 3. Trim the current node to contain the first $SPLITPOS characters.
   *    As usually if the new node length is just 1, set iscompr to 0.
   *    Take the iskey / associated value as it was in the original node.
   *    Fix the parent's reference.
   *
   * 4. Set the postfix node as the only child pointer of the trimmed
   *    node created at step 3. */

  /* ------------------------- ALGORITHM 1 --------------------------- */
  if (h->iscompr && i != len) {
    debugf("ALGO 1: Stopped at compressed node %.*s (%p)\n", h->size, h->data, (void *)h);
    debugf("Still to insert: %.*s\n", (int)(len - i), s + i);
    debugf("Splitting at %d: '%c'\n", j, ((char *)h->data)[j]);
    debugf("Other (key) letter is '%c'\n", s[i]);

    /* 1: Save next pointer. */
    raxNode **childfield = raxNodeLastChildPtr(h);
    raxNode *next;
    memcpy(&next, childfield, sizeof(next));
    debugf("Next is %p\n", (void *)next);
    debugf("iskey %d\n", h->iskey);
    if (h->iskey) {
      debugf("key value is %p\n", raxGetData(h));
    }

    /* Set the length of the additional nodes we will need. */
    size_t trimmedlen = j;
    size_t postfixlen = h->size - j - 1;
    /* The split node inherits the value only when it replaces 'h' (3a). */
    int split_node_is_key = !trimmedlen && h->iskey && !h->isnull;
    size_t nodesize;

    /* 2: Create the split node. Also allocate the other nodes we'll need
     * ASAP, so that it will be simpler to handle OOM. */
    raxNode *splitnode = raxNewNode(1, split_node_is_key);
    raxNode *trimmed = NULL;
    raxNode *postfix = NULL;

    if (trimmedlen) {
      nodesize = sizeof(raxNode) + trimmedlen + raxPadding(trimmedlen) + sizeof(raxNode *);
      if (h->iskey && !h->isnull) nodesize += sizeof(void *);
      trimmed = rax_malloc(nodesize);
    }

    if (postfixlen) {
      nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *);
      postfix = rax_malloc(nodesize);
    }

    /* OOM? Abort now that the tree is untouched. */
    if (splitnode == NULL || (trimmedlen && trimmed == NULL) || (postfixlen && postfix == NULL)) {
      rax_free(splitnode);
      rax_free(trimmed);
      rax_free(postfix);
      errno = ENOMEM;
      return 0;
    }
    splitnode->data[0] = h->data[j];
    rax->alloc_size += rax_ptr_alloc_size(splitnode);

    if (j == 0) {
      /* 3a: Replace the old node with the split node. */
      if (h->iskey) {
        void *ndata = raxGetData(h);
        raxSetData(splitnode, ndata);
      }
      memcpy(parentlink, &splitnode, sizeof(splitnode));
    } else {
      /* 3b: Trim the compressed node. */
      trimmed->size = j;
      memcpy(trimmed->data, h->data, j);
      trimmed->iscompr = j > 1 ? 1 : 0;
      trimmed->iskey = h->iskey;
      trimmed->isnull = h->isnull;
      if (h->iskey && !h->isnull) {
        void *ndata = raxGetData(h);
        raxSetData(trimmed, ndata);
      }
      raxNode **cp = raxNodeLastChildPtr(trimmed);
      memcpy(cp, &splitnode, sizeof(splitnode));
      memcpy(parentlink, &trimmed, sizeof(trimmed));
      parentlink = cp; /* Set parentlink to splitnode parent. */
      rax->numnodes++;
      rax->alloc_size += rax_ptr_alloc_size(trimmed);
    }

    /* 4: Create the postfix node: what remains of the original
     * compressed node after the split. */
    if (postfixlen) {
      /* 4a: create a postfix node. */
      postfix->iskey = 0;
      postfix->isnull = 0;
      postfix->size = postfixlen;
      postfix->iscompr = postfixlen > 1;
      memcpy(postfix->data, h->data + j + 1, postfixlen);
      raxNode **cp = raxNodeLastChildPtr(postfix);
      memcpy(cp, &next, sizeof(next));
      rax->numnodes++;
      rax->alloc_size += rax_ptr_alloc_size(postfix);
    } else {
      /* 4b: just use next as postfix node. */
      postfix = next;
    }

    /* 5: Set splitnode first child as the postfix node. */
    raxNode **splitchild = raxNodeLastChildPtr(splitnode);
    memcpy(splitchild, &postfix, sizeof(postfix));

    /* 6. Continue insertion: this will cause the splitnode to
     * get a new child (the non common character at the currently
     * inserted key). */
    rax->alloc_size -= rax_ptr_alloc_size(h);
    rax_free(h);
    h = splitnode;
  } else if (h->iscompr && i == len) {
    /* ------------------------- ALGORITHM 2 --------------------------- */
    debugf("ALGO 2: Stopped at compressed node %.*s (%p) j = %d\n", h->size, h->data, (void *)h, j);

    /* Allocate postfix & trimmed nodes ASAP to fail for OOM gracefully. */
    size_t postfixlen = h->size - j;
    size_t nodesize = sizeof(raxNode) + postfixlen + raxPadding(postfixlen) + sizeof(raxNode *);
    if (data != NULL) nodesize += sizeof(void *);
    raxNode *postfix = rax_malloc(nodesize);

    nodesize = sizeof(raxNode) + j + raxPadding(j) + sizeof(raxNode *);
    if (h->iskey && !h->isnull) nodesize += sizeof(void *);
    raxNode *trimmed = rax_malloc(nodesize);

    if (postfix == NULL || trimmed == NULL) {
      rax_free(postfix);
      rax_free(trimmed);
      errno = ENOMEM;
      return 0;
    }

    /* 1: Save next pointer. */
    raxNode **childfield = raxNodeLastChildPtr(h);
    raxNode *next;
    memcpy(&next, childfield, sizeof(next));

    /* 2: Create the postfix node. */
    postfix->size = postfixlen;
    postfix->iscompr = postfixlen > 1;
    postfix->iskey = 1;
    postfix->isnull = 0;
    memcpy(postfix->data, h->data + j, postfixlen);
    raxSetData(postfix, data);
    raxNode **cp = raxNodeLastChildPtr(postfix);
    memcpy(cp, &next, sizeof(next));
    rax->numnodes++;
    rax->alloc_size += rax_ptr_alloc_size(postfix);

    /* 3: Trim the compressed node. */
    trimmed->size = j;
    trimmed->iscompr = j > 1;
    trimmed->iskey = 0;
    trimmed->isnull = 0;
    memcpy(trimmed->data, h->data, j);
    memcpy(parentlink, &trimmed, sizeof(trimmed));
    if (h->iskey) {
      void *aux = raxGetData(h);
      raxSetData(trimmed, aux);
    }
    rax->alloc_size += rax_ptr_alloc_size(trimmed);

    /* Fix the trimmed node child pointer to point to
     * the postfix node. */
    cp = raxNodeLastChildPtr(trimmed);
    memcpy(cp, &postfix, sizeof(postfix));

    /* Finish! We don't need to continue with the insertion
     * algorithm for ALGO 2. The key is already inserted. */
    rax->numele++;
    rax->alloc_size -= rax_ptr_alloc_size(h);
    rax_free(h);
    return 1; /* Key inserted. */
  }

  /* We walked the radix tree as far as we could, but still there are left
   * chars in our string. We need to insert the missing nodes. */
  while (i < len) {
    raxNode *child;
    size_t oldalloc = rax_ptr_alloc_size(h);

    /* If this node is going to have a single child, and there
     * are other characters, so that that would result in a chain
     * of single-childed nodes, turn it into a compressed node. */
    if (h->size == 0 && len - i > 1) {
      debugf("Inserting compressed node\n");
      size_t comprsize = len - i;
      if (comprsize > RAX_NODE_MAX_SIZE) comprsize = RAX_NODE_MAX_SIZE;
      raxNode *newh = raxCompressNode(h, s + i, comprsize, &child);
      if (newh == NULL) goto oom;
      h = newh;
      memcpy(parentlink, &h, sizeof(h));
      parentlink = raxNodeLastChildPtr(h);
      i += comprsize;
    } else {
      debugf("Inserting normal node\n");
      raxNode **new_parentlink;
      raxNode *newh = raxAddChild(h, s[i], &child, &new_parentlink);
      if (newh == NULL) goto oom;
      h = newh;
      memcpy(parentlink, &h, sizeof(h));
      parentlink = new_parentlink;
      i++;
    }
    rax->numnodes++;
    rax->alloc_size = rax->alloc_size - oldalloc + rax_ptr_alloc_size(h) + rax_ptr_alloc_size(child);
    h = child;
  }
  size_t oldalloc = rax_ptr_alloc_size(h);
  raxNode *newh = raxReallocForData(h, data);
  if (newh == NULL) goto oom;
  h = newh;
  if (!h->iskey) rax->numele++;
  raxSetData(h, data);
  memcpy(parentlink, &h, sizeof(h));
  rax->alloc_size = rax->alloc_size - oldalloc + rax_ptr_alloc_size(h);
  return 1; /* Element inserted. */

oom:
  /* This code path handles out of memory after part of the sub-tree was
   * already modified. Set the node as a key, and then remove it. However we
   * do that only if the node is a terminal node, otherwise if the OOM
   * happened reallocating a node in the middle, we don't need to free
   * anything. */
  if (h->size == 0) {
    h->isnull = 1;
    h->iskey = 1;
    rax->numele++; /* Compensate the next remove. */
    checkedRaxRemove(rax, s, i, NULL);
  }
  errno = ENOMEM;
  return 0;
}

/* Overwriting insert. Just a wrapper for raxGenericInsert() that will
 * update the element if there is already one for the same key. */
int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) {
  return raxGenericInsert(rax, s, len, data, old, 1);
}

/* Non overwriting insert function: if an element with the same key
 * exists, the value is not updated and the function returns 0.
 * This is just a wrapper for raxGenericInsert(). */
int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old) {
  return raxGenericInsert(rax, s, len, data, old, 0);
}

/* Find a key in the rax: return 1 if the item is found, 0 otherwise.
 * If there is an item and 'value' is passed in a non-NULL pointer,
 * the value associated with the item is set at that address (it is NULL
 * for a key that was stored with a NULL value). */
int raxFind(rax *rax, unsigned char *s, size_t len, void **value) {
  raxNode *h;

  debugf("### Lookup: %.*s\n", (int)len, s);
  int splitpos = 0;
  size_t i = raxLowWalk(rax, s, len, &h, NULL, &splitpos, NULL);
  /* Not found if the walk stopped early, ended in the middle of a
   * compressed node, or reached a node that is not a key. */
  if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) return 0;
  if (value != NULL) *value = raxGetData(h);
  return 1;
}

/* Return the memory address where the 'parent' node stores the specified
 * 'child' pointer, so that the caller can update the pointer with another
 * one if needed. The function assumes it will find a match, otherwise the
 * operation is an undefined behavior (it will continue scanning the
 * memory without any bound checking). */
raxNode **raxFindParentLink(raxNode *parent, raxNode *child) {
  raxNode **cp = raxNodeFirstChildPtr(parent);
  raxNode *c;
  while (1) {
    memcpy(&c, cp, sizeof(c));
    if (c == child) break;
    cp++;
  }
  return cp;
}

/* Low level child removal from node. The new node pointer (after the child
 * removal) is returned. Note that this function does not fix the pointer
 * of the parent node in its parent, so this task is up to the caller.
 * The function never fails for out of memory.
*/
raxNode *raxRemoveChild(raxNode *parent, raxNode *child) {
  debugnode("raxRemoveChild before", parent);
  /* If parent is a compressed node (having a single child, as for definition
   * of the data structure), the removal of the child consists into turning
   * it into a normal node without children. The node is not reallocated
   * in this case. */
  if (parent->iscompr) {
    void *data = NULL;
    if (parent->iskey) data = raxGetData(parent);
    parent->isnull = 0;
    parent->iscompr = 0;
    parent->size = 0;
    /* The node length changed, so the value must be stored again. */
    if (parent->iskey) raxSetData(parent, data);
    debugnode("raxRemoveChild after", parent);
    return parent;
  }

  /* Otherwise we need to scan for the child pointer and memmove()
   * accordingly.
   *
   * 1. To start we seek the first element in both the children
   *    pointers and edge bytes in the node. */
  raxNode **cp = raxNodeFirstChildPtr(parent);
  raxNode **c = cp;
  unsigned char *e = parent->data;

  /* 2. Search the child pointer to remove inside the array of children
   *    pointers. The scan is unbounded: 'child' must be a child of
   *    'parent'. */
  while (1) {
    raxNode *aux;
    memcpy(&aux, c, sizeof(aux));
    if (aux == child) break;
    c++;
    e++;
  }

  /* 3. Remove the edge and the pointer by memmoving the remaining children
   *    pointer and edge bytes one position before. */
  int taillen = parent->size - (e - parent->data) - 1;
  debugf("raxRemoveChild tail len: %d\n", taillen);
  memmove(e, e + 1, taillen);

  /* Compute the shift, that is the amount of bytes we should move our
   * child pointers to the left, since the removal of one edge character
   * and the corresponding padding change, may change the layout.
   * We just check if in the old version of the node there was at the
   * end just a single byte and all padding: in that case removing one char
   * will remove a whole sizeof(void*) word. */
  size_t shift = ((parent->size + 4) % sizeof(void *)) == 1 ? sizeof(void *) : 0;

  /* Move the children pointers before the deletion point.
   * NOTE(review): the element size is spelled sizeof(raxNode **) here and
   * below while the elements are raxNode * values; this relies on the two
   * pointer types having the same size. */
  if (shift) memmove(((char *)cp) - shift, cp, (parent->size - taillen - 1) * sizeof(raxNode **));

  /* Move the remaining "tail" pointers at the right position as well.
   * The value pointer, if any, follows the children and is moved with them. */
  size_t valuelen = (parent->iskey && !parent->isnull) ? sizeof(void *) : 0;
  memmove(((char *)c) - shift, c + 1, taillen * sizeof(raxNode **) + valuelen);

  /* 4. Update size. */
  parent->size--;

  /* realloc the node according to the theoretical memory usage, to free
   * data if we are over-allocating right now. */
  raxNode *newnode = rax_realloc(parent, raxNodeCurrentLength(parent));
  if (newnode) {
    debugnode("raxRemoveChild after", newnode);
  }
  /* Note: if rax_realloc() fails we just return the old address, which
   * is valid. */
  return newnode ? newnode : parent;
}

/* Remove the specified item. Returns 1 if the item was found and
 * deleted, 0 otherwise. If 'old' is not NULL and the item was found, the
 * value that was associated with it is stored at '*old'. */
int raxRemove(rax *rax, unsigned char *s, size_t len, void **old) {
  raxNode *h;
  raxStack ts;

  debugf("### Delete: %.*s\n", (int)len, s);
  raxStackInit(&ts);
  int splitpos = 0;
  size_t i = raxLowWalk(rax, s, len, &h, NULL, &splitpos, &ts);
  if (i != len || (h->iscompr && splitpos != 0) || !h->iskey) {
    raxStackFree(&ts);
    return 0;
  }
  if (old) *old = raxGetData(h);
  h->iskey = 0;
  rax->numele--;

  /* If this node has no children, the deletion needs to reclaim the
   * no longer used nodes. This is an iterative process that needs to
   * walk the tree upward, deleting all the nodes with just one child
   * that are not keys, until the head of the rax is reached or the first
   * node with more than one child is found. */

  int trycompress = 0; /* Will be set to 1 if we should try to optimize the tree resulting from the deletion. */

  if (h->size == 0) {
    debugf("Key deleted in node without children. Cleanup needed.\n");
    raxNode *child = NULL;
    while (h != rax->head) {
      child = h;
      debugf("Freeing child %p [%.*s] key:%d\n", (void *)child, (int)child->size, (char *)child->data, child->iskey);
      rax->alloc_size -= rax_ptr_alloc_size(child);
      rax_free(child);
      rax->numnodes--;
      h = raxStackPop(&ts);
      /* If this node has more than one child, or actually holds
       * a key, stop here. */
      if (h->iskey || (!h->iscompr && h->size != 1)) break;
    }
    if (child) {
      debugf("Unlinking child %p from parent %p\n", (void *)child, (void *)h);
      size_t oldalloc = rax_ptr_alloc_size(h);
      raxNode *new = raxRemoveChild(h, child);
      rax->alloc_size = rax->alloc_size - oldalloc + rax_ptr_alloc_size(new);
      /* The node moved: fix the link stored in its own parent (or the
       * tree head when there is no parent on the stack). */
      if (new != h) {
        raxNode *parent = raxStackPeek(&ts);
        raxNode **parentlink;
        if (parent == NULL) {
          parentlink = &rax->head;
        } else {
          parentlink = raxFindParentLink(parent, h);
        }
        memcpy(parentlink, &new, sizeof(new));
      }

      /* If after the removal the node has just a single child
       * and is not a key, we need to try to compress it. */
      if (new->size == 1 && new->iskey == 0) {
        trycompress = 1;
        h = new;
      }
    }
  } else if (h->size == 1) {
    /* If the node had just one child, after the removal of the key
     * further compression with adjacent nodes is potentially possible. */
    trycompress = 1;
  }

  /* Don't try node compression if our nodes pointers stack is not
   * complete because of OOM while executing raxLowWalk() */
  if (trycompress && ts.oom) trycompress = 0;

  /* Recompression: if trycompress is true, 'h' points to a radix tree node
   * that changed in a way that could allow to compress nodes in this
   * sub-branch. Compressed nodes represent chains of nodes that are not
   * keys and have a single child, so there are two deletion events that
   * may alter the tree so that further compression is needed:
   *
   * 1) A node with a single child was a key and now no longer is a key.
   * 2) A node with two children now has just one child.
   *
   * We try to navigate upward till there are other nodes that can be
   * compressed, when we reach the upper node which is not a key and has
   * a single child, we scan the chain of children to collect the
   * compressible part of the tree, and replace the current node with the
   * new one, fixing the child pointer to reference the first non
   * compressible node.
   *
   * Example of case "1". A tree stores the keys "FOO" = 1 and
   * "FOOBAR" = 2:
   *
   *
   * "FOO" -> "BAR" -> [] (2)
   *           (1)
   *
   * After the removal of "FOO" the tree can be compressed as:
   *
   * "FOOBAR" -> [] (2)
   *
   *
   * Example of case "2". A tree stores the keys "FOOBAR" = 1 and
   * "FOOTER" = 2:
   *
   *          |B| -> "AR" -> [] (1)
   * "FOO" -> |-|
   *          |T| -> "ER" -> [] (2)
   *
   * After the removal of "FOOTER" the resulting tree is:
   *
   * "FOO" -> |B| -> "AR" -> [] (1)
   *
   * That can be compressed into:
   *
   * "FOOBAR" -> [] (1)
   */
  if (trycompress) {
    debugf("After removing %.*s:\n", (int)len, s);
    debugnode("Compression may be needed", h);
    debugf("Seek start node\n");

    /* Try to reach the upper node that is compressible.
     * At the end of the loop 'h' will point to the first node we
     * can try to compress and 'parent' to its parent. */
    raxNode *parent;
    while (1) {
      parent = raxStackPop(&ts);
      if (!parent || parent->iskey || (!parent->iscompr && parent->size != 1)) break;
      h = parent;
      debugnode("Going up to", h);
    }
    raxNode *start = h; /* Compression starting node. */

    /* Scan chain of nodes we can compress. */
    size_t comprsize = h->size;
    int nodes = 1;
    while (h->size != 0) {
      raxNode **cp = raxNodeLastChildPtr(h);
      memcpy(&h, cp, sizeof(h));
      if (h->iskey || (!h->iscompr && h->size != 1)) break;
      /* Stop here if going to the next node would result into
       * a compressed node larger than h->size can hold. */
      if (comprsize + h->size > RAX_NODE_MAX_SIZE) break;
      nodes++;
      comprsize += h->size;
    }
    if (nodes > 1) {
      /* If we can compress, create the new node and populate it. */
      size_t nodesize = sizeof(raxNode) + comprsize + raxPadding(comprsize) + sizeof(raxNode *);
      raxNode *new = rax_malloc(nodesize);
      /* An out of memory here just means we cannot optimize this
       * node, but the tree is left in a consistent state. */
      if (new == NULL) {
        raxStackFree(&ts);
        return 1;
      }
      new->iskey = 0;
      new->isnull = 0;
      new->iscompr = 1;
      new->size = comprsize;
      rax->numnodes++;
      rax->alloc_size += rax_ptr_alloc_size(new);

      /* Scan again, this time to populate the new node content and
       * to fix the new node child pointer. At the same time we free
       * all the nodes that we'll no longer use. The stop conditions are
       * the same as in the first scan, so the same nodes are visited. */
      comprsize = 0;
      h = start;
      while (h->size != 0) {
        memcpy(new->data + comprsize, h->data, h->size);
        comprsize += h->size;
        raxNode **cp = raxNodeLastChildPtr(h);
        raxNode *tofree = h;
        memcpy(&h, cp, sizeof(h));
        rax->alloc_size -= rax_ptr_alloc_size(tofree);
        rax_free(tofree);
        rax->numnodes--;
        if (h->iskey || (!h->iscompr && h->size != 1)) break;
        if (comprsize + h->size > RAX_NODE_MAX_SIZE) break;
      }
      debugnode("New node", new);

      /* Now 'h' points to the first node that we still need to use,
       * so our new node child pointer will point to it. */
      raxNode **cp = raxNodeLastChildPtr(new);
      memcpy(cp, &h, sizeof(h));

      /* Fix parent link. */
      if (parent) {
        raxNode **parentlink = raxFindParentLink(parent, start);
        memcpy(parentlink, &new, sizeof(new));
      } else {
        rax->head = new;
      }

      debugf("Compressed %d nodes, %d total bytes\n", nodes, (int)comprsize);
    }
  }
  raxStackFree(&ts);
  return 1;
}

/* This is the core of raxFree(): performs a depth-first scan of the
 * tree and releases all the nodes found. For every key holding a non-NULL
 * value, 'free_callback' (when not NULL) is invoked with the value and
 * 'argument'. rax->numnodes is decremented for each released node. */
void raxRecursiveFree(rax *rax, raxNode *n, void (*free_callback)(void*, void*), void* argument) {
  debugnode("free traversing",n);
  int numchildren = n->iscompr ? 1 : n->size;
  /* Children are visited from the last pointer backward. */
  raxNode **cp = raxNodeLastChildPtr(n);
  while (numchildren--) {
    raxNode *child;
    memcpy(&child, cp, sizeof(child));
    raxRecursiveFree(rax,child,free_callback,argument);
    cp--;
  }
  debugnode("free depth-first", n);
  if (free_callback && n->iskey && !n->isnull) free_callback(raxGetData(n), argument);
  rax_free(n);
  rax->numnodes--;
}

/* Free the entire radix tree, invoking a free_callback function for each key's data.
* An additional argument is passed to the free_callback function.*/
void raxFreeWithCallbackAndArgument(rax *rax, void (*free_callback)(void*, void*), void* argument) {
    raxNode *root = rax->head;
    raxRecursiveFree(rax,root,free_callback, argument);
    /* Every node must have been accounted for by the recursive release. */
    assert(rax->numnodes == 0);
    rax_free(rax);
}

/* Adapter that lets a one-argument value destructor be used where the
 * two-argument callback form is expected: 'argument' carries the original
 * one-argument function (or NULL, in which case nothing is called). */
void freeCallbackWrapper(void* data, void* argument) {
    if (argument) {
        void (*user_cb)(void*) = (void (*)(void*))argument;
        user_cb(data);
    }
}

/* Free a whole radix tree, calling 'free_callback' (when not NULL) on the
 * auxiliary data of every key. */
void raxFreeWithCallback(rax *rax, void (*free_callback)(void*)) {
    void *opaque_cb = (void*)free_callback;
    raxFreeWithCallbackAndArgument(rax, freeCallbackWrapper, opaque_cb);
}

/* Free a whole radix tree without touching the values stored in it. */
void raxFree(rax *rax) {
    raxFreeWithCallback(rax, NULL);
}

/* ------------------------------- Iterator --------------------------------- */

/* Prepare the iterator 'it' to walk the tree 'rt'. Call this once, then
 * position the iterator with raxSeek(): until that happens the iterator is
 * flagged as EOF, so raxPrev()/raxNext() simply report end of iteration. */
void raxStart(raxIterator *it, rax *rt) {
    it->rt = rt;
    it->flags = RAX_ITER_EOF; /* No crash if the iterator is not seeked. */
    /* The key starts out empty and lives in the embedded static buffer. */
    it->key = it->key_static_string;
    it->key_max = RAX_ITER_STATIC_LEN;
    it->key_len = 0;
    it->data = NULL;
    it->node_cb = NULL;
    raxStackInit(&it->stack);
}

/* Append characters at the current key string of the iterator 'it'. This
 * is a low level function used to implement the iterator, not callable by
 * the user. Returns 0 on out of memory, otherwise 1 is returned. */
int raxIteratorAddChars(raxIterator *it, unsigned char *s, size_t len) {
    if (len == 0) return 1;
    if (it->key_max < it->key_len + len) {
        unsigned char *old = (it->key == it->key_static_string) ? NULL : it->key;
        size_t new_max = (it->key_len + len) * 2;
        it->key = rax_realloc(old, new_max);
        if (it->key == NULL) {
            it->key = (!old) ?
it->key_static_string : old; errno = ENOMEM; return 0; } if (old == NULL) memcpy(it->key, it->key_static_string, it->key_len); it->key_max = new_max; } /* Use memmove since there could be an overlap between 's' and * it->key when we use the current key in order to re-seek. */ memmove(it->key + it->key_len, s, len); it->key_len += len; return 1; } /* Remove the specified number of chars from the right of the current * iterator key. */ void raxIteratorDelChars(raxIterator *it, size_t count) { it->key_len -= count; } /* Do an iteration step towards the next element. At the end of the step the * iterator key will represent the (new) current key. If it is not possible * to step in the specified direction since there are no longer elements, the * iterator is flagged with RAX_ITER_EOF. * * If 'noup' is true the function starts directly scanning for the next * lexicographically smaller children, and the current node is already assumed * to be the parent of the last key node, so the first operation to go back to * the parent will be skipped. This option is used by raxSeek() when * implementing seeking a non existing element with the ">" or "<" options: * the starting node is not a key in that particular case, so we start the scan * from a node that does not represent the key set. * * The function returns 1 on success or 0 on out of memory. */ int raxIteratorNextStep(raxIterator *it, int noup) { if (it->flags & RAX_ITER_EOF) { return 1; } else if (it->flags & RAX_ITER_JUST_SEEKED) { it->flags &= ~RAX_ITER_JUST_SEEKED; return 1; } /* Save key len, stack items and the node where we are currently * so that on iterator EOF we can restore the current key and state. */ size_t orig_key_len = it->key_len; size_t orig_stack_items = it->stack.items; raxNode *orig_node = it->node; while (1) { int children = it->node->iscompr ? 
1 : it->node->size; if (!noup && children) { debugf("GO DEEPER\n"); /* Seek the lexicographically smaller key in this subtree, which * is the first one found always going towards the first child * of every successive node. */ if (!raxStackPush(&it->stack, it->node)) return 0; raxNode **cp = raxNodeFirstChildPtr(it->node); if (!raxIteratorAddChars(it, it->node->data, it->node->iscompr ? it->node->size : 1)) return 0; memcpy(&it->node, cp, sizeof(it->node)); /* Call the node callback if any, and replace the node pointer * if the callback returns true. */ if (it->node_cb && it->node_cb(&it->node)) memcpy(cp, &it->node, sizeof(it->node)); /* For "next" step, stop every time we find a key along the * way, since the key is lexicographically smaller compared to * what follows in the sub-children. */ if (it->node->iskey) { it->data = raxGetData(it->node); return 1; } } else { /* If we finished exploring the previous sub-tree, switch to the * new one: go upper until a node is found where there are * children representing keys lexicographically greater than the * current key. */ while (1) { int old_noup = noup; /* Already on head? Can't go up, iteration finished. */ if (!noup && it->node == it->rt->head) { it->flags |= RAX_ITER_EOF; it->stack.items = orig_stack_items; it->key_len = orig_key_len; it->node = orig_node; return 1; } /* If there are no children at the current node, try parent's * next child. */ unsigned char prevchild = it->key[it->key_len - 1]; if (!noup) { it->node = raxStackPop(&it->stack); } else { noup = 0; } /* Adjust the current key to represent the node we are * at. */ int todel = it->node->iscompr ? it->node->size : 1; raxIteratorDelChars(it, todel); /* Try visiting the next child if there was at least one * additional child. */ if (!it->node->iscompr && it->node->size > (old_noup ? 
0 : 1)) { raxNode **cp = raxNodeFirstChildPtr(it->node); int i = 0; while (i < it->node->size) { debugf("SCAN NEXT %c\n", it->node->data[i]); if (it->node->data[i] > prevchild) break; i++; cp++; } if (i != it->node->size) { debugf("SCAN found a new node\n"); raxIteratorAddChars(it, it->node->data + i, 1); if (!raxStackPush(&it->stack, it->node)) return 0; memcpy(&it->node, cp, sizeof(it->node)); /* Call the node callback if any, and replace the node * pointer if the callback returns true. */ if (it->node_cb && it->node_cb(&it->node)) memcpy(cp, &it->node, sizeof(it->node)); if (it->node->iskey) { it->data = raxGetData(it->node); return 1; } break; } } } } } } /* Seek the greatest key in the subtree at the current node. Return 0 on * out of memory, otherwise 1. This is a helper function for different * iteration functions below. */ int raxSeekGreatest(raxIterator *it) { while (it->node->size) { if (it->node->iscompr) { if (!raxIteratorAddChars(it, it->node->data, it->node->size)) return 0; } else { if (!raxIteratorAddChars(it, it->node->data + it->node->size - 1, 1)) return 0; } raxNode **cp = raxNodeLastChildPtr(it->node); if (!raxStackPush(&it->stack, it->node)) return 0; memcpy(&it->node, cp, sizeof(it->node)); } return 1; } /* Like raxIteratorNextStep() but implements an iteration step moving * to the lexicographically previous element. The 'noup' option has a similar * effect to the one of raxIteratorNextStep(). */ int raxIteratorPrevStep(raxIterator *it, int noup) { if (it->flags & RAX_ITER_EOF) { return 1; } else if (it->flags & RAX_ITER_JUST_SEEKED) { it->flags &= ~RAX_ITER_JUST_SEEKED; return 1; } /* Save key len, stack items and the node where we are currently * so that on iterator EOF we can restore the current key and state. */ size_t orig_key_len = it->key_len; size_t orig_stack_items = it->stack.items; raxNode *orig_node = it->node; while (1) { int old_noup = noup; /* Already on head? Can't go up, iteration finished. 
*/ if (!noup && it->node == it->rt->head) { it->flags |= RAX_ITER_EOF; it->stack.items = orig_stack_items; it->key_len = orig_key_len; it->node = orig_node; return 1; } unsigned char prevchild = it->key[it->key_len - 1]; if (!noup) { it->node = raxStackPop(&it->stack); } else { noup = 0; } /* Adjust the current key to represent the node we are * at. */ int todel = it->node->iscompr ? it->node->size : 1; raxIteratorDelChars(it, todel); /* Try visiting the prev child if there is at least one * child. */ if (!it->node->iscompr && it->node->size > (old_noup ? 0 : 1)) { raxNode **cp = raxNodeLastChildPtr(it->node); int i = it->node->size - 1; while (i >= 0) { debugf("SCAN PREV %c\n", it->node->data[i]); if (it->node->data[i] < prevchild) break; i--; cp--; } /* If we found a new subtree to explore in this node, * go deeper following all the last children in order to * find the key lexicographically greater. */ if (i != -1) { debugf("SCAN found a new node\n"); /* Enter the node we just found. */ if (!raxIteratorAddChars(it, it->node->data + i, 1)) return 0; if (!raxStackPush(&it->stack, it->node)) return 0; memcpy(&it->node, cp, sizeof(it->node)); /* Seek sub-tree max. */ if (!raxSeekGreatest(it)) return 0; } } /* Return the key: this could be the key we found scanning a new * subtree, or if we did not find a new subtree to explore here, * before giving up with this node, check if it's a key itself. */ if (it->node->iskey) { it->data = raxGetData(it->node); return 1; } } } /* Seek an iterator at the specified element. * Return 0 if the seek failed for syntax error or out of memory. Otherwise * 1 is returned. When 0 is returned for out of memory, errno is set to * the ENOMEM value. */ int raxSeek(raxIterator *it, const char *op, unsigned char *ele, size_t len) { int eq = 0, lt = 0, gt = 0, first = 0, last = 0; it->stack.items = 0; /* Just resetting. Initialized by raxStart(). 
*/ it->flags |= RAX_ITER_JUST_SEEKED; it->flags &= ~RAX_ITER_EOF; it->key_len = 0; it->node = NULL; /* Set flags according to the operator used to perform the seek. */ if (op[0] == '>') { gt = 1; if (op[1] == '=') eq = 1; } else if (op[0] == '<') { lt = 1; if (op[1] == '=') eq = 1; } else if (op[0] == '=') { eq = 1; } else if (op[0] == '^') { first = 1; } else if (op[0] == '$') { last = 1; } else { errno = 0; return 0; /* Error. */ } /* If there are no elements, set the EOF condition immediately and * return. */ if (it->rt->numele == 0) { it->flags |= RAX_ITER_EOF; return 1; } if (first) { /* Seeking the first key greater or equal to the empty string * is equivalent to seeking the smaller key available. */ return raxSeek(it, ">=", NULL, 0); } if (last) { /* Find the greatest key taking always the last child till a * final node is found. */ it->node = it->rt->head; if (!raxSeekGreatest(it)) return 0; assert(it->node->iskey); it->data = raxGetData(it->node); return 1; } /* We need to seek the specified key. What we do here is to actually * perform a lookup, and later invoke the prev/next key code that * we already use for iteration. */ int splitpos = 0; size_t i = raxLowWalk(it->rt, ele, len, &it->node, NULL, &splitpos, &it->stack); /* Return OOM on incomplete stack info. */ if (it->stack.oom) return 0; if (eq && i == len && (!it->node->iscompr || splitpos == 0) && it->node->iskey) { /* We found our node, since the key matches and we have an * "equal" condition. */ if (!raxIteratorAddChars(it, ele, len)) return 0; /* OOM. */ it->data = raxGetData(it->node); } else if (lt || gt) { /* Exact key not found or eq flag not set. We have to set as current * key the one represented by the node we stopped at, and perform * a next/prev operation to seek. */ raxIteratorAddChars(it, ele, i - splitpos); /* We need to set the iterator in the correct state to call next/prev * step in order to seek the desired element. 
*/ debugf("After initial seek: i=%d len=%d key=%.*s\n", (int)i, (int)len, (int)it->key_len, it->key); if (i != len && !it->node->iscompr) { /* If we stopped in the middle of a normal node because of a * mismatch, add the mismatching character to the current key * and call the iterator with the 'noup' flag so that it will try * to seek the next/prev child in the current node directly based * on the mismatching character. */ if (!raxIteratorAddChars(it, ele + i, 1)) return 0; debugf("Seek normal node on mismatch: %.*s\n", (int)it->key_len, (char *)it->key); it->flags &= ~RAX_ITER_JUST_SEEKED; if (lt && !raxIteratorPrevStep(it, 1)) return 0; if (gt && !raxIteratorNextStep(it, 1)) return 0; it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ } else if (i != len && it->node->iscompr) { debugf("Compressed mismatch: %.*s\n", (int)it->key_len, (char *)it->key); /* In case of a mismatch within a compressed node. */ int nodechar = it->node->data[splitpos]; int keychar = ele[i]; it->flags &= ~RAX_ITER_JUST_SEEKED; if (gt) { /* If the key the compressed node represents is greater * than our seek element, continue forward, otherwise set the * state in order to go back to the next sub-tree. */ if (nodechar > keychar) { if (!raxIteratorNextStep(it, 0)) return 0; } else { if (!raxIteratorAddChars(it, it->node->data, it->node->size)) return 0; if (!raxIteratorNextStep(it, 1)) return 0; } } if (lt) { /* If the key the compressed node represents is smaller * than our seek element, seek the greater key in this * subtree, otherwise set the state in order to go back to * the previous sub-tree. */ if (nodechar < keychar) { if (!raxSeekGreatest(it)) return 0; it->data = raxGetData(it->node); } else { if (!raxIteratorAddChars(it, it->node->data, it->node->size)) return 0; if (!raxIteratorPrevStep(it, 1)) return 0; } } it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. 
*/ } else { debugf("No mismatch: %.*s\n", (int)it->key_len, (char *)it->key); /* If there was no mismatch we are into a node representing the * key, (but which is not a key or the seek operator does not * include 'eq'), or we stopped in the middle of a compressed node * after processing all the key. Continue iterating as this was * a legitimate key we stopped at. */ it->flags &= ~RAX_ITER_JUST_SEEKED; if (it->node->iscompr && it->node->iskey && splitpos && lt) { /* If we stopped in the middle of a compressed node with * perfect match, and the condition is to seek a key "<" than * the specified one, then if this node is a key it already * represents our match. For instance we may have nodes: * * "f" -> "oobar" = 1 -> "" = 2 * * Representing keys "f" = 1, "foobar" = 2. A seek for * the key < "foo" will stop in the middle of the "oobar" * node, but will be our match, representing the key "f". * * So in that case, we don't seek backward. */ it->data = raxGetData(it->node); } else { if (gt && !raxIteratorNextStep(it, 0)) return 0; if (lt && !raxIteratorPrevStep(it, 0)) return 0; } it->flags |= RAX_ITER_JUST_SEEKED; /* Ignore next call. */ } } else { /* If we are here just eq was set but no match was found. */ it->flags |= RAX_ITER_EOF; return 1; } return 1; } /* Go to the next element in the scope of the iterator 'it'. * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. */ int raxNext(raxIterator *it) { if (!raxIteratorNextStep(it, 0)) { errno = ENOMEM; return 0; } if (it->flags & RAX_ITER_EOF) { errno = 0; return 0; } return 1; } /* Go to the previous element in the scope of the iterator 'it'. * If EOF (or out of memory) is reached, 0 is returned, otherwise 1 is * returned. In case 0 is returned because of OOM, errno is set to ENOMEM. 
*/
int raxPrev(raxIterator *it) {
    if (!raxIteratorPrevStep(it, 0)) {
        errno = ENOMEM;
        return 0;
    }
    if (it->flags & RAX_ITER_EOF) {
        errno = 0;
        return 0;
    }
    return 1;
}

/* Perform a random walk starting in the current position of the iterator.
 * Return 0 if the tree is empty or on out of memory. Otherwise 1 is returned
 * and the iterator is set to the node reached after doing a random walk
 * of 'steps' steps. If the 'steps' argument is 0, the random walk is performed
 * using a random number of steps between 1 and two times the logarithm of
 * the number of elements.
 *
 * A step is counted every time the walk lands on a node that is a key, and
 * the walk always terminates on a key node. When the tree is empty the
 * iterator is additionally flagged with RAX_ITER_EOF.
 *
 * NOTE: if you use this function to generate random elements from the radix
 * tree, expect a disappointing distribution. A random walk produces good
 * random elements if the tree is not sparse, however in the case of a radix
 * tree certain keys will be reported much more often than others. At least
 * this function should be able to explore every possible element eventually. */
int raxRandomWalk(raxIterator *it, size_t steps) {
    if (it->rt->numele == 0) {
        it->flags |= RAX_ITER_EOF;
        return 0;
    }

    if (steps == 0) {
        size_t fle = 1 + floor(log(it->rt->numele));
        fle *= 2;
        steps = 1 + rand() % fle;
    }

    raxNode *n = it->node;
    while (steps > 0 || !n->iskey) {
        int numchildren = n->iscompr ? 1 : n->size;
        /* Possible moves: one per child, plus "go up" unless we are on the
         * head, which has no parent. */
        int choices = numchildren + (n != it->rt->head);
        /* A head without children has nowhere to go. In a non empty tree
         * that only happens when the single stored key is the empty string,
         * so the head itself is the element to report. Stopping here also
         * avoids the modulo by zero below. */
        if (choices == 0) break;
        int r = rand() % choices;

        if (r == numchildren) {
            /* Go up to parent. */
            n = raxStackPop(&it->stack);
            int todel = n->iscompr ? n->size : 1;
            raxIteratorDelChars(it, todel);
        } else {
            /* Select a random child. */
            if (n->iscompr) {
                if (!raxIteratorAddChars(it, n->data, n->size)) return 0;
            } else {
                if (!raxIteratorAddChars(it, n->data + r, 1)) return 0;
            }
            raxNode **cp = raxNodeFirstChildPtr(n) + r;
            if (!raxStackPush(&it->stack, n)) return 0;
            memcpy(&n, cp, sizeof(n));
        }
        if (n->iskey) steps--;
    }
    it->node = n;
    it->data = raxGetData(it->node);
    return 1;
}

/* Compare the key currently pointed by the iterator to the specified
 * key according to the specified operator.
Returns 1 if the comparison is * true, otherwise 0 is returned. */ int raxCompare(raxIterator *iter, const char *op, unsigned char *key, size_t key_len) { int eq = 0, lt = 0, gt = 0; if (op[0] == '=' || op[1] == '=') eq = 1; if (op[0] == '>') gt = 1; else if (op[0] == '<') lt = 1; else if (op[1] != '=') return 0; /* Syntax error. */ size_t minlen = key_len < iter->key_len ? key_len : iter->key_len; int cmp = memcmp(iter->key, key, minlen); /* Handle == */ if (lt == 0 && gt == 0) return cmp == 0 && key_len == iter->key_len; /* Handle >, >=, <, <= */ if (cmp == 0) { /* Same prefix: longer wins. */ if (eq && key_len == iter->key_len) return 1; else if (lt) return iter->key_len < key_len; else if (gt) return iter->key_len > key_len; else return 0; /* Avoid warning, just 'eq' is handled before. */ } else if (cmp > 0) { return gt ? 1 : 0; } else /* (cmp < 0) */ { return lt ? 1 : 0; } } /* Free the iterator. */ void raxStop(raxIterator *it) { if (it->key != it->key_static_string) rax_free(it->key); raxStackFree(&it->stack); } /* Return if the iterator is in an EOF state. This happens when raxSeek() * failed to seek an appropriate element, so that raxNext() or raxPrev() * will return zero, or when an EOF condition was reached while iterating * with raxNext() and raxPrev(). */ int raxEOF(raxIterator *it) { return it->flags & RAX_ITER_EOF; } /* Return the number of elements inside the radix tree. */ uint64_t raxSize(rax *rax) { return rax->numele; } /* Return the rax tree allocation size in bytes */ size_t raxAllocSize(rax *rax) { return rax->alloc_size; } /* ----------------------------- Introspection ------------------------------ */ /* This function is mostly used for debugging and learning purposes. * It shows an ASCII representation of a tree on standard output, outline * all the nodes and the contained keys. 
* * The representation is as follow: * * "foobar" (compressed node) * [abc] (normal node with three children) * [abc]=0x12345678 (node is a key, pointing to value 0x12345678) * [] (a normal empty node) * * Children are represented in new indented lines, each children prefixed by * the "`-(x)" string, where "x" is the edge byte. * * [abc] * `-(a) "ladin" * `-(b) [kj] * `-(c) [] * * However when a node has a single child the following representation * is used instead: * * [abc] -> "ladin" -> [] */ /* The actual implementation of raxShow(). */ void raxRecursiveShow(int level, int lpad, raxNode *n) { char s = n->iscompr ? '"' : '['; char e = n->iscompr ? '"' : ']'; int numchars = printf("%c%.*s%c", s, n->size, n->data, e); if (n->iskey) { numchars += printf("=%p", raxGetData(n)); } int numchildren = n->iscompr ? 1 : n->size; /* Note that 7 and 4 magic constants are the string length * of " `-(x) " and " -> " respectively. */ if (level) { lpad += (numchildren > 1) ? 7 : 4; if (numchildren == 1) lpad += numchars; } raxNode **cp = raxNodeFirstChildPtr(n); for (int i = 0; i < numchildren; i++) { char *branch = " `-(%c) "; if (numchildren > 1) { printf("\n"); for (int j = 0; j < lpad; j++) putchar(' '); printf(branch, n->data[i]); } else { printf(" -> "); } raxNode *child; memcpy(&child, cp, sizeof(child)); raxRecursiveShow(level + 1, lpad, child); cp++; } } /* Show a tree, as outlined in the comment above. */ void raxShow(rax *rax) { raxRecursiveShow(0, 0, rax->head); putchar('\n'); } /* Used by debugnode() macro to show info about a given node. */ void raxDebugShowNode(const char *msg, raxNode *n) { if (raxDebugMsg == 0) return; printf("%s: %p [%.*s] key:%u size:%u children:", msg, (void *)n, (int)n->size, (char *)n->data, n->iskey, n->size); int numcld = n->iscompr ? 
1 : n->size; raxNode **cldptr = raxNodeLastChildPtr(n) - (numcld - 1); while (numcld--) { raxNode *child; memcpy(&child, cldptr, sizeof(child)); cldptr++; printf("%p ", (void *)child); } printf("\n"); fflush(stdout); } /* Touch all the nodes of a tree returning a check sum. This is useful * in order to make Valgrind detect if there is something wrong while * reading the data structure. * * This function was used in order to identify Rax bugs after a big refactoring * using this technique: * * 1. The rax-test is executed using Valgrind, adding a printf() so that for * the fuzz tester we see what iteration in the loop we are in. * 2. After every modification of the radix tree made by the fuzz tester * in rax-test.c, we add a call to raxTouch(). * 3. Now as soon as an operation will corrupt the tree, raxTouch() will * detect it (via Valgrind) immediately. We can add more calls to narrow * the state. * 4. At this point a good idea is to enable Rax debugging messages immediately * before the moment the tree is corrupted, to see what happens. */ unsigned long raxTouch(raxNode *n) { debugf("Touching %p\n", (void *)n); unsigned long sum = 0; if (n->iskey) { sum += (unsigned long)raxGetData(n); } int numchildren = n->iscompr ? 1 : n->size; raxNode **cp = raxNodeFirstChildPtr(n); int count = 0; for (int i = 0; i < numchildren; i++) { if (numchildren > 1) { sum += (long)n->data[i]; } raxNode *child; memcpy(&child, cp, sizeof(child)); if (child == (void *)0x65d1760) count++; if (count > 1) exit(1); sum += raxTouch(child); cp++; } return sum; } int checkedRaxRemove(rax *rax, unsigned char *s, size_t len, void **old) { int res = raxRemove(rax, s, len, old); if(res == 0) { // lp freed but node not removed! fprintf(stderr, "Error: corrupted listpack found."); abort(); } return res; } ================================================ FILE: src/redis/rax.h ================================================ /* Rax -- A radix tree implementation. * * Copyright (c) 2017-2018, Redis Ltd. 
* All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef RAX_H #define RAX_H #include /* Representation of a radix tree as implemented in this file, that contains * the strings "foo", "foobar" and "footer" after the insertion of each * word. When the node represents a key inside the radix tree, we write it * between [], otherwise it is written between (). 
*
 * This is the vanilla representation:
 *
 *              (f) ""
 *                \
 *                (o) "f"
 *                  \
 *                  (o) "fo"
 *                    \
 *                  [t   b] "foo"
 *                  /     \
 *         "foot" (e)     (a) "foob"
 *                /         \
 *      "foote" (r)         (r) "fooba"
 *              /             \
 *    "footer" []             [] "foobar"
 *
 * However, this implementation implements a very common optimization where
 * successive nodes having a single child are "compressed" into the node
 * itself as a string of characters, each representing a next-level child,
 * and only the link to the node representing the last character node is
 * provided inside the representation. So the above representation is turned
 * into:
 *
 *                  ["foo"] ""
 *                     |
 *                  [t   b] "foo"
 *                  /     \
 *        "foot" ("er")    ("ar") "foob"
 *                 /          \
 *       "footer" []          [] "foobar"
 *
 * However this optimization makes the implementation a bit more complex.
 * For instance if a key "first" is added in the above radix tree, a
 * "node splitting" operation is needed, since the "foo" prefix is no longer
 * composed of nodes having a single child one after the other. This is the
 * above tree and the resulting node splitting after this event happens:
 *
 *
 *                    (f) ""
 *                    /
 *                 (i o) "f"
 *                 /   \
 *    "firs"  ("rst")  (o) "fo"
 *              /        \
 *    "first" []       [t   b] "foo"
 *                     /     \
 *           "foot" ("er")    ("ar") "foob"
 *                    /          \
 *          "footer" []          [] "foobar"
 *
 * Similarly after deletion, if a new chain of nodes having a single child
 * is created (the chain must also not include nodes that represent keys),
 * it must be compressed back into a single node.
 *
 */

/* Largest value the 29 bit 'size' field of a node can hold. */
#define RAX_NODE_MAX_SIZE ((1 << 29) - 1)
typedef struct raxNode {
    uint32_t iskey : 1;   /* Does this node contain a key? */
    uint32_t isnull : 1;  /* Associated value is NULL (don't store it). */
    uint32_t iscompr : 1; /* Node is compressed. */
    uint32_t size : 29;   /* Number of children, or compressed string len. */
    /* Data layout is as follows:
     *
     * If node is not compressed we have 'size' bytes, one for each children
     * character, and 'size' raxNode pointers, point to each child node.
     * Note how the character is not stored in the children but in the
     * edge of the parents:
     *
     * [header iscompr=0][abc][a-ptr][b-ptr][c-ptr](value-ptr?)
     *
     * if node is compressed (iscompr bit is 1) the node has 1 children.
     * In that case the 'size' bytes of the string stored immediately at
     * the start of the data section, represent a sequence of successive
     * nodes linked one after the other, for which only the last one in
     * the sequence is actually represented as a node, and pointed to by
     * the current compressed node.
     *
     * [header iscompr=1][xyz][z-ptr](value-ptr?)
     *
     * Both compressed and not compressed nodes can represent a key
     * with associated data in the radix tree at any level (not just terminal
     * nodes).
     *
     * If the node has an associated key (iskey=1) and is not NULL
     * (isnull=0), then after the raxNode pointers pointing to the
     * children, an additional value pointer is present (as you can see
     * in the representation above as "value-ptr" field).
     */
    unsigned char data[];
} raxNode;

/* The radix tree itself: root pointer plus bookkeeping counters. */
typedef struct rax {
    raxNode *head;     /* Pointer to root node of tree */
    uint64_t numele;   /* Number of keys in the tree */
    uint64_t numnodes; /* Number of rax nodes in the tree */
    size_t alloc_size; /* Total allocation size of the tree in bytes */
} rax;

/* Stack data structure used by raxLowWalk() in order to, optionally, return
 * a list of parent nodes to the caller. The nodes do not have a "parent"
 * field for space concerns, so we use the auxiliary stack when needed. */
#define RAX_STACK_STATIC_ITEMS 32
typedef struct raxStack {
    void **stack;           /* Points to static_items or an heap allocated array. */
    size_t items, maxitems; /* Number of items contained and total space. */
    /* Up to RAX_STACK_STATIC_ITEMS items we avoid to allocate on the heap
     * and use this static array of pointers instead. */
    void *static_items[RAX_STACK_STATIC_ITEMS];
    int oom; /* True if pushing into this stack failed for OOM at some point. */
} raxStack;

/* Optional callback used for iterators and be notified on each rax node,
 * including nodes not representing keys. If the callback returns true
 * the callback changed the node pointer in the iterator structure, and the
 * iterator implementation will have to replace the pointer in the radix tree
 * internals. This allows the callback to reallocate the node to perform
 * very special operations, normally not needed by normal applications.
 *
 * This callback is used to perform very low level analysis of the radix tree
 * structure, scanning each possible node (but the root node), or in order to
 * reallocate the nodes to reduce the allocation fragmentation (this is the
 * server's application for this callback).
 *
 * This is currently only supported in forward iterations (raxNext) */
typedef int (*raxNodeCallback)(raxNode **noderef);

/* Radix tree iterator state is encapsulated into this data structure. */

/* Size in bytes of the key buffer embedded in every iterator. */
#define RAX_ITER_STATIC_LEN 128
/* Iterator was just seeked. Return current element for the first iteration
 * and clear the flag. */
#define RAX_ITER_JUST_SEEKED (1 << 0)
/* End of iteration reached. */
#define RAX_ITER_EOF (1 << 1)
/* Safe iterator, allows operations while iterating. But it is slower. */
#define RAX_ITER_SAFE (1 << 2)
typedef struct raxIterator {
    int flags;          /* Combination of the RAX_ITER_* bits above. */
    rax *rt;            /* Radix tree we are iterating. */
    unsigned char *key; /* The current string. */
    void *data;         /* Data associated to this key. */
    size_t key_len;     /* Current key length. */
    size_t key_max;     /* Max key len the current key buffer can hold. */
    unsigned char key_static_string[RAX_ITER_STATIC_LEN];
    raxNode *node;           /* Current node. Only for unsafe iteration. */
    raxStack stack;          /* Stack used for unsafe iteration. */
    raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */
} raxIterator;

/* Exported API.
*/

/* Tree-level entry points: raxNew() takes no arguments and returns a rax
 * handle; the others take that handle as their first argument. */
rax *raxNew(void);
int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old);
int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old);
int raxRemove(rax *rax, unsigned char *s, size_t len, void **old);
int raxFind(rax *rax, unsigned char *s, size_t len, void **value);
void raxFree(rax *rax);
void raxFreeWithCallback(rax *rax, void (*free_callback)(void*));
void raxFreeWithCallbackAndArgument(rax *rax, void (*free_callback)(void*, void*), void* argument);

/* Iterator entry points: every function below operates on a raxIterator. */
void raxStart(raxIterator *it, rax *rt);
int raxSeek(raxIterator *it, const char *op, unsigned char *ele, size_t len);
int raxNext(raxIterator *it);
int raxPrev(raxIterator *it);
int raxRandomWalk(raxIterator *it, size_t steps);
int raxCompare(raxIterator *iter, const char *op, unsigned char *key, size_t key_len);
void raxStop(raxIterator *it);
int raxEOF(raxIterator *it);

/* Remaining tree, node and debugging helpers. */
void raxShow(rax *rax);
uint64_t raxSize(rax *rax);
size_t raxAllocSize(rax *rax);
unsigned long raxTouch(raxNode *n);
void raxSetDebugMsg(int onoff);
/* NOTE(review): same parameter list as raxRemove(); the difference in
 * behavior is not visible from this header -- see the definition. */
int checkedRaxRemove(rax *rax, unsigned char *s, size_t len, void **old);

/* Internal API. May be used by the node callback in order to access rax nodes
 * in a low level way, so this function is exported as well. */
void raxSetData(raxNode *n, void *data);

#endif

================================================
FILE: src/redis/rax_malloc.h
================================================
/* Rax -- A radix tree implementation.
 *
 * Copyright (c) 2017, Redis Ltd.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
*
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

/* Allocator selection.
 *
 * This file is used in order to change the Rax allocator at compile time.
 * Just define the following defines to what you want to use. Also add
 * the include of your alternate allocator if needed (not needed in order
 * to use the default libc allocator). */

#ifndef RAX_ALLOC_H
#define RAX_ALLOC_H
#include "zmalloc.h"
/* Each rax allocation primitive is an alias for the zmalloc function named
 * on the same line, so rax memory goes through the zmalloc.h allocator. */
#define rax_malloc zmalloc
#define rax_realloc zrealloc
#define rax_free zfree
#define rax_ptr_alloc_size zmalloc_size
#endif

================================================
FILE: src/redis/rdb.h
================================================
/*
 * Copyright (c) 2009-2012, Salvatore Sanfilippo
 * All rights reserved.
*
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef __RDB_H
#define __RDB_H

#include
#include
#include

#include "redis_aux.h"

/* Two version constants are defined here: RDB_VERSION (12) and the lower
 * RDB_SER_VERSION (9) that is intended for serialization. */

/* The current RDB version. When the format changes in a way that is no longer
 * backward compatible this number gets incremented. */
#define RDB_VERSION 12

/* We would like to serialize to version 9 such that our rdb files
 * can be loaded by redis version 6 (RDB_VERSION 9) */
#define RDB_SER_VERSION 9

/* Defines related to the dump file format.
To store 32 bits lengths for short
 * keys requires a lot of space, so we check the most significant 2 bits of
 * the first byte to interpreter the length:
 *
 * 00|XXXXXX => if the two MSB are 00 the len is the 6 bits of this byte
 * 01|XXXXXX XXXXXXXX =>  01, the len is 14 bits, 6 bits + 8 bits of next byte
 * 10|000000 [32 bit integer] => A full 32 bit len in net byte order will follow
 * 10|000001 [64 bit integer] => A full 64 bit len in net byte order will follow
 * 11|OBKIND this means: specially encoded object will follow. The six bits
 *           number specify the kind of object that follows.
 *           See the RDB_ENC_* defines.
 *
 * Lengths up to 63 are stored using a single byte, most DB keys, and may
 * values, will fit inside. */
/* RDB_6BITLEN, RDB_14BITLEN and RDB_ENCVAL are the values of the two most
 * significant bits in the table above (00, 01 and 11). RDB_32BITLEN and
 * RDB_64BITLEN are complete first-byte values (10|000000 and 10|000001). */
#define RDB_6BITLEN 0
#define RDB_14BITLEN 1
#define RDB_32BITLEN 0x80
#define RDB_64BITLEN 0x81
#define RDB_ENCVAL 3
#define RDB_LENERR UINT64_MAX

/* When a length of a string object stored on disk has the first two bits
 * set, the remaining six bits specify a special encoding for the object
 * accordingly to the following defines: */
#define RDB_ENC_INT8 0        /* 8 bit signed integer */
#define RDB_ENC_INT16 1       /* 16 bit signed integer */
#define RDB_ENC_INT32 2       /* 32 bit signed integer */
#define RDB_ENC_LZF 3         /* string compressed with FASTLZ */

/* Map object types to RDB object types. Macros starting with OBJ_ are for
 * memory storage and may change. Instead RDB types must be fixed because
 * we store them on disk. */
#define RDB_TYPE_STRING 0
#define RDB_TYPE_LIST   1
#define RDB_TYPE_SET    2
#define RDB_TYPE_ZSET   3
#define RDB_TYPE_HASH   4
#define RDB_TYPE_ZSET_2 5 /* ZSET version 2 with doubles stored in binary. */
/* RDB_TYPE_MODULE and RDB_TYPE_MODULE_PRE_GA share the value 6. */
#define RDB_TYPE_MODULE 6
#define RDB_TYPE_MODULE_PRE_GA 6 /* Used in 4.0 release candidates */
#define RDB_TYPE_MODULE_2 7 /* Module value with annotations for parsing without
                               the generating module being loaded. */
/* NOTE: WHEN ADDING NEW RDB TYPE, UPDATE rdbIsObjectType() BELOW */
/* (In this header the range-check macro is spelled __rdbIsObjectType.) */

/* Object types for encoded objects.
*/
#define RDB_TYPE_HASH_ZIPMAP    9
#define RDB_TYPE_LIST_ZIPLIST  10
#define RDB_TYPE_SET_INTSET    11
#define RDB_TYPE_ZSET_ZIPLIST  12
#define RDB_TYPE_HASH_ZIPLIST  13
#define RDB_TYPE_LIST_QUICKLIST 14
#define RDB_TYPE_STREAM_LISTPACKS 15
#define RDB_TYPE_HASH_LISTPACK 16
#define RDB_TYPE_ZSET_LISTPACK 17
#define RDB_TYPE_LIST_QUICKLIST_2   18
#define RDB_TYPE_STREAM_LISTPACKS_2 19
#define RDB_TYPE_SET_LISTPACK  20
#define RDB_TYPE_STREAM_LISTPACKS_3 21
/* NOTE: WHEN ADDING NEW RDB TYPE, UPDATE rdbIsObjectType() BELOW */

/* Test if a type is an object type. */
/* Accepts 0..7 and 9..21; 8 is rejected. The argument `t` is expanded four
 * times, so it should not have side effects. */
#define __rdbIsObjectType(t) (((t) >= 0 && (t) <= 7) || ((t) >= 9 && (t) <= 21))

/* Range 200-240 is used by Dragonfly specific opcodes */

/* Special RDB opcodes (saved/loaded with rdbSaveType/rdbLoadType). */
/* Note: RDB_OPCODE_FUNCTION and RDB_OPCODE_FUNCTION_PRE_GA are both 246. */
#define RDB_OPCODE_SLOT_INFO  244   /* Individual slot info, such as slot id and size (cluster mode only). */
#define RDB_OPCODE_FUNCTION   246   /* engine data */
#define RDB_OPCODE_FUNCTION2  245   /* function library data */
#define RDB_OPCODE_FUNCTION_PRE_GA  246   /* old function library data for 7.0 rc1 and rc2 */
#define RDB_OPCODE_MODULE_AUX 247   /* Module auxiliary data. */
#define RDB_OPCODE_IDLE       248   /* LRU idle time. */
#define RDB_OPCODE_FREQ       249   /* LFU frequency. */
#define RDB_OPCODE_AUX        250   /* RDB aux field. */
#define RDB_OPCODE_RESIZEDB   251   /* Hash table resize hint. */
#define RDB_OPCODE_EXPIRETIME_MS 252    /* Expire time in milliseconds. */
#define RDB_OPCODE_EXPIRETIME 253       /* Old expire time in seconds. */
#define RDB_OPCODE_SELECTDB   254   /* DB number of the following keys. */
#define RDB_OPCODE_EOF        255   /* End of the RDB file. */

/* Module serialized values sub opcodes */
#define RDB_MODULE_OPCODE_EOF   0   /* End of module value. */
#define RDB_MODULE_OPCODE_SINT  1   /* Signed integer. */
#define RDB_MODULE_OPCODE_UINT  2   /* Unsigned integer. */
#define RDB_MODULE_OPCODE_FLOAT 3   /* Float. */
#define RDB_MODULE_OPCODE_DOUBLE 4  /* Double. */
#define RDB_MODULE_OPCODE_STRING 5  /* String.
*/

/* rdbLoad...() functions flags. */
/* Apart from RDB_LOAD_NONE (0), each flag is a distinct single bit. */
#define RDB_LOAD_NONE   0
#define RDB_LOAD_ENC    (1<<0)
#define RDB_LOAD_PLAIN  (1<<1)
#define RDB_LOAD_SDS    (1<<2)

/* flags on the purpose of rdb save or load */
/* Apart from RDBFLAGS_NONE (0), each flag is a distinct single bit. */
#define RDBFLAGS_NONE 0                 /* No special RDB loading. */
#define RDBFLAGS_AOF_PREAMBLE (1<<0)    /* Load/save the RDB as AOF preamble. */
#define RDBFLAGS_REPLICATION (1<<1)     /* Load/save for SYNC. */
#define RDBFLAGS_ALLOW_DUP (1<<2)       /* Allow duplicated keys when loading.*/
#define RDBFLAGS_FEED_REPL (1<<3)       /* Feed replication stream when loading.*/
#define RDBFLAGS_KEEP_CACHE (1<<4)      /* Don't reclaim cache after rdb file is generated */

/* When rdbLoadObject() returns NULL, the err flag is
 * set to hold the type of error that occurred */
#define RDB_LOAD_ERR_EMPTY_KEY  1   /* Error of empty key */
#define RDB_LOAD_ERR_OTHER      2   /* Any other errors */

// ROMAN: those constants should be factored out to redis_base.h or something.
// Currently moved here from server.h
#define LONG_STR_SIZE      21          /* Bytes needed for long -> str + '\0' */

#define REDIS_VERSION "6.2.11"

#endif

================================================
FILE: src/redis/read.c
================================================
/*
 * Copyright (c) 2009-2011, Salvatore Sanfilippo
 * Copyright (c) 2010-2011, Pieter Noordhuis
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include #include #include #include "sdsalloc.h" #include "read.h" #include "sds.h" /* Initial size of our nested reply stack and how much we grow it when needd */ #define REDIS_READER_STACK_SIZE 9 static void __redisReaderSetError(redisReader *r, int type, const char *str) { size_t len; if (r->reply != NULL && r->fn && r->fn->freeObject) { r->fn->freeObject(r->reply); r->reply = NULL; } /* Clear input buffer on errors. */ sdsfree(r->buf); r->buf = NULL; r->pos = r->len = 0; /* Reset task stack. */ r->ridx = -1; /* Set error. */ r->err = type; len = strlen(str); len = len < (sizeof(r->errstr)-1) ? 
len : (sizeof(r->errstr)-1); memcpy(r->errstr,str,len); r->errstr[len] = '\0'; } static size_t chrtos(char *buf, size_t size, char byte) { size_t len = 0; switch(byte) { case '\\': case '"': len = snprintf(buf,size,"\"\\%c\"",byte); break; case '\n': len = snprintf(buf,size,"\"\\n\""); break; case '\r': len = snprintf(buf,size,"\"\\r\""); break; case '\t': len = snprintf(buf,size,"\"\\t\""); break; case '\a': len = snprintf(buf,size,"\"\\a\""); break; case '\b': len = snprintf(buf,size,"\"\\b\""); break; default: if (isprint(byte)) len = snprintf(buf,size,"\"%c\"",byte); else len = snprintf(buf,size,"\"\\x%02x\"",(unsigned char)byte); break; } return len; } static void __redisReaderSetErrorProtocolByte(redisReader *r, char byte) { char cbuf[8], sbuf[128]; chrtos(cbuf,sizeof(cbuf),byte); snprintf(sbuf,sizeof(sbuf), "Protocol error, got %s as reply type byte", cbuf); __redisReaderSetError(r,REDIS_ERR_PROTOCOL,sbuf); } static void __redisReaderSetErrorOOM(redisReader *r) { __redisReaderSetError(r,REDIS_ERR_OOM,"Out of memory"); } static char *readBytes(redisReader *r, unsigned int bytes) { char *p; if (r->len-r->pos >= bytes) { p = r->buf+r->pos; r->pos += bytes; return p; } return NULL; } /* Find pointer to \r\n. */ static char *seekNewline(char *s, size_t len) { char *ret; /* We cannot match with fewer than 2 bytes */ if (len < 2) return NULL; /* Search up to len - 1 characters */ len--; /* Look for the \r */ while ((ret = memchr(s, '\r', len)) != NULL) { if (ret[1] == '\n') { /* Found. */ break; } /* Continue searching. */ ret++; len -= ret - s; s = ret; } return ret; } /* Convert a string into a long long. Returns REDIS_OK if the string could be * parsed into a (non-overflowing) long long, REDIS_ERR otherwise. The value * will be set to the parsed value when appropriate. 
* * Note that this function demands that the string strictly represents * a long long: no spaces or other characters before or after the string * representing the number are accepted, nor zeroes at the start if not * for the string "0" representing the zero number. * * Because of its strictness, it is safe to use this function to check if * you can convert a string into a long long, and obtain back the string * from the number without any loss in the string representation. */ static int string2ll(const char *s, size_t slen, long long *value) { const char *p = s; size_t plen = 0; int negative = 0; unsigned long long v; if (plen == slen) return REDIS_ERR; /* Special case: first and only digit is 0. */ if (slen == 1 && p[0] == '0') { if (value != NULL) *value = 0; return REDIS_OK; } if (p[0] == '-') { negative = 1; p++; plen++; /* Abort on only a negative sign. */ if (plen == slen) return REDIS_ERR; } /* First digit should be 1-9, otherwise the string should just be 0. */ if (p[0] >= '1' && p[0] <= '9') { v = p[0]-'0'; p++; plen++; } else if (p[0] == '0' && slen == 1) { *value = 0; return REDIS_OK; } else { return REDIS_ERR; } while (plen < slen && p[0] >= '0' && p[0] <= '9') { if (v > (ULLONG_MAX / 10)) /* Overflow. */ return REDIS_ERR; v *= 10; if (v > (ULLONG_MAX - (p[0]-'0'))) /* Overflow. */ return REDIS_ERR; v += p[0]-'0'; p++; plen++; } /* Return if not all bytes were used. */ if (plen < slen) return REDIS_ERR; if (negative) { if (v > ((unsigned long long)(-(LLONG_MIN+1))+1)) /* Overflow. */ return REDIS_ERR; if (value != NULL) *value = -v; } else { if (v > LLONG_MAX) /* Overflow. 
*/ return REDIS_ERR; if (value != NULL) *value = v; } return REDIS_OK; } static char *readLine(redisReader *r, int *_len) { char *p, *s; int len; p = r->buf+r->pos; s = seekNewline(p,(r->len-r->pos)); if (s != NULL) { len = s-(r->buf+r->pos); r->pos += len+2; /* skip \r\n */ if (_len) *_len = len; return p; } return NULL; } static void moveToNextTask(redisReader *r) { redisReadTask *cur, *prv; while (r->ridx >= 0) { /* Return a.s.a.p. when the stack is now empty. */ if (r->ridx == 0) { r->ridx--; return; } cur = r->task[r->ridx]; prv = r->task[r->ridx-1]; assert(prv->type == REDIS_REPLY_ARRAY || prv->type == REDIS_REPLY_MAP || prv->type == REDIS_REPLY_ATTR || prv->type == REDIS_REPLY_SET || prv->type == REDIS_REPLY_PUSH); if (cur->idx == prv->elements-1) { r->ridx--; } else { /* Reset the type because the next item can be anything */ assert(cur->idx < prv->elements); cur->type = -1; cur->elements = -1; cur->idx++; return; } } } static int processLineItem(redisReader *r) { redisReadTask *cur = r->task[r->ridx]; void *obj; char *p; int len; if ((p = readLine(r,&len)) != NULL) { if (cur->type == REDIS_REPLY_INTEGER) { long long v; if (string2ll(p, len, &v) == REDIS_ERR) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad integer value"); return REDIS_ERR; } if (r->fn && r->fn->createInteger) { obj = r->fn->createInteger(cur,v); } else { obj = (void*)REDIS_REPLY_INTEGER; } } else if (cur->type == REDIS_REPLY_DOUBLE) { char buf[326], *eptr; double d; if ((size_t)len >= sizeof(buf)) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Double value is too large"); return REDIS_ERR; } memcpy(buf,p,len); buf[len] = '\0'; if (len == 3 && strcasecmp(buf,"inf") == 0) { d = INFINITY; /* Positive infinite. */ } else if (len == 4 && strcasecmp(buf,"-inf") == 0) { d = -INFINITY; /* Negative infinite. */ } else if ((len == 3 && strcasecmp(buf,"nan") == 0) || (len == 4 && strcasecmp(buf, "-nan") == 0)) { d = NAN; /* nan. 
*/ } else { d = strtod((char*)buf,&eptr); /* RESP3 only allows "inf", "-inf", and finite values, while * strtod() allows other variations on infinity, * etc. We explicity handle our two allowed infinite cases and NaN * above, so strtod() should only result in finite values. */ if (buf[0] == '\0' || eptr != &buf[len] || !isfinite(d)) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad double value"); return REDIS_ERR; } } if (r->fn && r->fn->createDouble) { obj = r->fn->createDouble(cur,d,buf,len); } else { obj = (void*)REDIS_REPLY_DOUBLE; } } else if (cur->type == REDIS_REPLY_NIL) { if (len != 0) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad nil value"); return REDIS_ERR; } if (r->fn && r->fn->createNil) obj = r->fn->createNil(cur); else obj = (void*)REDIS_REPLY_NIL; } else if (cur->type == REDIS_REPLY_BOOL) { int bval; if (len != 1 || !strchr("tTfF", p[0])) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad bool value"); return REDIS_ERR; } bval = p[0] == 't' || p[0] == 'T'; if (r->fn && r->fn->createBool) obj = r->fn->createBool(cur,bval); else obj = (void*)REDIS_REPLY_BOOL; } else if (cur->type == REDIS_REPLY_BIGNUM) { /* Ensure all characters are decimal digits (with possible leading * minus sign). */ for (int i = 0; i < len; i++) { /* XXX Consider: Allow leading '+'? Error on leading '0's? */ if (i == 0 && p[0] == '-') continue; if (p[i] < '0' || p[i] > '9') { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad bignum value"); return REDIS_ERR; } } if (r->fn && r->fn->createString) obj = r->fn->createString(cur,p,len); else obj = (void*)REDIS_REPLY_BIGNUM; } else { /* Type will be error or status. 
*/ for (int i = 0; i < len; i++) { if (p[i] == '\r' || p[i] == '\n') { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad simple string value"); return REDIS_ERR; } } if (r->fn && r->fn->createString) obj = r->fn->createString(cur,p,len); else obj = (void*)(uintptr_t)(cur->type); } if (obj == NULL) { __redisReaderSetErrorOOM(r); return REDIS_ERR; } /* Set reply if this is the root object. */ if (r->ridx == 0) r->reply = obj; moveToNextTask(r); return REDIS_OK; } return REDIS_ERR; } static int processBulkItem(redisReader *r) { redisReadTask *cur = r->task[r->ridx]; void *obj = NULL; char *p, *s; long long len; unsigned long bytelen; int success = 0; p = r->buf+r->pos; s = seekNewline(p,r->len-r->pos); if (s != NULL) { p = r->buf+r->pos; bytelen = s-(r->buf+r->pos)+2; /* include \r\n */ if (string2ll(p, bytelen - 2, &len) == REDIS_ERR) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad bulk string length"); return REDIS_ERR; } if (len < -1 || (LLONG_MAX > SIZE_MAX && len > (long long)SIZE_MAX)) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bulk string length out of range"); return REDIS_ERR; } if (len == -1) { /* The nil object can always be created. */ if (r->fn && r->fn->createNil) obj = r->fn->createNil(cur); else obj = (void*)REDIS_REPLY_NIL; success = 1; } else { /* Only continue when the buffer contains the entire bulk item. */ bytelen += len+2; /* include \r\n */ if (r->pos+bytelen <= r->len) { if ((cur->type == REDIS_REPLY_VERB && len < 4) || (cur->type == REDIS_REPLY_VERB && s[5] != ':')) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Verbatim string 4 bytes of content type are " "missing or incorrectly encoded."); return REDIS_ERR; } if (r->fn && r->fn->createString) obj = r->fn->createString(cur,s+2,len); else obj = (void*)(uintptr_t)cur->type; success = 1; } } /* Proceed when obj was created. */ if (success) { if (obj == NULL) { __redisReaderSetErrorOOM(r); return REDIS_ERR; } r->pos += bytelen; /* Set reply if this is the root object. 
*/ if (r->ridx == 0) r->reply = obj; moveToNextTask(r); return REDIS_OK; } } return REDIS_ERR; } static int redisReaderGrow(redisReader *r) { redisReadTask **aux; int newlen; /* Grow our stack size */ newlen = r->tasks + REDIS_READER_STACK_SIZE; aux = s_realloc(r->task, sizeof(*r->task) * newlen); if (aux == NULL) goto oom; r->task = aux; /* Allocate new tasks */ for (; r->tasks < newlen; r->tasks++) { r->task[r->tasks] = s_calloc(sizeof(**r->task)); if (r->task[r->tasks] == NULL) goto oom; } return REDIS_OK; oom: __redisReaderSetErrorOOM(r); return REDIS_ERR; } /* Process the array, map and set types. */ static int processAggregateItem(redisReader *r) { redisReadTask *cur = r->task[r->ridx]; void *obj; char *p; long long elements; int root = 0, len; if (r->ridx == r->tasks - 1) { if (redisReaderGrow(r) == REDIS_ERR) return REDIS_ERR; } if ((p = readLine(r,&len)) != NULL) { if (string2ll(p, len, &elements) == REDIS_ERR) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Bad multi-bulk length"); return REDIS_ERR; } root = (r->ridx == 0); if (elements < -1 || (LLONG_MAX > SIZE_MAX && elements > SIZE_MAX) || (r->maxelements > 0 && elements > r->maxelements)) { __redisReaderSetError(r,REDIS_ERR_PROTOCOL, "Multi-bulk length out of range"); return REDIS_ERR; } if (elements == -1) { if (r->fn && r->fn->createNil) obj = r->fn->createNil(cur); else obj = (void*)REDIS_REPLY_NIL; if (obj == NULL) { __redisReaderSetErrorOOM(r); return REDIS_ERR; } moveToNextTask(r); } else { if (cur->type == REDIS_REPLY_MAP || cur->type == REDIS_REPLY_ATTR) elements *= 2; if (r->fn && r->fn->createArray) obj = r->fn->createArray(cur,elements); else obj = (void*)(uintptr_t)cur->type; if (obj == NULL) { __redisReaderSetErrorOOM(r); return REDIS_ERR; } /* Modify task stack when there are more than 0 elements. 
*/ if (elements > 0) { cur->elements = elements; cur->obj = obj; r->ridx++; r->task[r->ridx]->type = -1; r->task[r->ridx]->elements = -1; r->task[r->ridx]->idx = 0; r->task[r->ridx]->obj = NULL; r->task[r->ridx]->parent = cur; r->task[r->ridx]->privdata = r->privdata; } else { moveToNextTask(r); } } /* Set reply if this is the root object. */ if (root) r->reply = obj; return REDIS_OK; } return REDIS_ERR; } static int processItem(redisReader *r) { redisReadTask *cur = r->task[r->ridx]; char *p; /* check if we need to read type */ if (cur->type < 0) { if ((p = readBytes(r,1)) != NULL) { switch (p[0]) { case '-': cur->type = REDIS_REPLY_ERROR; break; case '+': cur->type = REDIS_REPLY_STATUS; break; case ':': cur->type = REDIS_REPLY_INTEGER; break; case ',': cur->type = REDIS_REPLY_DOUBLE; break; case '_': cur->type = REDIS_REPLY_NIL; break; case '$': cur->type = REDIS_REPLY_STRING; break; case '*': cur->type = REDIS_REPLY_ARRAY; break; case '%': cur->type = REDIS_REPLY_MAP; break; case '|': cur->type = REDIS_REPLY_ATTR; break; case '~': cur->type = REDIS_REPLY_SET; break; case '#': cur->type = REDIS_REPLY_BOOL; break; case '=': cur->type = REDIS_REPLY_VERB; break; case '>': cur->type = REDIS_REPLY_PUSH; break; case '(': cur->type = REDIS_REPLY_BIGNUM; break; default: __redisReaderSetErrorProtocolByte(r,*p); return REDIS_ERR; } } else { /* could not consume 1 byte */ return REDIS_ERR; } } /* process typed item */ switch(cur->type) { case REDIS_REPLY_ERROR: case REDIS_REPLY_STATUS: case REDIS_REPLY_INTEGER: case REDIS_REPLY_DOUBLE: case REDIS_REPLY_NIL: case REDIS_REPLY_BOOL: case REDIS_REPLY_BIGNUM: return processLineItem(r); case REDIS_REPLY_STRING: case REDIS_REPLY_VERB: return processBulkItem(r); case REDIS_REPLY_ARRAY: case REDIS_REPLY_MAP: case REDIS_REPLY_ATTR: case REDIS_REPLY_SET: case REDIS_REPLY_PUSH: return processAggregateItem(r); default: assert(NULL); return REDIS_ERR; /* Avoid warning. 
*/ } } redisReader *redisReaderCreateWithFunctions(redisReplyObjectFunctions *fn) { redisReader *r; r = s_calloc(sizeof(redisReader)); if (r == NULL) return NULL; r->buf = sdsempty(); if (r->buf == NULL) goto oom; r->task = s_calloc(REDIS_READER_STACK_SIZE * sizeof(*r->task)); if (r->task == NULL) goto oom; for (; r->tasks < REDIS_READER_STACK_SIZE; r->tasks++) { r->task[r->tasks] = s_calloc(sizeof(**r->task)); if (r->task[r->tasks] == NULL) goto oom; } r->fn = fn; r->maxbuf = REDIS_READER_MAX_BUF; r->maxelements = REDIS_READER_MAX_ARRAY_ELEMENTS; r->ridx = -1; return r; oom: redisReaderFree(r); return NULL; } void redisReaderFree(redisReader *r) { if (r == NULL) return; if (r->reply != NULL && r->fn && r->fn->freeObject) r->fn->freeObject(r->reply); if (r->task) { /* We know r->task[i] is allocated if i < r->tasks */ for (int i = 0; i < r->tasks; i++) { s_free(r->task[i]); } s_free(r->task); } sdsfree(r->buf); s_free(r); } int redisReaderFeed(redisReader *r, const char *buf, size_t len) { sds newbuf; /* Return early when this reader is in an erroneous state. */ if (r->err) return REDIS_ERR; /* Copy the provided buffer. */ if (buf != NULL && len >= 1) { /* Destroy internal buffer when it is empty and is quite large. */ if (r->len == 0 && r->maxbuf != 0 && sdsavail(r->buf) > r->maxbuf) { sdsfree(r->buf); r->buf = sdsempty(); if (r->buf == 0) goto oom; r->pos = 0; } newbuf = sdscatlen(r->buf,buf,len); if (newbuf == NULL) goto oom; r->buf = newbuf; r->len = sdslen(r->buf); } return REDIS_OK; oom: __redisReaderSetErrorOOM(r); return REDIS_ERR; } int redisReaderGetReply(redisReader *r, void **reply) { /* Default target pointer to NULL. */ if (reply != NULL) *reply = NULL; /* Return early when this reader is in an erroneous state. */ if (r->err) return REDIS_ERR; /* When the buffer is empty, there will never be a reply. */ if (r->len == 0) return REDIS_OK; /* Set first item to process when the stack is empty. 
*/ if (r->ridx == -1) { r->task[0]->type = -1; r->task[0]->elements = -1; r->task[0]->idx = -1; r->task[0]->obj = NULL; r->task[0]->parent = NULL; r->task[0]->privdata = r->privdata; r->ridx = 0; } /* Process items in reply. */ while (r->ridx >= 0) if (processItem(r) != REDIS_OK) break; /* Return ASAP when an error occurred. */ if (r->err) return REDIS_ERR; /* Discard part of the buffer when we've consumed at least 1k, to avoid * doing unnecessary calls to memmove() in sds.c. */ if (r->pos >= 1024) { if (sdsrange(r->buf,r->pos,-1) < 0) return REDIS_ERR; r->pos = 0; r->len = sdslen(r->buf); } /* Emit a reply when there is one. */ if (r->ridx == -1) { if (reply != NULL) { *reply = r->reply; } else if (r->reply != NULL && r->fn && r->fn->freeObject) { r->fn->freeObject(r->reply); } r->reply = NULL; } return REDIS_OK; } ================================================ FILE: src/redis/read.h ================================================ /* * Copyright (c) 2009-2011, Salvatore Sanfilippo * Copyright (c) 2010-2011, Pieter Noordhuis * * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */


#ifndef __HIREDIS_READ_H
#define __HIREDIS_READ_H
#include /* for size_t */

/* Status codes returned by the reader functions declared in this header. */
#define REDIS_ERR -1
#define REDIS_OK 0

/* When an error occurs, the err flag in a context is set to hold the type of
 * error that occurred. REDIS_ERR_IO means there was an I/O error and you
 * should use the "errno" variable to find out what is wrong.
 * For other values, the "errstr" field will hold a description. */
/* Note: the list is not in numeric order; REDIS_ERR_OTHER is 2. */
#define REDIS_ERR_IO 1 /* Error in read or write */
#define REDIS_ERR_EOF 3 /* End of file */
#define REDIS_ERR_PROTOCOL 4 /* Protocol error */
#define REDIS_ERR_OOM 5 /* Out of memory */
#define REDIS_ERR_TIMEOUT 6 /* Timed out */
#define REDIS_ERR_OTHER 2 /* Everything else... */

/* Reply type tags; the parser in read.c stores one of these in the `type`
 * field of a redisReadTask after reading the type byte of an item. */
#define REDIS_REPLY_STRING 1
#define REDIS_REPLY_ARRAY 2
#define REDIS_REPLY_INTEGER 3
#define REDIS_REPLY_NIL 4
#define REDIS_REPLY_STATUS 5
#define REDIS_REPLY_ERROR 6
#define REDIS_REPLY_DOUBLE 7
#define REDIS_REPLY_BOOL 8
#define REDIS_REPLY_MAP 9
#define REDIS_REPLY_SET 10
#define REDIS_REPLY_ATTR 11
#define REDIS_REPLY_PUSH 12
#define REDIS_REPLY_BIGNUM 13
#define REDIS_REPLY_VERB 14

/* Default max unused reader buffer.
*/
#define REDIS_READER_MAX_BUF (1024*16)

/* Default multi-bulk element limit */
#define REDIS_READER_MAX_ARRAY_ELEMENTS ((1LL<<32) - 1)

#ifdef __cplusplus
extern "C" {
#endif

/* One entry of the reader's task stack: the item currently being parsed at
 * a given nesting level. */
typedef struct redisReadTask {
    int type;
    long long elements; /* number of elements in multibulk container */
    int idx; /* index in parent (array) object */
    void *obj; /* holds user-generated value for a read task */
    struct redisReadTask *parent; /* parent task */
    void *privdata; /* user-settable arbitrary field */
} redisReadTask;

/* Callbacks used by the parser to build reply objects. In read.c a NULL
 * return from a create* callback is reported as an out-of-memory error, and
 * when a callback (or the whole table) is missing the parser uses the reply
 * type constant cast to a pointer instead of a real object. */
typedef struct redisReplyObjectFunctions {
    void *(*createString)(const redisReadTask*, char*, size_t);
    void *(*createArray)(const redisReadTask*, size_t);
    void *(*createInteger)(const redisReadTask*, long long);
    void *(*createDouble)(const redisReadTask*, double, char*, size_t);
    void *(*createNil)(const redisReadTask*);
    void *(*createBool)(const redisReadTask*, int);
    void (*freeObject)(void*);
} redisReplyObjectFunctions;

/* Incremental protocol parser state. */
typedef struct redisReader {
    int err; /* Error flags, 0 when there is no error */
    char errstr[128]; /* String representation of error when applicable */

    char *buf; /* Read buffer */
    size_t pos; /* Buffer cursor */
    size_t len; /* Buffer length */
    size_t maxbuf; /* Max length of unused buffer */
    long long maxelements; /* Max multi-bulk elements */

    redisReadTask **task; /* Task stack, one separately allocated entry per level */
    int tasks; /* Number of allocated entries in `task` */

    int ridx; /* Index of current read task */
    void *reply; /* Temporary reply pointer */

    redisReplyObjectFunctions *fn;
    void *privdata;
} redisReader;

/* Public API for the protocol parser.
*/
redisReader *redisReaderCreateWithFunctions(redisReplyObjectFunctions *fn);
void redisReaderFree(redisReader *r);
int redisReaderFeed(redisReader *r, const char *buf, size_t len);
int redisReaderGetReply(redisReader *r, void **reply);

/* Field accessors; each casts its argument to redisReader* first. */
#define redisReaderSetPrivdata(_r, _p) (int)(((redisReader*)(_r))->privdata = (_p))
#define redisReaderGetObject(_r) (((redisReader*)(_r))->reply)
#define redisReaderGetError(_r) (((redisReader*)(_r))->errstr)

#ifdef __cplusplus
}
#endif

#endif

================================================
FILE: src/redis/redis_aux.c
================================================
#include "redis_aux.h"

#include
#include

#include "crc64.h"
#include "endianconv.h"
#include "zmalloc.h"

/* Global settings stub; filled in by InitRedisTables(). */
Server server;

/* Initialize the CRC64 tables and reset `server` to zero before installing
 * the default limits below. */
void InitRedisTables() {
  crc64_init();
  memset(&server, 0, sizeof(server));

  server.max_map_field_len = 64;
  server.max_listpack_map_bytes = 1024;
  server.stream_node_max_entries = 100;
}

/* Toggle the 64 bit unsigned integer pointed by *p from little endian to
 * big endian */
void memrev64(void* p) {
  unsigned char *x = p, t;

  /* Swap the bytes pairwise from the outside in: (0,7), (1,6), (2,5), (3,4). */
  t = x[0];
  x[0] = x[7];
  x[7] = t;
  t = x[1];
  x[1] = x[6];
  x[6] = t;
  t = x[2];
  x[2] = x[5];
  x[5] = t;
  t = x[3];
  x[3] = x[4];
  x[4] = t;
}

// used by t_stream.c
/* Return `v` with its eight bytes in reverse order. */
uint64_t intrev64(uint64_t v) {
  memrev64(&v);
  return v;
}

================================================
FILE: src/redis/redis_aux.h
================================================
#ifndef __REDIS_AUX_H
#define __REDIS_AUX_H

#include "sds.h"

/* redis.h auxiliary definitions */
/* the last one in object.h is OBJ_STREAM and it is 6,
 * this will add enough place for Redis types to grow */
#define OBJ_JSON 15U
#define OBJ_SBF 16U
#define OBJ_CMS 17U
#define OBJ_TOPK 18U

// A pseudo type for keys stored in the db, same as OBJ_MODULE which is not used in Dragonfly.
#define OBJ_KEY 5U /* How many types of objects exist */ #define OBJ_TYPE_MAX 19U #define CONFIG_RUN_ID_SIZE 40U typedef struct ServerStub { size_t max_map_field_len, max_listpack_map_bytes; long long stream_node_max_entries; } Server; extern Server server; #define ZSET_MAX_LISTPACK_ENTRIES 128 #define ZSET_MAX_LISTPACK_VALUE 32 void InitRedisTables(); /* The actual Redis Object */ #define OBJ_STRING 0U /* String object. */ #define OBJ_LIST 1U /* List object. */ #define OBJ_SET 2U /* Set object. */ #define OBJ_ZSET 3U /* Sorted set object. */ #define OBJ_HASH 4U /* Hash object. */ #define OBJ_MODULE 5U /* Module object. */ #define OBJ_STREAM 6U /* Stream object. */ /* Objects encoding. Some kind of objects like Strings and Hashes can be * internally represented in multiple ways. The 'encoding' field of the object * is set to one of this fields for this object. */ #define OBJ_ENCODING_RAW 0U /* Raw representation */ #define OBJ_ENCODING_INT 1U /* Encoded as integer */ #define OBJ_ENCODING_HT 2U /* Encoded as hash table */ #define OBJ_ENCODING_ZIPMAP 3U /* Encoded as zipmap */ #define OBJ_ENCODING_LINKEDLIST 4U /* No longer used: old list encoding. 
*/ #define OBJ_ENCODING_ZIPLIST 5U /* Encoded as ziplist */ #define OBJ_ENCODING_INTSET 6U /* Encoded as intset */ #define OBJ_ENCODING_SKIPLIST 7U /* Encoded as skiplist */ #define OBJ_ENCODING_EMBSTR 8U /* Embedded sds string encoding */ // #define OBJ_ENCODING_QUICKLIST 9U /* Encoded as linked list of ziplists */ #define OBJ_ENCODING_STREAM 10U /* Encoded as a radix tree of listpacks */ #define OBJ_ENCODING_LISTPACK 11 /* Encoded as a listpack */ #define OBJ_ENCODING_COMPRESS_INTERNAL 15U /* Kept as lzf compressed, to pass compressed blob to another thread */ #endif /* __REDIS_AUX_H */ ================================================ FILE: src/redis/sds.c ================================================ /* SDSLib 2.0 -- A C dynamic strings library * * Copyright (c) 2006-2015, Salvatore Sanfilippo * Copyright (c) 2015, Oran Agra * Copyright (c) 2015, Redis Labs, Inc * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include "sds.h" #include "sdsalloc.h" const char *SDS_NOINIT = "SDS_NOINIT"; static inline int sdsHdrSize(char type) { switch(type&SDS_TYPE_MASK) { case SDS_TYPE_5: return sizeof(struct sdshdr5); case SDS_TYPE_8: return sizeof(struct sdshdr8); case SDS_TYPE_16: return sizeof(struct sdshdr16); case SDS_TYPE_32: return sizeof(struct sdshdr32); case SDS_TYPE_64: return sizeof(struct sdshdr64); } return 0; } static inline char sdsReqType(size_t string_size) { if (string_size < 1<<5) return SDS_TYPE_5; if (string_size < 1<<8) return SDS_TYPE_8; if (string_size < 1<<16) return SDS_TYPE_16; #if (LONG_MAX == LLONG_MAX) if (string_size < 1ll<<32) return SDS_TYPE_32; return SDS_TYPE_64; #else return SDS_TYPE_32; #endif } static inline size_t sdsTypeMaxSize(char type) { if (type == SDS_TYPE_5) return (1<<5) - 1; if (type == SDS_TYPE_8) return (1<<8) - 1; if (type == SDS_TYPE_16) return (1<<16) - 1; #if (LONG_MAX == LLONG_MAX) if (type == SDS_TYPE_32) return (1ll<<32) - 1; #endif return -1; /* this is equivalent to the max SDS_TYPE_64 or SDS_TYPE_32 */ } /* Create a new sds string with the content specified by the 'init' pointer * and 'initlen'. * If NULL is used for 'init' the string is initialized with zero bytes. 
* If SDS_NOINIT is used, the buffer is left uninitialized; * * The string is always null-terminated (all the sds strings are, always) so * even if you create an sds string with: * * mystring = sdsnewlen("abc",3); * * You can print the string with printf() as there is an implicit \0 at the * end of the string. However the string is binary safe and can contain * \0 characters in the middle, as the length is stored in the sds header. */ sds _sdsnewlen(const void *init, size_t initlen, int trymalloc) { void *sh; sds s; char type = sdsReqType(initlen); /* Empty strings are usually created in order to append. Use type 8 * since type 5 is not good at this. */ if (type == SDS_TYPE_5 && initlen == 0) type = SDS_TYPE_8; int hdrlen = sdsHdrSize(type); unsigned char *fp; /* flags pointer. */ size_t usable; assert(initlen + hdrlen + 1 > initlen); /* Catch size_t overflow */ sh = trymalloc? s_trymalloc_usable(hdrlen+initlen+1, &usable) : s_malloc_usable(hdrlen+initlen+1, &usable); if (sh == NULL) return NULL; if (init==SDS_NOINIT) init = NULL; else if (!init) memset(sh, 0, hdrlen+initlen+1); s = (char*)sh+hdrlen; fp = ((unsigned char*)s)-1; usable = usable-hdrlen-1; if (usable > sdsTypeMaxSize(type)) usable = sdsTypeMaxSize(type); switch(type) { case SDS_TYPE_5: { *fp = type | (initlen << SDS_TYPE_BITS); break; } case SDS_TYPE_8: { SDS_HDR_VAR(8,s); sh->len = initlen; sh->alloc = usable; *fp = type; break; } case SDS_TYPE_16: { SDS_HDR_VAR(16,s); sh->len = initlen; sh->alloc = usable; *fp = type; break; } case SDS_TYPE_32: { SDS_HDR_VAR(32,s); sh->len = initlen; sh->alloc = usable; *fp = type; break; } case SDS_TYPE_64: { SDS_HDR_VAR(64,s); sh->len = initlen; sh->alloc = usable; *fp = type; break; } } if (initlen && init) memcpy(s, init, initlen); s[initlen] = '\0'; return s; } sds sdsnewlen(const void *init, size_t initlen) { return _sdsnewlen(init, initlen, 0); } /* Create an empty (zero length) sds string. Even in this case the string * always has an implicit null term. 
*/ sds sdsempty(void) { return sdsnewlen("",0); } /* Create a new sds string starting from a null terminated C string. */ sds sdsnew(const char *init) { size_t initlen = (init == NULL) ? 0 : strlen(init); return sdsnewlen(init, initlen); } /* Duplicate an sds string. */ sds sdsdup(const sds s) { return sdsnewlen(s, sdslen(s)); } /* Free an sds string. No operation is performed if 's' is NULL. */ void sdsfree(sds s) { if (s == NULL) return; s_free((char*)s-sdsHdrSize(s[-1])); } /* Set the sds string length to the length as obtained with strlen(), so * considering as content only up to the first null term character. * * This function is useful when the sds string is hacked manually in some * way, like in the following example: * * s = sdsnew("foobar"); * s[2] = '\0'; * sdsupdatelen(s); * printf("%d\n", sdslen(s)); * * The output will be "2", but if we comment out the call to sdsupdatelen() * the output will be "6" as the string was modified but the logical length * remains 6 bytes. */ void sdsupdatelen(sds s) { size_t reallen = strlen(s); sdssetlen(s, reallen); } /* Modify an sds string in-place to make it empty (zero length). * However all the existing buffer is not discarded but set as free space * so that next append operations will not require allocations up to the * number of bytes previously available. */ void sdsclear(sds s) { sdssetlen(s, 0); s[0] = '\0'; } /* Enlarge the free space at the end of the sds string so that the caller * is sure that after calling this function can overwrite up to addlen * bytes after the end of the string, plus one more byte for nul term. * If there's already sufficient free space, this function returns without any * action, if there isn't sufficient free space, it'll allocate what's missing, * and possibly more: * When greedy is 1, enlarge more than needed, to avoid need for future reallocs * on incremental growth. * When greedy is 0, enlarge just enough so that there's free space for 'addlen'. 
* * Note: this does not change the *length* of the sds string as returned * by sdslen(), but only the free buffer space we have. */ sds _sdsMakeRoomFor(sds s, size_t addlen, int greedy) { void *sh, *newsh; size_t avail = sdsavail(s); size_t len, newlen, reqlen; char type, oldtype = s[-1] & SDS_TYPE_MASK; int hdrlen; size_t usable; /* Return ASAP if there is enough space left. */ if (avail >= addlen) return s; len = sdslen(s); sh = (char*)s-sdsHdrSize(oldtype); reqlen = newlen = (len+addlen); (void)reqlen; assert(newlen > len); /* Catch size_t overflow */ if (greedy == 1) { if (newlen < SDS_MAX_PREALLOC) newlen *= 2; else newlen += SDS_MAX_PREALLOC; } type = sdsReqType(newlen); /* Don't use type 5: the user is appending to the string and type 5 is * not able to remember empty space, so sdsMakeRoomFor() must be called * at every appending operation. */ if (type == SDS_TYPE_5) type = SDS_TYPE_8; hdrlen = sdsHdrSize(type); assert(hdrlen + newlen + 1 > reqlen); /* Catch size_t overflow */ if (oldtype==type) { newsh = s_realloc_usable(sh, hdrlen+newlen+1, &usable); if (newsh == NULL) return NULL; s = (char*)newsh+hdrlen; } else { /* Since the header size changes, need to move the string forward, * and can't use realloc */ newsh = s_malloc_usable(hdrlen+newlen+1, &usable); if (newsh == NULL) return NULL; memcpy((char*)newsh+hdrlen, s, len+1); s_free(sh); s = (char*)newsh+hdrlen; s[-1] = type; sdssetlen(s, len); } usable = usable-hdrlen-1; if (usable > sdsTypeMaxSize(type)) usable = sdsTypeMaxSize(type); sdssetalloc(s, usable); return s; } /* Enlarge the free space at the end of the sds string more than needed, * This is useful to avoid repeated re-allocations when repeatedly appending to the sds. */ sds sdsMakeRoomFor(sds s, size_t addlen) { return _sdsMakeRoomFor(s, addlen, 1); } /* Unlike sdsMakeRoomFor(), this one just grows to the necessary size. 
*/ sds sdsMakeRoomForNonGreedy(sds s, size_t addlen) { return _sdsMakeRoomFor(s, addlen, 0); } /* Reallocate the sds string so that it has no free space at the end. The * contained string remains not altered, but next concatenation operations * will require a reallocation. * * After the call, the passed sds string is no longer valid and all the * references must be substituted with the new pointer returned by the call. */ sds sdsRemoveFreeSpace(sds s) { void *sh, *newsh; char type, oldtype = s[-1] & SDS_TYPE_MASK; int hdrlen, oldhdrlen = sdsHdrSize(oldtype); size_t len = sdslen(s); size_t avail = sdsavail(s); sh = (char*)s-oldhdrlen; /* Return ASAP if there is no space left. */ if (avail == 0) return s; /* Check what would be the minimum SDS header that is just good enough to * fit this string. */ type = sdsReqType(len); hdrlen = sdsHdrSize(type); /* If the type is the same, or at least a large enough type is still * required, we just realloc(), letting the allocator to do the copy * only if really needed. Otherwise if the change is huge, we manually * reallocate the string to use the different header type. */ if (oldtype==type || type > SDS_TYPE_8) { newsh = s_realloc(sh, oldhdrlen+len+1); if (newsh == NULL) return NULL; s = (char*)newsh+oldhdrlen; } else { newsh = s_malloc(hdrlen+len+1); if (newsh == NULL) return NULL; memcpy((char*)newsh+hdrlen, s, len+1); s_free(sh); s = (char*)newsh+hdrlen; s[-1] = type; sdssetlen(s, len); } sdssetalloc(s, len); return s; } /* Resize the allocation, this can make the allocation bigger or smaller, * if the size is smaller than currently used len, the data will be truncated */ sds sdsResize(sds s, size_t size) { void *sh, *newsh; char type, oldtype = s[-1] & SDS_TYPE_MASK; int hdrlen, oldhdrlen = sdsHdrSize(oldtype); size_t len = sdslen(s); sh = (char*)s-oldhdrlen; /* Return ASAP if the size is already good. */ if (sdsalloc(s) == size) return s; /* Truncate len if needed. 
*/ if (size < len) len = size; /* Check what would be the minimum SDS header that is just good enough to * fit this string. */ type = sdsReqType(size); /* Don't use type 5, it is not good for strings that are resized. */ if (type == SDS_TYPE_5) type = SDS_TYPE_8; hdrlen = sdsHdrSize(type); /* If the type is the same, or can hold the size in it with low overhead * (larger than SDS_TYPE_8), we just realloc(), letting the allocator * to do the copy only if really needed. Otherwise if the change is * huge, we manually reallocate the string to use the different header * type. */ if (oldtype==type || (type < oldtype && type > SDS_TYPE_8)) { newsh = s_realloc(sh, oldhdrlen+size+1); if (newsh == NULL) return NULL; s = (char*)newsh+oldhdrlen; } else { newsh = s_malloc(hdrlen+size+1); if (newsh == NULL) return NULL; memcpy((char*)newsh+hdrlen, s, len); s_free(sh); s = (char*)newsh+hdrlen; s[-1] = type; } s[len] = 0; sdssetlen(s, len); sdssetalloc(s, size); return s; } /* Return the total size of the allocation of the specified sds string, * including: * 1) The sds header before the pointer. * 2) The string. * 3) The free buffer at the end if any. * 4) The implicit null term. */ size_t sdsAllocSize(sds s) { size_t alloc = sdsalloc(s); return sdsHdrSize(s[-1])+alloc+1; } /* Return the pointer of the actual SDS allocation (normally SDS strings * are referenced by the start of the string buffer). */ void *sdsAllocPtr(sds s) { return (void*) (s-sdsHdrSize(s[-1])); } /* Increment the sds length and decrements the left free space at the * end of the string according to 'incr'. Also set the null term * in the new end of the string. * * This function is used in order to fix the string length after the * user calls sdsMakeRoomFor(), writes something after the end of * the current string, and finally needs to set the new length. * * Note: it is possible to use a negative increment in order to * right-trim the string. 
* * Usage example: * * Using sdsIncrLen() and sdsMakeRoomFor() it is possible to mount the * following schema, to cat bytes coming from the kernel to the end of an * sds string without copying into an intermediate buffer: * * oldlen = sdslen(s); * s = sdsMakeRoomFor(s, BUFFER_SIZE); * nread = read(fd, s+oldlen, BUFFER_SIZE); * ... check for nread <= 0 and handle it ... * sdsIncrLen(s, nread); */ void sdsIncrLen(sds s, ssize_t incr) { unsigned char flags = s[-1]; size_t len; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: { unsigned char *fp = ((unsigned char*)s)-1; unsigned char oldlen = SDS_TYPE_5_LEN(flags); assert((incr > 0 && oldlen+incr < 32) || (incr < 0 && oldlen >= (unsigned int)(-incr))); *fp = SDS_TYPE_5 | ((oldlen+incr) << SDS_TYPE_BITS); len = oldlen+incr; break; } case SDS_TYPE_8: { SDS_HDR_VAR(8,s); assert((incr >= 0 && sh->alloc-sh->len >= incr) || (incr < 0 && sh->len >= (unsigned int)(-incr))); len = (sh->len += incr); break; } case SDS_TYPE_16: { SDS_HDR_VAR(16,s); assert((incr >= 0 && sh->alloc-sh->len >= incr) || (incr < 0 && sh->len >= (unsigned int)(-incr))); len = (sh->len += incr); break; } case SDS_TYPE_32: { SDS_HDR_VAR(32,s); assert((incr >= 0 && sh->alloc-sh->len >= (unsigned int)incr) || (incr < 0 && sh->len >= (unsigned int)(-incr))); len = (sh->len += incr); break; } case SDS_TYPE_64: { SDS_HDR_VAR(64,s); assert((incr >= 0 && sh->alloc-sh->len >= (uint64_t)incr) || (incr < 0 && sh->len >= (uint64_t)(-incr))); len = (sh->len += incr); break; } default: len = 0; /* Just to avoid compilation warnings. */ } s[len] = '\0'; } /* Grow the sds to have the specified length. Bytes that were not part of * the original length of the sds will be set to zero. * * if the specified length is smaller than the current length, no operation * is performed. 
*/ sds sdsgrowzero(sds s, size_t len) { size_t curlen = sdslen(s); if (len <= curlen) return s; s = sdsMakeRoomFor(s,len-curlen); if (s == NULL) return NULL; /* Make sure added region doesn't contain garbage */ memset(s+curlen,0,(len-curlen+1)); /* also set trailing \0 byte */ sdssetlen(s, len); return s; } /* Append the specified binary-safe string pointed by 't' of 'len' bytes to the * end of the specified sds string 's'. * * After the call, the passed sds string is no longer valid and all the * references must be substituted with the new pointer returned by the call. */ sds sdscatlen(sds s, const void *t, size_t len) { size_t curlen = sdslen(s); s = sdsMakeRoomFor(s,len); if (s == NULL) return NULL; memcpy(s+curlen, t, len); sdssetlen(s, curlen+len); s[curlen+len] = '\0'; return s; } /* Append the specified null terminated C string to the sds string 's'. * * After the call, the passed sds string is no longer valid and all the * references must be substituted with the new pointer returned by the call. */ sds sdscat(sds s, const char *t) { return sdscatlen(s, t, strlen(t)); } /* Append the specified sds 't' to the existing sds 's'. * * After the call, the modified sds string is no longer valid and all the * references must be substituted with the new pointer returned by the call. */ sds sdscatsds(sds s, const sds t) { return sdscatlen(s, t, sdslen(t)); } /* Destructively modify the sds string 's' to hold the specified binary * safe string pointed by 't' of length 'len' bytes. */ sds sdscpylen(sds s, const char *t, size_t len) { if (sdsalloc(s) < len) { s = sdsMakeRoomFor(s,len-sdslen(s)); if (s == NULL) return NULL; } memcpy(s, t, len); s[len] = '\0'; sdssetlen(s, len); return s; } /* Like sdscpylen() but 't' must be a null-terminated string so that the length * of the string is obtained with strlen(). */ sds sdscpy(sds s, const char *t) { return sdscpylen(s, t, strlen(t)); } /* Helper for sdscatlonglong() doing the actual number -> string * conversion. 
's' must point to a string with room for at least * SDS_LLSTR_SIZE bytes. * * The function returns the length of the null-terminated string * representation stored at 's'. */ #define SDS_LLSTR_SIZE 21 int sdsll2str(char *s, long long value) { char *p, aux; unsigned long long v; size_t l; /* Generate the string representation, this method produces * a reversed string. */ if (value < 0) { /* Since v is unsigned, if value==LLONG_MIN, -LLONG_MIN will overflow. */ if (value != LLONG_MIN) { v = -value; } else { v = ((unsigned long long)LLONG_MAX) + 1; } } else { v = value; } p = s; do { *p++ = '0'+(v%10); v /= 10; } while(v); if (value < 0) *p++ = '-'; /* Compute length and add null term. */ l = p-s; *p = '\0'; /* Reverse the string. */ p--; while(s < p) { aux = *s; *s = *p; *p = aux; s++; p--; } return l; } /* Identical sdsll2str(), but for unsigned long long type. */ int sdsull2str(char *s, unsigned long long v) { char *p, aux; size_t l; /* Generate the string representation, this method produces * a reversed string. */ p = s; do { *p++ = '0'+(v%10); v /= 10; } while(v); /* Compute length and add null term. */ l = p-s; *p = '\0'; /* Reverse the string. */ p--; while(s < p) { aux = *s; *s = *p; *p = aux; s++; p--; } return l; } /* Create an sds string from a long long value. It is much faster than: * * sdscatprintf(sdsempty(),"%lld\n", value); */ sds sdsfromlonglong(long long value) { char buf[SDS_LLSTR_SIZE + 10]; int len = sdsll2str(buf,value); return sdsnewlen(buf,len); } /* Like sdscatprintf() but gets va_list instead of being variadic. */ sds sdscatvprintf(sds s, const char *fmt, va_list ap) { va_list cpy; char staticbuf[1024], *buf = staticbuf, *t; size_t buflen = strlen(fmt)*2; int bufstrlen; /* We try to start using a static buffer for speed. * If not possible we revert to heap allocation. 
*/ if (buflen > sizeof(staticbuf)) { buf = s_malloc(buflen); if (buf == NULL) return NULL; } else { buflen = sizeof(staticbuf); } /* Alloc enough space for buffer and \0 after failing to * fit the string in the current buffer size. */ while(1) { va_copy(cpy,ap); bufstrlen = vsnprintf(buf, buflen, fmt, cpy); va_end(cpy); if (bufstrlen < 0) { if (buf != staticbuf) s_free(buf); return NULL; } if (((size_t)bufstrlen) >= buflen) { if (buf != staticbuf) s_free(buf); buflen = ((size_t)bufstrlen) + 1; buf = s_malloc(buflen); if (buf == NULL) return NULL; continue; } break; } /* Finally concat the obtained string to the SDS string and return it. */ t = sdscatlen(s, buf, bufstrlen); if (buf != staticbuf) s_free(buf); return t; } /* Append to the sds string 's' a string obtained using printf-alike format * specifier. * * After the call, the modified sds string is no longer valid and all the * references must be substituted with the new pointer returned by the call. * * Example: * * s = sdsnew("Sum is: "); * s = sdscatprintf(s,"%d+%d = %d",a,b,a+b). * * Often you need to create a string from scratch with the printf-alike * format. When this is the need, just use sdsempty() as the target string: * * s = sdscatprintf(sdsempty(), "... your format ...", args); */ sds sdscatprintf(sds s, const char *fmt, ...) { va_list ap; char *t; va_start(ap, fmt); t = sdscatvprintf(s,fmt,ap); va_end(ap); return t; } /* This function is similar to sdscatprintf, but much faster as it does * not rely on sprintf() family functions implemented by the libc that * are often very slow. Moreover directly handling the sds string as * new data is concatenated provides a performance improvement. * * However this function only handles an incompatible subset of printf-alike * format specifiers: * * %s - C String * %S - SDS string * %i - signed int * %I - 64 bit signed integer (long long, int64_t) * %u - unsigned int * %U - 64 bit unsigned integer (unsigned long long, uint64_t) * %% - Verbatim "%" character. 
*/
sds sdscatfmt(sds s, char const *fmt, ...) {
    const char *f; /* Next format specifier byte to process. */
    long i = sdslen(s); /* Position of the next byte to write to dest str. */
    va_list ap;

    /* Reserve room for twice the format string up front so that short
     * expansions do not reallocate over and over. Only a heuristic. */
    s = sdsMakeRoomFor(s, strlen(fmt)*2);
    va_start(ap,fmt);

    for (f = fmt; *f != '\0'; f++) {
        char numbuf[SDS_LLSTR_SIZE]; /* text of %i %I %u %U arguments */
        char *chunk; /* bytes to append for this specifier */
        size_t chunklen;
        char next;

        /* Make sure there is always space for at least 1 char. */
        if (sdsavail(s) == 0) s = sdsMakeRoomFor(s,1);

        /* Ordinary byte: copied as is. */
        if (*f != '%') {
            s[i++] = *f;
            sdsinclen(s,1);
            continue;
        }

        /* A '%' that ends the format is dropped. */
        next = f[1];
        if (next == '\0') continue;
        f++;

        switch(next) {
        case 's':
            chunk = va_arg(ap,char*);
            chunklen = strlen(chunk);
            break;
        case 'S':
            chunk = va_arg(ap,char*);
            chunklen = sdslen(chunk);
            break;
        case 'i':
            chunklen = sdsll2str(numbuf, va_arg(ap,int));
            chunk = numbuf;
            break;
        case 'I':
            chunklen = sdsll2str(numbuf, va_arg(ap,long long));
            chunk = numbuf;
            break;
        case 'u':
            chunklen = sdsull2str(numbuf, va_arg(ap,unsigned int));
            chunk = numbuf;
            break;
        case 'U':
            chunklen = sdsull2str(numbuf, va_arg(ap,unsigned long long));
            chunk = numbuf;
            break;
        default:
            /* Handle %% and generally %<unknown>: emit the byte itself. */
            s[i++] = next;
            sdsinclen(s,1);
            continue;
        }

        /* Shared tail of every specifier that expands to a byte run. */
        if (sdsavail(s) < chunklen) s = sdsMakeRoomFor(s,chunklen);
        memcpy(s+i, chunk, chunklen);
        sdsinclen(s,chunklen);
        i += chunklen;
    }
    va_end(ap);

    /* Add null-term */
    s[i] = '\0';
    return s;
}

/* Remove the part of the string from left and from right composed just of
 * contiguous characters found in 'cset', that is a null terminated C string.
 *
 * After the call, the modified sds string is no longer valid and all the
 * references must be substituted with the new pointer returned by the call.
* * Example: * * s = sdsnew("AA...AA.a.aa.aHelloWorld :::"); * s = sdstrim(s,"Aa. :"); * printf("%s\n", s); * * Output will be just "HelloWorld". */ sds sdstrim(sds s, const char *cset) { char *end, *sp, *ep; size_t len; sp = s; ep = end = s+sdslen(s)-1; while(sp <= end && strchr(cset, *sp)) sp++; while(ep > sp && strchr(cset, *ep)) ep--; len = (ep-sp)+1; if (s != sp) memmove(s, sp, len); s[len] = '\0'; sdssetlen(s,len); return s; } /* Changes the input string to be a subset of the original. * It does not release the free space in the string, so a call to * sdsRemoveFreeSpace may be wise after. */ void sdssubstr(sds s, size_t start, size_t len) { /* Clamp out of range input */ size_t oldlen = sdslen(s); if (start >= oldlen) start = len = 0; if (len > oldlen-start) len = oldlen-start; /* Move the data */ if (len) memmove(s, s+start, len); s[len] = 0; sdssetlen(s,len); } /* Turn the string into a smaller (or equal) string containing only the * substring specified by the 'start' and 'end' indexes. * * start and end can be negative, where -1 means the last character of the * string, -2 the penultimate character, and so forth. * * The interval is inclusive, so the start and end characters will be part * of the resulting string. * * The string is modified in-place. * * Return value: * -1 (error) if sdslen(s) is larger than maximum positive ssize_t value. * 0 on success. * * Example: * * s = sdsnew("Hello World"); * sdsrange(s,1,-1); => "ello World" */ int sdsrange(sds s, ssize_t start, ssize_t end) { size_t newlen, len = sdslen(s); if (len > SSIZE_MAX) return -1; if (len == 0) return 0; if (start < 0) { start = len+start; if (start < 0) start = 0; } if (end < 0) { end = len+end; if (end < 0) end = 0; } newlen = (start > end) ? 0 : (end-start)+1; if (newlen != 0) { if (start >= (ssize_t)len) { newlen = 0; } else if (end >= (ssize_t)len) { end = len-1; newlen = (start > end) ? 
0 : (end-start)+1; } } else { start = 0; } if (start && newlen) memmove(s, s+start, newlen); s[newlen] = 0; sdssetlen(s,newlen); return 0; } /* Apply tolower() to every character of the sds string 's'. */ void sdstolower(sds s) { size_t len = sdslen(s), j; for (j = 0; j < len; j++) s[j] = tolower(s[j]); } /* Apply toupper() to every character of the sds string 's'. */ void sdstoupper(sds s) { size_t len = sdslen(s), j; for (j = 0; j < len; j++) s[j] = toupper(s[j]); } /* Compare two sds strings s1 and s2 with memcmp(). * * Return value: * * positive if s1 > s2. * negative if s1 < s2. * 0 if s1 and s2 are exactly the same binary string. * * If two strings share exactly the same prefix, but one of the two has * additional characters, the longer string is considered to be greater than * the smaller one. */ int sdscmp(const sds s1, const sds s2) { size_t l1, l2, minlen; int cmp; l1 = sdslen(s1); l2 = sdslen(s2); minlen = (l1 < l2) ? l1 : l2; cmp = memcmp(s1,s2,minlen); if (cmp == 0) return l1>l2? 1: (l1". * * After the call, the modified sds string is no longer valid and all the * references must be substituted with the new pointer returned by the call. */ sds sdscatrepr(sds s, const char *p, size_t len) { s = sdscatlen(s,"\"",1); while(len--) { switch(*p) { case '\\': case '"': s = sdscatprintf(s,"\\%c",*p); break; case '\n': s = sdscatlen(s,"\\n",2); break; case '\r': s = sdscatlen(s,"\\r",2); break; case '\t': s = sdscatlen(s,"\\t",2); break; case '\a': s = sdscatlen(s,"\\a",2); break; case '\b': s = sdscatlen(s,"\\b",2); break; default: if (isprint(*p)) s = sdscatprintf(s,"%c",*p); else s = sdscatprintf(s,"\\x%02x",(unsigned char)*p); break; } p++; } return sdscatlen(s,"\"",1); } /* Helper function for sdssplitargs() that returns non zero if 'c' * is a valid hex digit. 
*/
int is_hex_digit(char c) {
    if (c >= '0' && c <= '9') return 1;
    if (c >= 'a' && c <= 'f') return 1;
    return c >= 'A' && c <= 'F';
}

/* Helper for sdssplitargs(): value (0 to 15) of the hex digit 'c', in
 * either case. Any other character yields 0. */
int hex_digit_to_int(char c) {
    if (c >= '0' && c <= '9') return c - '0';
    if (c >= 'a' && c <= 'f') return 10 + (c - 'a');
    if (c >= 'A' && c <= 'F') return 10 + (c - 'A');
    return 0;
}

/* Split a line into arguments, where every argument can be in the
 * following programming-language REPL-alike form:
 *
 * foo bar "newline are supported\n" and "\xff\x00otherstuff"
 *
 * The number of arguments is stored into *argc, and an array
 * of sds is returned.
 *
 * The caller should free the resulting array of sds strings with
 * sdsfreesplitres().
 *
 * Note that sdscatrepr() is able to convert back a string into
 * a quoted string in the same format sdssplitargs() is able to parse.
* * The function returns the allocated tokens on success, even when the * input string is empty, or NULL if the input contains unbalanced * quotes or closed quotes followed by non space characters * as in: "foo"bar or "foo' */ sds *sdssplitargs(const char *line, int *argc) { const char *p = line; char *current = NULL; char **vector = NULL; *argc = 0; while(1) { /* skip blanks */ while(*p && isspace(*p)) p++; if (*p) { /* get a token */ int inq=0; /* set to 1 if we are in "quotes" */ int insq=0; /* set to 1 if we are in 'single quotes' */ int done=0; if (current == NULL) current = sdsempty(); while(!done) { if (inq) { if (*p == '\\' && *(p+1) == 'x' && is_hex_digit(*(p+2)) && is_hex_digit(*(p+3))) { unsigned char byte; byte = (hex_digit_to_int(*(p+2))*16)+ hex_digit_to_int(*(p+3)); current = sdscatlen(current,(char*)&byte,1); p += 3; } else if (*p == '\\' && *(p+1)) { char c; p++; switch(*p) { case 'n': c = '\n'; break; case 'r': c = '\r'; break; case 't': c = '\t'; break; case 'b': c = '\b'; break; case 'a': c = '\a'; break; default: c = *p; break; } current = sdscatlen(current,&c,1); } else if (*p == '"') { /* closing quote must be followed by a space or * nothing at all. */ if (*(p+1) && !isspace(*(p+1))) goto err; done=1; } else if (!*p) { /* unterminated quotes */ goto err; } else { current = sdscatlen(current,p,1); } } else if (insq) { if (*p == '\\' && *(p+1) == '\'') { p++; current = sdscatlen(current,"'",1); } else if (*p == '\'') { /* closing quote must be followed by a space or * nothing at all. 
*/ if (*(p+1) && !isspace(*(p+1))) goto err; done=1; } else if (!*p) { /* unterminated quotes */ goto err; } else { current = sdscatlen(current,p,1); } } else { switch(*p) { case ' ': case '\n': case '\r': case '\t': case '\0': done=1; break; case '"': inq=1; break; case '\'': insq=1; break; default: current = sdscatlen(current,p,1); break; } } if (*p) p++; } /* add the token to the vector */ vector = s_realloc(vector,((*argc)+1)*sizeof(char*)); vector[*argc] = current; (*argc)++; current = NULL; } else { /* Even on empty input string return something not NULL. */ if (vector == NULL) vector = s_malloc(sizeof(void*)); return vector; } } err: while((*argc)--) sdsfree(vector[*argc]); s_free(vector); if (current) sdsfree(current); *argc = 0; return NULL; } /* Modify the string substituting all the occurrences of the set of * characters specified in the 'from' string to the corresponding character * in the 'to' array. * * For instance: sdsmapchars(mystring, "ho", "01", 2) * will have the effect of turning the string "hello" into "0ell1". * * The function returns the sds string pointer, that is always the same * as the input pointer since no resize is needed. */ sds sdsmapchars(sds s, const char *from, const char *to, size_t setlen) { size_t j, i, l = sdslen(s); for (j = 0; j < l; j++) { for (i = 0; i < setlen; i++) { if (s[j] == from[i]) { s[j] = to[i]; break; } } } return s; } /* Join an array of C strings using the specified separator (also a C string). * Returns the result as an sds string. */ sds sdsjoin(char **argv, int argc, char *sep) { sds join = sdsempty(); int j; for (j = 0; j < argc; j++) { join = sdscat(join, argv[j]); if (j != argc-1) join = sdscat(join,sep); } return join; } /* Like sdsjoin, but joins an array of SDS strings. 
*/ sds sdsjoinsds(sds *argv, int argc, const char *sep, size_t seplen) { sds join = sdsempty(); int j; for (j = 0; j < argc; j++) { join = sdscatsds(join, argv[j]); if (j != argc-1) join = sdscatlen(join,sep,seplen); } return join; } /* Wrappers to the allocators used by SDS. Note that SDS will actually * just use the macros defined into sdsalloc.h in order to avoid to pay * the overhead of function calls. Here we define these wrappers only for * the programs SDS is linked to, if they want to touch the SDS internals * even if they use a different allocator. */ void *sds_malloc(size_t size) { return s_malloc(size); } void *sds_realloc(void *ptr, size_t size) { return s_realloc(ptr,size); } void sds_free(void *ptr) { s_free(ptr); } /* Perform expansion of a template string and return the result as a newly * allocated sds. * * Template variables are specified using curly brackets, e.g. {variable}. * An opening bracket can be quoted by repeating it twice. */ sds sdstemplate(const char *template, sdstemplate_callback_t cb_func, void *cb_arg) { sds res = sdsempty(); const char *p = template; while (*p) { /* Find next variable, copy everything until there */ const char *sv = strchr(p, '{'); if (!sv) { /* Not found: copy till rest of template and stop */ res = sdscat(res, p); break; } else if (sv > p) { /* Found: copy anything up to the beginning of the variable */ res = sdscatlen(res, p, sv - p); } /* Skip into variable name, handle premature end or quoting */ sv++; if (!*sv) goto error; /* Premature end of template */ if (*sv == '{') { /* Quoted '{' */ p = sv + 1; res = sdscat(res, "{"); continue; } /* Find end of variable name, handle premature end of template */ const char *ev = strchr(sv, '}'); if (!ev) goto error; /* Pass variable name to callback and obtain value. If callback failed, * abort. 
*/ sds varname = sdsnewlen(sv, ev - sv); sds value = cb_func(varname, cb_arg); sdsfree(varname); if (!value) goto error; /* Append value to result and continue */ res = sdscat(res, value); sdsfree(value); p = ev + 1; } return res; error: sdsfree(res); return NULL; } #ifdef REDIS_TEST #include #include #include "testhelp.h" #define UNUSED(x) (void)(x) static sds sdsTestTemplateCallback(sds varname, void *arg) { UNUSED(arg); static const char *_var1 = "variable1"; static const char *_var2 = "variable2"; if (!strcmp(varname, _var1)) return sdsnew("value1"); else if (!strcmp(varname, _var2)) return sdsnew("value2"); else return NULL; } int sdsTest(int argc, char **argv, int flags) { UNUSED(argc); UNUSED(argv); UNUSED(flags); { sds x = sdsnew("foo"), y; test_cond("Create a string and obtain the length", sdslen(x) == 3 && memcmp(x,"foo\0",4) == 0); sdsfree(x); x = sdsnewlen("foo",2); test_cond("Create a string with specified length", sdslen(x) == 2 && memcmp(x,"fo\0",3) == 0); x = sdscat(x,"bar"); test_cond("Strings concatenation", sdslen(x) == 5 && memcmp(x,"fobar\0",6) == 0); x = sdscpy(x,"a"); test_cond("sdscpy() against an originally longer string", sdslen(x) == 1 && memcmp(x,"a\0",2) == 0); x = sdscpy(x,"xyzxxxxxxxxxxyyyyyyyyyykkkkkkkkkk"); test_cond("sdscpy() against an originally shorter string", sdslen(x) == 33 && memcmp(x,"xyzxxxxxxxxxxyyyyyyyyyykkkkkkkkkk\0",33) == 0); sdsfree(x); x = sdscatprintf(sdsempty(),"%d",123); test_cond("sdscatprintf() seems working in the base case", sdslen(x) == 3 && memcmp(x,"123\0",4) == 0); sdsfree(x); x = sdscatprintf(sdsempty(),"a%cb",0); test_cond("sdscatprintf() seems working with \\0 inside of result", sdslen(x) == 3 && memcmp(x,"a\0""b\0",4) == 0); { sdsfree(x); char etalon[1024*1024]; for (size_t i = 0; i < sizeof(etalon); i++) { etalon[i] = '0'; } x = sdscatprintf(sdsempty(),"%0*d",(int)sizeof(etalon),0); test_cond("sdscatprintf() can print 1MB", sdslen(x) == sizeof(etalon) && memcmp(x,etalon,sizeof(etalon)) == 0); } 
sdsfree(x); x = sdsnew("--"); x = sdscatfmt(x, "Hello %s World %I,%I--", "Hi!", LLONG_MIN,LLONG_MAX); test_cond("sdscatfmt() seems working in the base case", sdslen(x) == 60 && memcmp(x,"--Hello Hi! World -9223372036854775808," "9223372036854775807--",60) == 0); printf("[%s]\n",x); sdsfree(x); x = sdsnew("--"); x = sdscatfmt(x, "%u,%U--", UINT_MAX, ULLONG_MAX); test_cond("sdscatfmt() seems working with unsigned numbers", sdslen(x) == 35 && memcmp(x,"--4294967295,18446744073709551615--",35) == 0); sdsfree(x); x = sdsnew(" x "); sdstrim(x," x"); test_cond("sdstrim() works when all chars match", sdslen(x) == 0); sdsfree(x); x = sdsnew(" x "); sdstrim(x," "); test_cond("sdstrim() works when a single char remains", sdslen(x) == 1 && x[0] == 'x'); sdsfree(x); x = sdsnew("xxciaoyyy"); sdstrim(x,"xy"); test_cond("sdstrim() correctly trims characters", sdslen(x) == 4 && memcmp(x,"ciao\0",5) == 0); y = sdsdup(x); sdsrange(y,1,1); test_cond("sdsrange(...,1,1)", sdslen(y) == 1 && memcmp(y,"i\0",2) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,1,-1); test_cond("sdsrange(...,1,-1)", sdslen(y) == 3 && memcmp(y,"iao\0",4) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,-2,-1); test_cond("sdsrange(...,-2,-1)", sdslen(y) == 2 && memcmp(y,"ao\0",3) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,2,1); test_cond("sdsrange(...,2,1)", sdslen(y) == 0 && memcmp(y,"\0",1) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,1,100); test_cond("sdsrange(...,1,100)", sdslen(y) == 3 && memcmp(y,"iao\0",4) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,100,100); test_cond("sdsrange(...,100,100)", sdslen(y) == 0 && memcmp(y,"\0",1) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,4,6); test_cond("sdsrange(...,4,6)", sdslen(y) == 0 && memcmp(y,"\0",1) == 0); sdsfree(y); y = sdsdup(x); sdsrange(y,3,6); test_cond("sdsrange(...,3,6)", sdslen(y) == 1 && memcmp(y,"o\0",2) == 0); sdsfree(y); sdsfree(x); x = sdsnew("foo"); y = sdsnew("foa"); test_cond("sdscmp(foo,foa)", sdscmp(x,y) > 0); sdsfree(y); sdsfree(x); x = 
sdsnew("bar"); y = sdsnew("bar"); test_cond("sdscmp(bar,bar)", sdscmp(x,y) == 0); sdsfree(y); sdsfree(x); x = sdsnew("aar"); y = sdsnew("bar"); test_cond("sdscmp(bar,bar)", sdscmp(x,y) < 0); sdsfree(y); sdsfree(x); x = sdsnewlen("\a\n\0foo\r",7); y = sdscatrepr(sdsempty(),x,sdslen(x)); test_cond("sdscatrepr(...data...)", memcmp(y,"\"\\a\\n\\x00foo\\r\"",15) == 0); { unsigned int oldfree; char *p; int i; size_t step = 10, j; sdsfree(x); sdsfree(y); x = sdsnew("0"); test_cond("sdsnew() free/len buffers", sdslen(x) == 1 && sdsavail(x) == 0); /* Run the test a few times in order to hit the first two * SDS header types. */ for (i = 0; i < 10; i++) { size_t oldlen = sdslen(x); x = sdsMakeRoomFor(x,step); int type = x[-1]&SDS_TYPE_MASK; test_cond("sdsMakeRoomFor() len", sdslen(x) == oldlen); if (type != SDS_TYPE_5) { test_cond("sdsMakeRoomFor() free", sdsavail(x) >= step); oldfree = sdsavail(x); UNUSED(oldfree); } p = x+oldlen; for (j = 0; j < step; j++) { p[j] = 'A'+j; } sdsIncrLen(x,step); } test_cond("sdsMakeRoomFor() content", memcmp("0ABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJABCDEFGHIJ",x,101) == 0); test_cond("sdsMakeRoomFor() final length",sdslen(x)==101); sdsfree(x); } /* Simple template */ x = sdstemplate("v1={variable1} v2={variable2}", sdsTestTemplateCallback, NULL); test_cond("sdstemplate() normal flow", memcmp(x,"v1=value1 v2=value2",19) == 0); sdsfree(x); /* Template with callback error */ x = sdstemplate("v1={variable1} v3={doesnotexist}", sdsTestTemplateCallback, NULL); test_cond("sdstemplate() with callback error", x == NULL); /* Template with empty var name */ x = sdstemplate("v1={", sdsTestTemplateCallback, NULL); test_cond("sdstemplate() with empty var name", x == NULL); /* Template with truncated var name */ x = sdstemplate("v1={start", sdsTestTemplateCallback, NULL); test_cond("sdstemplate() with truncated var name", x == NULL); /* Template with quoting */ x = sdstemplate("v1={{{variable1}} {{} 
v2={variable2}", sdsTestTemplateCallback, NULL); test_cond("sdstemplate() with quoting", memcmp(x,"v1={value1} {} v2=value2",24) == 0); sdsfree(x); /* Test sdsresize - extend */ x = sdsnew("1234567890123456789012345678901234567890"); x = sdsResize(x, 200); test_cond("sdsrezie() expand len", sdslen(x) == 40); test_cond("sdsrezie() expand strlen", strlen(x) == 40); test_cond("sdsrezie() expand alloc", sdsalloc(x) == 200); /* Test sdsresize - trim free space */ x = sdsResize(x, 80); test_cond("sdsrezie() shrink len", sdslen(x) == 40); test_cond("sdsrezie() shrink strlen", strlen(x) == 40); test_cond("sdsrezie() shrink alloc", sdsalloc(x) == 80); /* Test sdsresize - crop used space */ x = sdsResize(x, 30); test_cond("sdsrezie() crop len", sdslen(x) == 30); test_cond("sdsrezie() crop strlen", strlen(x) == 30); test_cond("sdsrezie() crop alloc", sdsalloc(x) == 30); /* Test sdsresize - extend to different class */ x = sdsResize(x, 400); test_cond("sdsrezie() expand len", sdslen(x) == 30); test_cond("sdsrezie() expand strlen", strlen(x) == 30); test_cond("sdsrezie() expand alloc", sdsalloc(x) == 400); /* Test sdsresize - shrink to different class */ x = sdsResize(x, 4); test_cond("sdsrezie() crop len", sdslen(x) == 4); test_cond("sdsrezie() crop strlen", strlen(x) == 4); test_cond("sdsrezie() crop alloc", sdsalloc(x) == 4); sdsfree(x); } return 0; } #endif ================================================ FILE: src/redis/sds.h ================================================ /* SDSLib 2.0 -- A C dynamic strings library * * Copyright (c) 2006-2015, Salvatore Sanfilippo * Copyright (c) 2015, Oran Agra * Copyright (c) 2015, Redis Labs, Inc * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. 
* * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __SDS_H #define __SDS_H #define SDS_MAX_PREALLOC (1024*1024) extern const char *SDS_NOINIT; #include #include #include typedef char *sds; /* Note: sdshdr5 is never used, we just access the flags byte directly. * However is here to document the layout of type 5 SDS strings. 
*/ struct __attribute__ ((__packed__)) sdshdr5 { unsigned char flags; /* 3 lsb of type, and 5 msb of string length */ char buf[]; }; struct __attribute__ ((__packed__)) sdshdr8 { uint8_t len; /* used */ uint8_t alloc; /* excluding the header and null terminator */ unsigned char flags; /* 3 lsb of type, 5 unused bits */ char buf[]; }; struct __attribute__ ((__packed__)) sdshdr16 { uint16_t len; /* used */ uint16_t alloc; /* excluding the header and null terminator */ unsigned char flags; /* 3 lsb of type, 5 unused bits */ char buf[]; }; struct __attribute__ ((__packed__)) sdshdr32 { uint32_t len; /* used */ uint32_t alloc; /* excluding the header and null terminator */ unsigned char flags; /* 3 lsb of type, 5 unused bits */ char buf[]; }; struct __attribute__ ((__packed__)) sdshdr64 { uint64_t len; /* used */ uint64_t alloc; /* excluding the header and null terminator */ unsigned char flags; /* 3 lsb of type, 5 unused bits */ char buf[]; }; #define SDS_TYPE_5 0 #define SDS_TYPE_8 1 #define SDS_TYPE_16 2 #define SDS_TYPE_32 3 #define SDS_TYPE_64 4 #define SDS_TYPE_MASK 7 #define SDS_TYPE_BITS 3 #define SDS_HDR(T,s) ((struct sdshdr##T *)((s)-(sizeof(struct sdshdr##T)))) #define SDS_HDR_VAR(T,s) struct sdshdr##T *sh = SDS_HDR(T,s); #define SDS_TYPE_5_LEN(f) ((f)>>SDS_TYPE_BITS) static inline size_t sdslen(const sds s) { unsigned char flags = s[-1]; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: return SDS_TYPE_5_LEN(flags); case SDS_TYPE_8: return SDS_HDR(8,s)->len; case SDS_TYPE_16: return SDS_HDR(16,s)->len; case SDS_TYPE_32: return SDS_HDR(32,s)->len; case SDS_TYPE_64: return SDS_HDR(64,s)->len; } return 0; } static inline size_t sdsavail(const sds s) { unsigned char flags = s[-1]; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: { return 0; } case SDS_TYPE_8: { SDS_HDR_VAR(8,s); return sh->alloc - sh->len; } case SDS_TYPE_16: { SDS_HDR_VAR(16,s); return sh->alloc - sh->len; } case SDS_TYPE_32: { SDS_HDR_VAR(32,s); return sh->alloc - sh->len; } case SDS_TYPE_64: { 
SDS_HDR_VAR(64,s); return sh->alloc - sh->len; } } return 0; } static inline void sdssetlen(sds s, size_t newlen) { unsigned char flags = s[-1]; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: { unsigned char *fp = ((unsigned char*)s)-1; *fp = SDS_TYPE_5 | (newlen << SDS_TYPE_BITS); } break; case SDS_TYPE_8: SDS_HDR(8,s)->len = newlen; break; case SDS_TYPE_16: SDS_HDR(16,s)->len = newlen; break; case SDS_TYPE_32: SDS_HDR(32,s)->len = newlen; break; case SDS_TYPE_64: SDS_HDR(64,s)->len = newlen; break; } } static inline void sdsinclen(sds s, size_t inc) { unsigned char flags = s[-1]; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: { unsigned char *fp = ((unsigned char*)s)-1; unsigned char newlen = SDS_TYPE_5_LEN(flags)+inc; *fp = SDS_TYPE_5 | (newlen << SDS_TYPE_BITS); } break; case SDS_TYPE_8: SDS_HDR(8,s)->len += inc; break; case SDS_TYPE_16: SDS_HDR(16,s)->len += inc; break; case SDS_TYPE_32: SDS_HDR(32,s)->len += inc; break; case SDS_TYPE_64: SDS_HDR(64,s)->len += inc; break; } } /* sdsalloc() = sdsavail() + sdslen() */ static inline size_t sdsalloc(const sds s) { unsigned char flags = s[-1]; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: return SDS_TYPE_5_LEN(flags); case SDS_TYPE_8: return SDS_HDR(8,s)->alloc; case SDS_TYPE_16: return SDS_HDR(16,s)->alloc; case SDS_TYPE_32: return SDS_HDR(32,s)->alloc; case SDS_TYPE_64: return SDS_HDR(64,s)->alloc; } return 0; } static inline void sdssetalloc(sds s, size_t newlen) { unsigned char flags = s[-1]; switch(flags&SDS_TYPE_MASK) { case SDS_TYPE_5: /* Nothing to do, this type has no total allocation info. 
*/ break; case SDS_TYPE_8: SDS_HDR(8,s)->alloc = newlen; break; case SDS_TYPE_16: SDS_HDR(16,s)->alloc = newlen; break; case SDS_TYPE_32: SDS_HDR(32,s)->alloc = newlen; break; case SDS_TYPE_64: SDS_HDR(64,s)->alloc = newlen; break; } } sds sdsnewlen(const void *init, size_t initlen); sds sdsnew(const char *init); sds sdsempty(void); sds sdsdup(const sds s); void sdsfree(sds s); sds sdsgrowzero(sds s, size_t len); sds sdscatlen(sds s, const void *t, size_t len); sds sdscat(sds s, const char *t); sds sdscatsds(sds s, const sds t); sds sdscpylen(sds s, const char *t, size_t len); sds sdscpy(sds s, const char *t); sds sdscatvprintf(sds s, const char *fmt, va_list ap); #ifdef __GNUC__ sds sdscatprintf(sds s, const char *fmt, ...) __attribute__((format(printf, 2, 3))); #else sds sdscatprintf(sds s, const char *fmt, ...); #endif sds sdscatfmt(sds s, char const *fmt, ...); sds sdstrim(sds s, const char *cset); void sdssubstr(sds s, size_t start, size_t len); int sdsrange(sds s, ssize_t start, ssize_t end); void sdsupdatelen(sds s); void sdsclear(sds s); int sdscmp(const sds s1, const sds s2); sds *sdssplitlen(const char *s, ssize_t len, const char *sep, int seplen, int *count); void sdsfreesplitres(sds *tokens, int count); void sdstolower(sds s); void sdstoupper(sds s); sds sdsfromlonglong(long long value); sds sdscatrepr(sds s, const char *p, size_t len); sds *sdssplitargs(const char *line, int *argc); sds sdsmapchars(sds s, const char *from, const char *to, size_t setlen); sds sdsjoin(char **argv, int argc, char *sep); sds sdsjoinsds(sds *argv, int argc, const char *sep, size_t seplen); /* Callback for sdstemplate. The function gets called by sdstemplate * every time a variable needs to be expanded. The variable name is * provided as variable, and the callback is expected to return a * substitution value. Returning a NULL indicates an error. 
*/ typedef sds (*sdstemplate_callback_t)(const sds variable, void *arg); sds sdstemplate(const char *templ, sdstemplate_callback_t cb_func, void *cb_arg); /* Low level functions exposed to the user API */ sds sdsMakeRoomFor(sds s, size_t addlen); sds sdsMakeRoomForNonGreedy(sds s, size_t addlen); void sdsIncrLen(sds s, ssize_t incr); sds sdsRemoveFreeSpace(sds s); sds sdsResize(sds s, size_t size); size_t sdsAllocSize(sds s); void *sdsAllocPtr(sds s); /* Export the allocator used by SDS to the program using SDS. * Sometimes the program SDS is linked to, may use a different set of * allocators, but may want to allocate or free things that SDS will * respectively free or allocate. */ void *sds_malloc(size_t size); void *sds_realloc(void *ptr, size_t size); void sds_free(void *ptr); #ifdef REDIS_TEST int sdsTest(int argc, char *argv[], int flags); #endif #endif ================================================ FILE: src/redis/sdsalloc.h ================================================ /* SDSLib 2.0 -- A C dynamic strings library * * Copyright (c) 2006-2015, Salvatore Sanfilippo * Copyright (c) 2015, Redis Labs, Inc * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ /* SDS allocator selection. * * This file is used in order to change the SDS allocator at compile time. * Just define the following defines to what you want to use. Also add * the include of your alternate allocator if needed (not needed in order * to use the default libc allocator). */ #ifndef __SDS_ALLOC_H__ #define __SDS_ALLOC_H__ #include "zmalloc.h" #define s_malloc zmalloc #define s_realloc zrealloc #define s_calloc zcalloc #define s_trymalloc ztrymalloc #define s_tryrealloc ztryrealloc #define s_free zfree #define s_malloc_usable zmalloc_usable #define s_realloc_usable zrealloc_usable #define s_trymalloc_usable ztrymalloc_usable #define s_tryrealloc_usable ztryrealloc_usable #define s_free_usable zfree_usable #endif ================================================ FILE: src/redis/siphash.c ================================================ /* SipHash reference C implementation Copyright (c) 2012-2016 Jean-Philippe Aumasson Copyright (c) 2012-2014 Daniel J. Bernstein Copyright (c) 2017 Salvatore Sanfilippo To the extent possible under law, the author(s) have dedicated all copyright and related and neighboring rights to this software to the public domain worldwide. 
This software is distributed without any warranty. You should have received a copy of the CC0 Public Domain Dedication along with this software. If not, see . ---------------------------------------------------------------------------- This version was modified by Salvatore Sanfilippo in the following ways: 1. We use SipHash 1-2. This is not believed to be as strong as the suggested 2-4 variant, but AFAIK there are not trivial attacks against this reduced-rounds version, and it runs at the same speed as Murmurhash2 that we used previously, while the 2-4 variant slowed down Redis by a 4% figure more or less. 2. Hard-code rounds in the hope the compiler can optimize it more in this raw from. Anyway we always want the standard 2-4 variant. 3. Modify the prototype and implementation so that the function directly returns an uint64_t value, the hash itself, instead of receiving an output buffer. This also means that the output size is set to 8 bytes and the 16 bytes output code handling was removed. 4. Provide a case insensitive variant to be used when hashing strings that must be considered identical by the hash table regardless of the case. If we don't have directly a case insensitive hash function, we need to perform a text transformation in some temporary buffer, which is costly. 5. Remove debugging code. 6. Modified the original test.c file to be a stand-alone function testing the function in the new form (returning an uint64_t) using just the relevant test vector. */ #include #include #include #include #include /* Fast tolower() alike function that does not care about locale * but just returns a-z instead of A-Z. 
*/
int siptlw(int c) {
    if (c >= 'A' && c <= 'Z') {
        return c+('a'-'A');
    } else {
        return c;
    }
}

/* NO_SANITIZE(name) expands to __attribute__((no_sanitize(name))) when the
 * compiler reports support for it through __has_attribute, and to nothing
 * otherwise. */
#if defined(__has_attribute)
#if __has_attribute(no_sanitize)
#define NO_SANITIZE(sanitizer) __attribute__((no_sanitize(sanitizer)))
#endif
#endif

#if !defined(NO_SANITIZE)
#define NO_SANITIZE(sanitizer)
#endif

/* Test if the CPU is Little Endian and supports unaligned accesses.
 * Two interesting conditions to speedup the function that happen to be
 * in most of x86 servers. */
#if defined(__X86_64__) || defined(__x86_64__) || defined (__i386__) \
    || defined (__aarch64__) || defined (__arm64__)
#define UNALIGNED_LE_CPU
#endif

/* Rotate the 64 bit value 'x' left by 'b' bits. */
#define ROTL(x, b) (uint64_t)(((x) << (b)) | ((x) >> (64 - (b))))

/* Store the 32 bit value 'v' into p[0..3], least significant byte first.
 * Note: expands to several statements, not a single expression. */
#define U32TO8_LE(p, v)                                                        \
    (p)[0] = (uint8_t)((v));                                                   \
    (p)[1] = (uint8_t)((v) >> 8);                                              \
    (p)[2] = (uint8_t)((v) >> 16);                                             \
    (p)[3] = (uint8_t)((v) >> 24);

/* Store the 64 bit value 'v' into p[0..7] as two little endian halves. */
#define U64TO8_LE(p, v)                                                        \
    U32TO8_LE((p), (uint32_t)((v)));                                           \
    U32TO8_LE((p) + 4, (uint32_t)((v) >> 32));

/* Load a little endian 64 bit value from the 8 bytes at 'p': a direct
 * (possibly unaligned) load when UNALIGNED_LE_CPU is defined, otherwise
 * assembled byte by byte. */
#ifdef UNALIGNED_LE_CPU
#define U8TO64_LE(p) (*((uint64_t*)(p)))
#else
#define U8TO64_LE(p)                                                           \
    (((uint64_t)((p)[0])) | ((uint64_t)((p)[1]) << 8) |                        \
     ((uint64_t)((p)[2]) << 16) | ((uint64_t)((p)[3]) << 24) |                 \
     ((uint64_t)((p)[4]) << 32) | ((uint64_t)((p)[5]) << 40) |                 \
     ((uint64_t)((p)[6]) << 48) | ((uint64_t)((p)[7]) << 56))
#endif

/* Same as the byte by byte U8TO64_LE, but every byte goes through siptlw()
 * first, so 'A'-'Z' are folded to 'a'-'z'. */
#define U8TO64_LE_NOCASE(p)                                                    \
    (((uint64_t)(siptlw((p)[0]))) |                                           \
     ((uint64_t)(siptlw((p)[1])) << 8) |                                      \
     ((uint64_t)(siptlw((p)[2])) << 16) |                                     \
     ((uint64_t)(siptlw((p)[3])) << 24) |                                     \
     ((uint64_t)(siptlw((p)[4])) << 32) |                                              \
     ((uint64_t)(siptlw((p)[5])) << 40) |                                              \
     ((uint64_t)(siptlw((p)[6])) << 48) |                                              \
     ((uint64_t)(siptlw((p)[7])) << 56))

/* One SipHash round. Operates on the variables v0, v1, v2 and v3 that must
 * be in scope at the point of use. */
#define SIPROUND                                                               \
    do {                                                                       \
        v0 += v1;                                                              \
        v1 = ROTL(v1, 13);                                                     \
        v1 ^= v0;                                                              \
        v0 = ROTL(v0, 32);                                                     \
        v2 += v3;                                                              \
        v3 = ROTL(v3, 16);                                                     \
        v3 ^= v2;                                                              \
        v0 += v3;                                                              \
        v3 = ROTL(v3, 21);                                                     \
        v3 ^= v0;                                                              \
        v2 += v1;                                                              \
        v1 = ROTL(v1, 17);                                                     \
        v1 ^= v2;                                                              \
        v2 = ROTL(v2, 32);                                                     \
    } while (0)

/* Hash the 'inlen' bytes at 'in' with the key at 'k' and return the 64 bit
 * result. 16 bytes are read from 'k' (two 64 bit words, at k and k+8).
 *
 * One SIPROUND is applied per 8 byte input word and two after the final
 * word, i.e. the 1-2 round configuration. The final word is built from the
 * 0-7 trailing input bytes, with the input length placed in its top byte.
 *
 * Without UNALIGNED_LE_CPU the result is serialized through U64TO8_LE into
 * 'hash' before being returned. */
NO_SANITIZE("alignment")
uint64_t siphash(const uint8_t *in, const size_t inlen, const uint8_t *k) {
#ifndef UNALIGNED_LE_CPU
    uint64_t hash;
    uint8_t *out = (uint8_t*) &hash;
#endif
    uint64_t v0 = 0x736f6d6570736575ULL;
    uint64_t v1 = 0x646f72616e646f6dULL;
    uint64_t v2 = 0x6c7967656e657261ULL;
    uint64_t v3 = 0x7465646279746573ULL;
    uint64_t k0 = U8TO64_LE(k);
    uint64_t k1 = U8TO64_LE(k + 8);
    uint64_t m;
    /* 'end' marks the end of the last complete 8 byte word. */
    const uint8_t *end = in + inlen - (inlen % sizeof(uint64_t));
    const int left = inlen & 7;
    uint64_t b = ((uint64_t)inlen) << 56;
    v3 ^= k1;
    v2 ^= k0;
    v1 ^= k1;
    v0 ^= k0;

    for (; in != end; in += 8) {
        m = U8TO64_LE(in);
        v3 ^= m;

        SIPROUND;

        v0 ^= m;
    }

    /* Fold the trailing bytes into 'b'; every case falls into the next. */
    switch (left) {
    case 7: b |= ((uint64_t)in[6]) << 48; /* fall-thru */
    case 6: b |= ((uint64_t)in[5]) << 40; /* fall-thru */
    case 5: b |= ((uint64_t)in[4]) << 32; /* fall-thru */
    case 4: b |= ((uint64_t)in[3]) << 24; /* fall-thru */
    case 3: b |= ((uint64_t)in[2]) << 16; /* fall-thru */
    case 2: b |= ((uint64_t)in[1]) << 8; /* fall-thru */
    case 1: b |= ((uint64_t)in[0]); break;
    case 0: break;
    }

    v3 ^= b;

    SIPROUND;

    v0 ^= b;
    v2 ^= 0xff;

    SIPROUND;
    SIPROUND;

    b = v0 ^ v1 ^ v2 ^ v3;
#ifndef UNALIGNED_LE_CPU
    U64TO8_LE(out, b);
    return hash;
#else
    return b;
#endif
}

/* Case insensitive variant of siphash(): identical structure, but every
 * input byte is passed through siptlw() before being mixed in. The key
 * bytes are read as they are. */
NO_SANITIZE("alignment")
uint64_t siphash_nocase(const uint8_t *in, const size_t inlen, const uint8_t *k)
{
#ifndef UNALIGNED_LE_CPU
    uint64_t hash;
    uint8_t *out = (uint8_t*) &hash;
#endif
    uint64_t v0 = 0x736f6d6570736575ULL;
    uint64_t v1 = 0x646f72616e646f6dULL;
    uint64_t v2 = 0x6c7967656e657261ULL;
    uint64_t v3 = 0x7465646279746573ULL;
    uint64_t k0 = U8TO64_LE(k);
    uint64_t k1 = U8TO64_LE(k + 8);
    uint64_t m;
    /* 'end' marks the end of the last complete 8 byte word. */
    const uint8_t *end = in + inlen - (inlen % sizeof(uint64_t));
    const int left = inlen & 7;
    uint64_t b = ((uint64_t)inlen) << 56;
    v3 ^= k1;
    v2 ^= k0;
    v1 ^= k1;
    v0 ^= k0;

    for (; in != end; in += 8) {
        m = U8TO64_LE_NOCASE(in);
        v3 ^= m;

        SIPROUND;

        v0 ^= m;
    }

    /* Fold the trailing bytes into 'b'; every case falls into the next. */
    switch (left) {
    case 7: b |= ((uint64_t)siptlw(in[6])) << 48; /* fall-thru */
    case 6: b |= ((uint64_t)siptlw(in[5])) << 40; /* fall-thru */
    case 5: b |= ((uint64_t)siptlw(in[4])) << 32; /* fall-thru */
    case 4: b |= ((uint64_t)siptlw(in[3])) << 24; /* fall-thru */
    case 3: b |= ((uint64_t)siptlw(in[2])) << 16; /* fall-thru */
    case 2: b |= ((uint64_t)siptlw(in[1])) << 8; /* fall-thru */
    case 1: b |= ((uint64_t)siptlw(in[0])); break;
    case 0: break;
    }

    v3 ^= b;

    SIPROUND;

    v0 ^= b;
    v2 ^= 0xff;

    SIPROUND;
    SIPROUND;

    b = v0 ^ v1 ^ v2 ^ v3;
#ifndef UNALIGNED_LE_CPU
    U64TO8_LE(out, b);
    return hash;
#else
    return b;
#endif
}


/* --------------------------------- TEST ------------------------------------ */

#ifdef SIPHASH_TEST

/* Expected 8 byte outputs, one row per input length from 0 to 63. */
const uint8_t vectors_sip64[64][8] = {
    { 0x31, 0x0e, 0x0e, 0xdd, 0x47, 0xdb, 0x6f, 0x72, },
    { 0xfd, 0x67, 0xdc, 0x93, 0xc5, 0x39, 0xf8, 0x74, },
    { 0x5a, 0x4f, 0xa9, 0xd9, 0x09, 0x80, 0x6c, 0x0d, },
    { 0x2d, 0x7e, 0xfb, 0xd7, 0x96, 0x66, 0x67, 0x85, },
    { 0xb7, 0x87, 0x71, 0x27, 0xe0, 0x94, 0x27, 0xcf, },
    { 0x8d, 0xa6, 0x99, 0xcd, 0x64, 0x55, 0x76, 0x18, },
    { 0xce, 0xe3, 0xfe, 0x58, 0x6e, 0x46, 0xc9, 0xcb, },
    { 0x37, 0xd1, 0x01, 0x8b, 0xf5, 0x00, 0x02, 0xab, },
    { 0x62, 0x24, 0x93, 0x9a, 0x79, 0xf5, 0xf5, 0x93, },
    { 0xb0, 0xe4, 0xa9, 0x0b, 0xdf, 0x82, 0x00, 0x9e, },
    { 0xf3, 0xb9, 0xdd, 0x94, 0xc5, 0xbb, 0x5d, 0x7a, },
    { 0xa7, 0xad, 0x6b, 0x22, 0x46, 0x2f, 0xb3, 0xf4, },
    { 0xfb, 0xe5, 0x0e, 0x86, 0xbc, 0x8f, 0x1e, 0x75, },
    { 0x90, 0x3d, 0x84, 0xc0, 0x27, 0x56, 0xea, 0x14, },
    { 0xee, 0xf2, 0x7a, 0x8e, 0x90, 0xca, 0x23, 0xf7, },
    { 0xe5, 0x45, 0xbe, 0x49, 0x61, 0xca, 0x29, 0xa1, },
    { 0xdb, 0x9b, 0xc2, 0x57, 0x7f, 0xcc, 0x2a, 0x3f, },
    { 0x94, 0x47, 0xbe, 0x2c, 0xf5, 0xe9, 0x9a, 0x69, },
    { 0x9c, 0xd3, 0x8d, 0x96, 0xf0, 0xb3, 0xc1, 0x4b, },
    { 0xbd, 0x61, 0x79, 0xa7, 0x1d, 0xc9, 0x6d, 0xbb, },
    { 0x98, 0xee, 0xa2, 0x1a, 0xf2, 0x5c, 0xd6, 0xbe, },
    { 0xc7, 0x67, 0x3b, 0x2e, 0xb0, 0xcb, 0xf2, 0xd0, },
    { 0x88, 0x3e, 0xa3, 0xe3, 0x95, 0x67, 0x53, 0x93, },
    { 0xc8, 0xce, 0x5c, 0xcd, 0x8c, 0x03, 0x0c, 0xa8, },
    { 0x94, 0xaf, 0x49, 0xf6, 0xc6, 0x50, 0xad, 0xb8, },
    { 0xea, 0xb8, 0x85, 0x8a, 0xde, 0x92, 0xe1, 0xbc, },
    { 0xf3, 0x15, 0xbb, 0x5b, 0xb8, 0x35, 0xd8, 0x17, },
    { 0xad, 0xcf, 0x6b, 0x07, 0x63, 0x61, 0x2e, 0x2f, },
    { 0xa5, 0xc9, 0x1d, 0xa7, 0xac, 0xaa, 0x4d, 0xde, },
    { 0x71, 0x65, 0x95, 0x87, 0x66, 0x50, 0xa2, 0xa6, },
    { 0x28, 0xef, 0x49, 0x5c, 0x53, 0xa3, 0x87, 0xad, },
    { 0x42, 0xc3, 0x41, 0xd8, 0xfa, 0x92, 0xd8, 0x32, },
    { 0xce, 0x7c, 0xf2, 0x72, 0x2f, 0x51, 0x27, 0x71, },
    { 0xe3, 0x78, 0x59, 0xf9, 0x46, 0x23, 0xf3, 0xa7, },
    { 0x38, 0x12, 0x05, 0xbb, 0x1a, 0xb0, 0xe0, 0x12, },
    { 0xae, 0x97, 0xa1, 0x0f, 0xd4, 0x34, 0xe0, 0x15, },
    { 0xb4, 0xa3, 0x15, 0x08, 0xbe, 0xff, 0x4d, 0x31, },
    { 0x81, 0x39, 0x62, 0x29, 0xf0, 0x90, 0x79, 0x02, },
    { 0x4d, 0x0c, 0xf4, 0x9e, 0xe5, 0xd4, 0xdc, 0xca, },
    { 0x5c, 0x73, 0x33, 0x6a, 0x76, 0xd8, 0xbf, 0x9a, },
    { 0xd0, 0xa7, 0x04, 0x53, 0x6b, 0xa9, 0x3e, 0x0e, },
    { 0x92, 0x59, 0x58, 0xfc, 0xd6, 0x42, 0x0c, 0xad, },
    { 0xa9, 0x15, 0xc2, 0x9b, 0xc8, 0x06, 0x73, 0x18, },
    { 0x95, 0x2b, 0x79, 0xf3, 0xbc, 0x0a, 0xa6, 0xd4, },
    { 0xf2, 0x1d, 0xf2, 0xe4, 0x1d, 0x45, 0x35, 0xf9, },
    { 0x87, 0x57, 0x75, 0x19, 0x04, 0x8f, 0x53, 0xa9, },
    { 0x10, 0xa5, 0x6c, 0xf5, 0xdf, 0xcd, 0x9a, 0xdb, },
    { 0xeb, 0x75, 0x09, 0x5c, 0xcd, 0x98, 0x6c, 0xd0, },
    { 0x51, 0xa9, 0xcb, 0x9e, 0xcb, 0xa3, 0x12, 0xe6, },
    { 0x96, 0xaf, 0xad, 0xfc, 0x2c, 0xe6, 0x66, 0xc7, },
    { 0x72, 0xfe, 0x52, 0x97, 0x5a, 0x43, 0x64, 0xee, },
    { 0x5a, 0x16, 0x45, 0xb2, 0x76, 0xd5, 0x92, 0xa1, },
    { 0xb2, 0x74, 0xcb, 0x8e, 0xbf, 0x87, 0x87, 0x0a, },
    { 0x6f, 0x9b, 0xb4, 0x20, 0x3d, 0xe7, 0xb3, 0x81, },
    { 0xea, 0xec, 0xb2, 0xa3, 0x0b, 0x22, 0xa8, 0x7f, },
    { 0x99, 0x24, 0xa4, 0x3c, 0xc1, 0x31, 0x57, 0x24, },
    { 0xbd, 0x83, 0x8d, 0x3a, 0xaf, 0xbf, 0x8d, 0xb7, },
    { 0x0b, 0x1a, 0x2a, 0x32, 0x65, 0xd5, 0x1a, 0xea, },
    { 0x13, 0x50, 0x79, 0xa3, 0x23, 0x1c, 0xe6, 0x60, },
    { 0x93, 0x2b, 0x28, 0x46, 0xe4, 0xd7, 0x06, 0x66, },
    { 0xe1, 0x91, 0x5f, 0x5c, 0xb1, 0xec, 0xa4, 0x6c, },
    { 0xf3, 0x25, 0x96, 0x5c, 0xa1, 0x6d, 0x62, 0x9f, },
    { 0x57, 0x5f, 0xf2, 0x8e, 0x60, 0x38, 0x1b, 0xe5, },
    { 0x72, 0x45, 0x06, 0xeb, 0x4c, 0x32, 0x8a, 0x95, },
};


/* Test siphash using a test vector.
Returns 0 if the function passed * all the tests, otherwise 1 is returned. * * IMPORTANT: The test vector is for SipHash 2-4. Before running * the test revert back the siphash() function to 2-4 rounds since * now it uses 1-2 rounds. */ int siphash_test(void) { uint8_t in[64], k[16]; int i; int fails = 0; for (i = 0; i < 16; ++i) k[i] = i; for (i = 0; i < 64; ++i) { in[i] = i; uint64_t hash = siphash(in, i, k); const uint8_t *v = NULL; v = (uint8_t *)vectors_sip64; if (memcmp(&hash, v + (i * 8), 8)) { /* printf("fail for %d bytes\n", i); */ fails++; } } /* Run a few basic tests with the case insensitive version. */ uint64_t h1, h2; h1 = siphash((uint8_t*)"hello world",11,(uint8_t*)"1234567812345678"); h2 = siphash_nocase((uint8_t*)"hello world",11,(uint8_t*)"1234567812345678"); if (h1 != h2) fails++; h1 = siphash((uint8_t*)"hello world",11,(uint8_t*)"1234567812345678"); h2 = siphash_nocase((uint8_t*)"HELLO world",11,(uint8_t*)"1234567812345678"); if (h1 != h2) fails++; h1 = siphash((uint8_t*)"HELLO world",11,(uint8_t*)"1234567812345678"); h2 = siphash_nocase((uint8_t*)"HELLO world",11,(uint8_t*)"1234567812345678"); if (h1 == h2) fails++; if (!fails) return 0; return 1; } int main(void) { if (siphash_test() == 0) { printf("SipHash test: OK\n"); return 0; } else { printf("SipHash test: FAILED\n"); return 1; } } #endif ================================================ FILE: src/redis/stream.h ================================================ #ifndef STREAM_H #define STREAM_H #include "util.h" #include "rax.h" #include "sds.h" #include "listpack.h" typedef struct redisObject robj; /* Stream item ID: a 128 bit number composed of a milliseconds time and * a sequence counter. IDs generated in the same millisecond (or in a past * millisecond if the clock jumped backward) will use the millisecond time * of the latest generated ID and an incremented sequence. */ typedef struct streamID { uint64_t ms; /* Unix time in milliseconds. */ uint64_t seq; /* Sequence number. 
/* Top-level stream object: the entries live in a radix tree, plus the
 * bookkeeping needed for IDs, counters and consumer groups. */
typedef struct stream {
  struct rax *rax;               /* The radix tree holding the stream. */
  uint64_t length;               /* Current number of elements inside this stream. */
  streamID last_id;              /* Zero if there are no items yet. */
  streamID first_id;             /* The first non-tombstone entry, zero if empty. */
  streamID max_deleted_entry_id; /* The maximal ID that was deleted. */
  uint64_t entries_added;        /* All time count of elements added. */
  struct rax *cgroups;           /* Consumer groups dictionary: name -> streamCG */
} stream;

/* We define an iterator to iterate stream items in an abstract way, without
 * caring about the radix tree + listpack representation. Technically speaking
 * the iterator is only used inside streamReplyWithRange(), so it could just
 * be implemented inside the function, but practically there is the AOF
 * rewriting code that also needs to iterate the stream to emit the XADD
 * commands. */
typedef struct streamIterator {
  stream *stream;                     /* The stream we are iterating. */
  streamID master_id;                 /* ID of the master entry at listpack head. */
  uint64_t master_fields_count;       /* Master entry's number of fields. */
  unsigned char *master_fields_start; /* Master entry's first field in the listpack. */
  unsigned char *master_fields_ptr;   /* Master field to emit next. */
  int entry_flags;                    /* Flags of the entry we are emitting. */
  int rev;                            /* True if iterating end to start (reverse). */
  int skip_tombstones;                /* True if not emitting tombstone entries. */
  uint64_t start_key[2];              /* Start key as 128 bit big endian. */
  uint64_t end_key[2];                /* End key as 128 bit big endian. */
  raxIterator ri;                     /* Rax iterator. */
  unsigned char *lp;                  /* Current listpack. */
  unsigned char *lp_ele;              /* Current listpack cursor. */
  unsigned char *lp_flags;            /* Current entry flags pointer. */
  /* Buffers used to hold the string of lpGet() when the element is
   * integer encoded, so that there is no string representation of the
   * element inside the listpack itself. */
  unsigned char field_buf[LP_INTBUF_SIZE];
  unsigned char value_buf[LP_INTBUF_SIZE];
} streamIterator;

/* Consumer group. */
typedef struct streamCG {
  streamID last_id;       /* Last delivered (not acknowledged) ID for this
                             group. Consumers that just ask for more
                             messages will be served with IDs greater than
                             this one. */
  long long entries_read; /* In a perfect world (CG starts at 0-0, no dels, no
                             XGROUP SETID, ...), this is the total number of
                             group reads. In the real world, the reasoning
                             behind this value is detailed at the top comment
                             of streamEstimateDistanceFromFirstEverEntry(). */
  rax *pel;               /* Pending entries list. This is a radix tree that
                             has every message delivered to consumers (without
                             the NOACK option) that was not yet acknowledged
                             as processed. The key of the radix tree is the
                             big endian encoded ID, while the associated value
                             is a streamNACK structure.
                             NOTE(review): the previous wording said the key is
                             a "64 bit" number, while streamEncodeID() produces
                             a 128 bit key - confirm which one the PEL uses. */
  rax *consumers;         /* A radix tree representing the consumers by name
                             and their associated representation in the form
                             of streamConsumer structures. */
} streamCG;

/* A specific consumer in a consumer group. */
typedef struct streamConsumer {
  mstime_t seen_time;   /* Last time this consumer tried to perform an action
                           (attempted reading/claiming). */
  mstime_t active_time; /* Last time this consumer was active (successful
                           reading/claiming). */
  sds name;             /* Consumer name. This is how the consumer will be
                           identified in the consumer group protocol. Case
                           sensitive. */
  rax *pel;             /* Consumer specific pending entries list: all the
                           pending messages delivered to this consumer and not
                           yet acknowledged. Keys are big endian message IDs,
                           while values are the same streamNACK structure
                           referenced in the "pel" of the consumer group
                           structure itself, so the value is shared. */
} streamConsumer;

/* Pending (not yet acknowledged) message in a consumer group. */
typedef struct streamNACK {
  mstime_t delivery_time;   /* Last time this message was delivered. */
  uint64_t delivery_count;  /* Number of times this message was delivered. */
  streamConsumer *consumer; /* The consumer this message was delivered to
                               in the last delivery. */
} streamNACK;

/* Parsed arguments shared by XADD and XTRIM. */
typedef struct {
  /* XADD options */
  streamID id;     /* User-provided ID, for XADD only. */
  int id_given;    /* Was an ID different from "*" specified? For XADD only. */
  int seq_given;   /* Was an ID different from "ms-*" specified? For XADD only. */
  int no_mkstream; /* If set to 1 do not create a new stream. */

  /* XADD + XTRIM common options */
  int trim_strategy;         /* TRIM_STRATEGY_* */
  int trim_strategy_arg_idx; /* Index of the count in MAXLEN/MINID, for rewriting. */
  int approx_trim;           /* If 1 only delete whole radix tree nodes, so
                              * the trim argument is not applied verbatim. */
  long long limit;           /* Maximum amount of entries to trim. If 0, no limitation
                              * on the amount of trimming work is enforced. */
  /* TRIM_STRATEGY_MAXLEN options */
  long long maxlen; /* After trimming, leave the stream at this length. */
  /* TRIM_STRATEGY_MINID options */
  streamID minid; /* Trim by ID (no stream entries with ID < 'minid' will remain). */
} streamAddTrimArgs;

/* Prototypes of exported APIs. */
// struct client;

/* Flags for streamCreateConsumer */
#define SCC_DEFAULT 0
#define SCC_NO_NOTIFY (1 << 0)  /* Do not notify key space if consumer created */
#define SCC_NO_DIRTIFY (1 << 1) /* Do not dirty++ if consumer created */

/* Sentinels meaning "this counter is not known". */
#define SCG_INVALID_ENTRIES_READ -1
#define SCG_INVALID_LAG -1

/* Values for streamAddTrimArgs.trim_strategy. */
#define TRIM_STRATEGY_NONE 0
#define TRIM_STRATEGY_MAXLEN 1
#define TRIM_STRATEGY_MINID 2

/* Every stream item inside the listpack has a flags field that is used to
 * mark the entry as deleted, or as having the same fields as the "master"
 * entry at the start of the listpack. */
#define STREAM_ITEM_FLAG_NONE 0              /* No special flags. */
#define STREAM_ITEM_FLAG_DELETED (1 << 0)    /* Entry is deleted. Skip it. */
#define STREAM_ITEM_FLAG_SAMEFIELDS (1 << 1) /* Same fields as the master entry. */
*/ void streamIteratorStart(streamIterator *si, stream *s, streamID *start, streamID *end, int rev); int streamIteratorGetID(streamIterator *si, streamID *id, int64_t *numfields); void streamIteratorGetField(streamIterator *si, unsigned char **fieldptr, unsigned char **valueptr, int64_t *fieldlen, int64_t *valuelen); void streamIteratorStop(streamIterator *si); streamCG *streamCreateCG(stream *s, const char *name, size_t namelen, streamID *id, long long entries_read); void streamDecodeID(void *buf, streamID *id); int streamCompareID(streamID *a, streamID *b); void streamFreeNACK(streamNACK *na); void streamGetEdgeID(stream *s, int first, int skip_tombstones, streamID *edge_id); long long streamEstimateDistanceFromFirstEverEntry(stream *s, streamID *id); #endif ================================================ FILE: src/redis/t_stream.c ================================================ /* * Copyright (c) 2017, Redis Ltd. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include "endianconv.h" #include "stream.h" #include "redis_aux.h" #include "zmalloc.h" /* For stream commands that require multiple IDs * when the number of IDs is less than 'STREAMID_STATIC_VECTOR_LEN', * avoid malloc allocation.*/ #define STREAMID_STATIC_VECTOR_LEN 8 /* Max pre-allocation for listpack. This is done to avoid abuse of a user * setting stream_node_max_bytes to a huge number. */ #define STREAM_LISTPACK_MAX_PRE_ALLOCATE 4096 /* Don't let listpacks grow too big, even if the user config allows it. * doing so can lead to an overflow (trying to store more than 32bit length * into the listpack header), or actually an assertion since lpInsert * will return NULL. */ #define STREAM_LISTPACK_MAX_SIZE (1 << 30) /* ----------------------------------------------------------------------- * Low level stream encoding: a radix tree of listpacks. * ----------------------------------------------------------------------- */ static inline int64_t lpGetIntegerIfValid(unsigned char *ele, int *valid) { int64_t v; unsigned char *e = lpGet(ele, &v, NULL); if (e == NULL) { if (valid) *valid = 1; return v; } long long ll; int ret = string2ll((char *)e, v, &ll); if (valid) *valid = ret; else serverAssert(ret != 0); v = ll; return v; } #define lpGetInteger(ele) lpGetIntegerIfValid(ele, NULL) /* Get an edge streamID of a given listpack. 
* 'master_id' is an input param, used to build the 'edge_id' output param */ /* Convert the specified stream entry ID as a 128 bit big endian number, so * that the IDs can be sorted lexicographically. */ static void streamEncodeID(void *buf, streamID *id) { uint64_t e[2]; e[0] = htonu64(id->ms); e[1] = htonu64(id->seq); memcpy(buf, e, sizeof(e)); } /* This is the reverse of streamEncodeID(): the decoded ID will be stored * in the 'id' structure passed by reference. The buffer 'buf' must point * to a 128 bit big-endian encoded ID. */ void streamDecodeID(void *buf, streamID *id) { uint64_t e[2]; memcpy(e, buf, sizeof(e)); id->ms = ntohu64(e[0]); id->seq = ntohu64(e[1]); } /* Compare two stream IDs. Return -1 if a < b, 0 if a == b, 1 if a > b. */ int streamCompareID(streamID *a, streamID *b) { if (a->ms > b->ms) return 1; else if (a->ms < b->ms) return -1; /* The ms part is the same. Check the sequence part. */ else if (a->seq > b->seq) return 1; else if (a->seq < b->seq) return -1; /* Everything is the same: IDs are equal. */ return 0; } /* Retrieves the ID of the stream edge entry. An edge is either the first or * the last ID in the stream, and may be a tombstone. To filter out tombstones, * set the'skip_tombstones' argument to 1. */ void streamGetEdgeID(stream *s, int first, int skip_tombstones, streamID *edge_id) { streamIterator si; int64_t numfields; streamIteratorStart(&si, s, NULL, NULL, !first); si.skip_tombstones = skip_tombstones; int found = streamIteratorGetID(&si, edge_id, &numfields); if (!found) { streamID min_id = {0, 0}, max_id = {UINT64_MAX, UINT64_MAX}; *edge_id = first ? max_id : min_id; } streamIteratorStop(&si); } /* Initialize the stream iterator, so that we can call iterating functions * to get the next items. This requires a corresponding streamIteratorStop() * at the end. The 'rev' parameter controls the direction. 
If it's zero the * iteration is from the start to the end element (inclusive), otherwise * if rev is non-zero, the iteration is reversed. * * Once the iterator is initialized, we iterate like this: * * streamIterator myiterator; * streamIteratorStart(&myiterator,...); * int64_t numfields; * while(streamIteratorGetID(&myiterator,&ID,&numfields)) { * while(numfields--) { * unsigned char *key, *value; * size_t key_len, value_len; * streamIteratorGetField(&myiterator,&key,&value,&key_len,&value_len); * * ... do what you want with key and value ... * } * } * streamIteratorStop(&myiterator); */ void streamIteratorStart(streamIterator *si, stream *s, streamID *start, streamID *end, int rev) { /* Initialize the iterator and translates the iteration start/stop * elements into a 128 big big-endian number. */ if (start) { streamEncodeID(si->start_key, start); } else { si->start_key[0] = 0; si->start_key[1] = 0; } if (end) { streamEncodeID(si->end_key, end); } else { si->end_key[0] = UINT64_MAX; si->end_key[1] = UINT64_MAX; } /* Seek the correct node in the radix tree. */ raxStart(&si->ri, s->rax); if (!rev) { if (start && (start->ms || start->seq)) { raxSeek(&si->ri, "<=", (unsigned char *)si->start_key, sizeof(si->start_key)); if (raxEOF(&si->ri)) raxSeek(&si->ri, "^", NULL, 0); } else { raxSeek(&si->ri, "^", NULL, 0); } } else { if (end && (end->ms || end->seq)) { raxSeek(&si->ri, "<=", (unsigned char *)si->end_key, sizeof(si->end_key)); if (raxEOF(&si->ri)) raxSeek(&si->ri, "$", NULL, 0); } else { raxSeek(&si->ri, "$", NULL, 0); } } si->stream = s; si->lp = NULL; /* There is no current listpack right now. */ si->lp_ele = NULL; /* Current listpack cursor. */ si->rev = rev; /* Direction, if non-zero reversed, from end to start. */ si->skip_tombstones = 1; /* By default tombstones aren't emitted. */ } /* Return 1 and store the current item ID at 'id' if there are still * elements within the iteration range, otherwise return 0 in order to * signal the iteration terminated. 
/* Advance the iterator to the next entry inside the iteration range.
 *
 * si        - iterator initialized with streamIteratorStart().
 * id        - output: ID of the entry found. It is also overwritten with the
 *             ID of every candidate entry that gets examined, so its content
 *             is not meaningful when 0 is returned.
 * numfields - output: number of fields of the entry, that is, how many times
 *             streamIteratorGetField() can be called for it. Like 'id' it is
 *             written for every candidate examined.
 *
 * Returns 1 if an entry was found, 0 when the iteration is over: either the
 * radix tree has no more nodes in the iteration direction, or the next
 * emittable entry is outside the range. */
int streamIteratorGetID(streamIterator *si, streamID *id, int64_t *numfields) {
  while (1) { /* Will stop when element > stop_key or end of radix tree. */
    /* If the current listpack is set to NULL, this is the start of the
     * iteration or the previous listpack was completely iterated.
     * Go to the next node. */
    if (si->lp == NULL || si->lp_ele == NULL) {
      if (!si->rev && !raxNext(&si->ri))
        return 0;
      else if (si->rev && !raxPrev(&si->ri))
        return 0;
      serverAssert(si->ri.key_len == sizeof(streamID));
      /* Get the master ID: it is the key of the radix tree node. */
      streamDecodeID(si->ri.key,&si->master_id);
      /* Get the master fields count. */
      si->lp = si->ri.data;
      si->lp_ele = lpFirst(si->lp); /* Seek items count */
      si->lp_ele = lpNext(si->lp,si->lp_ele); /* Seek deleted count. */
      si->lp_ele = lpNext(si->lp,si->lp_ele); /* Seek num fields. */
      si->master_fields_count = lpGetInteger(si->lp_ele);
      si->lp_ele = lpNext(si->lp,si->lp_ele); /* Seek first field. */
      si->master_fields_start = si->lp_ele;
      /* We are now pointing to the first field of the master entry.
       * We need to seek either the first or the last entry depending
       * on the direction of the iteration. */
      if (!si->rev) {
        /* If we are iterating in normal order, skip the master fields
         * to seek the first actual entry. */
        for (uint64_t i = 0; i < si->master_fields_count; i++)
          si->lp_ele = lpNext(si->lp,si->lp_ele);
      } else {
        /* If we are iterating in reverse direction, just seek the
         * last part of the last entry in the listpack (that is, the
         * fields count). */
        si->lp_ele = lpLast(si->lp);
      }
    } else if (si->rev) {
      /* If we are iterating in the reverse order, and this is not
       * the first entry emitted for this listpack, then we already
       * emitted the current entry, and have to go back to the previous
       * one. */
      int64_t lp_count = lpGetInteger(si->lp_ele);
      while (lp_count--) si->lp_ele = lpPrev(si->lp, si->lp_ele);
      /* Seek lp-count of prev entry. */
      si->lp_ele = lpPrev(si->lp, si->lp_ele);
    }

    /* For every radix tree node, iterate the corresponding listpack,
     * returning elements when they are within range. */
    while (1) {
      if (!si->rev) {
        /* If we are going forward, skip the previous entry
         * lp-count field (or in case of the master entry, the zero
         * term field) */
        si->lp_ele = lpNext(si->lp,si->lp_ele);
        if (si->lp_ele == NULL) break;
      } else {
        /* If we are going backward, read the number of elements this
         * entry is composed of, and jump backward N times to seek
         * its start. */
        int64_t lp_count = lpGetInteger(si->lp_ele);
        if (lp_count == 0) { /* We reached the master entry. */
          si->lp = NULL;
          si->lp_ele = NULL;
          break;
        }
        while(lp_count--) si->lp_ele = lpPrev(si->lp,si->lp_ele);
      }

      /* Get the flags entry. Its position is remembered in lp_flags. */
      si->lp_flags = si->lp_ele;
      int64_t flags = lpGetInteger(si->lp_ele);
      si->lp_ele = lpNext(si->lp,si->lp_ele); /* Seek ID. */

      /* Get the ID: it is encoded as difference between the master
       * ID and this entry ID. */
      *id = si->master_id;
      id->ms += lpGetInteger(si->lp_ele);
      si->lp_ele = lpNext(si->lp, si->lp_ele);
      id->seq += lpGetInteger(si->lp_ele);
      si->lp_ele = lpNext(si->lp, si->lp_ele);
      /* Big endian form of the ID, to compare it with the range keys. */
      unsigned char buf[sizeof(streamID)];
      streamEncodeID(buf, id);

      /* The number of fields is stored in the entry only when it does not
       * share the fields of the master entry. */
      if (flags & STREAM_ITEM_FLAG_SAMEFIELDS) {
        *numfields = si->master_fields_count;
      } else {
        *numfields = lpGetInteger(si->lp_ele);
        si->lp_ele = lpNext(si->lp, si->lp_ele);
      }
      serverAssert(*numfields >= 0);

      /* If current >= start, and the entry is not marked as
       * deleted or tombstones are included, emit it. */
      if (!si->rev) {
        if (memcmp(buf,si->start_key,sizeof(streamID)) >= 0 &&
            (!si->skip_tombstones || !(flags & STREAM_ITEM_FLAG_DELETED))) {
          if (memcmp(buf,si->end_key,sizeof(streamID)) > 0) return 0; /* We are already out of range. */
          si->entry_flags = flags;
          /* Field names will be read from the master entry. */
          if (flags & STREAM_ITEM_FLAG_SAMEFIELDS) si->master_fields_ptr = si->master_fields_start;
          return 1; /* Valid item returned. */
        }
      } else {
        /* Reverse iteration: same logic with the two range bounds swapped. */
        if (memcmp(buf, si->end_key, sizeof(streamID)) <= 0 &&
            (!si->skip_tombstones || !(flags & STREAM_ITEM_FLAG_DELETED))) {
          if (memcmp(buf, si->start_key, sizeof(streamID)) < 0) return 0; /* We are already out of range. */
          si->entry_flags = flags;
          /* Field names will be read from the master entry. */
          if (flags & STREAM_ITEM_FLAG_SAMEFIELDS) si->master_fields_ptr = si->master_fields_start;
          return 1; /* Valid item returned. */
        }
      }

      /* If we do not emit, we have to discard if we are going
       * forward, or seek the previous entry if we are going
       * backward. */
      if (!si->rev) {
        /* With SAMEFIELDS only the values are stored in the entry, otherwise
         * every field takes two elements: its name and its value. */
        int64_t to_discard = (flags & STREAM_ITEM_FLAG_SAMEFIELDS) ? *numfields : *numfields * 2;
        for (int64_t i = 0; i < to_discard; i++) si->lp_ele = lpNext(si->lp, si->lp_ele);
      } else {
        int64_t prev_times = 4; /* flag + id ms + id seq + one more to
                                   go back to the previous entry "count"
                                   field. */
        /* If the entry was not flagged SAMEFIELDS we also read the
         * number of fields, so go back one more. */
        if (!(flags & STREAM_ITEM_FLAG_SAMEFIELDS)) prev_times++;
        while (prev_times--) si->lp_ele = lpPrev(si->lp, si->lp_ele);
      }
    }

    /* End of listpack reached. Try the next/prev radix tree node. */
  }
}
/* Fetch the next field/value pair of the entry the iterator is positioned on.
 * Call it right after streamIteratorGetID(), once for every field reported by
 * that call.
 *
 * The field name is read from the master entry when the current entry is
 * flagged STREAM_ITEM_FLAG_SAMEFIELDS, otherwise from the entry itself; the
 * value is always read from the entry. Pointers and lengths are returned by
 * reference and remain valid until the next iterator call, as long as the
 * stream is not modified in the meantime. */
void streamIteratorGetField(streamIterator *si, unsigned char **fieldptr, unsigned char **valueptr,
                            int64_t *fieldlen, int64_t *valuelen) {
  /* Select the cursor the field name has to be read from. */
  unsigned char **name_cursor =
      (si->entry_flags & STREAM_ITEM_FLAG_SAMEFIELDS) ? &si->master_fields_ptr : &si->lp_ele;

  *fieldptr = lpGet(*name_cursor, fieldlen, si->field_buf);
  *name_cursor = lpNext(si->lp, *name_cursor);

  *valueptr = lpGet(si->lp_ele, valuelen, si->value_buf);
  si->lp_ele = lpNext(si->lp, si->lp_ele);
}

/* Terminate an iteration started with streamIteratorStart(). Only the rax
 * iterator needs to be released: the stream iterator itself is supposed to
 * be stack allocated. */
void streamIteratorStop(streamIterator *si) {
  raxStop(&si->ri);
}

/* Return non-zero if 'id' is the 0-0 ID. */
static int streamIDEqZero(streamID *id) {
  return id->ms == 0 && id->seq == 0;
}
The ID less than the stream's first current entry's ID, and there are no * tombstones. Here the estimated counter is the result of subtracting the * stream's length from its added_entries. * 4. The stream's added_entries is zero, meaning that no entries were ever * added. * * The special return value of ULLONG_MAX signals that the counter's value isn't * obtainable. It is returned in these cases: * 1. The provided ID, if it even exists, is somewhere between the stream's * current first and last entries' IDs, or in the future. * 2. The stream contains one or more tombstones. */ long long streamEstimateDistanceFromFirstEverEntry(stream *s, streamID *id) { /* The counter of any ID in an empty, never-before-used stream is 0. */ if (!s->entries_added) { return 0; } /* In the empty stream, if the ID is smaller or equal to the last ID, * it can set to the current added_entries value. */ if (!s->length && streamCompareID(id, &s->last_id) < 1) { return s->entries_added; } if (!streamIDEqZero(id) && streamCompareID(id, &s->max_deleted_entry_id) < 0) { /* The ID is before the last tombstone, so the counter is unknown. */ return SCG_INVALID_ENTRIES_READ; } int cmp_last = streamCompareID(id, &s->last_id); if (cmp_last == 0) { /* Return the exact counter of the last entry in the stream. */ return s->entries_added; } else if (cmp_last > 0) { /* The counter of a future ID is unknown. */ return SCG_INVALID_ENTRIES_READ; } int cmp_id_first = streamCompareID(id, &s->first_id); int cmp_xdel_first = streamCompareID(&s->max_deleted_entry_id, &s->first_id); if (streamIDEqZero(&s->max_deleted_entry_id) || cmp_xdel_first < 0) { /* There's definitely no fragmentation ahead. */ if (cmp_id_first < 0) { /* Return the estimated counter. */ return s->entries_added - s->length; } else if (cmp_id_first == 0) { /* Return the exact counter of the first entry in the stream. 
*/ return s->entries_added - s->length + 1; } } /* The ID is either before an XDEL that fragments the stream or an arbitrary * ID. Either case, so we can't make a prediction. */ return SCG_INVALID_ENTRIES_READ; } /* Send the stream items in the specified range to the client 'c'. The range * the client will receive is between start and end inclusive, if 'count' is * non zero, no more than 'count' elements are sent. * * The 'end' pointer can be NULL to mean that we want all the elements from * 'start' till the end of the stream. If 'rev' is non zero, elements are * produced in reversed order from end to start. * * The function returns the number of entries emitted. * * If group and consumer are not NULL, the function performs additional work: * 1. It updates the last delivered ID in the group in case we are * sending IDs greater than the current last ID. * 2. If the requested IDs are already assigned to some other consumer, the * function will not return it to the client. * 3. An entry in the pending list will be created for every entry delivered * for the first time to this consumer. * 4. The group's read counter is incremented if it is already valid and there * are no future tombstones, or is invalidated (set to 0) otherwise. If the * counter is invalid to begin with, we try to obtain it for the last * delivered ID. * * The behavior may be modified passing non-zero flags: * * STREAM_RWR_NOACK: Do not create PEL entries, that is, the point "3" above * is not performed. * STREAM_RWR_RAWENTRIES: Do not emit array boundaries, but just the entries, * and return the number of entries emitted as usually. * This is used when the function is just used in order * to emit data and there is some higher level logic. 
/* Free a NACK (pending entry) structure.
 *
 * Only the streamNACK allocation itself is handed to zfree(): the function
 * does not touch the consumer the NACK points to, nor any radix tree that may
 * still hold a reference to it.
 * NOTE(review): presumably callers remove the NACK from the pending entries
 * lists themselves - confirm against the call sites, which are outside this
 * file. */
void streamFreeNACK(streamNACK *na) {
  zfree(na);
}
*/ streamCG *streamCreateCG(stream *s, const char *name, size_t namelen, streamID *id, long long entries_read) { if (s->cgroups == NULL) s->cgroups = raxNew(); if (raxFind(s->cgroups, (unsigned char *)name, namelen, NULL)) return NULL; streamCG *cg = zmalloc(sizeof(*cg)); cg->pel = raxNew(); cg->consumers = raxNew(); cg->last_id = *id; cg->entries_read = entries_read; raxInsert(s->cgroups, (unsigned char *)name, namelen, cg, NULL); return cg; } ================================================ FILE: src/redis/util.c ================================================ /* * Copyright (c) 2009-2012, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
/* Return the number of decimal digits needed to print 'v'.
 * Used by ll2string() to know in advance the length of its output. */
static uint32_t digits10(uint64_t v) {
    /* pow10[i] is the smallest number having i+2 digits. */
    static const uint64_t pow10[11] = {
        10ULL,         100ULL,         1000ULL,         10000ULL,
        100000ULL,     1000000ULL,     10000000ULL,     100000000ULL,
        1000000000ULL, 10000000000ULL, 100000000000ULL,
    };
    uint32_t digits = 0;

    /* Peel off twelve digits at a time until at most twelve are left. */
    while (v >= 1000000000000ULL) {
        v /= 1000000000000ULL;
        digits += 12;
    }
    for (uint32_t i = 0; i < 11; i++) {
        if (v < pow10[i]) return digits + i + 1;
    }
    return digits + 12;
}

/* Write the decimal representation of 'svalue' into 'dst', null terminated.
 *
 * dst    - destination buffer.
 * dstlen - size of 'dst' in bytes; must have room for the terminator too.
 * svalue - the number to convert.
 *
 * Returns the length of the string produced (terminator excluded), or 0 if
 * the buffer is too small, in which case 'dst' is left untouched.
 *
 * Digits are emitted two at a time from the least significant ones, using a
 * lookup table holding the strings "00" ... "99". */
int ll2string(char *dst, size_t dstlen, long long svalue) {
    static const char pairs[201] =
        "0001020304050607080910111213141516171819"
        "2021222324252627282930313233343536373839"
        "4041424344454647484950515253545556575859"
        "6061626364656667686970717273747576777879"
        "8081828384858687888990919293949596979899";

    /* Work on the magnitude as an unsigned number, remembering the sign.
     * LLONG_MIN is special since it can't be negated as a long long. */
    const int negative = svalue < 0;
    unsigned long long magnitude;
    if (!negative)
        magnitude = svalue;
    else if (svalue == LLONG_MIN)
        magnitude = ((unsigned long long)LLONG_MAX) + 1;
    else
        magnitude = -svalue;

    /* Check that digits, sign and terminator fit in the buffer. */
    const uint32_t length = digits10(magnitude) + negative;
    if (length >= dstlen) return 0;

    /* Fill the buffer backward, starting from the terminator. */
    char *cursor = dst + length;
    *cursor = '\0';
    while (magnitude >= 100) {
        const char *pair = pairs + (magnitude % 100) * 2;
        magnitude /= 100;
        *--cursor = pair[1];
        *--cursor = pair[0];
    }

    /* One or two digits are left. */
    if (magnitude < 10) {
        *--cursor = (char)('0' + magnitude);
    } else {
        const char *pair = pairs + magnitude * 2;
        *--cursor = pair[1];
        *--cursor = pair[0];
    }

    if (negative) dst[0] = '-';
    return length;
}
/* Parse the 'slen' bytes at 's' as a long long.
 *
 * Returns 1 and stores the number in '*value' (when 'value' is not NULL) if
 * the whole input is the canonical representation of a number in the long
 * long range, otherwise returns 0 and leaves '*value' untouched.
 *
 * The parser is strict on purpose: no spaces or other characters are allowed
 * before or after the number, no '+' sign, and no leading zeroes (the only
 * accepted string starting with '0' is "0" itself). Because of this, a
 * successfully parsed string can be rebuilt from the number without any
 * loss. */
int string2ll(const char *s, size_t slen, long long *value) {
    /* An empty string is not a number. */
    if (slen == 0) return 0;

    /* The only valid input starting with a zero is "0". */
    if (slen == 1 && s[0] == '0') {
        if (value != NULL) *value = 0;
        return 1;
    }

    const char *cur = s;
    const char *const end = s + slen;
    int negative = 0;

    /* Remember the sign and parse the rest as a positive number. */
    if (*cur == '-') {
        negative = 1;
        cur++;
        if (cur == end) return 0; /* Just a minus sign. */
    }

    /* The most significant digit must be in the range 1-9. */
    if (*cur < '1' || *cur > '9') return 0;
    unsigned long long acc = *cur++ - '0';

    /* Accumulate the other digits, detecting overflows at every step. */
    for (; cur != end && *cur >= '0' && *cur <= '9'; cur++) {
        const int digit = *cur - '0';
        if (acc > (ULLONG_MAX / 10)) return 0;
        acc *= 10;
        if (acc > (ULLONG_MAX - digit)) return 0;
        acc += digit;
    }

    /* Trailing bytes that are not digits. */
    if (cur != end) return 0;

    /* Apply the sign, checking that the result fits in a long long. */
    if (negative) {
        if (acc > ((unsigned long long)(-(LLONG_MIN + 1)) + 1)) return 0;
        if (value != NULL) *value = -acc;
    } else {
        if (acc > LLONG_MAX) return 0;
        if (value != NULL) *value = acc;
    }
    return 1;
}
* * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __REDIS_UTIL_H #define __REDIS_UTIL_H #include #include #include /* The maximum number of characters needed to represent a long double * as a string (long double has a huge range). 
* This should be the size of the buffer given to ld2string */ #define MAX_LONG_DOUBLE_CHARS 5*1024 /* Error codes */ #define C_OK 0 #define C_ERR -1 int ll2string(char *s, size_t len, long long value); int string2ll(const char *s, size_t slen, long long *value); #define LOG_MAX_LEN 1024 /* Default maximum length of syslog messages.*/ /* Log levels */ #define LL_DEBUG 0 #define LL_VERBOSE 1 #define LL_NOTICE 2 #define LL_WARNING 3 #define LL_RAW (1<<10) /* Modifier to log without timestamp */ /* Bytes needed for long -> str + '\0' */ #define LONG_STR_SIZE 21 void serverLog(int level, const char *fmt, ...); void _serverPanic(const char *file, int line, const char *msg, ...); void _serverAssert(const char *estr, const char *file, int line); #define serverPanic(...) _serverPanic(__FILE__,__LINE__,__VA_ARGS__),_exit(1) #define serverAssert(_e) ((_e)?(void)0 : (_serverAssert(#_e,__FILE__,__LINE__),_exit(1))) typedef long long mstime_t; /* millisecond time type. */ #endif ================================================ FILE: src/redis/ziplist.c ================================================ /* The ziplist is a specially encoded dually linked list that is designed * to be very memory efficient. It stores both strings and integer values, * where integers are encoded as actual integers instead of a series of * characters. It allows push and pop operations on either side of the list * in O(1) time. However, because every operation requires a reallocation of * the memory used by the ziplist, the actual complexity is related to the * amount of memory used by the ziplist. * * ---------------------------------------------------------------------------- * * ZIPLIST OVERALL LAYOUT * ====================== * * The general layout of the ziplist is as follows: * * ... * * NOTE: all fields are stored in little endian, if not specified otherwise. * * is an unsigned integer to hold the number of bytes that * the ziplist occupies, including the four bytes of the zlbytes field itself. 
* This value needs to be stored to be able to resize the entire structure * without the need to traverse it first. * * <uint32_t zltail> is the offset to the last entry in the list. This allows * a pop operation on the far side of the list without the need for full * traversal. * * <uint16_t zllen> is the number of entries. When there are more than * 2^16-2 entries, this value is set to 2^16-1 and we need to traverse the * entire list to know how many items it holds. * * <uint8_t zlend> is a special entry representing the end of the ziplist. * It is encoded as a single byte equal to 255. No other normal entry starts * with a byte set to the value of 255. * * ZIPLIST ENTRIES * =============== * * Every entry in the ziplist is prefixed by metadata that contains two pieces * of information. First, the length of the previous entry is stored to be * able to traverse the list from back to front. Second, the entry encoding is * provided. It represents the entry type, integer or string, and in the case * of strings it also represents the length of the string payload. * So a complete entry is stored like this: * * <prevlen> <encoding> <entry-data> * * Sometimes the encoding represents the entry itself, like for small integers * as we'll see later. In such a case the <entry-data> part is missing, and we * could have just: * * <prevlen> <encoding> * * The length of the previous entry, <prevlen>, is encoded in the following way: * If this length is smaller than 254 bytes, it will only consume a single * byte representing the length as an unsigned 8 bit integer. When the length * is greater than or equal to 254, it will consume 5 bytes. The first byte is * set to 254 (FE) to indicate a larger value is following. The remaining 4 * bytes take the length of the previous entry as value. * * So practically an entry is encoded in the following way: * * <prevlen from 0 to 253> <encoding> <entry> * * Or alternatively if the previous entry length is greater than 253 bytes * the following encoding is used: * * 0xFE <4 bytes unsigned little endian prevlen> <encoding> <entry> * * The encoding field of the entry depends on the content of the * entry.
When the entry is a string, the first 2 bits of the encoding first * byte will hold the type of encoding used to store the length of the string, * followed by the actual length of the string. When the entry is an integer * the first 2 bits are both set to 1. The following 2 bits are used to specify * what kind of integer will be stored after this header. An overview of the * different types and encodings is as follows. The first byte is always enough * to determine the kind of entry. * * |00pppppp| - 1 byte * String value with length less than or equal to 63 bytes (6 bits). * "pppppp" represents the unsigned 6 bit length. * |01pppppp|qqqqqqqq| - 2 bytes * String value with length less than or equal to 16383 bytes (14 bits). * IMPORTANT: The 14 bit number is stored in big endian. * |10000000|qqqqqqqq|rrrrrrrr|ssssssss|tttttttt| - 5 bytes * String value with length greater than or equal to 16384 bytes. * Only the 4 bytes following the first byte represents the length * up to 2^32-1. The 6 lower bits of the first byte are not used and * are set to zero. * IMPORTANT: The 32 bit number is stored in big endian. * |11000000| - 3 bytes * Integer encoded as int16_t (2 bytes). * |11010000| - 5 bytes * Integer encoded as int32_t (4 bytes). * |11100000| - 9 bytes * Integer encoded as int64_t (8 bytes). * |11110000| - 4 bytes * Integer encoded as 24 bit signed (3 bytes). * |11111110| - 2 bytes * Integer encoded as 8 bit signed (1 byte). * |1111xxxx| - (with xxxx between 0001 and 1101) immediate 4 bit integer. * Unsigned integer from 0 to 12. The encoded value is actually from * 1 to 13 because 0000 and 1111 can not be used, so 1 should be * subtracted from the encoded 4 bit value to obtain the right value. * |11111111| - End of ziplist special entry. * * Like for the ziplist header, all the integers are represented in little * endian byte order, even when this code is compiled in big endian systems. 
* * EXAMPLES OF ACTUAL ZIPLISTS * =========================== * * The following is a ziplist containing the two elements representing * the strings "2" and "5". It is composed of 15 bytes, that we visually * split into sections: * * [0f 00 00 00] [0c 00 00 00] [02 00] [00 f3] [02 f6] [ff] * | | | | | | * zlbytes zltail zllen "2" "5" end * * The first 4 bytes represent the number 15, that is the number of bytes * the whole ziplist is composed of. The second 4 bytes are the offset * at which the last ziplist entry is found, that is 12, in fact the * last entry, that is "5", is at offset 12 inside the ziplist. * The next 16 bit integer represents the number of elements inside the * ziplist, its value is 2 since there are just two elements inside. * Finally "00 f3" is the first entry representing the number 2. It is * composed of the previous entry length, which is zero because this is * our first entry, and the byte F3 which corresponds to the encoding * |1111xxxx| with xxxx between 0001 and 1101. We need to remove the "F" * higher order bits 1111, and subtract 1 from the "3", so the entry value * is "2". The next entry has a prevlen of 02, since the first entry is * composed of exactly two bytes. The entry itself, F6, is encoded exactly * like the first entry, and 6-1 = 5, so the value of the entry is 5. * Finally the special entry FF signals the end of the ziplist. * * Adding another element to the above string with the value "Hello World" * allows us to show how the ziplist encodes small strings. We'll just show * the hex dump of the entry itself. Imagine the bytes as following the * entry that stores "5" in the ziplist above: * * [02] [0b] [48 65 6c 6c 6f 20 57 6f 72 6c 64] * * The first byte, 02, is the length of the previous entry. The next * byte represents the encoding in the pattern |00pppppp| that means * that the entry is a string of length , so 0B means that * an 11 bytes string follows. 
From the third byte (48) to the last (64) * there are just the ASCII characters for "Hello World". * * ---------------------------------------------------------------------------- * * Copyright (c) 2009-2012, Pieter Noordhuis * Copyright (c) 2009-2017, 2020, Redis Ltd. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include #include #include "zmalloc.h" #include "util.h" #include "ziplist.h" #include "config.h" #include "endianconv.h" #define ZIP_END 255 /* Special "end of ziplist" entry. 
/* ZIP_BIG_PREVLEN - 1 is the max number of bytes of the previous entry, for
 * the "prevlen" field prefixing each entry, to be represented with just a
 * single byte. Otherwise it is represented as FE AA BB CC DD, where
 * AA BB CC DD are a 4 bytes unsigned integer representing the previous
 * entry len. */
#define ZIP_BIG_PREVLEN 254

/* Different encoding/length possibilities */
#define ZIP_STR_MASK 0xc0
#define ZIP_INT_MASK 0x30
#define ZIP_STR_06B (0 << 6)
#define ZIP_STR_14B (1 << 6)
#define ZIP_STR_32B (2 << 6)
#define ZIP_INT_16B (0xc0 | 0<<4)
#define ZIP_INT_32B (0xc0 | 1<<4)
#define ZIP_INT_64B (0xc0 | 2<<4)
#define ZIP_INT_24B (0xc0 | 3<<4)
#define ZIP_INT_8B 0xfe

/* 4 bit integer immediate encoding |1111xxxx| with xxxx between
 * 0001 and 1101. */

/* Mask to extract the 4 bits value. The nibble stores value+1, so one must
 * be subtracted from it to reconstruct the value. */
#define ZIP_INT_IMM_MASK 0x0f
#define ZIP_INT_IMM_MIN 0xf1 /* 11110001 */
#define ZIP_INT_IMM_MAX 0xfd /* 11111101 */

#define INT24_MAX 0x7fffff
#define INT24_MIN (-INT24_MAX - 1)

/* Macro to determine if the entry is a string. String entries never start
 * with "11" as most significant bits of the first byte. */
#define ZIP_IS_STR(enc) (((enc) & ZIP_STR_MASK) < ZIP_STR_MASK)

/* Utility macros.*/

/* Return total bytes a ziplist is composed of. */
#define ZIPLIST_BYTES(zl) (*((uint32_t*)(zl)))

/* Return the offset of the last item inside the ziplist. */
#define ZIPLIST_TAIL_OFFSET(zl) (*((uint32_t*)((zl)+sizeof(uint32_t))))

/* Return the length of a ziplist, or UINT16_MAX if the length cannot be
 * determined without scanning the whole ziplist. */
#define ZIPLIST_LENGTH(zl) (*((uint16_t*)((zl)+sizeof(uint32_t)*2)))

/* The size of a ziplist header: two 32 bit integers for the total
 * bytes count and last item offset. One 16 bit integer for the number
 * of items field. */
#define ZIPLIST_HEADER_SIZE (sizeof(uint32_t)*2+sizeof(uint16_t))

/* Size of the "end of ziplist" entry. Just one byte. */
#define ZIPLIST_END_SIZE (sizeof(uint8_t))

/* Return the pointer to the first entry of a ziplist. */
#define ZIPLIST_ENTRY_HEAD(zl) ((zl)+ZIPLIST_HEADER_SIZE)

/* Return the pointer to the last entry of a ziplist, using the
 * last entry offset inside the ziplist header. */
#define ZIPLIST_ENTRY_TAIL(zl) ((zl)+intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)))

/* Return the pointer to the last byte of a ziplist, which is, the
 * end of ziplist FF entry. */
#define ZIPLIST_ENTRY_END(zl) ((zl)+intrev32ifbe(ZIPLIST_BYTES(zl))-ZIPLIST_END_SIZE)

/* Increment the number of items field in the ziplist header. Note that this
 * macro should never overflow the unsigned 16 bit integer, since entries are
 * always pushed one at a time. When UINT16_MAX is reached we want the count
 * to stay there to signal that a full scan is needed to get the number of
 * items inside the ziplist.
 * 'incr' is parenthesized so that any expression can be passed safely. */
#define ZIPLIST_INCR_LENGTH(zl, incr) \
    { \
        if (intrev16ifbe(ZIPLIST_LENGTH(zl)) < UINT16_MAX) \
            ZIPLIST_LENGTH(zl) = intrev16ifbe(intrev16ifbe(ZIPLIST_LENGTH(zl))+(incr)); \
    }

/* Don't let ziplists grow over 1GB in any case, don't wanna risk overflow in
 * zlbytes*/
#define ZIPLIST_MAX_SAFETY_SIZE (1<<30)

/* Return 1 if 'add' more bytes can be appended to 'zl' (which may be NULL,
 * meaning "no ziplist yet") without exceeding ZIPLIST_MAX_SAFETY_SIZE,
 * 0 otherwise. */
int ziplistSafeToAdd(unsigned char* zl, size_t add) {
    const size_t limit = ZIPLIST_MAX_SAFETY_SIZE;
    size_t len = zl? ziplistBlobLen(zl): 0;
    /* Compare without computing len + add, which could wrap around for a
     * huge 'add' and make an oversized request look acceptable. */
    if (add > limit || len > limit - add)
        return 0;
    return 1;
}

/* We use this function to receive information about a ziplist entry.
 * Note that this is not how the data is actually encoded, is just what we
 * get filled by a function in order to operate more easily. */
typedef struct zlentry {
    unsigned int prevrawlensize; /* Bytes used to encode the previous entry len*/
    unsigned int prevrawlen;     /* Previous entry len. */
    unsigned int lensize;        /* Bytes used to encode this entry type/len.
                                    For example strings have a 1, 2 or 5 bytes
                                    header. Integers always use a single byte.*/
    unsigned int len;            /* Bytes used to represent the actual entry.
                                    For strings this is just the string length
                                    while for integers it is 1, 2, 3, 4, 8 or
                                    0 (for 4 bit immediate) depending on the
                                    number range. */
    unsigned int headersize;     /* prevrawlensize + lensize. */
    unsigned char encoding;      /* Set to ZIP_STR_* or ZIP_INT_* depending on
                                    the entry encoding. However for 4 bits
                                    immediate integers this can assume a range
                                    of values and must be range-checked. */
    unsigned char *p;            /* Pointer to the very start of the entry, that
                                    is, this points to prev-entry-len field. */
} zlentry;

/* Reset every field of the zlentry pointed to by 'zle'. */
#define ZIPLIST_ENTRY_ZERO(zle) \
    { \
        (zle)->prevrawlensize = (zle)->prevrawlen = 0; \
        (zle)->lensize = (zle)->len = (zle)->headersize = 0; \
        (zle)->encoding = 0; \
        (zle)->p = NULL; \
    }

/* Extract the encoding from the byte pointed by 'ptr' and set it into
 * 'encoding' field of the zlentry structure. For string encodings only the
 * two type bits are kept; the length bits are masked away. */
#define ZIP_ENTRY_ENCODING(ptr, encoding) \
    do { \
        (encoding) = ((ptr)[0]); \
        if ((encoding) < ZIP_STR_MASK) (encoding) &= ZIP_STR_MASK; \
    } while(0)

#define ZIP_ENCODING_SIZE_INVALID 0xff
* On error, return ZIP_ENCODING_SIZE_INVALID */ static inline unsigned int zipEncodingLenSize(unsigned char encoding) { if (encoding == ZIP_INT_16B || encoding == ZIP_INT_32B || encoding == ZIP_INT_24B || encoding == ZIP_INT_64B || encoding == ZIP_INT_8B) return 1; if (encoding >= ZIP_INT_IMM_MIN && encoding <= ZIP_INT_IMM_MAX) return 1; if (encoding == ZIP_STR_06B) return 1; if (encoding == ZIP_STR_14B) return 2; if (encoding == ZIP_STR_32B) return 5; return ZIP_ENCODING_SIZE_INVALID; } #define ZIP_ASSERT_ENCODING(encoding) \ do { \ assert(zipEncodingLenSize(encoding) != ZIP_ENCODING_SIZE_INVALID); \ } while (0) /* Return bytes needed to store integer encoded by 'encoding' */ static inline unsigned int zipIntSize(unsigned char encoding) { switch(encoding) { case ZIP_INT_8B: return 1; case ZIP_INT_16B: return 2; case ZIP_INT_24B: return 3; case ZIP_INT_32B: return 4; case ZIP_INT_64B: return 8; } if (encoding >= ZIP_INT_IMM_MIN && encoding <= ZIP_INT_IMM_MAX) return 0; /* 4 bit immediate */ /* bad encoding, covered by a previous call to ZIP_ASSERT_ENCODING */ valkey_unreachable(); return 0; } /* Write the encoding header of the entry in 'p'. If p is NULL it just returns * the amount of bytes required to encode such a length. Arguments: * * 'encoding' is the encoding we are using for the entry. It could be * ZIP_INT_* or ZIP_STR_* or between ZIP_INT_IMM_MIN and ZIP_INT_IMM_MAX * for single-byte small immediate integers. * * 'rawlen' is only used for ZIP_STR_* encodings and is the length of the * string that this entry represents. * * The function returns the number of bytes used by the encoding/length * header stored in 'p'. */ unsigned int zipStoreEntryEncoding(unsigned char *p, unsigned char encoding, unsigned int rawlen) { unsigned char len = 1, buf[5]; if (ZIP_IS_STR(encoding)) { /* Although encoding is given it may not be set for strings, * so we determine it here using the raw length. 
*/ if (rawlen <= 0x3f) { if (!p) return len; buf[0] = ZIP_STR_06B | rawlen; } else if (rawlen <= 0x3fff) { len += 1; if (!p) return len; buf[0] = ZIP_STR_14B | ((rawlen >> 8) & 0x3f); buf[1] = rawlen & 0xff; } else { len += 4; if (!p) return len; buf[0] = ZIP_STR_32B; buf[1] = (rawlen >> 24) & 0xff; buf[2] = (rawlen >> 16) & 0xff; buf[3] = (rawlen >> 8) & 0xff; buf[4] = rawlen & 0xff; } } else { /* Implies integer encoding, so length is always 1. */ if (!p) return len; buf[0] = encoding; } /* Store this length at p. */ memcpy(p,buf,len); return len; } /* Decode the entry encoding type and data length (string length for strings, * number of bytes used for the integer for integer entries) encoded in 'ptr'. * The 'encoding' variable is input, extracted by the caller, the 'lensize' * variable will hold the number of bytes required to encode the entry * length, and the 'len' variable will hold the entry length. * On invalid encoding error, lensize is set to 0. */ #define ZIP_DECODE_LENGTH(ptr, encoding, lensize, len) \ do { \ if ((encoding) < ZIP_STR_MASK) { \ if ((encoding) == ZIP_STR_06B) { \ (lensize) = 1; \ (len) = (ptr)[0] & 0x3f; \ } else if ((encoding) == ZIP_STR_14B) { \ (lensize) = 2; \ (len) = (((ptr)[0] & 0x3f) << 8) | (ptr)[1]; \ } else if ((encoding) == ZIP_STR_32B) { \ (lensize) = 5; \ (len) = ((uint32_t)(ptr)[1] << 24) | ((uint32_t)(ptr)[2] << 16) | ((uint32_t)(ptr)[3] << 8) | \ ((uint32_t)(ptr)[4]); \ } else { \ (lensize) = 0; /* bad encoding, should be covered by a previous */ \ (len) = 0; /* ZIP_ASSERT_ENCODING / zipEncodingLenSize, or */ \ /* match the lensize after this macro with 0. 
*/ \ } \ } else { \ (lensize) = 1; \ if ((encoding) == ZIP_INT_8B) \ (len) = 1; \ else if ((encoding) == ZIP_INT_16B) \ (len) = 2; \ else if ((encoding) == ZIP_INT_24B) \ (len) = 3; \ else if ((encoding) == ZIP_INT_32B) \ (len) = 4; \ else if ((encoding) == ZIP_INT_64B) \ (len) = 8; \ else if (encoding >= ZIP_INT_IMM_MIN && encoding <= ZIP_INT_IMM_MAX) \ (len) = 0; /* 4 bit immediate */ \ else \ (lensize) = (len) = 0; /* bad encoding */ \ } \ } while(0) /* Encode the length of the previous entry and write it to "p". This only * uses the larger encoding (required in __ziplistCascadeUpdate). */ int zipStorePrevEntryLengthLarge(unsigned char *p, unsigned int len) { uint32_t u32; if (p != NULL) { p[0] = ZIP_BIG_PREVLEN; u32 = len; memcpy(p+1,&u32,sizeof(u32)); memrev32ifbe(p+1); } return 1 + sizeof(uint32_t); } /* Encode the length of the previous entry and write it to "p". Return the * number of bytes needed to encode this length if "p" is NULL. */ unsigned int zipStorePrevEntryLength(unsigned char *p, unsigned int len) { if (p == NULL) { return (len < ZIP_BIG_PREVLEN) ? 1 : sizeof(uint32_t) + 1; } else { if (len < ZIP_BIG_PREVLEN) { p[0] = len; return 1; } else { return zipStorePrevEntryLengthLarge(p,len); } } } /* Return the number of bytes used to encode the length of the previous * entry. The length is returned by setting the var 'prevlensize'. */ #define ZIP_DECODE_PREVLENSIZE(ptr, prevlensize) \ do { \ if ((ptr)[0] < ZIP_BIG_PREVLEN) { \ (prevlensize) = 1; \ } else { \ (prevlensize) = 5; \ } \ } while(0) /* Return the length of the previous element, and the number of bytes that * are used in order to encode the previous element length. * 'ptr' must point to the prevlen prefix of an entry (that encodes the * length of the previous entry in order to navigate the elements backward). * The length of the previous entry is stored in 'prevlen', the number of * bytes needed to encode the previous entry length are stored in * 'prevlensize'. 
*/ #define ZIP_DECODE_PREVLEN(ptr, prevlensize, prevlen) \ do { \ ZIP_DECODE_PREVLENSIZE(ptr, prevlensize); \ if ((prevlensize) == 1) { \ (prevlen) = (ptr)[0]; \ } else { /* prevlensize == 5 */ \ (prevlen) = ((ptr)[4] << 24) | ((ptr)[3] << 16) | ((ptr)[2] << 8) | ((ptr)[1]); \ } \ } while(0) /* Given a pointer 'p' to the prevlen info that prefixes an entry, this * function returns the difference in number of bytes needed to encode * the prevlen if the previous entry changes of size. * * So if A is the number of bytes used right now to encode the 'prevlen' * field. * * And B is the number of bytes that are needed in order to encode the * 'prevlen' if the previous element will be updated to one of size 'len'. * * Then the function returns B - A * * So the function returns a positive number if more space is needed, * a negative number if less space is needed, or zero if the same space * is needed. */ int zipPrevLenByteDiff(unsigned char *p, unsigned int len) { unsigned int prevlensize; ZIP_DECODE_PREVLENSIZE(p, prevlensize); return zipStorePrevEntryLength(NULL, len) - prevlensize; } /* Check if string pointed to by 'entry' can be encoded as an integer. * Stores the integer value in 'v' and its encoding in 'encoding'. */ int zipTryEncoding(unsigned char *entry, unsigned int entrylen, long long *v, unsigned char *encoding) { long long value; if (entrylen >= 32 || entrylen == 0) return 0; if (string2ll((char*)entry,entrylen,&value)) { /* Great, the string can be encoded. Check what's the smallest * of our encoding types that can hold this value. 
*/ if (value >= 0 && value <= 12) { *encoding = ZIP_INT_IMM_MIN+value; } else if (value >= INT8_MIN && value <= INT8_MAX) { *encoding = ZIP_INT_8B; } else if (value >= INT16_MIN && value <= INT16_MAX) { *encoding = ZIP_INT_16B; } else if (value >= INT24_MIN && value <= INT24_MAX) { *encoding = ZIP_INT_24B; } else if (value >= INT32_MIN && value <= INT32_MAX) { *encoding = ZIP_INT_32B; } else { *encoding = ZIP_INT_64B; } *v = value; return 1; } return 0; } /* Store integer 'value' at 'p', encoded as 'encoding' */ void zipSaveInteger(unsigned char *p, int64_t value, unsigned char encoding) { int16_t i16; int32_t i32; int64_t i64; if (encoding == ZIP_INT_8B) { ((int8_t*)p)[0] = (int8_t)value; } else if (encoding == ZIP_INT_16B) { i16 = value; memcpy(p,&i16,sizeof(i16)); memrev16ifbe(p); } else if (encoding == ZIP_INT_24B) { i32 = ((uint64_t)value)<<8; memrev32ifbe(&i32); memcpy(p,((uint8_t*)&i32)+1,sizeof(i32)-sizeof(uint8_t)); } else if (encoding == ZIP_INT_32B) { i32 = value; memcpy(p,&i32,sizeof(i32)); memrev32ifbe(p); } else if (encoding == ZIP_INT_64B) { i64 = value; memcpy(p,&i64,sizeof(i64)); memrev64ifbe(p); } else if (encoding >= ZIP_INT_IMM_MIN && encoding <= ZIP_INT_IMM_MAX) { /* Nothing to do, the value is stored in the encoding itself. 
*/ } else { assert(NULL); } } /* Read integer encoded as 'encoding' from 'p' */ int64_t zipLoadInteger(unsigned char *p, unsigned char encoding) { int16_t i16; int32_t i32; int64_t i64, ret = 0; if (encoding == ZIP_INT_8B) { ret = ((int8_t*)p)[0]; } else if (encoding == ZIP_INT_16B) { memcpy(&i16,p,sizeof(i16)); memrev16ifbe(&i16); ret = i16; } else if (encoding == ZIP_INT_32B) { memcpy(&i32,p,sizeof(i32)); memrev32ifbe(&i32); ret = i32; } else if (encoding == ZIP_INT_24B) { i32 = 0; memcpy(((uint8_t*)&i32)+1,p,sizeof(i32)-sizeof(uint8_t)); memrev32ifbe(&i32); ret = i32>>8; } else if (encoding == ZIP_INT_64B) { memcpy(&i64,p,sizeof(i64)); memrev64ifbe(&i64); ret = i64; } else if (encoding >= ZIP_INT_IMM_MIN && encoding <= ZIP_INT_IMM_MAX) { ret = (encoding & ZIP_INT_IMM_MASK)-1; } else { assert(NULL); } return ret; } /* Fills a struct with all information about an entry. * This function is the "unsafe" alternative to the one below. * Generally, all function that return a pointer to an element in the ziplist * will assert that this element is valid, so it can be freely used. * Generally functions such ziplistGet assume the input pointer is already * validated (since it's the return value of another function). */ static inline void zipEntry(unsigned char *p, zlentry *e) { ZIP_DECODE_PREVLEN(p, e->prevrawlensize, e->prevrawlen); ZIP_ENTRY_ENCODING(p + e->prevrawlensize, e->encoding); ZIP_DECODE_LENGTH(p + e->prevrawlensize, e->encoding, e->lensize, e->len); assert(e->lensize != 0); /* check that encoding was valid. */ e->headersize = e->prevrawlensize + e->lensize; e->p = p; } /* Fills a struct with all information about an entry. * This function is safe to use on untrusted pointers, it'll make sure not to * try to access memory outside the ziplist payload. * Returns 1 if the entry is valid, and 0 otherwise. 
*/ static inline int zipEntrySafe(unsigned char* zl, size_t zlbytes, unsigned char *p, zlentry *e, int validate_prevlen) { unsigned char *zlfirst = zl + ZIPLIST_HEADER_SIZE; unsigned char *zllast = zl + zlbytes - ZIPLIST_END_SIZE; #define OUT_OF_RANGE(p) (unlikely((p) < zlfirst || (p) > zllast)) /* If there's no possibility for the header to reach outside the ziplist, * take the fast path. (max lensize and prevrawlensize are both 5 bytes) */ if (p >= zlfirst && p + 10 < zllast) { ZIP_DECODE_PREVLEN(p, e->prevrawlensize, e->prevrawlen); ZIP_ENTRY_ENCODING(p + e->prevrawlensize, e->encoding); ZIP_DECODE_LENGTH(p + e->prevrawlensize, e->encoding, e->lensize, e->len); e->headersize = e->prevrawlensize + e->lensize; e->p = p; /* We didn't call ZIP_ASSERT_ENCODING, so we check lensize was set to 0. */ if (unlikely(e->lensize == 0)) return 0; /* Make sure the entry doesn't reach outside the edge of the ziplist */ if (OUT_OF_RANGE(p + e->headersize + e->len)) return 0; /* Make sure prevlen doesn't reach outside the edge of the ziplist */ if (validate_prevlen && OUT_OF_RANGE(p - e->prevrawlen)) return 0; return 1; } /* Make sure the pointer doesn't reach outside the edge of the ziplist */ if (OUT_OF_RANGE(p)) return 0; /* Make sure the encoded prevlen header doesn't reach outside the allocation */ ZIP_DECODE_PREVLENSIZE(p, e->prevrawlensize); if (OUT_OF_RANGE(p + e->prevrawlensize)) return 0; /* Make sure encoded entry header is valid. */ ZIP_ENTRY_ENCODING(p + e->prevrawlensize, e->encoding); e->lensize = zipEncodingLenSize(e->encoding); if (unlikely(e->lensize == ZIP_ENCODING_SIZE_INVALID)) return 0; /* Make sure the encoded entry header doesn't reach outside the allocation */ if (OUT_OF_RANGE(p + e->prevrawlensize + e->lensize)) return 0; /* Decode the prevlen and entry len headers. 
*/ ZIP_DECODE_PREVLEN(p, e->prevrawlensize, e->prevrawlen); ZIP_DECODE_LENGTH(p + e->prevrawlensize, e->encoding, e->lensize, e->len); e->headersize = e->prevrawlensize + e->lensize; /* Make sure the entry doesn't reach outside the edge of the ziplist */ if (OUT_OF_RANGE(p + e->headersize + e->len)) return 0; /* Make sure prevlen doesn't reach outside the edge of the ziplist */ if (validate_prevlen && OUT_OF_RANGE(p - e->prevrawlen)) return 0; e->p = p; return 1; #undef OUT_OF_RANGE } /* Return the total number of bytes used by the entry pointed to by 'p'. */ static inline unsigned int zipRawEntryLengthSafe(unsigned char* zl, size_t zlbytes, unsigned char *p) { zlentry e; zipEntrySafe(zl, zlbytes, p, &e, 0); return e.headersize + e.len; } /* Return the total number of bytes used by the entry pointed to by 'p'. */ static inline unsigned int zipRawEntryLength(unsigned char *p) { zlentry e; zipEntry(p, &e); return e.headersize + e.len; } /* Validate that the entry doesn't reach outside the ziplist allocation. */ static inline void zipAssertValidEntry(unsigned char* zl, size_t zlbytes, unsigned char *p) { zlentry e; int res = zipEntrySafe(zl, zlbytes, p, &e, 1); assert(res); (void)res; } /* Create a new empty ziplist. */ unsigned char *ziplistNew(void) { unsigned int bytes = ZIPLIST_HEADER_SIZE+ZIPLIST_END_SIZE; unsigned char *zl = zmalloc(bytes); ZIPLIST_BYTES(zl) = intrev32ifbe(bytes); ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(ZIPLIST_HEADER_SIZE); ZIPLIST_LENGTH(zl) = 0; zl[bytes-1] = ZIP_END; return zl; } /* Resize the ziplist. */ unsigned char *ziplistResize(unsigned char *zl, size_t len) { assert(len < UINT32_MAX); zl = zrealloc(zl,len); ZIPLIST_BYTES(zl) = intrev32ifbe(len); zl[len-1] = ZIP_END; return zl; } /* When an entry is inserted, we need to set the prevlen field of the next * entry to equal the length of the inserted entry. 
It can occur that this * length cannot be encoded in 1 byte and the next entry needs to be grow * a bit larger to hold the 5-byte encoded prevlen. This can be done for free, * because this only happens when an entry is already being inserted (which * causes a realloc and memmove). However, encoding the prevlen may require * that this entry is grown as well. This effect may cascade throughout * the ziplist when there are consecutive entries with a size close to * ZIP_BIG_PREVLEN, so we need to check that the prevlen can be encoded in * every consecutive entry. * * Note that this effect can also happen in reverse, where the bytes required * to encode the prevlen field can shrink. This effect is deliberately ignored, * because it can cause a "flapping" effect where a chain prevlen fields is * first grown and then shrunk again after consecutive inserts. Rather, the * field is allowed to stay larger than necessary, because a large prevlen * field implies the ziplist is holding large entries anyway. * * The pointer "p" points to the first entry that does NOT need to be * updated, i.e. consecutive fields MAY need an update. */ unsigned char *__ziplistCascadeUpdate(unsigned char *zl, unsigned char *p) { zlentry cur; size_t prevlen, prevlensize, prevoffset; /* Informat of the last changed entry. */ size_t firstentrylen; /* Used to handle insert at head. */ size_t rawlen, curlen = intrev32ifbe(ZIPLIST_BYTES(zl)); size_t extra = 0, cnt = 0, offset; size_t delta = 4; /* Extra bytes needed to update a entry's prevlen (5-1). */ unsigned char *tail = zl + intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)); /* Empty ziplist */ if (p[0] == ZIP_END) return zl; zipEntry( p, &cur); /* no need for "safe" variant since the input pointer was validated by the function that returned it. 
*/ firstentrylen = prevlen = cur.headersize + cur.len; prevlensize = zipStorePrevEntryLength(NULL, prevlen); prevoffset = p - zl; p += prevlen; /* Iterate ziplist to find out how many extra bytes do we need to update it. */ while (p[0] != ZIP_END) { assert(zipEntrySafe(zl, curlen, p, &cur, 0)); /* Abort when "prevlen" has not changed. */ if (cur.prevrawlen == prevlen) break; /* Abort when entry's "prevlensize" is big enough. */ if (cur.prevrawlensize >= prevlensize) { if (cur.prevrawlensize == prevlensize) { zipStorePrevEntryLength(p, prevlen); } else { /* This would result in shrinking, which we want to avoid. * So, set "prevlen" in the available bytes. */ zipStorePrevEntryLengthLarge(p, prevlen); } break; } /* cur.prevrawlen means cur is the former head entry. */ assert(cur.prevrawlen == 0 || cur.prevrawlen + delta == prevlen); /* Update prev entry's info and advance the cursor. */ rawlen = cur.headersize + cur.len; prevlen = rawlen + delta; prevlensize = zipStorePrevEntryLength(NULL, prevlen); prevoffset = p - zl; p += rawlen; extra += delta; cnt++; } /* Extra bytes is zero all update has been done(or no need to update). */ if (extra == 0) return zl; /* Update tail offset after loop. */ if (tail == zl + prevoffset) { /* When the last entry we need to update is also the tail, update tail offset * unless this is the only entry that was updated (so the tail offset didn't change). */ if (extra - delta != 0) { ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)) + extra - delta); } } else { /* Update the tail offset in cases where the last entry we updated is not the tail. */ ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)) + extra); } /* Now "p" points at the first unchanged byte in original ziplist, * move data after that to new ziplist. 
*/ offset = p - zl; zl = ziplistResize(zl, curlen + extra); p = zl + offset; memmove(p + extra, p, curlen - offset - 1); p += extra; /* Iterate all entries that need to be updated tail to head. */ while (cnt) { zipEntry(zl + prevoffset, &cur); /* no need for "safe" variant since we already iterated on all these entries above. */ rawlen = cur.headersize + cur.len; /* Move entry to tail and reset prevlen. */ memmove(p - (rawlen - cur.prevrawlensize), zl + prevoffset + cur.prevrawlensize, rawlen - cur.prevrawlensize); p -= (rawlen + delta); if (cur.prevrawlen == 0) { /* "cur" is the previous head entry, update its prevlen with firstentrylen. */ zipStorePrevEntryLength(p, firstentrylen); } else { /* An entry's prevlen can only increment 4 bytes. */ zipStorePrevEntryLength(p, cur.prevrawlen+delta); } /* Forward to previous entry. */ prevoffset -= cur.prevrawlen; cnt--; } return zl; } /* Delete "num" entries, starting at "p". Returns pointer to the ziplist. */ unsigned char *__ziplistDelete(unsigned char *zl, unsigned char *p, unsigned int num) { unsigned int i, totlen, deleted = 0; size_t offset; int nextdiff = 0; zlentry first, tail; size_t zlbytes = intrev32ifbe(ZIPLIST_BYTES(zl)); zipEntry(p, &first); /* no need for "safe" variant since the input pointer was validated by the function that returned it. */ for (i = 0; p[0] != ZIP_END && i < num; i++) { p += zipRawEntryLengthSafe(zl, zlbytes, p); deleted++; } assert(p >= first.p); totlen = p-first.p; /* Bytes taken by the element(s) to delete. */ if (totlen > 0) { uint32_t set_tail; if (p[0] != ZIP_END) { /* Storing `prevrawlen` in this entry may increase or decrease the * number of bytes required compare to the current `prevrawlen`. * There always is room to store this, because it was previously * stored by an entry that is now being deleted. 
*/ nextdiff = zipPrevLenByteDiff(p,first.prevrawlen); /* Note that there is always space when p jumps backward: if * the new previous entry is large, one of the deleted elements * had a 5 bytes prevlen header, so there is for sure at least * 5 bytes free and we need just 4. */ p -= nextdiff; assert(p >= first.p && p= first.p. we know totlen >= 0, * so we know that p > first.p and this is guaranteed not to reach * beyond the allocation, even if the entries lens are corrupted. */ size_t bytes_to_move = zlbytes-(p-zl)-1; memmove(first.p,p,bytes_to_move); } else { /* The entire tail was deleted. No need to move memory. */ set_tail = (first.p-zl)-first.prevrawlen; } /* Resize the ziplist */ offset = first.p-zl; zlbytes -= totlen - nextdiff; zl = ziplistResize(zl, zlbytes); p = zl+offset; /* Update record count */ ZIPLIST_INCR_LENGTH(zl,-deleted); /* Set the tail offset computed above */ assert(set_tail <= zlbytes - ZIPLIST_END_SIZE); ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(set_tail); /* When nextdiff != 0, the raw length of the next entry has changed, so * we need to cascade the update throughout the ziplist */ if (nextdiff != 0) zl = __ziplistCascadeUpdate(zl, p); } return zl; } /* Insert item at "p". */ unsigned char *__ziplistInsert(unsigned char *zl, unsigned char *p, unsigned char *s, unsigned int slen) { size_t curlen = intrev32ifbe(ZIPLIST_BYTES(zl)), reqlen, newlen; unsigned int prevlensize, prevlen = 0; size_t offset; int nextdiff = 0; unsigned char encoding = 0; long long value = 123456789; /* initialized to avoid warning. Using a value that is easy to see if for some reason we use it uninitialized. */ zlentry tail; /* Find out prevlen for the entry that is inserted. 
*/ if (p[0] != ZIP_END) { ZIP_DECODE_PREVLEN(p, prevlensize, prevlen); } else { unsigned char *ptail = ZIPLIST_ENTRY_TAIL(zl); if (ptail[0] != ZIP_END) { prevlen = zipRawEntryLengthSafe(zl, curlen, ptail); } } /* See if the entry can be encoded */ if (zipTryEncoding(s,slen,&value,&encoding)) { /* 'encoding' is set to the appropriate integer encoding */ reqlen = zipIntSize(encoding); } else { /* 'encoding' is untouched, however zipStoreEntryEncoding will use the * string length to figure out how to encode it. */ reqlen = slen; } /* We need space for both the length of the previous entry and * the length of the payload. */ reqlen += zipStorePrevEntryLength(NULL,prevlen); reqlen += zipStoreEntryEncoding(NULL,encoding,slen); /* When the insert position is not equal to the tail, we need to * make sure that the next entry can hold this entry's length in * its prevlen field. */ int forcelarge = 0; nextdiff = (p[0] != ZIP_END) ? zipPrevLenByteDiff(p,reqlen) : 0; if (nextdiff == -4 && reqlen < 4) { nextdiff = 0; forcelarge = 1; } /* Store offset because a realloc may change the address of zl. */ offset = p-zl; newlen = curlen+reqlen+nextdiff; zl = ziplistResize(zl,newlen); p = zl+offset; /* Apply memory move when necessary and update tail offset. */ if (p[0] != ZIP_END) { /* Subtract one because of the ZIP_END bytes */ memmove(p+reqlen,p-nextdiff,curlen-offset-1+nextdiff); /* Encode this entry's raw length in the next entry. */ if (forcelarge) zipStorePrevEntryLengthLarge(p+reqlen,reqlen); else zipStorePrevEntryLength(p+reqlen,reqlen); /* Update offset for tail */ ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)) + reqlen); /* When the tail contains more than one entry, we need to take * "nextdiff" in account as well. Otherwise, a change in the * size of prevlen doesn't have an effect on the *tail* offset. 
*/ zipEntrySafe(zl, newlen, p + reqlen, &tail, 1); if (p[reqlen+tail.headersize+tail.len] != ZIP_END) { ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)) + nextdiff); } } else { /* This element will be the new tail. */ ZIPLIST_TAIL_OFFSET(zl) = intrev32ifbe(p-zl); } /* When nextdiff != 0, the raw length of the next entry has changed, so * we need to cascade the update throughout the ziplist */ if (nextdiff != 0) { offset = p-zl; zl = __ziplistCascadeUpdate(zl,p+reqlen); p = zl+offset; } /* Write the entry */ p += zipStorePrevEntryLength(p,prevlen); p += zipStoreEntryEncoding(p,encoding,slen); if (ZIP_IS_STR(encoding)) { memcpy(p,s,slen); } else { zipSaveInteger(p,value,encoding); } ZIPLIST_INCR_LENGTH(zl,1); return zl; } /* Merge ziplists 'first' and 'second' by appending 'second' to 'first'. * * NOTE: The larger ziplist is reallocated to contain the new merged ziplist. * Either 'first' or 'second' can be used for the result. The parameter not * used will be free'd and set to NULL. * * After calling this function, the input parameters are no longer valid since * they are changed and free'd in-place. * * The result ziplist is the contents of 'first' followed by 'second'. * * On failure: returns NULL if the merge is impossible. * On success: returns the merged ziplist (which is expanded version of either * 'first' or 'second', also frees the other unused input ziplist, and sets the * input ziplist argument equal to newly reallocated ziplist return value. */ unsigned char *ziplistMerge(unsigned char **first, unsigned char **second) { /* If any params are null, we can't merge, so NULL. */ if (first == NULL || *first == NULL || second == NULL || *second == NULL) return NULL; /* Can't merge same list into itself. 
*/ if (*first == *second) return NULL; size_t first_bytes = intrev32ifbe(ZIPLIST_BYTES(*first)); size_t first_len = intrev16ifbe(ZIPLIST_LENGTH(*first)); size_t second_bytes = intrev32ifbe(ZIPLIST_BYTES(*second)); size_t second_len = intrev16ifbe(ZIPLIST_LENGTH(*second)); int append; unsigned char *source, *target; size_t target_bytes, source_bytes; /* Pick the largest ziplist so we can resize easily in-place. * We must also track if we are now appending or prepending to * the target ziplist. */ if (first_len >= second_len) { /* retain first, append second to first. */ target = *first; target_bytes = first_bytes; source = *second; source_bytes = second_bytes; append = 1; } else { /* else, retain second, prepend first to second. */ target = *second; target_bytes = second_bytes; source = *first; source_bytes = first_bytes; append = 0; } /* Calculate final bytes (subtract one pair of metadata) */ size_t zlbytes = first_bytes + second_bytes - ZIPLIST_HEADER_SIZE - ZIPLIST_END_SIZE; size_t zllength = first_len + second_len; /* Combined zl length should be limited within UINT16_MAX */ zllength = zllength < UINT16_MAX ? zllength : UINT16_MAX; /* larger values can't be stored into ZIPLIST_BYTES */ assert(zlbytes < UINT32_MAX); /* Save offset positions before we start ripping memory apart. */ size_t first_offset = intrev32ifbe(ZIPLIST_TAIL_OFFSET(*first)); size_t second_offset = intrev32ifbe(ZIPLIST_TAIL_OFFSET(*second)); /* Extend target to new zlbytes then append or prepend source. 
*/ target = zrealloc(target, zlbytes); if (append) { /* append == appending to target */ /* Copy source after target (copying over original [END]): * [TARGET - END, SOURCE - HEADER] */ memcpy(target + target_bytes - ZIPLIST_END_SIZE, source + ZIPLIST_HEADER_SIZE, source_bytes - ZIPLIST_HEADER_SIZE); } else { /* !append == prepending to target */ /* Move target *contents* exactly size of (source - [END]), * then copy source into vacated space (source - [END]): * [SOURCE - END, TARGET - HEADER] */ memmove(target + source_bytes - ZIPLIST_END_SIZE, target + ZIPLIST_HEADER_SIZE, target_bytes - ZIPLIST_HEADER_SIZE); memcpy(target, source, source_bytes - ZIPLIST_END_SIZE); } /* Update header metadata. */ ZIPLIST_BYTES(target) = intrev32ifbe(zlbytes); ZIPLIST_LENGTH(target) = intrev16ifbe(zllength); /* New tail offset is: * + N bytes of first ziplist * - 1 byte for [END] of first ziplist * + M bytes for the offset of the original tail of the second ziplist * - J bytes for HEADER because second_offset keeps no header. */ ZIPLIST_TAIL_OFFSET(target) = intrev32ifbe((first_bytes - ZIPLIST_END_SIZE) + (second_offset - ZIPLIST_HEADER_SIZE)); /* __ziplistCascadeUpdate just fixes the prev length values until it finds a * correct prev length value (then it assumes the rest of the list is okay). * We tell CascadeUpdate to start at the first ziplist's tail element to fix * the merge seam. */ target = __ziplistCascadeUpdate(target, target+first_offset); /* Now free and NULL out what we didn't realloc */ if (append) { zfree(*second); *second = NULL; *first = target; } else { zfree(*first); *first = NULL; *second = target; } return target; } unsigned char *ziplistPush(unsigned char *zl, unsigned char *s, unsigned int slen, int where) { unsigned char *p; p = (where == ZIPLIST_HEAD) ? ZIPLIST_ENTRY_HEAD(zl) : ZIPLIST_ENTRY_END(zl); return __ziplistInsert(zl,p,s,slen); } /* Returns an offset to use for iterating with ziplistNext. 
When the given * index is negative, the list is traversed back to front. When the list * doesn't contain an element at the provided index, NULL is returned. */ unsigned char *ziplistIndex(unsigned char *zl, int index) { unsigned char *p; unsigned int prevlensize, prevlen = 0; size_t zlbytes = intrev32ifbe(ZIPLIST_BYTES(zl)); if (index < 0) { index = (-index)-1; p = ZIPLIST_ENTRY_TAIL(zl); if (p[0] != ZIP_END) { /* No need for "safe" check: when going backwards, we know the header * we're parsing is in the range, we just need to assert (below) that * the size we take doesn't cause p to go outside the allocation. */ ZIP_DECODE_PREVLENSIZE(p, prevlensize); assert(p + prevlensize < zl + zlbytes - ZIPLIST_END_SIZE); ZIP_DECODE_PREVLEN(p, prevlensize, prevlen); while (prevlen > 0 && index--) { p -= prevlen; assert(p >= zl + ZIPLIST_HEADER_SIZE && p < zl + zlbytes - ZIPLIST_END_SIZE); ZIP_DECODE_PREVLEN(p, prevlensize, prevlen); } } } else { p = ZIPLIST_ENTRY_HEAD(zl); while (index--) { /* Use the "safe" length: When we go forward, we need to be careful * not to decode an entry header if it's past the ziplist allocation. */ p += zipRawEntryLengthSafe(zl, zlbytes, p); if (p[0] == ZIP_END) break; } } if (p[0] == ZIP_END || index > 0) return NULL; zipAssertValidEntry(zl, zlbytes, p); return p; } /* Return pointer to next entry in ziplist. * * zl is the pointer to the ziplist * p is the pointer to the current element * * The element after 'p' is returned, otherwise NULL if we are at the end. */ unsigned char *ziplistNext(unsigned char *zl, unsigned char *p) { ((void) zl); size_t zlbytes = intrev32ifbe(ZIPLIST_BYTES(zl)); /* "p" could be equal to ZIP_END, caused by ziplistDelete, * and we should return NULL. Otherwise, we should return NULL * when the *next* element is ZIP_END (there is no next entry). 
*/ if (p[0] == ZIP_END) { return NULL; } p += zipRawEntryLength(p); if (p[0] == ZIP_END) { return NULL; } zipAssertValidEntry(zl, zlbytes, p); return p; } /* Return pointer to previous entry in ziplist. */ unsigned char *ziplistPrev(unsigned char *zl, unsigned char *p) { unsigned int prevlensize, prevlen = 0; /* Iterating backwards from ZIP_END should return the tail. When "p" is * equal to the first element of the list, we're already at the head, * and should return NULL. */ if (p[0] == ZIP_END) { p = ZIPLIST_ENTRY_TAIL(zl); return (p[0] == ZIP_END) ? NULL : p; } else if (p == ZIPLIST_ENTRY_HEAD(zl)) { return NULL; } else { ZIP_DECODE_PREVLEN(p, prevlensize, prevlen); assert(prevlen > 0); p-=prevlen; size_t zlbytes = intrev32ifbe(ZIPLIST_BYTES(zl)); zipAssertValidEntry(zl, zlbytes, p); return p; } } /* Get entry pointed to by 'p' and store in either '*sstr' or 'sval' depending * on the encoding of the entry. '*sstr' is always set to NULL to be able * to find out whether the string pointer or the integer value was set. * Return 0 if 'p' points to the end of the ziplist, 1 otherwise. */ unsigned int ziplistGet(unsigned char *p, unsigned char **sstr, unsigned int *slen, long long *sval) { zlentry entry; if (p == NULL || p[0] == ZIP_END) return 0; if (sstr) *sstr = NULL; zipEntry(p, &entry); /* no need for "safe" variant since the input pointer was validated by the function that returned it. */ if (ZIP_IS_STR(entry.encoding)) { if (sstr) { *slen = entry.len; *sstr = p+entry.headersize; } } else { if (sval) { *sval = zipLoadInteger(p+entry.headersize,entry.encoding); } } return 1; } /* Insert an entry at "p". */ unsigned char *ziplistInsert(unsigned char *zl, unsigned char *p, unsigned char *s, unsigned int slen) { return __ziplistInsert(zl,p,s,slen); } /* Delete a single entry from the ziplist, pointed to by *p. * Also update *p in place, to be able to iterate over the * ziplist, while deleting entries. 
*/ unsigned char *ziplistDelete(unsigned char *zl, unsigned char **p) { size_t offset = *p-zl; zl = __ziplistDelete(zl,*p,1); /* Store pointer to current element in p, because ziplistDelete will * do a realloc which might result in a different "zl"-pointer. * When the delete direction is back to front, we might delete the last * entry and end up with "p" pointing to ZIP_END, so check this. */ *p = zl+offset; return zl; } /* Delete a range of entries from the ziplist. */ unsigned char *ziplistDeleteRange(unsigned char *zl, int index, unsigned int num) { unsigned char *p = ziplistIndex(zl,index); return (p == NULL) ? zl : __ziplistDelete(zl,p,num); } /* Replaces the entry at p. This is equivalent to a delete and an insert, * but avoids some overhead when replacing a value of the same size. */ unsigned char *ziplistReplace(unsigned char *zl, unsigned char *p, unsigned char *s, unsigned int slen) { /* get metadata of the current entry */ zlentry entry; zipEntry(p, &entry); /* compute length of entry to store, excluding prevlen */ unsigned int reqlen; unsigned char encoding = 0; long long value = 123456789; /* initialized to avoid warning. */ if (zipTryEncoding(s,slen,&value,&encoding)) { reqlen = zipIntSize(encoding); /* encoding is set */ } else { reqlen = slen; /* encoding == 0 */ } reqlen += zipStoreEntryEncoding(NULL,encoding,slen); if (reqlen == entry.lensize + entry.len) { /* Simply overwrite the element. */ p += entry.prevrawlensize; p += zipStoreEntryEncoding(p,encoding,slen); if (ZIP_IS_STR(encoding)) { memcpy(p,s,slen); } else { zipSaveInteger(p,value,encoding); } } else { /* Fallback. */ zl = ziplistDelete(zl,&p); zl = ziplistInsert(zl,p,s,slen); } return zl; } /* Compare entry pointer to by 'p' with 'sstr' of length 'slen'. */ /* Return 1 if equal. 
*/ unsigned int ziplistCompare(unsigned char *p, unsigned char *sstr, unsigned int slen) { zlentry entry; unsigned char sencoding; long long zval, sval; if (p[0] == ZIP_END) return 0; zipEntry(p, &entry); /* no need for "safe" variant since the input pointer was validated by the function that returned it. */ if (ZIP_IS_STR(entry.encoding)) { /* Raw compare */ if (entry.len == slen) { return memcmp(p+entry.headersize,sstr,slen) == 0; } else { return 0; } } else { /* Try to compare encoded values. Don't compare encoding because * different implementations may encoded integers differently. */ if (zipTryEncoding(sstr,slen,&sval,&sencoding)) { zval = zipLoadInteger(p+entry.headersize,entry.encoding); return zval == sval; } } return 0; } /* Find pointer to the entry equal to the specified entry. Skip 'skip' entries * between every comparison. Returns NULL when the field could not be found. */ unsigned char * ziplistFind(unsigned char *zl, unsigned char *p, unsigned char *vstr, unsigned int vlen, unsigned int skip) { int skipcnt = 0; unsigned char vencoding = 0; long long vll = 0; size_t zlbytes = ziplistBlobLen(zl); while (p[0] != ZIP_END) { struct zlentry e; unsigned char *q; int res = zipEntrySafe(zl, zlbytes, p, &e, 1); assert(res); (void)res; q = p + e.prevrawlensize + e.lensize; if (skipcnt == 0) { /* Compare current entry with specified entry */ if (ZIP_IS_STR(e.encoding)) { if (e.len == vlen && memcmp(q, vstr, vlen) == 0) { return p; } } else { /* Find out if the searched field can be encoded. Note that * we do it only the first time, once done vencoding is set * to non-zero and vll is set to the integer value. */ if (vencoding == 0) { if (!zipTryEncoding(vstr, vlen, &vll, &vencoding)) { /* If the entry can't be encoded we set it to * UCHAR_MAX so that we don't retry again the next * time. 
*/ vencoding = UCHAR_MAX; } /* Must be non-zero by now */ assert(vencoding); } /* Compare current entry with specified entry, do it only * if vencoding != UCHAR_MAX because if there is no encoding * possible for the field it can't be a valid integer. */ if (vencoding != UCHAR_MAX) { long long ll = zipLoadInteger(q, e.encoding); if (ll == vll) { return p; } } } /* Reset skip count */ skipcnt = skip; } else { /* Skip entry */ skipcnt--; } /* Move to next entry */ p = q + e.len; } return NULL; } /* Return length of ziplist. */ unsigned int ziplistLen(unsigned char *zl) { unsigned int len = 0; if (intrev16ifbe(ZIPLIST_LENGTH(zl)) < UINT16_MAX) { len = intrev16ifbe(ZIPLIST_LENGTH(zl)); } else { unsigned char *p = zl+ZIPLIST_HEADER_SIZE; size_t zlbytes = intrev32ifbe(ZIPLIST_BYTES(zl)); while (*p != ZIP_END) { p += zipRawEntryLengthSafe(zl, zlbytes, p); len++; } /* Re-store length if small enough */ if (len < UINT16_MAX) ZIPLIST_LENGTH(zl) = intrev16ifbe(len); } return len; } /* Return ziplist blob size in bytes. 
*/ size_t ziplistBlobLen(unsigned char *zl) { return intrev32ifbe(ZIPLIST_BYTES(zl)); } void ziplistRepr(unsigned char *zl) { unsigned char *p; int index = 0; zlentry entry; size_t zlbytes = ziplistBlobLen(zl); printf("{total bytes %u} " "{num entries %u}\n" "{tail offset %u}\n", intrev32ifbe(ZIPLIST_BYTES(zl)), intrev16ifbe(ZIPLIST_LENGTH(zl)), intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl))); p = ZIPLIST_ENTRY_HEAD(zl); while(*p != ZIP_END) { zipEntrySafe(zl, zlbytes, p, &entry, 1); printf( "{\n" "\taddr 0x%08lx,\n" "\tindex %2d,\n" "\toffset %5lu,\n" "\thdr+entry len: %5u,\n" "\thdr len%2u,\n" "\tprevrawlen: %5u,\n" "\tprevrawlensize: %2u,\n" "\tpayload %5u\n", (long unsigned)p, index, (unsigned long)(p - zl), entry.headersize + entry.len, entry.headersize, entry.prevrawlen, entry.prevrawlensize, entry.len); printf("\tbytes: "); for (unsigned int i = 0; i < entry.headersize+entry.len; i++) { printf("%02x|",p[i]); } printf("\n"); p += entry.headersize; if (ZIP_IS_STR(entry.encoding)) { printf("\t[str]"); if (entry.len > 40) { if (fwrite(p,40,1,stdout) == 0) perror("fwrite"); printf("..."); } else { if (entry.len && fwrite(p, entry.len, 1, stdout) == 0) perror("fwrite"); } } else { printf("\t[int]%lld", (long long) zipLoadInteger(p,entry.encoding)); } printf("\n}\n"); p += entry.len; index++; } printf("{end}\n\n"); } /* Validate the integrity of the data structure. * when `deep` is 0, only the integrity of the header is validated. * when `deep` is 1, we scan all the entries one by one. */ int ziplistValidateIntegrity(unsigned char *zl, size_t size, int deep, ziplistValidateEntryCB entry_cb, void *cb_userdata) { /* check that we can actually read the header. (and ZIP_END) */ if (size < ZIPLIST_HEADER_SIZE + ZIPLIST_END_SIZE) return 0; /* check that the encoded size in the header must match the allocated size. */ size_t bytes = intrev32ifbe(ZIPLIST_BYTES(zl)); if (bytes != size) return 0; /* the last byte must be the terminator. 
*/ if (zl[size - ZIPLIST_END_SIZE] != ZIP_END) return 0; /* make sure the tail offset isn't reaching outside the allocation. */ if (intrev32ifbe(ZIPLIST_TAIL_OFFSET(zl)) > size - ZIPLIST_END_SIZE) return 0; if (!deep) return 1; unsigned int count = 0; unsigned int header_count = intrev16ifbe(ZIPLIST_LENGTH(zl)); unsigned char *p = ZIPLIST_ENTRY_HEAD(zl); unsigned char *prev = NULL; size_t prev_raw_size = 0; while(*p != ZIP_END) { struct zlentry e; /* Decode the entry headers and fail if invalid or reaches outside the allocation */ if (!zipEntrySafe(zl, size, p, &e, 1)) return 0; /* Make sure the record stating the prev entry size is correct. */ if (e.prevrawlen != prev_raw_size) return 0; /* Optionally let the caller validate the entry too. */ if (entry_cb && !entry_cb(p, header_count, cb_userdata)) return 0; /* Move to the next entry */ prev_raw_size = e.headersize + e.len; prev = p; p += e.headersize + e.len; count++; } /* Make sure 'p' really does point to the end of the ziplist. */ if (p != zl + bytes - ZIPLIST_END_SIZE) return 0; /* Make sure the entry really do point to the start of the last entry. */ if (prev != NULL && prev != ZIPLIST_ENTRY_TAIL(zl)) return 0; /* Check that the count in the header is correct */ if (header_count != UINT16_MAX && count != header_count) return 0; return 1; } /* Randomly select a pair of key and value. * total_count is a pre-computed length/2 of the ziplist (to avoid calls to ziplistLen) * 'key' and 'val' are used to store the result key value pair. * 'val' can be NULL if the value is not needed. 
*/ void ziplistRandomPair(unsigned char *zl, unsigned long total_count, ziplistEntry *key, ziplistEntry *val) { int ret; unsigned char *p; /* Avoid div by zero on corrupt ziplist */ assert(total_count); /* Generate even numbers, because ziplist saved K-V pair */ int r = (rand() % total_count) * 2; p = ziplistIndex(zl, r); ret = ziplistGet(p, &key->sval, &key->slen, &key->lval); assert(ret != 0); (void)ret; if (!val) return; p = ziplistNext(zl, p); ret = ziplistGet(p, &val->sval, &val->slen, &val->lval); assert(ret != 0); } /* int compare for qsort */ int uintCompare(const void *a, const void *b) { return (*(unsigned int *) a - *(unsigned int *) b); } /* Helper method to store a string into from val or lval into dest */ static inline void ziplistSaveValue(unsigned char *val, unsigned int len, long long lval, ziplistEntry *dest) { dest->sval = val; dest->slen = len; dest->lval = lval; } /* Randomly select count of key value pairs and store into 'keys' and * 'vals' args. The order of the picked entries is random, and the selections * are non-unique (repetitions are possible). * The 'vals' arg can be NULL in which case we skip these. */ void ziplistRandomPairs(unsigned char *zl, unsigned int count, ziplistEntry *keys, ziplistEntry *vals) { unsigned char *p, *key, *value; unsigned int klen = 0, vlen = 0; long long klval = 0, vlval = 0; /* Notice: the index member must be first due to the use in uintCompare */ typedef struct { unsigned int index; unsigned int order; } rand_pick; rand_pick *picks = zmalloc(sizeof(rand_pick)*count); unsigned int total_size = ziplistLen(zl)/2; /* Avoid div by zero on corrupt ziplist */ assert(total_size); /* create a pool of random indexes (some may be duplicate). */ for (unsigned int i = 0; i < count; i++) { picks[i].index = (rand() % total_size) * 2; /* Generate even indexes */ /* keep track of the order we picked them */ picks[i].order = i; } /* sort by indexes. 
*/ qsort(picks, count, sizeof(rand_pick), uintCompare); /* fetch the elements form the ziplist into a output array respecting the original order. */ unsigned int zipindex = picks[0].index, pickindex = 0; p = ziplistIndex(zl, zipindex); while (ziplistGet(p, &key, &klen, &klval) && pickindex < count) { p = ziplistNext(zl, p); assert(ziplistGet(p, &value, &vlen, &vlval)); while (pickindex < count && zipindex == picks[pickindex].index) { int storeorder = picks[pickindex].order; ziplistSaveValue(key, klen, klval, &keys[storeorder]); if (vals) ziplistSaveValue(value, vlen, vlval, &vals[storeorder]); pickindex++; } zipindex += 2; p = ziplistNext(zl, p); } zfree(picks); } /* Randomly select count of key value pairs and store into 'keys' and * 'vals' args. The selections are unique (no repetitions), and the order of * the picked entries is NOT-random. * The 'vals' arg can be NULL in which case we skip these. * The return value is the number of items picked which can be lower than the * requested count if the ziplist doesn't hold enough pairs. 
*/
unsigned int ziplistRandomPairsUnique(unsigned char *zl, unsigned int count, ziplistEntry *keys, ziplistEntry *vals) {
    unsigned char *p, *key;
    unsigned int klen = 0;
    long long klval = 0;
    unsigned int total_size = ziplistLen(zl)/2;
    unsigned int index = 0;
    int fetched;
    if (count > total_size)
        count = total_size;

    /* To only iterate once, every time we try to pick a member, the probability
     * we pick it is the quotient of the count left we want to pick and the
     * count still we haven't visited in the dict, this way, we could make every
     * member be equally picked.*/
    p = ziplistIndex(zl, 0);
    unsigned int picked = 0, remaining = count;
    while (picked < count && p) {
        double randomDouble = ((double)rand()) / RAND_MAX;
        double threshold = ((double)remaining) / (total_size - index);
        if (randomDouble <= threshold) {
            /* ziplistGet is called outside assert() so the fetch still
             * happens when assertions are compiled out (NDEBUG). */
            fetched = ziplistGet(p, &key, &klen, &klval);
            assert(fetched);
            (void)fetched;
            ziplistSaveValue(key, klen, klval, &keys[picked]);
            p = ziplistNext(zl, p);
            assert(p);
            if (vals) {
                fetched = ziplistGet(p, &key, &klen, &klval);
                assert(fetched);
                (void)fetched;
                ziplistSaveValue(key, klen, klval, &vals[picked]);
            }
            remaining--;
            picked++;
        } else {
            /* Not picked: step over the key so "p" is on the value. */
            p = ziplistNext(zl, p);
            assert(p);
        }
        /* Step over the value to reach the next key (or NULL at the end). */
        p = ziplistNext(zl, p);
        index++;
    }
    return picked;
}

================================================
FILE: src/redis/ziplist.h
================================================
/*
 * Copyright (c) 2009-2012, Pieter Noordhuis
 * Copyright (c) 2009-2012, Salvatore Sanfilippo
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
*   * Neither the name of Redis nor the names of its contributors may be used
 *     to endorse or promote products derived from this software without
 *     specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#ifndef _ZIPLIST_H
#define _ZIPLIST_H

/* Values for the 'where' argument of ziplistPush: ZIPLIST_HEAD inserts before
 * the first entry, any other value appends at the end. */
#define ZIPLIST_HEAD 0
#define ZIPLIST_TAIL 1

/* Each entry in the ziplist is either a string or an integer. */
typedef struct {
    /* When string is used, it is provided with the length (slen). */
    unsigned char *sval;
    unsigned int slen;
    /* When integer is used, 'sval' is NULL, and lval holds the value. */
    long long lval;
} ziplistEntry;

/* Functions returning 'unsigned char *zl' may reallocate the list: callers
 * must continue with the returned pointer, not the one they passed in. */

unsigned char *ziplistNew(void);
/* Appends '*second' to '*first'. Returns the merged list, or NULL when an
 * argument is NULL or both name the same list. The input that is not kept is
 * freed and set to NULL. */
unsigned char *ziplistMerge(unsigned char **first, unsigned char **second);
/* Inserts 's' (of 'slen' bytes) at the head or the end, see ZIPLIST_HEAD. */
unsigned char *ziplistPush(unsigned char *zl, unsigned char *s, unsigned int slen, int where);
/* Returns the entry at 'index' (negative counts from the tail, -1 being the
 * last entry), or NULL when there is no such entry. */
unsigned char *ziplistIndex(unsigned char *zl, int index);
/* Returns the entry after 'p', or NULL when there is none. */
unsigned char *ziplistNext(unsigned char *zl, unsigned char *p);
/* Returns the entry before 'p' (the tail when 'p' is at the end marker), or
 * NULL when 'p' is the head or the list is empty. */
unsigned char *ziplistPrev(unsigned char *zl, unsigned char *p);
/* Reads the entry at 'p': string entries set '*sval'/'*slen', integer entries
 * leave '*sval' NULL and set '*lval'. Returns 0 when 'p' is NULL or at the
 * end of the list, 1 otherwise. */
unsigned int ziplistGet(unsigned char *p, unsigned char **sval, unsigned int *slen, long long *lval);
/* Inserts 's' (of 'slen' bytes) at position 'p'. */
unsigned char *ziplistInsert(unsigned char *zl, unsigned char *p, unsigned char *s, unsigned int slen);
/* Deletes the entry at '*p' and updates '*p' to the same offset in the
 * returned list. */
unsigned char *ziplistDelete(unsigned char *zl, unsigned char **p);
/* Deletes up to 'num' entries starting at 'index'; the list is returned
 * unchanged when 'index' names no entry. */
unsigned char *ziplistDeleteRange(unsigned char *zl, int index, unsigned int num);
/* Replaces the entry at 'p' with 's', overwriting in place when the encoded
 * size is unchanged and deleting + inserting otherwise. */
unsigned char *ziplistReplace(unsigned char *zl, unsigned char *p, unsigned char *s, unsigned int slen);
/* Returns 1 when the entry at 'p' equals 's' (of 'slen' bytes), 0 otherwise. */
unsigned int ziplistCompare(unsigned char *p, unsigned char *s, unsigned int slen);
/* Searches from 'p' for an entry equal to 'vstr', skipping 'skip' entries
 * between comparisons. Returns the matching entry or NULL. */
unsigned char *ziplistFind(unsigned char *zl, unsigned char *p, unsigned char *vstr, unsigned int vlen, unsigned int skip);
/* Returns the number of entries. */
unsigned int ziplistLen(unsigned char *zl);
/* Returns the total size in bytes recorded in the ziplist header. */
size_t ziplistBlobLen(unsigned char *zl);
/* Prints a description of the header and every entry to stdout. */
void ziplistRepr(unsigned char *zl);
/* Callback invoked by ziplistValidateIntegrity for every entry; a zero return
 * value makes the validation fail. */
typedef int (*ziplistValidateEntryCB)(unsigned char* p, unsigned int head_count, void* userdata);
/* Returns 1 when the ziplist of 'size' bytes is consistent, 0 otherwise. Only
 * the header is checked unless 'deep' is non-zero. 'entry_cb' may be NULL. */
int ziplistValidateIntegrity(unsigned char *zl, size_t size, int deep,
                             ziplistValidateEntryCB entry_cb, void *cb_userdata);
/* Picks one random key/value pair; 'total_count' is the number of pairs and
 * must be non-zero. 'val' may be NULL. */
void ziplistRandomPair(unsigned char *zl, unsigned long total_count, ziplistEntry *key, ziplistEntry *val);
/* Picks 'count' random pairs, repetitions possible. 'vals' may be NULL. */
void ziplistRandomPairs(unsigned char *zl, unsigned int count, ziplistEntry *keys, ziplistEntry *vals);
/* Picks up to 'count' distinct pairs and returns how many were picked.
 * 'vals' may be NULL. */
unsigned int ziplistRandomPairsUnique(unsigned char *zl, unsigned int count, ziplistEntry *keys, ziplistEntry *vals);
int ziplistSafeToAdd(unsigned char* zl, size_t add);

#ifdef REDIS_TEST
int ziplistTest(int argc, char *argv[], int accurate);
#endif

#endif /* _ZIPLIST_H */

================================================
FILE:
src/redis/zmalloc.c ================================================ /* zmalloc - total amount of allocated memory aware version of malloc() * * Copyright (c) 2009-2010, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include #include #include #include #include /* This function provide us access to the original libc free(). This is useful * for instance to free results obtained by backtrace_symbols(). 
We need * to define this function before including zmalloc.h that may shadow the * free implementation if we use jemalloc or another non standard allocator. */ /*void zlibc_free(void *ptr) { free(ptr); }*/ #include #include #include "config.h" #include "zmalloc.h" #include "atomicvar.h" #ifdef HAVE_MALLOC_SIZE #define PREFIX_SIZE (0) #define ASSERT_NO_SIZE_OVERFLOW(sz) #else #if defined(__sun) || defined(__sparc) || defined(__sparc__) #define PREFIX_SIZE (sizeof(long long)) #else #define PREFIX_SIZE (sizeof(size_t)) #endif #define ASSERT_NO_SIZE_OVERFLOW(sz) assert((sz) + PREFIX_SIZE > (sz)) #endif /* When using the libc allocator, use a minimum allocation size to match the * jemalloc behavior that doesn't return NULL in this case. */ #define MALLOC_MIN_SIZE(x) ((x) > 0 ? (x) : sizeof(long)) /* Explicitly override malloc/free etc when using tcmalloc. */ #if defined(USE_TCMALLOC) #define malloc(size) tc_malloc(size) #define calloc(count,size) tc_calloc(count,size) #define realloc(ptr,size) tc_realloc(ptr,size) #define free(ptr) tc_free(ptr) #elif defined(USE_JEMALLOC) #define malloc(size) je_malloc(size) #define calloc(count,size) je_calloc(count,size) #define realloc(ptr,size) je_realloc(ptr,size) #define free(ptr) je_free(ptr) #define mallocx(size,flags) je_mallocx(size,flags) #define dallocx(ptr,flags) je_dallocx(ptr,flags) #endif #define update_zmalloc_stat_alloc(__n) used_memory_tl += (__n) #define update_zmalloc_stat_free(__n) used_memory_tl -= (__n) __thread ssize_t used_memory_tl = 0; static void zmalloc_default_oom(size_t size) { fprintf(stderr, "zmalloc: Out of memory trying to allocate %zu bytes\n", size); fflush(stderr); abort(); } static void (*zmalloc_oom_handler)(size_t) = zmalloc_default_oom; void init_zmalloc_threadlocal() { } /* Try allocating memory, and return NULL if failed. * '*usable' is set to the usable size if non NULL. 
*/ void *ztrymalloc_usable(size_t size, size_t *usable) { ASSERT_NO_SIZE_OVERFLOW(size); void *ptr = malloc(MALLOC_MIN_SIZE(size)+PREFIX_SIZE); if (!ptr) return NULL; #ifdef HAVE_MALLOC_SIZE size = zmalloc_size(ptr); update_zmalloc_stat_alloc(size); if (usable) *usable = size; return ptr; #else *((size_t*)ptr) = size; update_zmalloc_stat_alloc(size+PREFIX_SIZE); if (usable) *usable = size; return (char*)ptr+PREFIX_SIZE; #endif } /* Allocate memory or panic */ void *zmalloc(size_t size) { void *ptr = ztrymalloc_usable(size, NULL); if (!ptr) zmalloc_oom_handler(size); return ptr; } /* Try allocating memory, and return NULL if failed. */ void *ztrymalloc(size_t size) { void *ptr = ztrymalloc_usable(size, NULL); return ptr; } /* Allocate memory or panic. * '*usable' is set to the usable size if non NULL. */ void *zmalloc_usable(size_t size, size_t *usable) { void *ptr = ztrymalloc_usable(size, usable); if (!ptr) zmalloc_oom_handler(size); return ptr; } size_t znallocx(size_t size) { #if defined(USE_JEMALLOC) return je_ncallocx(size, 0); #else return size; #endif } void zfree_size(void* ptr, size_t size) { #if defined(USE_JEMALLOC) je_sdallocx(ptr, size, 0); #else free(ptr); (void)size; #endif } /* Allocation and free functions that bypass the thread cache * and go straight to the allocator arena bins. * Currently implemented only for jemalloc. Used for online defragmentation. */ #ifdef HAVE_DEFRAG void *zmalloc_no_tcache(size_t size) { ASSERT_NO_SIZE_OVERFLOW(size); void *ptr = mallocx(size+PREFIX_SIZE, MALLOCX_TCACHE_NONE); if (!ptr) zmalloc_oom_handler(size); update_zmalloc_stat_alloc(zmalloc_size(ptr)); return ptr; } void zfree_no_tcache(void *ptr) { if (ptr == NULL) return; update_zmalloc_stat_free(zmalloc_size(ptr)); dallocx(ptr, MALLOCX_TCACHE_NONE); } #endif /* Try allocating memory and zero it, and return NULL if failed. * '*usable' is set to the usable size if non NULL. 
*/
void *ztrycalloc_usable(size_t size, size_t *usable) {
    ASSERT_NO_SIZE_OVERFLOW(size);
    void *ptr = calloc(1, MALLOC_MIN_SIZE(size)+PREFIX_SIZE);
    if (ptr == NULL) return NULL;

#ifdef HAVE_MALLOC_SIZE
    /* Account for the size reported by the allocator, not the requested one. */
    size = zmalloc_size(ptr);
    update_zmalloc_stat_alloc(size);
    if (usable) *usable = size;
    return ptr;
#else
    /* Store the requested size in the prefix and return the address that
     * follows it. */
    *((size_t*)ptr) = size;
    update_zmalloc_stat_alloc(size+PREFIX_SIZE);
    if (usable) *usable = size;
    return (char*)ptr+PREFIX_SIZE;
#endif
}

/* Allocate zeroed memory or panic (the OOM handler is called on failure). */
void *zcalloc(size_t size) {
    void *ptr = ztrycalloc_usable(size, NULL);
    if (!ptr) zmalloc_oom_handler(size);
    return ptr;
}

/* Try allocating zeroed memory, and return NULL if failed. */
void *ztrycalloc(size_t size) {
    void *ptr = ztrycalloc_usable(size, NULL);
    return ptr;
}

/* Allocate zeroed memory or panic.
 * '*usable' is set to the usable size if non NULL. */
void *zcalloc_usable(size_t size, size_t *usable) {
    void *ptr = ztrycalloc_usable(size, usable);
    if (!ptr) zmalloc_oom_handler(size);
    return ptr;
}

/* Try reallocating memory, and return NULL if failed.
 * '*usable' is set to the usable size if non NULL.
 *
 * Special cases handled before touching the allocator:
 *  - size == 0 with a non-NULL 'ptr' frees the block and returns NULL;
 *  - a NULL 'ptr' behaves like ztrymalloc_usable().
 * When realloc() itself fails, '*usable' is set to 0, NULL is returned and
 * the accounting is left untouched. */
void *ztryrealloc_usable(void *ptr, size_t size, size_t *usable) {
    ASSERT_NO_SIZE_OVERFLOW(size);
#ifndef HAVE_MALLOC_SIZE
    void *realptr;
#endif
    size_t oldsize;
    void *newptr;

    /* not allocating anything, just redirect to free. */
    if (size == 0 && ptr != NULL) {
        zfree(ptr);
        if (usable) *usable = 0;
        return NULL;
    }
    /* Not freeing anything, just redirect to malloc. */
    if (ptr == NULL)
        return ztrymalloc_usable(size, usable);

#ifdef HAVE_MALLOC_SIZE
    /* The old size must be read before realloc(): 'ptr' may be invalid
     * afterwards. */
    oldsize = zmalloc_size(ptr);
    newptr = realloc(ptr,size);
    if (newptr == NULL) {
        if (usable) *usable = 0;
        return NULL;
    }

    update_zmalloc_stat_free(oldsize);
    size = zmalloc_size(newptr);
    update_zmalloc_stat_alloc(size);
    if (usable) *usable = size;
    return newptr;
#else
    /* Step back to the real start of the block, where the size is stored. */
    realptr = (char*)ptr-PREFIX_SIZE;
    oldsize = *((size_t*)realptr);
    newptr = realloc(realptr,size+PREFIX_SIZE);
    if (newptr == NULL) {
        if (usable) *usable = 0;
        return NULL;
    }

    *((size_t*)newptr) = size;
    /* Both counter updates leave the prefix out, so its contribution
     * (added once at allocation time) is preserved. */
    update_zmalloc_stat_free(oldsize);
    update_zmalloc_stat_alloc(size);
    if (usable) *usable = size;
    return (char*)newptr+PREFIX_SIZE;
#endif
}

/* Reallocate memory or panic. The new region is NOT zeroed. A NULL result
 * is only treated as out-of-memory when 'size' is non zero, because
 * resizing to 0 legitimately returns NULL. */
void *zrealloc(void *ptr, size_t size) {
    ptr = ztryrealloc_usable(ptr, size, NULL);
    if (!ptr && size != 0) zmalloc_oom_handler(size);
    return ptr;
}

/* Try reallocating memory, and return NULL if failed. */
void *ztryrealloc(void *ptr, size_t size) {
    ptr = ztryrealloc_usable(ptr, size, NULL);
    return ptr;
}

/* Reallocate memory or panic.
 * '*usable' is set to the usable size if non NULL. */
void *zrealloc_usable(void *ptr, size_t size, size_t *usable) {
    ptr = ztryrealloc_usable(ptr, size, usable);
    if (!ptr && size != 0) zmalloc_oom_handler(size);
    return ptr;
}

/* Provide zmalloc_size() for systems where this function is not provided by
 * malloc itself, given that in that case we store a header with this
 * information as the first bytes of every allocation.
*/ #ifndef HAVE_MALLOC_SIZE size_t zmalloc_size(void *ptr) { void *realptr = (char*)ptr-PREFIX_SIZE; size_t size = *((size_t*)realptr); return size+PREFIX_SIZE; } size_t zmalloc_usable_size(void *ptr) { return zmalloc_size(ptr)-PREFIX_SIZE; } #endif void zfree(void *ptr) { #ifndef HAVE_MALLOC_SIZE void *realptr; size_t oldsize; #endif if (ptr == NULL) return; #ifdef HAVE_MALLOC_SIZE update_zmalloc_stat_free(zmalloc_size(ptr)); free(ptr); #else realptr = (char*)ptr-PREFIX_SIZE; oldsize = *((size_t*)realptr); update_zmalloc_stat_free(oldsize+PREFIX_SIZE); free(realptr); #endif } void zmalloc_set_oom_handler(void (*oom_handler)(size_t)) { zmalloc_oom_handler = oom_handler; } /* Get the RSS information in an OS-specific way. * * WARNING: the function zmalloc_get_rss() is not designed to be fast * and may not be called in the busy loops where Redis tries to release * memory expiring or swapping out objects. * * For this kind of "fast RSS reporting" usages use instead the * function RedisEstimateRSS() that is a much faster (and less precise) * version of the function. 
*/ #if defined(HAVE_PROC_STAT) #include #include #include size_t zmalloc_get_rss(void) { int page = sysconf(_SC_PAGESIZE); size_t rss; char buf[4096]; char filename[256]; int fd, count; char *p, *x; snprintf(filename,256,"/proc/%ld/stat",(long) getpid()); if ((fd = open(filename,O_RDONLY)) == -1) return 0; if (read(fd,buf,4096) <= 0) { close(fd); return 0; } close(fd); p = buf; count = 23; /* RSS is the 24th field in /proc//stat */ while(p && count--) { p = strchr(p,' '); if (p) p++; } if (!p) return 0; x = strchr(p,' '); if (!x) return 0; *x = '\0'; rss = strtoll(p,NULL,10); rss *= page; return rss; } #elif defined(HAVE_TASKINFO) #include #include #include #include size_t zmalloc_get_rss(void) { task_t task = MACH_PORT_NULL; struct task_basic_info t_info; mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; if (task_for_pid(current_task(), getpid(), &task) != KERN_SUCCESS) return 0; task_info(task, TASK_BASIC_INFO, (task_info_t)&t_info, &t_info_count); return t_info.resident_size; } #elif defined(__FreeBSD__) || defined(__DragonFly__) #include #include #include size_t zmalloc_get_rss(void) { struct kinfo_proc info; size_t infolen = sizeof(info); int mib[4]; mib[0] = CTL_KERN; mib[1] = KERN_PROC; mib[2] = KERN_PROC_PID; mib[3] = getpid(); if (sysctl(mib, 4, &info, &infolen, NULL, 0) == 0) #if defined(__FreeBSD__) return (size_t)info.ki_rssize * getpagesize(); #else return (size_t)info.kp_vm_rssize * getpagesize(); #endif return 0L; } #elif defined(__NetBSD__) #include #include size_t zmalloc_get_rss(void) { struct kinfo_proc2 info; size_t infolen = sizeof(info); int mib[6]; mib[0] = CTL_KERN; mib[1] = KERN_PROC; mib[2] = KERN_PROC_PID; mib[3] = getpid(); mib[4] = sizeof(info); mib[5] = 1; if (sysctl(mib, 4, &info, &infolen, NULL, 0) == 0) return (size_t)info.p_vm_rssize * getpagesize(); return 0L; } #elif defined(HAVE_PSINFO) #include #include #include size_t zmalloc_get_rss(void) { struct prpsinfo info; char filename[256]; int fd; 
snprintf(filename,256,"/proc/%ld/psinfo",(long) getpid());

    if ((fd = open(filename,O_RDONLY)) == -1) return 0;
    /* The process information is fetched with an ioctl on the psinfo file. */
    if (ioctl(fd, PIOCPSINFO, &info) == -1) {
        close(fd);
        return 0;
    }

    close(fd);
    return info.pr_rssize;
}
#else
size_t zmalloc_get_rss(void) {
    /* If we can't get the RSS in an OS-specific way for this system just
     * return the memory usage we estimated in zmalloc().
     *
     * Fragmentation will appear to be always 1 (no fragmentation)
     * of course... */
    return zmalloc_used_memory();
}
#endif

#if defined(USE_JEMALLOC)

/* Fill the three counters with jemalloc statistics, in bytes. All of them are
 * zeroed first; the mallctl return codes are not inspected, so a failed query
 * leaves the corresponding value at 0. Always returns 1. */
int zmalloc_get_allocator_info(size_t *allocated,
                               size_t *active,
                               size_t *resident) {
    uint64_t epoch = 1;
    size_t sz;
    *allocated = *resident = *active = 0;
    /* Update the statistics cached by mallctl. */
    sz = sizeof(epoch);
    je_mallctl("epoch", &epoch, &sz, &epoch, sz);
    sz = sizeof(size_t);
    /* Unlike RSS, this does not include RSS from shared libraries and other non
     * heap mappings. */
    je_mallctl("stats.resident", resident, &sz, NULL, 0);
    /* Unlike resident, this does not include the pages jemalloc reserves
     * for re-use (purge will clean that). */
    je_mallctl("stats.active", active, &sz, NULL, 0);
    /* Unlike zmalloc_used_memory, this matches the stats.resident by taking
     * into account all allocations done by this process (not only zmalloc). */
    je_mallctl("stats.allocated", allocated, &sz, NULL, 0);
    return 1;
}

/* Turn jemalloc's background thread on (non-zero 'enable') or off (0). */
void set_jemalloc_bg_thread(int enable) {
    /* let jemalloc do purging asynchronously, required when there's no traffic
     * after flushdb */
    char val = !!enable;
    je_mallctl("background_thread", NULL, 0, &val, 1);
}

/* Ask jemalloc to purge its arenas. Returns 0 when both mallctl calls
 * succeed and -1 otherwise. */
int jemalloc_purge() {
    /* return all unused (reserved) pages to the OS */
    char tmp[32];
    unsigned narenas = 0;
    size_t sz = sizeof(unsigned);
    if (!je_mallctl("arenas.narenas", &narenas, &sz, NULL, 0)) {
        /* NOTE(review): 'narenas' is unsigned but is printed with %d. */
        sprintf(tmp, "arena.%d.purge", narenas);
        if (!je_mallctl(tmp, NULL, 0, NULL, 0))
            return 0;
    }
    return -1;
}

#else

/* Without jemalloc there are no allocator statistics: every counter is
 * reported as 0 and 1 is returned. */
int zmalloc_get_allocator_info(size_t *allocated,
                               size_t *active,
                               size_t *resident) {
    *allocated = *resident = *active = 0;
    return 1;
}

/* No-op when jemalloc is not in use. */
void set_jemalloc_bg_thread(int enable) {
    ((void)(enable));
}

/* No-op when jemalloc is not in use; reports success. */
int jemalloc_purge() {
    return 0;
}

#endif

#if defined(__APPLE__)
/* For proc_pidinfo() used later in zmalloc_get_smap_bytes_by_field().
 * Note that this file cannot be included in zmalloc.h because it includes
 * a Darwin queue.h file where there is a "LIST_HEAD" macro (!) defined
 * conflicting with Redis user code. */
#include
#endif

/* Get the sum of the specified field (converted from kb to bytes) in
 * /proc/self/smaps. The field must be specified with trailing ":" as it
 * appears in the smaps output.
 *
 * If a pid is specified, the information is extracted for such a pid,
 * otherwise if pid is -1 the information reported is about the
 * current process.
*
 * Example: zmalloc_get_smap_bytes_by_field("Rss:",-1); */
#if defined(HAVE_PROC_SMAPS)
size_t zmalloc_get_smap_bytes_by_field(char *field, long pid) {
    char line[1024];
    size_t bytes = 0;
    int flen = strlen(field);
    FILE *fp;

    if (pid == -1) {
        fp = fopen("/proc/self/smaps","r");
    } else {
        char filename[128];
        snprintf(filename,sizeof(filename),"/proc/%ld/smaps",pid);
        fp = fopen(filename,"r");
    }

    /* An unreadable smaps file is reported as 0 bytes. */
    if (!fp) return 0;
    while(fgets(line,sizeof(line),fp) != NULL) {
        /* Only lines that start with the requested field name are summed. */
        if (strncmp(line,field,flen) == 0) {
            /* Cut the line at the first 'k' (the "kB" unit) so that strtol()
             * only sees the number following the field name. */
            char *p = strchr(line,'k');
            if (p) {
                *p = '\0';
                bytes += strtol(line+flen,NULL,10) * 1024;
            }
        }
    }
    fclose(fp);
    return bytes;
}
#else
/* Get sum of the specified field from libproc api call.
 * As the values are reported on a per page basis we need to convert
 * them accordingly.
 *
 * Note that AnonHugePages is a no-op as THP feature
 * is not supported in this platform.
 *
 * On platforms other than Apple this always returns 0. */
size_t zmalloc_get_smap_bytes_by_field(char *field, long pid) {
#if defined(__APPLE__)
    struct proc_regioninfo pri;
    /* -1 means "the current process". */
    if (pid == -1) pid = getpid();
    if (proc_pidinfo(pid, PROC_PIDREGIONINFO, 0, &pri,
                     PROC_PIDREGIONINFO_SIZE) == PROC_PIDREGIONINFO_SIZE)
    {
        int pagesize = getpagesize();
        if (!strcmp(field, "Private_Dirty:")) {
            return (size_t)pri.pri_pages_dirtied * pagesize;
        } else if (!strcmp(field, "Rss:")) {
            return (size_t)pri.pri_pages_resident * pagesize;
        } else if (!strcmp(field, "AnonHugePages:")) {
            return 0;
        }
    }
    return 0;
#endif
    /* Reached only when the Apple branch above is compiled out; the casts
     * silence unused-parameter warnings. */
    ((void) field);
    ((void) pid);
    return 0;
}
#endif

/* Return the total number bytes in pages marked as Private Dirty.
 *
 * Note: depending on the platform and memory footprint of the process, this
 * call can be slow, exceeding 1000ms! */
size_t zmalloc_get_private_dirty(long pid) {
    return zmalloc_get_smap_bytes_by_field("Private_Dirty:",pid);
}

/* Returns the size of physical memory (RAM) in bytes.
 * It looks ugly, but this is the cleanest way to achieve cross platform results.
* Cleaned up from:
 *
 * http://nadeausoftware.com/articles/2012/09/c_c_tip_how_get_physical_memory_size_system
 *
 * Note that this function:
 * 1) Was released under the following CC attribution license:
 *    http://creativecommons.org/licenses/by/3.0/deed.en_US.
 * 2) Was originally implemented by David Robert Nadeau.
 * 3) Was modified for Redis by Matt Stancliff.
 * 4) This note exists in order to comply with the original license.
 */
size_t zmalloc_get_memory_size(void) {
    /* Exactly one of the branches below is compiled, picked by the macros
     * the platform headers provide. Each branch returns 0 when it cannot
     * obtain the value. */
#if defined(__unix__) || defined(__unix) || defined(unix) || \
    (defined(__APPLE__) && defined(__MACH__))
#if defined(CTL_HW) && (defined(HW_MEMSIZE) || defined(HW_PHYSMEM64))
    /* sysctl() with a 64-bit result. */
    int mib[2];
    mib[0] = CTL_HW;
#if defined(HW_MEMSIZE)
    mib[1] = HW_MEMSIZE;            /* OSX. --------------------- */
#elif defined(HW_PHYSMEM64)
    mib[1] = HW_PHYSMEM64;          /* NetBSD, OpenBSD. --------- */
#endif
    int64_t size = 0;               /* 64-bit */
    size_t len = sizeof(size);
    if (sysctl( mib, 2, &size, &len, NULL, 0) == 0)
        return (size_t)size;
    return 0L;          /* Failed? */

#elif defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE)
    /* FreeBSD, Linux, OpenBSD, and Solaris. -------------------- */
    /* Number of physical pages multiplied by the page size. */
    return (size_t)sysconf(_SC_PHYS_PAGES) * (size_t)sysconf(_SC_PAGESIZE);

#elif defined(CTL_HW) && (defined(HW_PHYSMEM) || defined(HW_REALMEM))
    /* DragonFly BSD, FreeBSD, NetBSD, OpenBSD, and OSX. -------- */
    /* sysctl() with a 32-bit result. */
    int mib[2];
    mib[0] = CTL_HW;
#if defined(HW_REALMEM)
    mib[1] = HW_REALMEM;        /* FreeBSD. ----------------- */
#elif defined(HW_PHYSMEM)
    mib[1] = HW_PHYSMEM;        /* Others. ------------------ */
#endif
    unsigned int size = 0;      /* 32-bit */
    size_t len = sizeof(size);
    if (sysctl(mib, 2, &size, &len, NULL, 0) == 0)
        return (size_t)size;
    return 0L;          /* Failed? */
#else
    return 0L;          /* Unknown method to get the data. */
#endif
#else
    return 0L;          /* Unknown OS.
*/ #endif } #ifdef REDIS_TEST #define UNUSED(x) ((void)(x)) int zmalloc_test(int argc, char **argv, int accurate) { void *ptr; UNUSED(argc); UNUSED(argv); UNUSED(accurate); printf("Malloc prefix size: %d\n", (int) PREFIX_SIZE); printf("Initial used memory: %zu\n", zmalloc_used_memory()); ptr = zmalloc(123); printf("Allocated 123 bytes; used: %zu\n", zmalloc_used_memory()); ptr = zrealloc(ptr, 456); printf("Reallocated to 456 bytes; used: %zu\n", zmalloc_used_memory()); zfree(ptr); printf("Freed pointer; used: %zu\n", zmalloc_used_memory()); return 0; } #endif ================================================ FILE: src/redis/zmalloc.h ================================================ /* zmalloc - total amount of allocated memory aware version of malloc() * * Copyright (c) 2009-2010, Salvatore Sanfilippo * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of Redis nor the names of its contributors may be used * to endorse or promote products derived from this software without * specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #ifndef __ZMALLOC_H #define __ZMALLOC_H #include /* Double expansion needed for stringification of macro values. */ #define __xstr(s) __zm_str(s) #define __zm_str(s) #s #if defined(USE_JEMALLOC) #define ZMALLOC_LIB ("jemalloc-" __xstr(JEMALLOC_VERSION_MAJOR) "." __xstr(JEMALLOC_VERSION_MINOR) "." __xstr(JEMALLOC_VERSION_BUGFIX)) #include #if (JEMALLOC_VERSION_MAJOR == 2 && JEMALLOC_VERSION_MINOR >= 1) || (JEMALLOC_VERSION_MAJOR > 2) #define HAVE_MALLOC_SIZE 1 #define zmalloc_size(p) je_malloc_usable_size(p) #else #error "Newer version of jemalloc required" #endif #elif defined(__APPLE__) #include #define HAVE_MALLOC_SIZE 1 #ifdef USE_ZMALLOC_MI #define zmalloc_size(p) zmalloc_usable_size(p) #else #define zmalloc_size(p) malloc_size(p) #endif #define ZMALLOC_LIB "macos" #endif /* On native libc implementations, we should still do our best to provide a * HAVE_MALLOC_SIZE capability. This can be set explicitly as well: * * NO_MALLOC_USABLE_SIZE disables it on all platforms, even if they are * known to support it. * USE_MALLOC_USABLE_SIZE forces use of malloc_usable_size() regardless * of platform. 
*/ #ifndef ZMALLOC_LIB #define ZMALLOC_LIB "libc" #include #define HAVE_MALLOC_SIZE 1 #ifdef USE_ZMALLOC_MI #define zmalloc_size(p) zmalloc_usable_size(p) #else #define zmalloc_size(p) malloc_usable_size(p) #endif #endif // ZMALLOC_LIB /* We can enable the Redis defrag capabilities only if we are using Jemalloc * and the version used is our special version modified for Redis having * the ability to return per-allocation fragmentation hints. */ #if defined(USE_JEMALLOC) && defined(JEMALLOC_FRAG_HINT) #define HAVE_DEFRAG #endif void *zmalloc(size_t size); void *zcalloc(size_t size); void *zrealloc(void *ptr, size_t size); void *ztrymalloc(size_t size); void *ztrycalloc(size_t size); void *ztryrealloc(void *ptr, size_t size); void zfree(void *ptr); size_t znallocx(size_t size); // Equivalent to nallocx for jemalloc or mi_good_size for mimalloc. void zfree_size(void* ptr, size_t size); // equivalent to sdallocx or mi_free_size void *zmalloc_usable(size_t size, size_t *usable); void *zcalloc_usable(size_t size, size_t *usable); void *zrealloc_usable(void *ptr, size_t size, size_t *usable); void *ztrymalloc_usable(size_t size, size_t *usable); void *ztrycalloc_usable(size_t size, size_t *usable); void *ztryrealloc_usable(void *ptr, size_t size, size_t *usable); // size_t zmalloc_used_memory(void); void zmalloc_set_oom_handler(void (*oom_handler)(size_t)); size_t zmalloc_get_rss(void); int zmalloc_get_allocator_info(size_t *allocated, size_t *active, size_t *resident); void set_jemalloc_bg_thread(int enable); int jemalloc_purge(); size_t zmalloc_get_private_dirty(long pid); size_t zmalloc_get_smap_bytes_by_field(char *field, long pid); size_t zmalloc_get_memory_size(void); size_t zmalloc_usable_size(const void* p); /* get the memory usage + the number of wasted locations of memory Based on a given threshold (ratio < 1). 
Note that if a block is not used, it would not counted as wasted */ int zmalloc_get_allocator_wasted_blocks(float ratio, size_t* allocated, size_t* commited, size_t* wasted); struct fragmentation_info { size_t committed; // a temporary metric to compare against "committed" in production. // TODO: delete it once we are confident committed is computed correctly. size_t committed_golden; size_t wasted; unsigned bin; }; // Like zmalloc_get_allocator_wasted_blocks but incremental. // struct fragmentation_info must be passed first set to zero. Returns -1 needs to continue, // 0 if done. int zmalloc_get_allocator_fragmentation_step(float ratio, struct fragmentation_info* info); /* * checks whether a page that the pointer ptr located at is underutilized. * This uses the current local thread heap. * return 0 if not, 1 if underutilized */ struct mi_page_usage_stats_s; void zmalloc_page_is_underutilized(void* ptr, float ratio, int collect_stats, struct mi_page_usage_stats_s* result); char* zstrdup(const char* s); void init_zmalloc_threadlocal(void* heap); extern __thread ssize_t zmalloc_used_memory_tl; #undef __zm_str #undef __xstr #endif /* __ZMALLOC_H */ ================================================ FILE: src/redis/zmalloc_mi.c ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #include #include #define MI_BUILD_RELEASE 1 #include #include #include #include "zmalloc.h" __thread ssize_t zmalloc_used_memory_tl = 0; __thread mi_heap_t* zmalloc_heap = NULL; mi_page_usage_stats_t mi_heap_page_is_underutilized(mi_heap_t* heap, void* p, float ratio, bool collect_stats); /* Allocate memory or panic */ void* zmalloc(size_t size) { assert(zmalloc_heap); void* res = mi_heap_malloc(zmalloc_heap, size); size_t usable = mi_usable_size(res); // assertion does not hold. Basically mi_good_size is not a good function for // doing accounting. 
// assert(usable == mi_good_size(size)); zmalloc_used_memory_tl += usable; return res; } void* ztrymalloc_usable(size_t size, size_t* usable) { return zmalloc_usable(size, usable); } size_t zmalloc_usable_size(const void* p) { return mi_usable_size(p); } void zfree(void* ptr) { size_t usable = mi_usable_size(ptr); // assert(zmalloc_used_memory_tl >= (ssize_t)usable); zmalloc_used_memory_tl -= usable; mi_free_size(ptr, usable); } void* zrealloc(void* ptr, size_t size) { size_t usable; return zrealloc_usable(ptr, size, &usable); } void* zcalloc(size_t size) { // mi_good_size(size) is not working. try for example, size=690557. void* res = mi_heap_calloc(zmalloc_heap, 1, size); size_t usable = mi_usable_size(res); zmalloc_used_memory_tl += usable; return res; } void* zmalloc_usable(size_t size, size_t* usable) { assert(zmalloc_heap); void* res = mi_heap_malloc(zmalloc_heap, size); size_t uss = mi_usable_size(res); *usable = uss; zmalloc_used_memory_tl += uss; return res; } void* zrealloc_usable(void* ptr, size_t size, size_t* usable) { ssize_t prev = mi_usable_size(ptr); void* res = mi_heap_realloc(zmalloc_heap, ptr, size); ssize_t uss = mi_usable_size(res); *usable = uss; zmalloc_used_memory_tl += (uss - prev); return res; } size_t znallocx(size_t size) { return mi_good_size(size); } void zfree_size(void* ptr, size_t size) { ssize_t uss = mi_usable_size(ptr); zmalloc_used_memory_tl -= uss; mi_free_size(ptr, uss); } void* ztrymalloc(size_t size) { size_t usable; return zmalloc_usable(size, &usable); } void* ztrycalloc(size_t size) { size_t g = mi_good_size(size); zmalloc_used_memory_tl += g; void* ptr = mi_heap_calloc(zmalloc_heap, 1, size); assert(mi_usable_size(ptr) == g); return ptr; } typedef struct Sum_s { size_t allocated; size_t comitted; } Sum_t; typedef struct { size_t allocated; size_t comitted; size_t wasted; float ratio; } MemUtilized_t; bool heap_visit_cb(const mi_heap_t* heap, const mi_heap_area_t* area, void* block, size_t block_size, void* arg) { 
assert(area->used < (1u << 31)); Sum_t* sum = (Sum_t*)arg; // mimalloc mistakenly exports used in blocks instead of bytes. sum->allocated += block_size * area->used; sum->comitted += area->committed; return true; // continue iteration }; bool heap_count_wasted_blocks(const mi_heap_t* heap, const mi_heap_area_t* area, void* block, size_t block_size, void* arg) { assert(area->used < (1u << 31)); MemUtilized_t* sum = (MemUtilized_t*)arg; // mimalloc mistakenly exports used in blocks instead of bytes. size_t used = block_size * area->used; sum->allocated += used; sum->comitted += area->committed; if (used < area->committed * sum->ratio) { sum->wasted += (area->committed - used); } return true; // continue iteration }; int zmalloc_get_allocator_info(size_t* allocated, size_t* active, size_t* resident) { Sum_t sum = {0}; mi_heap_visit_blocks(zmalloc_heap, false /* visit all blocks*/, heap_visit_cb, &sum); *allocated = sum.allocated; *resident = sum.comitted; *active = 0; return 1; } int zmalloc_get_allocator_wasted_blocks(float ratio, size_t* allocated, size_t* commited, size_t* wasted) { MemUtilized_t sum = {.allocated = 0, .comitted = 0, .wasted = 0, .ratio = ratio}; mi_heap_visit_blocks(zmalloc_heap, false /* visit all blocks*/, heap_count_wasted_blocks, &sum); *allocated = sum.allocated; *commited = sum.comitted; *wasted = sum.wasted; return 1; } // Implemented based on this mimalloc code: // https://github.com/microsoft/mimalloc/blob/main/src/heap.c#L27 int zmalloc_get_allocator_fragmentation_step(float ratio, struct fragmentation_info* info) { if (zmalloc_heap->page_count == 0 || info->bin >= MI_BIN_FULL) { // We avoid iterating over full pages since they are fully utilized. 
return 0; } mi_page_queue_t* pq = &zmalloc_heap->pages[info->bin]; const mi_page_t* page = pq->first; while (page != NULL) { const mi_page_t* next = page->next; const size_t bsize = page->block_size; size_t committed = page->capacity * bsize; info->committed += committed; if (page->used < page->capacity) { size_t used = page->used * bsize; size_t threshold = (double)committed * ratio; if (used < threshold) { info->wasted += (committed - used); } } page = next; } info->bin++; if (info->bin == MI_BIN_FULL) { // reached end of bins, reset state info->committed_golden = info->committed; // Add total comitted size of MI_BIN_FULL that we do not traverse // as its tracked by zmalloc_heap->full_page_size variable. info->committed += zmalloc_heap->full_page_size; // TODO: it's a test code that makes sure `full_page_size` is correct. // Remove it once we are confident with the implementation. mi_page_queue_t* pq = &zmalloc_heap->pages[MI_BIN_FULL]; const mi_page_t* page = pq->first; while (page != NULL) { info->committed_golden += page->capacity * page->block_size; page = page->next; } info->bin = 0; return 0; } return -1; } void init_zmalloc_threadlocal(void* heap) { if (zmalloc_heap) return; zmalloc_heap = heap; } void zmalloc_page_is_underutilized(void* ptr, float ratio, int collect_stats, mi_page_usage_stats_t* result) { *result = mi_heap_page_is_underutilized(zmalloc_heap, ptr, ratio, collect_stats); } char* zstrdup(const char* s) { size_t l = strlen(s) + 1; char* p = zmalloc(l); memcpy(p, s, l); return p; } ================================================ FILE: src/server/CMakeLists.txt ================================================ option(DF_ENABLE_MEMORY_TRACKING "Adds memory tracking debugging via MEMORY TRACK command" ON) option(PRINT_STACKTRACES_ON_SIGNAL "Enables DF to print all fiber stacktraces on SIGUSR1" OFF) option(WITH_COLLECTION_CMDS "Compile SET/HASH/ZSET/STREAM commands" ON) option(WITH_EXTENSION_CMDS "Compile BLOOM/BITOPS/GEO/HLL/JSON commands" ON) 
# Tiered storage is compiled in by default; the block below forces it off on
# macOS.
option(WITH_TIERING "Compile with tiered storage support (forced OFF on macOS)" ON)

if(APPLE)
  message(STATUS "Macos detected. Set WITH_TIERING=off")
  set(WITH_TIERING OFF CACHE BOOL "Compile with tiered storage support (forced OFF on macOS)" FORCE)
endif()

add_executable(dragonfly dfly_main.cc version_monitor.cc)
add_custom_target(check_dfly WORKING_DIRECTORY .. COMMAND ctest -L DFLY)
cxx_link(dragonfly base dragonfly_lib)

if (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_BUILD_TYPE STREQUAL "Release")
  # Add core2 only to this file, thus avoiding instructions in this object file that
  # can cause SIGILL.
  set_source_files_properties(dfly_main.cc PROPERTIES COMPILE_FLAGS "-march=core2")
endif()

set_property(SOURCE dfly_main.cc APPEND PROPERTY COMPILE_DEFINITIONS
             SOURCE_PATH_FROM_BUILD_ENV=${CMAKE_SOURCE_DIR})

add_executable(dfly_bench dfly_bench.cc)
cxx_link(dfly_bench dfly_parser_lib fibers2 absl::random_random redis_lib)

# Include journal sources (not separate target for now)
add_subdirectory(journal)
if(NOT DEFINED DF_JOURNAL_SRCS)
  message(FATAL_ERROR "Journal source files not exported via DF_JOURNAL_SRCS")
endif()

# Define transaction library
add_library(dfly_transaction db_slice.cc blocking_controller.cc cluster_support.cc
            common.cc command_registry.cc execution_state.cc stats.cc synchronization.cc
            ${DF_JOURNAL_SRCS} server_state.cc table.cc transaction.cc tx_base.cc
            serializer_commons.cc acl/acl_log.cc slowlog.cc channel_store.cc)
cxx_link(dfly_transaction dfly_core strings_lib TRDP::fast_float TRDP::hdr_histogram)

# Include search module
add_subdirectory(search)
if(NOT DEFINED DF_SEARCH_SRCS)
  message(FATAL_ERROR "Search source files not exported via DF_SEARCH_SRCS")
endif()

if (WITH_SEARCH)
  add_definitions(-DWITH_SEARCH)
endif()

# Include tiering module
add_subdirectory(tiering)
if (WITH_TIERING)
  add_definitions(-DWITH_TIERING)
  SET(DF_TIERING_SRCS tiered_storage.cc)
  helio_cxx_test(tiered_storage_test dfly_test_lib LABELS DFLY)
endif()

# Include cluster sources definitions (not separate target for now)
add_subdirectory(cluster)
if (NOT
DEFINED DF_CLUSTER_SRCS) message(FATAL_ERROR "Cluster source files not exported via DF_CLUSTER_SRCS") endif() # Optionally compile collection commands if (WITH_COLLECTION_CMDS) set(DF_FAMILY_SRCS set_family.cc hset_family.cc zset_family.cc stream_family.cc) add_definitions(-DWITH_COLLECTION_CMDS) else() set(DF_FAMILY_SRCS collection_family_fallback.cc) endif() # Optionally compile extension commands if (WITH_EXTENSION_CMDS) list(APPEND DF_FAMILY_SRCS geo_family.cc hll_family.cc bitops_family.cc bloom_family.cc cms_family.cc json_family.cc) add_definitions(-DWITH_EXTENSION_CMDS) endif() # Optionally include tiered_storage which interfaces with tiering_module add_library(dragonfly_lib engine_shard.cc engine_shard_set.cc config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc error.cc family_utils.cc string_stats.cc ${DF_SEARCH_SRCS} server_family.cc string_family.cc list_family.cc generic_family.cc ${DF_FAMILY_SRCS} main_service.cc memory_cmd.cc rdb_load.cc rdb_load_context.cc rdb_save.cc replica.cc http_api.cc protocol_client.cc serializer_base.cc snapshot.cc script_mgr.cc detail/compressor.cc detail/decompress.cc detail/save_stages_controller.cc detail/snapshot_storage.cc version.cc container_utils.cc multi_command_squasher.cc ${DF_TIERING_SRCS} ${DF_CLUSTER_SRCS} acl/user.cc acl/user_registry.cc acl/acl_family.cc acl/validator.cc sharding.cc cmd_support.cc) if (DF_ENABLE_MEMORY_TRACKING) target_compile_definitions(dragonfly_lib PRIVATE DFLY_ENABLE_MEMORY_TRACKING) target_compile_definitions(dragonfly PRIVATE DFLY_ENABLE_MEMORY_TRACKING) endif() if (PRINT_STACKTRACES_ON_SIGNAL) target_compile_definitions(dragonfly_lib PRIVATE PRINT_STACKTRACES_ON_SIGNAL) endif() if (WITH_AWS) SET(AWS_LIB awsv2_lib) add_definitions(-DWITH_AWS) endif() if (WITH_GCP) SET(GCP_LIB gcp_lib) add_definitions(-DWITH_GCP) endif() cxx_link(dragonfly_lib dfly_transaction dfly_facade dfly_tiering redis_lib ${AWS_LIB} ${GCP_LIB} azure_lib jsonpath strings_lib html_lib http_client_lib 
absl::random_random TRDP::jsoncons TRDP::zstd TRDP::lz4 TRDP::croncpp TRDP::flatbuffers) if (DF_USE_SSL) set(TLS_LIB tls_lib) target_compile_definitions(dragonfly_lib PRIVATE DFLY_USE_SSL) endif() add_library(dfly_test_lib test_utils.cc) cxx_link(dfly_test_lib dragonfly_lib facade_test gtest_main_ext) helio_cxx_test(dragonfly_test dfly_test_lib LABELS DFLY) helio_cxx_test(multi_test dfly_test_lib LABELS DFLY) helio_cxx_test(generic_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(hset_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(list_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(server_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(set_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(stream_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(string_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(bitops_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(rdb_test dfly_test_lib DATA testdata/empty.rdb testdata/redis6_small.rdb testdata/redis6_stream.rdb testdata/hll.rdb testdata/redis7_small.rdb testdata/redis_json.rdb testdata/RDB_TYPE_STREAM_LISTPACKS_2.rdb testdata/RDB_TYPE_STREAM_LISTPACKS_3.rdb testdata/ignore_expiry.rdb LABELS DFLY) helio_cxx_test(zset_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(geo_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(blocking_controller_test dfly_test_lib LABELS DFLY) helio_cxx_test(json_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(json_family_memory_test dfly_test_lib LABELS DFLY) helio_cxx_test(journal/journal_test dfly_test_lib LABELS DFLY) helio_cxx_test(hll_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(string_stats_test dfly_test_lib LABELS DFLY) helio_cxx_test(bloom_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(cms_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(cluster/cluster_config_test dfly_test_lib LABELS DFLY) helio_cxx_test(cluster/cluster_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(acl/acl_family_test dfly_test_lib 
LABELS DFLY) helio_cxx_test(engine_shard_set_test dfly_test_lib LABELS DFLY) helio_cxx_test(serializer_base_test dfly_test_lib LABELS DFLY) add_dependencies(check_dfly dragonfly_test json_family_test list_family_test generic_family_test memcache_parser_test rdb_test journal_test redis_parser_test stream_family_test string_family_test bitops_family_test set_family_test zset_family_test geo_family_test hll_family_test cluster_config_test cluster_family_test acl_family_test json_family_memory_test) if (WITH_SEARCH) helio_cxx_test(search/search_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(search/aggregator_test dfly_test_lib LABELS DFLY) helio_cxx_test(search/index_join_test dfly_test_lib LABELS DFLY) add_dependencies(check_dfly search_family_test aggregator_test index_join_test) endif() ================================================ FILE: src/server/acl/acl_commands_def.h ================================================ // Copyright 2026, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "base/logging.h" namespace dfly::acl { /* There are 21 ACL categories as of redis 7 * */ enum AclCat { KEYSPACE = 1ULL << 0, READ = 1ULL << 1, WRITE = 1ULL << 2, SET = 1ULL << 3, SORTEDSET = 1ULL << 4, LIST = 1ULL << 5, HASH = 1ULL << 6, STRING = 1ULL << 7, BITMAP = 1ULL << 8, HYPERLOGLOG = 1ULL << 9, GEO = 1ULL << 10, STREAM = 1ULL << 11, PUBSUB = 1ULL << 12, ADMIN = 1ULL << 13, FAST = 1ULL << 14, SLOW = 1ULL << 15, BLOCKING = 1ULL << 16, DANGEROUS = 1ULL << 17, CONNECTION = 1ULL << 18, TRANSACTION = 1ULL << 19, SCRIPTING = 1ULL << 20, // Extensions CMS = 1ULL << 27, BLOOM = 1ULL << 28, FT_SEARCH = 1ULL << 29, THROTTLE = 1ULL << 30, JSON = 1ULL << 31 }; constexpr uint64_t ALL_COMMANDS = std::numeric_limits::max(); constexpr uint64_t NONE_COMMANDS = std::numeric_limits::min(); inline size_t NumberOfFamilies(size_t number = 0) { static size_t number_of_families = number; return number_of_families; } using CategoryIndexTable = absl::flat_hash_map; using ReverseCategoryIndexTable = std::vector; // bit index to index in the REVERSE_CATEGORY_INDEX_TABLE using CategoryToIdxStore = absl::flat_hash_map; using RevCommandField = std::vector; using RevCommandsIndexStore = std::vector; using CategoryToCommandsIndexStore = absl::flat_hash_map>; // Special flag/mask for all constexpr uint32_t NONE = 0; constexpr uint32_t ALL = std::numeric_limits::max(); enum class KeyOp : int8_t { READ, WRITE, READ_WRITE }; using GlobType = std::pair; struct AclKeys { std::vector key_globs; // The user is allowed to "touch" any key. No glob matching required. 
// Alias for ~* bool all_keys = false; }; // The second bool denotes if the pattern contains an asterisk and it's // used to pattern match PSUBSCRIBE that requires exact literals using GlobTypePubSub = std::pair; struct AclPubSub { std::vector globs; // The user can execute any variant of pub/sub/psub. No glob matching required. // Alias for &* just like all_keys for AclKeys above. bool all_channels = false; }; struct UserCredentials { uint32_t acl_categories{0}; std::vector acl_commands; AclKeys keys; AclPubSub pub_sub; std::string ns; size_t db{0}; }; } // namespace dfly::acl ================================================ FILE: src/server/acl/acl_family.cc ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. #include "server/acl/acl_family.h" #include #include #include #include #include #include #include #include #include #include #include #include #include "absl/container/flat_hash_set.h" #include "absl/flags/commandlineflag.h" #include "absl/strings/escaping.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "base/flags.h" #include "base/logging.h" #include "core/overloaded.h" #include "facade/dragonfly_connection.h" #include "facade/dragonfly_listener.h" #include "facade/facade_types.h" #include "facade/reply_builder.h" #include "io/file.h" #include "io/file_util.h" #include "server/acl/acl_commands_def.h" #include "server/acl/acl_log.h" #include "server/acl/validator.h" #include "server/command_registry.h" #include "server/common.h" #include "server/config_registry.h" #include "server/conn_context.h" #include "server/error.h" #include "server/server_state.h" #include "util/proactor_pool.h" using namespace std; ABSL_FLAG(string, aclfile, "", "Path and name to aclfile"); ABSL_DECLARE_FLAG(uint32_t, dbnum); namespace dfly::acl { namespace { string PasswordsToString(const 
absl::flat_hash_set& passwords, bool nopass, bool full_sha); using MaterializedContents = optional>>; MaterializedContents MaterializeFileContents(vector* usernames, string_view file_contents); string AclKeysToString(const AclKeys& keys); string AclPubSubToString(const AclPubSub& pub_sub); void SendAclSecurityEvents(const AclLog::LogEntry& entry, facade::RedisReplyBuilder* rb); string AclDbToString(size_t db); template void TraverseEvictImpl(P predicate, facade::Listener* main_listener, util::ProactorPool* pool); } // namespace AclFamily::AclFamily(UserRegistry* registry, util::ProactorPool* pool) : registry_(registry), pool_(pool) { dbnum_ = absl::GetFlag(FLAGS_dbnum); } void AclFamily::Acl(CmdArgList args, CommandContext* cmd_cntx) { cmd_cntx->SendError("Wrong number of arguments for acl command"); } void AclFamily::List(CmdArgList args, CommandContext* cmd_cntx) { const auto registry_with_lock = registry_->GetRegistryWithLock(); const auto& registry = registry_with_lock.registry; auto* rb = static_cast(cmd_cntx->rb()); rb->StartArray(registry.size()); for (const auto& [username, user] : registry) { string buffer = "user "; const string password = PasswordsToString(user.Passwords(), user.HasNopass(), false); const string acl_keys = AclKeysToString(user.Keys()); const string acl_pub_sub = AclPubSubToString(user.PubSub()); const string maybe_space_com = acl_keys.empty() ? "" : " "; const string acl_cat_and_commands = AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); const string db_index = AclDbToString(user.Db()); using namespace string_view_literals; absl::StrAppend(&buffer, username, " ", user.IsActive() ? 
"on "sv : "off "sv, password, acl_keys, maybe_space_com, acl_pub_sub, " ", acl_cat_and_commands, " $", db_index); rb->SendSimpleString(buffer); } } void AclFamily::StreamUpdatesToAllProactorConnections(const std::string& user, const Commands& update_commands, const AclKeys& update_keys, const AclPubSub& update_pub_sub, size_t db) { auto update_cb = [&]([[maybe_unused]] size_t id, util::Connection* conn) { DCHECK(conn); auto connection = static_cast(conn); if (!connection->IsHttp() && connection->cntx()) { auto* cntx = static_cast(connection->cntx()); if (user == cntx->authed_username) { cntx->acl_commands = update_commands; cntx->keys = update_keys; cntx->pub_sub = update_pub_sub; cntx->acl_db_idx = db; } } }; if (main_listener_ && main_listener_->protocol() == facade::Protocol::REDIS) { main_listener_->TraverseConnections(update_cb); } } using facade::ErrorReply; void AclFamily::SetUser(CmdArgList args, CommandContext* cmd_cntx) { string_view username = facade::ToSV(args[0]); auto reg = registry_->GetRegistryWithWriteLock(); const bool exists = reg.registry.contains(username); const bool has_all_keys = exists ? 
reg.registry.find(username)->second.Keys().all_keys : false; auto req = ParseAclSetUser(args.subspan(1), false, has_all_keys); auto error_case = [cmd_cntx](ErrorReply&& error) { cmd_cntx->SendError(error); }; auto update_case = [username, ®, cmd_cntx, this, exists](User::UpdateRequest&& req) { auto& user = reg.registry[username]; if (!exists) { User::UpdateRequest default_req; default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}}; user.Update(std::move(default_req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); } const bool reset_channels = req.reset_channels; user.Update(std::move(req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); // Send ok first because the connection might get evicted cmd_cntx->SendOk(); if (exists) { if (!reset_channels) { StreamUpdatesToAllProactorConnections(string(username), user.AclCommands(), user.Keys(), user.PubSub(), user.Db()); } // We evict connections that had their channels reseted else { EvictOpenConnectionsOnAllProactors({username}); } } }; std::visit(Overloaded{error_case, update_case}, std::move(req)); } void AclFamily::EvictOpenConnectionsOnAllProactors(const absl::flat_hash_set& users) { return TraverseEvictImpl( [&](auto* ctx) { auto* dfly_ctx = static_cast(ctx); return ctx && users.contains(dfly_ctx->authed_username); }, main_listener_, pool_); } void AclFamily::EvictOpenConnectionsOnAllProactorsWithRegistry( const UserRegistry::RegistryType& registry) { return TraverseEvictImpl( [&](auto* ctx) { auto* dfly_ctx = static_cast(ctx); return ctx && dfly_ctx->authed_username != "default" && registry.contains(dfly_ctx->authed_username); }, main_listener_, pool_); } void AclFamily::DelUser(CmdArgList args, CommandContext* cmd_cntx) { auto& registry = *registry_; absl::flat_hash_set users; for (auto arg : args) { string_view username = facade::ToSV(arg); if (username == "default") { continue; } if (registry.RemoveUser(username)) { users.insert(username); } 
} if (users.empty()) { cmd_cntx->rb()->SendLong(0); return; } VLOG(1) << "Evicting open acl connections"; EvictOpenConnectionsOnAllProactors(users); VLOG(1) << "Done evicting open acl connections"; cmd_cntx->rb()->SendLong(users.size()); } void AclFamily::WhoAmI(CmdArgList args, CommandContext* cmd_cntx) { auto* rb = static_cast(cmd_cntx->rb()); rb->SendBulkString(absl::StrCat("User is ", cmd_cntx->server_conn_cntx()->authed_username)); } string AclFamily::RegistryToString() const { auto registry_with_read_lock = registry_->GetRegistryWithLock(); auto& registry = registry_with_read_lock.registry; string result; for (auto& [username, user] : registry) { string command = "USER "; const string password = PasswordsToString(user.Passwords(), user.HasNopass(), true); const string acl_keys = AclKeysToString(user.Keys()); const string maybe_space = acl_keys.empty() ? "" : " "; const string acl_pub_sub = AclPubSubToString(user.PubSub()); const string acl_cat_and_commands = AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); const string db_index = AclDbToString(user.Db()); using namespace string_view_literals; absl::StrAppend(&result, command, username, " ", user.IsActive() ? 
"ON "sv : "OFF "sv, password, acl_keys, maybe_space, acl_pub_sub, " ", acl_cat_and_commands, " $", db_index, "\n"); } return result; } void AclFamily::Save(CmdArgList args, CommandContext* cmd_cntx) { auto acl_file_path = absl::GetFlag(FLAGS_aclfile); auto* builder = cmd_cntx->rb(); if (acl_file_path.empty()) { builder->SendError("Dragonfly is not configured to use an ACL file."); return; } auto res = io::OpenWrite(acl_file_path); if (!res) { std::string error = absl::StrCat("Failed to open the aclfile: ", res.error().message()); LOG(ERROR) << error; builder->SendError(error); return; } std::unique_ptr file(res.value()); std::string output = RegistryToString(); auto ec = file->Write(output); if (ec) { std::string error = absl::StrCat("Failed to write to the aclfile: ", ec.message()); LOG(ERROR) << error; builder->SendError(error); return; } ec = file->Close(); if (ec) { std::string error = absl::StrCat("Failed to close the aclfile ", ec.message()); LOG(WARNING) << error; builder->SendError(error); return; } builder->SendOk(); } GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path, SinkReplyBuilder* builder) { auto is_file_read = io::ReadFileToString(full_path); if (!is_file_read) { auto error = absl::StrCat("Dragonfly could not load ACL file ", full_path, " with error ", is_file_read.error().message()); LOG(WARNING) << error; return {std::move(error)}; } auto file_contents = std::move(is_file_read.value()); if (file_contents.empty()) { return {"Empty file"}; } std::vector usernames; auto materialized = MaterializeFileContents(&usernames, file_contents); if (!materialized) { std::string error = "Error materializing acl file"; LOG(WARNING) << error; return {std::move(error)}; } std::vector requests; for (auto& cmds : *materialized) { auto req = ParseAclSetUser(cmds, true); if (std::holds_alternative(req)) { auto error = std::move(std::get(req)); LOG(WARNING) << "Error while parsing aclfile: " << error.ToSv(); return {std::string(error.ToSv())}; } 
requests.push_back(std::move(std::get(req))); } auto registry_with_wlock = registry_->GetRegistryWithWriteLock(); auto& registry = registry_with_wlock.registry; if (builder) { builder->SendOk(); // Evict open connections for old users EvictOpenConnectionsOnAllProactorsWithRegistry(registry); registry.clear(); } for (size_t i = 0; i < usernames.size(); ++i) { User::UpdateRequest default_req; default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}}; auto& user = registry[usernames[i]]; user.Update(std::move(default_req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); user.Update(std::move(requests[i]), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); } if (!registry.contains("default")) { auto& user = registry["default"]; user.Update(registry_->DefaultUserUpdateRequest(), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); } return {}; } bool AclFamily::Load() { auto acl_file = absl::GetFlag(FLAGS_aclfile); return !LoadToRegistryFromFile(acl_file, nullptr); } void AclFamily::Load(CmdArgList args, CommandContext* cmd_cntx) { auto acl_file = absl::GetFlag(FLAGS_aclfile); auto* rb = static_cast(cmd_cntx->rb()); if (acl_file.empty()) { rb->SendError("Dragonfly is not configured to use an ACL file."); return; } const auto load_error = LoadToRegistryFromFile(acl_file, rb); if (load_error) { rb->SendError(absl::StrCat("Error loading: ", acl_file, " ", load_error.Format())); } } void AclFamily::Log(CmdArgList args, CommandContext* cmd_cntx) { auto* rb = static_cast(cmd_cntx->rb()); if (args.size() > 1) { return rb->SendError(facade::OpStatus::OUT_OF_RANGE); } size_t max_output = 10; if (!args.empty()) { auto option = facade::ToSV(args[0]); if (absl::EqualsIgnoreCase(option, "RESET")) { pool_->AwaitFiberOnAll( [](auto index, auto* context) { ServerState::tlocal()->acl_log.Reset(); }); rb->SendOk(); return; } if (!absl::SimpleAtoi(facade::ToSV(args[0]), &max_output)) { rb->SendError("Invalid 
count"); return; } } std::vector logs(pool_->size()); pool_->AwaitFiberOnAll([&logs, max_output](auto index, auto* context) { logs[index] = ServerState::tlocal()->acl_log.GetLog(max_output); }); size_t total_entries = 0; for (auto& log : logs) { total_entries += log.size(); } if (total_entries == 0) { rb->SendEmptyArray(); return; } auto n_way_minimum = [](const auto& logs) { size_t id = 0; AclLog::LogEntry limit; const AclLog::LogEntry* max = &limit; for (size_t i = 0; i < logs.size(); ++i) { if (!logs[i].empty() && logs[i].front() < *max) { id = i; max = &logs[i].front(); } } return id; }; rb->StartArray(total_entries); for (size_t i = 0; i < total_entries; ++i) { const auto min = n_way_minimum(logs); SendAclSecurityEvents(logs[min].front(), rb); logs[min].pop_front(); } } void AclFamily::Users(CmdArgList args, CommandContext* cmd_cntx) { const auto registry_with_lock = registry_->GetRegistryWithLock(); const auto& registry = registry_with_lock.registry; auto* rb = static_cast(cmd_cntx->rb()); rb->StartArray(registry.size()); for (const auto& [username, _] : registry) { rb->SendSimpleString(username); } } void AclFamily::Cat(CmdArgList args, CommandContext* cmd_cntx) { auto* rb = static_cast(cmd_cntx->rb()); if (args.size() > 1) { rb->SendError(facade::OpStatus::SYNTAX_ERR); return; } if (args.size() == 1) { string category = absl::AsciiStrToUpper(ArgS(args, 0)); if (!cat_table_.contains(category)) { auto error = absl::StrCat("Unknown category: ", category); rb->SendError(error); return; } const uint32_t cid_mask = cat_table_.find(category)->second; std::vector results; // TODO replace this with indexer auto cb = [cid_mask, &results](auto name, auto& cid) { if (cid_mask & cid.acl_categories()) { results.push_back(name); } }; cmd_registry_->Traverse(cb); rb->StartArray(results.size()); for (const auto& command : results) { rb->SendSimpleString(command); } return; } size_t total_categories = 0; for (auto& elem : reverse_cat_table_) { if (elem != "_RESERVED") { 
++total_categories; } } rb->StartArray(total_categories); for (auto& elem : reverse_cat_table_) { if (elem != "_RESERVED") { rb->SendSimpleString(elem); } } } void AclFamily::GetUser(CmdArgList args, CommandContext* cmd_cntx) { auto username = facade::ToSV(args[0]); const auto registry_with_lock = registry_->GetRegistryWithLock(); const auto& registry = registry_with_lock.registry; auto* rb = static_cast(cmd_cntx->rb()); if (!registry.contains(username)) { rb->SendNull(); return; } auto& user = registry.find(username)->second; std::string status = user.IsActive() ? "on" : "off"; auto pass = PasswordsToString(user.Passwords(), user.HasNopass(), false); if (!pass.empty()) { pass.pop_back(); } rb->StartArray(10); rb->SendSimpleString("flags"); const size_t total_elements = (pass != "nopass") ? 1 : 2; rb->StartArray(total_elements); rb->SendSimpleString(status); if (total_elements == 2) { rb->SendSimpleString(pass); } rb->SendSimpleString("passwords"); if (pass != "nopass" && !pass.empty()) { rb->SendSimpleString(pass); } else { rb->SendEmptyArray(); } rb->SendSimpleString("commands"); const std::string acl_cat_and_commands = AclCatAndCommandToString(user.CatChanges(), user.CmdChanges()); rb->SendSimpleString(acl_cat_and_commands); rb->SendSimpleString("keys"); std::string keys = AclKeysToString(user.Keys()); if (!keys.empty()) { rb->SendSimpleString(keys); } else { rb->SendEmptyArray(); } rb->SendSimpleString("channels"); std::string pub_sub = AclPubSubToString(user.PubSub()); rb->SendSimpleString(pub_sub); } void AclFamily::GenPass(CmdArgList args, CommandContext* cmd_cntx) { auto* builder = cmd_cntx->rb(); if (args.length() > 1) { builder->SendError(facade::UnknownSubCmd("GENPASS", "ACL")); return; } uint32_t random_bits = 256; if (args.length() == 1) { auto requested_bits = facade::ArgS(args, 0); if (!absl::SimpleAtoi(requested_bits, &random_bits) || random_bits == 0 || random_bits > 4096) { return builder->SendError( "ACL GENPASS argument must be the number of 
bits for the output password, a positive " "number up to 4096"); } } std::random_device urandom("/dev/urandom"); const size_t result_length = (random_bits + 3) / 4; constexpr size_t step_size = sizeof(decltype(std::random_device::max())); std::string response; for (size_t bytes_written = 0; bytes_written < result_length; bytes_written += step_size) { absl::StrAppendFormat(&response, "%08x", urandom()); } response.resize(result_length); builder->SendSimpleString(response); } void AclFamily::DryRun(CmdArgList args, CommandContext* cmd_cntx) { auto* rb = static_cast(cmd_cntx->rb()); auto username = facade::ArgS(args, 0); const auto registry_with_lock = registry_->GetRegistryWithLock(); const auto& registry = registry_with_lock.registry; if (!registry.contains(username)) { auto error = absl::StrCat("User '", username, "' not found"); rb->SendError(error); return; } string command = absl::AsciiStrToUpper(ArgS(args, 1)); auto* cid = cmd_registry_->Find(command); if (!cid || cid->IsAlias()) { auto error = absl::StrCat("Command '", command, "' not found"); rb->SendError(error); return; } const auto& user = registry.find(username)->second; // Stub, used to mimic connection context for a user. ConnectionContext stub(nullptr, acl::UserCredentials{}); stub.acl_commands = user.AclCommandsRef(); // "mock" without an actual connection we can't know which db is active so we skip this check // for DryRun. 
stub.acl_db_idx = {}; stub.keys = {{}, true}; const auto [is_allowed, reason] = IsUserAllowedToInvokeCommandGeneric(stub, *cid, {}); if (is_allowed) { rb->SendOk(); return; } auto msg = absl::StrCat("This user has no permissions to run the '", command, "' command"); rb->SendBulkString(msg); } void AclFamily::Init(facade::Listener* main_listener, UserRegistry* registry) { main_listener_ = main_listener; registry_ = registry; config_registry.RegisterMutable("requirepass", [this](const absl::CommandLineFlag& flag) { User::UpdateRequest rqst; rqst.passwords.push_back({flag.CurrentValue()}); registry_->MaybeAddAndUpdate("default", std::move(rqst)); return true; }); auto acl_file = absl::GetFlag(FLAGS_aclfile); if (!acl_file.empty() && Load()) { return; } registry_->Init(&CategoryToIdx(), &reverse_cat_table_, &CategoryToCommandsIndex()); } std::string AclFamily::AclCatToString(uint32_t acl_category, User::Sign sign) const { std::string res = sign == User::Sign::PLUS ? "+@" : "-@"; if (acl_category == acl::ALL) { absl::StrAppend(&res, "all"); return res; } const auto& index = CategoryToIdx().at(acl_category); absl::StrAppend(&res, absl::AsciiStrToLower(reverse_cat_table_[index])); return res; } std::string AclFamily::AclCommandToString(size_t family, uint64_t mask, User::Sign sign) const { // This is constant but can be optimized with an indexer const auto& rev_index = CommandsRevIndexer(); std::string res; std::string prefix = (sign == User::Sign::PLUS) ? 
"+" : "-"; if (mask == ALL_COMMANDS) { for (const auto& cmd : rev_index[family]) { absl::StrAppend(&res, prefix, absl::AsciiStrToLower(cmd), " "); } res.pop_back(); return res; } size_t pos = 0; while (mask != 0) { ++pos; mask = mask >> 1; } --pos; absl::StrAppend(&res, prefix, absl::AsciiStrToLower(rev_index[family][pos])); return res; } namespace { struct CategoryAndMetadata { User::CategoryChange change; User::ChangeMetadata metadata; }; struct CommandAndMetadata { User::CommandChange change; User::ChangeMetadata metadata; }; using MergeResult = std::vector>; MergeResult MergeTables(const User::CategoryChanges& categories, const User::CommandChanges& commands) { MergeResult result; for (auto [cat, meta] : categories) { result.push_back(CategoryAndMetadata{cat, meta}); } for (auto [cmd, meta] : commands) { result.push_back(CommandAndMetadata{cmd, meta}); } std::sort(result.begin(), result.end(), [](const auto& l, const auto& r) { auto fetch = [](const auto& l) { return l.metadata.seq_no; }; return std::visit(fetch, l) < std::visit(fetch, r); }); return result; } using MaterializedContents = std::optional>>; MaterializedContents MaterializeFileContents(std::vector* usernames, std::string_view file_contents) { // This is fine, a very large file will top at 1-2 mb. 
And that's for 5000+ users with 400 // characters per line std::vector commands = absl::StrSplit(file_contents, "\n"); std::vector> materialized; materialized.reserve(commands.size()); usernames->reserve(commands.size()); for (auto& command : commands) { if (command.empty()) continue; std::vector cmds = absl::StrSplit(command, ' ', absl::SkipEmpty()); if (!absl::EqualsIgnoreCase(cmds[0], "USER") || cmds.size() < 4) { return {}; } usernames->push_back(std::string(cmds[1])); cmds.erase(cmds.begin(), cmds.begin() + 2); materialized.push_back(cmds); } return materialized; } struct ParseKeyResult { std::string glob; KeyOp op; bool all_keys{false}; bool reset_keys{false}; }; std::optional MaybeParseAclKey(std::string_view command) { if (absl::EqualsIgnoreCase(command, "ALLKEYS") || command == "~*") { return ParseKeyResult{"", {}, true}; } if (absl::EqualsIgnoreCase(command, "RESETKEYS")) { return ParseKeyResult{"", {}, false, true}; } auto op = KeyOp::READ_WRITE; if (absl::StartsWith(command, "%RW")) { command = command.substr(3); } else if (absl::StartsWith(command, "%R")) { op = KeyOp::READ; command = command.substr(2); } else if (absl::StartsWith(command, "%W")) { op = KeyOp::WRITE; command = command.substr(2); } if (!absl::StartsWith(command, "~")) { return {}; } auto key = command.substr(1); if (key.empty()) { return {}; } return ParseKeyResult{std::string(key), op}; } struct ParsePubSubResult { std::string glob; bool has_asterisk{false}; bool all_channels{false}; bool reset_channels{false}; }; std::optional MaybeParseAclPubSub(std::string_view command) { if (absl::EqualsIgnoreCase(command, "ALLCHANNELS") || command == "&*") { return ParsePubSubResult{"", false, true, false}; } if (absl::EqualsIgnoreCase(command, "RESETCHANNELS")) { return ParsePubSubResult{"", false, false, true}; } if (absl::StartsWith(command, "&") && command.size() >= 2) { const auto glob = command.substr(1); const bool has_asterisk = glob.find('*') != std::string_view::npos; return 
ParsePubSubResult{std::string(glob), has_asterisk}; } return {}; } std::optional MaybeParseAclDflySelect(std::string_view command, uint32_t dbnum) { if (!absl::StartsWith(command, "$")) { return std::nullopt; } size_t res = 0; if (absl::SimpleAtoi(command.substr(1), &res) && res < dbnum) { return {res}; } if (absl::EqualsIgnoreCase(command.substr(1), "ALL")) { return {std::numeric_limits::max()}; } return std::nullopt; } std::string PrettyPrintSha(std::string_view pass, bool all) { if (all) { return absl::BytesToHexString(pass); } return absl::BytesToHexString(pass.substr(0, 15)).substr(0, 15); }; std::optional MaybeParsePassword(std::string_view command, bool hashed) { using UpPass = User::UpdatePass; if (command == "nopass") { return UpPass{"", false, true}; } if (command == "resetpass") { return UpPass{"", false, false, true}; } if (command[0] == '>' || (hashed && command[0] == '#')) { return UpPass{std::string(command.substr(1))}; } if (command[0] == '<') { return UpPass{std::string(command.substr(1)), true}; } return {}; } std::optional MaybeParseStatus(std::string_view command) { if (command == "ON") { return true; } if (command == "OFF") { return false; } return {}; } std::string PasswordsToString(const absl::flat_hash_set& passwords, bool nopass, bool full_sha) { if (nopass) { return "nopass "; } std::string result; for (const auto& pass : passwords) { absl::StrAppend(&result, "#", PrettyPrintSha(pass, full_sha), " "); } return result; } std::string AclKeysToString(const AclKeys& keys) { if (keys.all_keys) { return "~*"; } std::string result; for (auto& [pattern, op] : keys.key_globs) { if (op == KeyOp::READ_WRITE) { absl::StrAppend(&result, "~", pattern, " "); continue; } std::string op_str = (op == KeyOp::READ) ? 
"R" : "W"; absl::StrAppend(&result, "%", op_str, "~", pattern, " "); } if (!result.empty()) { result.pop_back(); } return result; } std::string AclPubSubToString(const AclPubSub& pub_sub) { if (pub_sub.all_channels) { return "&*"; } std::string result = "resetchannels "; for (const auto& [glob, has_asterisk] : pub_sub.globs) { absl::StrAppend(&result, "&", glob, " "); } if (result.back() == ' ') { result.pop_back(); } return result; } void SendAclSecurityEvents(const AclLog::LogEntry& entry, facade::RedisReplyBuilder* rb) { rb->StartArray(12); rb->SendSimpleString("reason"); using Reason = AclLog::Reason; std::string reason; if (entry.reason == Reason::COMMAND) { reason = "COMMAND"; } else if (entry.reason == Reason::KEY) { reason = "KEY"; } else if (entry.reason == Reason::PUB_SUB) { reason = "PUB_SUB"; } else { reason = "AUTH"; } rb->SendSimpleString(reason); rb->SendSimpleString("object"); rb->SendSimpleString(entry.object); rb->SendSimpleString("username"); rb->SendSimpleString(entry.username); rb->SendSimpleString("age-seconds"); auto now_diff = std::chrono::system_clock::now() - entry.entry_creation; auto secs = std::chrono::duration_cast(now_diff); auto left_over = now_diff - std::chrono::duration_cast(secs); auto age = absl::StrCat(secs.count(), ".", left_over.count()); rb->SendSimpleString(absl::StrCat(age)); rb->SendSimpleString("client-info"); rb->SendSimpleString(entry.client_info); rb->SendSimpleString("timestamp-created"); rb->SendLong(entry.entry_creation.time_since_epoch().count()); } std::string AclDbToString(size_t db) { return std::numeric_limits::max() == db ? "all" : absl::StrCat(db); } // Fetches the connections that predicate P evaluates to true and shuts them // down gracefully. 
template void TraverseEvictImpl(P predicate, facade::Listener* main_listener, util::ProactorPool* pool) { auto close_cb = [&](unsigned idx, util::ProactorBase* p) { std::vector connections; auto traverse_cb = [&](unsigned id, util::Connection* conn) { auto connection = static_cast(conn); auto ctx = connection->cntx(); if (predicate(ctx)) { connections.push_back(connection->Borrow()); } }; main_listener->TraverseConnectionsOnThread(traverse_cb, UINT32_MAX, nullptr); for (auto& tcon : connections) { facade::Connection* conn = tcon.Get(); if (conn && conn->socket()->proactor()->GetPoolIndex() == p->GetPoolIndex()) { // preemptive for TlsSocket conn->ShutdownSelfBlocking(); } } }; pool->AwaitFiberOnAll(close_cb); } } // namespace std::string AclFamily::AclCatAndCommandToString(const User::CategoryChanges& cat, const User::CommandChanges& cmds) const { std::string result; auto tables = MergeTables(cat, cmds); auto cat_visitor = [&result, this](const CategoryAndMetadata& val) { const auto& [change, meta] = val; absl::StrAppend(&result, AclCatToString(change, meta.sign), " "); }; auto cmd_visitor = [&result, this](const CommandAndMetadata& val) { const auto& [change, meta] = val; const auto [family, bit_index] = change; absl::StrAppend(&result, AclCommandToString(family, bit_index, meta.sign), " "); }; Overloaded visitor{cat_visitor, cmd_visitor}; for (auto change : tables) { std::visit(visitor, change); } if (!result.empty()) { result.pop_back(); } return result; } using OptCat = std::optional; // bool == true if + // bool == false if - std::pair AclFamily::MaybeParseAclCategory(std::string_view command) const { if (absl::EqualsIgnoreCase(command, "ALLCOMMANDS")) { return {cat_table_.at("ALL"), true}; } if (absl::EqualsIgnoreCase(command, "NOCOMMANDS")) { return {cat_table_.at("ALL"), false}; } if (absl::StartsWith(command, "+@")) { auto res = cat_table_.find(command.substr(2)); if (res == cat_table_.end()) { return {}; } return {res->second, true}; } if 
(absl::StartsWith(command, "-@")) { auto res = cat_table_.find(command.substr(2)); if (res == cat_table_.end()) { return {}; } return {res->second, false}; } return {}; } std::optional AclFamily::MaybeParseNamespace(std::string_view command) const { constexpr std::string_view kPrefix = "NAMESPACE:"; if (absl::StartsWith(command, kPrefix)) { return std::string(command.substr(kPrefix.size())); } return std::nullopt; } std::pair AclFamily::MaybeParseAclCommand( std::string_view command) const { if (absl::StartsWith(command, "+")) { auto res = cmd_registry_->Find(command.substr(1)); if (!res || res->IsAlias()) { return {}; } std::pair cmd{res->GetFamily(), res->GetBitIndex()}; return {cmd, true}; } if (absl::StartsWith(command, "-")) { auto res = cmd_registry_->Find(command.substr(1)); if (!res || res->IsAlias()) { return {}; } std::pair cmd{res->GetFamily(), res->GetBitIndex()}; return {cmd, false}; } return {}; } using facade::ErrorReply; std::variant AclFamily::ParseAclSetUser( const facade::ArgRange& args, bool hashed, bool has_all_keys, bool has_all_channels) const { User::UpdateRequest req; for (std::string_view arg : args) { if (auto pass = MaybeParsePassword(facade::ToSV(arg), hashed); pass) { req.passwords.push_back(std::move(*pass)); if (hashed && absl::StartsWith(facade::ToSV(arg), "#")) { req.passwords.back().is_hashed = true; } continue; } if (auto res = MaybeParseAclKey(facade::ToSV(arg)); res) { auto& [glob, op, all_keys, reset_keys] = *res; if ((has_all_keys && !all_keys && !reset_keys) || (req.allow_all_keys && !all_keys && !reset_keys)) { return ErrorReply(absl::StrCat( "Error in ACL SETUSER modifier \'", facade::ToSV(arg), "\': Adding a pattern after the * pattern (or the " "'allkeys' flag) is not valid and does not have any effect. 
Try 'resetkeys' to start " "with an empty list of patterns")); } req.allow_all_keys = all_keys; req.reset_all_keys = reset_keys; if (reset_keys) { has_all_keys = false; } req.keys.push_back({std::move(glob), op, all_keys, reset_keys}); continue; } if (auto res = MaybeParseAclPubSub(facade::ToSV(arg)); res) { auto& [glob, has_asterisk, all_channels, reset_channels] = *res; if ((has_all_channels && !all_channels && !reset_channels) || (req.all_channels && !all_channels && !reset_channels)) { return ErrorReply( absl::StrCat("ERR Error in ACL SETUSER modifier \'", facade::ToSV(arg), "\': Adding a pattern after the * pattern (or the 'allchannels' flag) is " "not valid and does not have any effect. Try 'resetchannels' to start " "with an empty list of channels")); } req.all_channels = all_channels; req.reset_channels = reset_channels; if (reset_channels) { has_all_channels = false; } req.pub_sub.push_back({std::move(glob), has_asterisk, all_channels, reset_channels}); continue; } if (auto res = MaybeParseAclDflySelect(facade::ToSV(arg), dbnum_); res) { if (req.select_db) { return ErrorReply("ERR Error, select db $ was used twice"); } req.select_db = res; continue; } std::string command = absl::AsciiStrToUpper(arg); if (auto status = MaybeParseStatus(command); status) { if (req.is_active) { return ErrorReply("Multiple ON/OFF are not allowed"); } req.is_active = *status; continue; } auto [cat, add] = MaybeParseAclCategory(command); if (cat) { using Sign = User::Sign; using Val = std::pair; auto val = add ? Val{Sign::PLUS, *cat} : Val{Sign::MINUS, *cat}; req.updates.push_back(val); continue; } auto ns = MaybeParseNamespace(command); if (ns.has_value()) { req.ns = *ns; continue; } auto [cmd, sign] = MaybeParseAclCommand(command); if (!cmd) { return ErrorReply(absl::StrCat("Unrecognized parameter ", command)); } using Sign = User::Sign; using Val = User::UpdateRequest::CommandsValueType; auto [index, bit] = *cmd; auto val = sign ? 
Val{Sign::PLUS, index, bit} : Val{Sign::MINUS, index, bit}; req.updates.push_back(val); } return req; } void AclFamily::BuildIndexers(RevCommandsIndexStore families) { size_t family_count = acl::NumberOfFamilies(families.size()); CommandsRevIndexer(std::move(families)); CategoryToCommandsIndexStore index; cmd_registry_->Traverse([&](std::string_view, auto& cid) { const uint32_t cat = cid.acl_categories(); const size_t family = cid.GetFamily(); DCHECK_LT(family, family_count); const uint64_t bit_index = cid.GetBitIndex(); for (size_t i = 0; i < 32; ++i) { if (cat & 1 << i) { std::string_view cat_name = reverse_cat_table_[i]; if (index[cat_name].empty()) { index[cat_name].resize(CommandsRevIndexer().size()); } index[cat_name][family] |= bit_index; } } }); CategoryToCommandsIndex(std::move(index)); CategoryToIdxStore idx_store; for (size_t i = 0; i < 32; ++i) { idx_store[1 << i] = i; } CategoryToIdx(std::move(idx_store)); } void AclFamily::Help(CmdArgList args, CommandContext* cmd_cntx) { string_view help_arr[] = { "ACL [ [value] [opt] ...]. Subcommands are:", "CAT []", " List all commands that belong to , or all command categories", " when no category is specified.", "DELUSER [ ...]", " Delete a list of users.", "DRYRUN [ ...]", " Returns whether the user can execute the given command without executing the command.", "GETUSER ", " Get the user's details.", "GENPASS []", " Generate a secure 256-bit user password. 
The optional `bits` argument can", " be used to specify a different size.", "LIST", " Show users details in config file format.", "LOAD", " Reload users from the ACL file.", "LOG [ | RESET]", " Show the ACL log entries.", "SAVE", " Save the current config to the ACL file.", "SETUSER [ ...]", " Create or modify a user with the specified attributes.", "USERS", " List all the registered usernames.", "WHOAMI", " Return the current connection username.", "HELP", " Print this help."}; auto* rb = static_cast(cmd_cntx->rb()); return rb->SendSimpleStrArr(help_arr); } using MemberFunc = void (AclFamily::*)(CmdArgList args, CommandContext* cmd_cntx); CommandId::Handler HandlerFunc(AclFamily* acl, MemberFunc f) { return [=](CmdArgList args, CommandContext* cmd_cntx) { return (acl->*f)(args, cmd_cntx); }; } #define HFUNC(x) SetHandler(HandlerFunc(this, &AclFamily::x)) constexpr uint32_t kAcl = acl::CONNECTION; constexpr uint32_t kList = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kSetUser = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kDelUser = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kWhoAmI = acl::SLOW; constexpr uint32_t kSave = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kLoad = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kLog = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kUsers = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kCat = acl::SLOW; constexpr uint32_t kGetUser = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kDryRun = acl::ADMIN | acl::SLOW | acl::DANGEROUS; constexpr uint32_t kGenPass = acl::SLOW; constexpr uint32_t kHelp = acl::SLOW; // We can't implement the ACL commands and its respective subcommands LIST, CAT, etc // the usual way, (that is, one command called ACL which then dispatches to the subcommand // based on the second argument) because each of the subcommands has different ACL // categories. 
Therefore, to keep it compatible with the CommandId, I need to treat them // as separate commands in the registry. This is the least intrusive change because it's very // easy to handle that case explicitly in `DispatchCommand`. void AclFamily::Register(dfly::CommandRegistry* registry) { using CI = dfly::CommandId; const uint32_t kAclMask = CO::ADMIN | CO::NOSCRIPT | CO::LOADING; registry->StartFamily(); *registry << CI{"ACL", CO::NOSCRIPT | CO::LOADING, 0, 0, 0, acl::kAcl}.HFUNC(Acl); *registry << CI{"ACL LIST", kAclMask, 1, 0, 0, acl::kList}.HFUNC(List); *registry << CI{"ACL SETUSER", kAclMask, -2, 0, 0, acl::kSetUser}.HFUNC(SetUser); *registry << CI{"ACL DELUSER", kAclMask, -2, 0, 0, acl::kDelUser}.HFUNC(DelUser); *registry << CI{"ACL WHOAMI", kAclMask, 1, 0, 0, acl::kWhoAmI}.HFUNC(WhoAmI); *registry << CI{"ACL SAVE", kAclMask, 1, 0, 0, acl::kSave}.HFUNC(Save); *registry << CI{"ACL LOAD", kAclMask, 1, 0, 0, acl::kLoad}.HFUNC(Load); *registry << CI{"ACL LOG", kAclMask, 0, 0, 0, acl::kLog}.HFUNC(Log); *registry << CI{"ACL USERS", kAclMask, 1, 0, 0, acl::kUsers}.HFUNC(Users); *registry << CI{"ACL CAT", kAclMask, -1, 0, 0, acl::kCat}.HFUNC(Cat); *registry << CI{"ACL GETUSER", kAclMask, 2, 0, 0, acl::kGetUser}.HFUNC(GetUser); *registry << CI{"ACL DRYRUN", kAclMask, 3, 0, 0, acl::kDryRun}.HFUNC(DryRun); *registry << CI{"ACL GENPASS", CO::NOSCRIPT | CO::LOADING, -1, 0, 0, acl::kGenPass}.HFUNC( GenPass); *registry << CI{"ACL HELP", kAclMask, 0, 0, 0, acl::kHelp}.HFUNC(Help); cmd_registry_ = registry; // build indexers BuildIndexers(cmd_registry_->GetFamilies()); } #undef HFUNC } // namespace dfly::acl ================================================ FILE: src/server/acl/acl_family.h ================================================ // Copyright 2025, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include #include #include "absl/container/flat_hash_set.h" #include "facade/facade_types.h" #include "helio/util/proactor_pool.h" #include "server/acl/acl_commands_def.h" #include "server/acl/user_registry.h" #include "server/command_registry.h" #include "server/execution_state.h" namespace facade { class SinkReplyBuilder; class Listener; } // namespace facade namespace dfly { using facade::CmdArgList; class ConnectionContext; namespace acl { class AclFamily final { public: explicit AclFamily(UserRegistry* registry, util::ProactorPool* pool); void Register(CommandRegistry* registry); void Init(facade::Listener* listener, UserRegistry* registry); private: using SinkReplyBuilder = facade::SinkReplyBuilder; void Acl(CmdArgList args, CommandContext* cmd_cntx); void List(CmdArgList args, CommandContext* cmd_cntx); void SetUser(CmdArgList args, CommandContext* cmd_cntx); void DelUser(CmdArgList args, CommandContext* cmd_cntx); void WhoAmI(CmdArgList args, CommandContext* cmd_cntx); void Save(CmdArgList args, CommandContext* cmd_cntx); void Load(CmdArgList args, CommandContext* cmd_cntx); // Helper function for bootstrap bool Load(); void Log(CmdArgList args, CommandContext* cmd_cntx); void Users(CmdArgList args, CommandContext* cmd_cntx); void Cat(CmdArgList args, CommandContext* cmd_cntx); void GetUser(CmdArgList args, CommandContext* cmd_cntx); void DryRun(CmdArgList args, CommandContext* cmd_cntx); void GenPass(CmdArgList args, CommandContext* cmd_cntx); void Help(CmdArgList args, CommandContext* cmd_cntx); // Helper function that updates all open connections and their // respective ACL fields on all the available proactor threads using Commands = std::vector; void StreamUpdatesToAllProactorConnections(const std::string& user, const Commands& update_commands, const AclKeys& update_keys, const AclPubSub& update_pub_sub, size_t db); // Helper function that closes all open connection from the deleted user void 
EvictOpenConnectionsOnAllProactors(const absl::flat_hash_set& user); // Helper function that closes all open connections for users in the registry void EvictOpenConnectionsOnAllProactorsWithRegistry(const UserRegistry::RegistryType& registry); // Helper function that loads the acl state of an acl file into the user registry GenericError LoadToRegistryFromFile(std::string_view full_path, SinkReplyBuilder* builder); // Serializes the whole registry into a string std::string RegistryToString() const; std::string AclCatToString(uint32_t acl_category, User::Sign sign) const; std::string AclCommandToString(size_t family, uint64_t mask, User::Sign sign) const; // Serializes category and command to string std::string AclCatAndCommandToString(const User::CategoryChanges& cat, const User::CommandChanges& cmds) const; using OptCat = std::optional; std::pair MaybeParseAclCategory(std::string_view command) const; using OptCommand = std::optional>; std::pair MaybeParseAclCommand(std::string_view command) const; std::optional MaybeParseNamespace(std::string_view command) const; std::variant ParseAclSetUser( const facade::ArgRange& args, bool hashed = false, bool has_all_keys = false, bool has_all_channels = false) const; void BuildIndexers(RevCommandsIndexStore families); // Data members facade::Listener* main_listener_{nullptr}; UserRegistry* registry_; CommandRegistry* cmd_registry_; util::ProactorPool* pool_; // Indexes // See definitions for NONE and ALL in facade/acl_commands_def.h const CategoryIndexTable cat_table_{{"KEYSPACE", KEYSPACE}, {"READ", READ}, {"WRITE", WRITE}, {"SET", SET}, {"SORTEDSET", SORTEDSET}, {"LIST", LIST}, {"HASH", HASH}, {"STRING", STRING}, {"BITMAP", BITMAP}, {"HYPERLOG", HYPERLOGLOG}, {"GEO", GEO}, {"STREAM", STREAM}, {"PUBSUB", PUBSUB}, {"ADMIN", ADMIN}, {"FAST", FAST}, {"SLOW", SLOW}, {"BLOCKING", BLOCKING}, {"DANGEROUS", DANGEROUS}, {"CONNECTION", CONNECTION}, {"TRANSACTION", TRANSACTION}, {"SCRIPTING", SCRIPTING}, {"CMS", CMS}, {"BLOOM", BLOOM}, 
{"FT_SEARCH", FT_SEARCH}, {"SEARCH", FT_SEARCH}, // Alias for FT_SEARCH {"THROTTLE", THROTTLE}, {"JSON", JSON}, {"ALL", ALL}}; // bit 0 at index 0 // bit 1 at index 1 // bit n at index n const ReverseCategoryIndexTable reverse_cat_table_{ "KEYSPACE", "READ", "WRITE", "SET", "SORTEDSET", "LIST", "HASH", "STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM", "PUBSUB", "ADMIN", "FAST", "SLOW", "BLOCKING", "DANGEROUS", "CONNECTION", "TRANSACTION", "SCRIPTING", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "CMS", "BLOOM", "FT_SEARCH", "THROTTLE", "JSON"}; // We need this to act as a const member, since the initialization of const data members // must be done on the constructor. However, these are initialized a little later, when // we Register the commands const CategoryToIdxStore& CategoryToIdx(CategoryToIdxStore store = {}) const { static CategoryToIdxStore cat_idx = std::move(store); return cat_idx; } const RevCommandsIndexStore& CommandsRevIndexer(RevCommandsIndexStore store = {}) const { static RevCommandsIndexStore rev_index_store = std::move(store); return rev_index_store; } const CategoryToCommandsIndexStore& CategoryToCommandsIndex( CategoryToCommandsIndexStore store = {}) const { static CategoryToCommandsIndexStore index = std::move(store); return index; } size_t dbnum_ = 0; // Only for testing interface public: // Helper accessors for tests. Do not use them directly. const ReverseCategoryIndexTable& GetRevTable() const { return reverse_cat_table_; } // We could make CommandsRevIndexer public, but I want this to be // clear that this is for TESTING so do not use this in the codebase const RevCommandsIndexStore& GetCommandsRevIndexer() const { return CommandsRevIndexer(); } }; } // namespace acl } // namespace dfly ================================================ FILE: src/server/acl/acl_family_test.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. 
// See LICENSE for licensing terms. // #include "server/acl/acl_family.h" #include #include #include #include "base/flags.h" #include "base/gtest.h" #include "base/logging.h" #include "facade/facade_test.h" #include "server/acl/acl_commands_def.h" #include "server/command_registry.h" #include "server/test_utils.h" using namespace testing; ABSL_DECLARE_FLAG(std::vector, rename_command); ABSL_DECLARE_FLAG(std::vector, command_alias); namespace dfly { class AclFamilyTest : public BaseFamilyTest { protected: }; class AclFamilyTestRename : public BaseFamilyTest { void SetUp() override { absl::SetFlag(&FLAGS_rename_command, {"ACL=ROCKS"}); absl::SetFlag(&FLAGS_command_alias, {"___SET=SET"}); ResetService(); } }; TEST_F(AclFamilyTest, AclSetUser) { TestInitAclFam(); auto resp = Run({"ACL", "SETUSER"}); EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl setuser' command")); resp = Run({"ACL", "SETUSER", "kostas", "ONN"}); EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter ONN")); resp = Run({"ACL", "SETUSER", "kostas", "+@nonsense"}); EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter +@NONSENSE")); resp = Run({"ACL", "SETUSER", "vlad"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all $all", "user vlad off resetchannels -@all $all")); resp = Run({"ACL", "SETUSER", "vlad", "+ACL"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); EXPECT_THAT(vec, UnorderedElementsAre("user default on nopass ~* &* +@all $all", "user vlad off resetchannels -@all +acl $all")); resp = Run({"ACL", "SETUSER", "vlad", "on", ">pass", ">temp"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); vec = resp.GetVec(); EXPECT_THAT(vec.size(), 2); auto contains_vlad = [](const auto& vec) { const std::string default_user = "user default on nopass ~* &* +@all $all"; const std::string a_permutation = "user vlad on #a6864eb339b0e1f #d74ff0ee8da3b98 
resetchannels -@all +acl $all"; const std::string b_permutation = "user vlad on #d74ff0ee8da3b98 #a6864eb339b0e1f resetchannels -@all +acl $all"; std::string_view other; if (vec[0] == default_user) { other = vec[1].GetView(); } else if (vec[1] == default_user) { other = vec[0].GetView(); } else { return false; } return other == a_permutation || other == b_permutation; }; EXPECT_THAT(contains_vlad(vec), true); resp = Run({"AUTH", "vlad", "pass"}); EXPECT_THAT(resp, "OK"); resp = Run({"AUTH", "vlad", "temp"}); EXPECT_THAT(resp, "OK"); resp = Run({"AUTH", "default", R"("")"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "vlad", ">another"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "vlad", "pass", "+@admin"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "adi", ">pass", "+@fast"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); auto vec = resp.GetVec(); EXPECT_THAT( vec, UnorderedElementsAre("user default on nopass ~* &* +@all $all", "user kostas off #d74ff0ee8da3b98 resetchannels -@all +@admin $all", "user adi off #d74ff0ee8da3b98 resetchannels -@all +@fast $all")); } TEST_F(AclFamilyTest, AclAuth) { TestInitAclFam(); auto resp = Run({"AUTH", "default", R"("")"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "shahar", ">mypass"}); EXPECT_THAT(resp, "OK"); resp = Run({"AUTH", "shahar", "wrongpass"}); EXPECT_THAT(resp, ErrArg("WRONGPASS invalid username-password pair or user is disabled.")); resp = Run({"AUTH", "shahar", "mypass"}); EXPECT_THAT(resp, ErrArg("WRONGPASS invalid username-password pair or user is disabled.")); // Activate the user resp = Run({"ACL", "SETUSER", "shahar", "ON", "+@fast"}); EXPECT_THAT(resp, "OK"); resp = Run({"AUTH", "shahar", "mypass"}); EXPECT_THAT(resp, "OK"); } TEST_F(AclFamilyTest, AclWhoAmI) { TestInitAclFam(); auto resp = Run({"ACL", "WHOAMI", "WHO"}); EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl whoami' command")); resp = Run({"ACL", "SETUSER", "kostas", "ON", 
">pass", "+@SLOW"}); EXPECT_THAT(resp, "OK"); resp = Run({"AUTH", "kostas", "pass"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "WHOAMI"}); EXPECT_THAT(resp, "User is kostas"); } TEST_F(AclFamilyTest, TestAllCategories) { const auto* fam = TestInitAclFam(); for (auto& cat : fam->GetRevTable()) { if (cat != "_RESERVED") { auto resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("+@", cat)}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* &* +@all $all", absl::StrCat("user kostas off resetchannels -@all ", "+@", absl::AsciiStrToLower(cat), " $all"))); resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-@", cat)}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* &* +@all $all", absl::StrCat("user kostas off resetchannels -@all ", "-@", absl::AsciiStrToLower(cat), " $all"))); resp = Run({"ACL", "DELUSER", "kostas"}); EXPECT_THAT(resp, IntArg(1)); } } for (auto& cat : fam->GetRevTable()) { if (cat != "_RESERVED") { auto resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("+@", cat)}); EXPECT_THAT(resp, "OK"); } } // This won't work because of __RESERVED // TODO(fix this) // auto resp = Run({"ACL", "LIST"}); // EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL", // absl::StrCat("user kostas off nopass ", "+@ALL"))); // // TODO(Bug here fix none/all) // auto resp = Run({"ACL", "SETUSER", "kostas", "+@NONE"}); // EXPECT_THAT(resp, "OK"); // // resp = Run({"ACL", "LIST"}); // EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass +@ALL", "user kostas // off nopass +@NONE")); } TEST_F(AclFamilyTest, TestAllCommands) { const auto* fam = TestInitAclFam(); const auto& rev_indexer = fam->GetCommandsRevIndexer(); for (const auto& family : rev_indexer) { for (const auto& command_name : family) { auto resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("+", 
command_name)}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* &* +@all $all", absl::StrCat("user kostas off resetchannels -@all ", "+", absl::AsciiStrToLower(command_name), " $all"))); resp = Run({"ACL", "SETUSER", "kostas", absl::StrCat("-", command_name)}); resp = Run({"ACL", "LIST"}); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("user default on nopass ~* &* +@all $all", absl::StrCat("user kostas off resetchannels -@all ", "-", absl::AsciiStrToLower(command_name), " $all"))); resp = Run({"ACL", "DELUSER", "kostas"}); EXPECT_THAT(resp, IntArg(1)); } } } TEST_F(AclFamilyTest, TestUsers) { TestInitAclFam(); auto resp = Run({"ACL", "SETUSER", "abhra", "ON"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "ari"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "USERS"}); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("default", "abhra", "ari")); } TEST_F(AclFamilyTest, TestCat) { TestInitAclFam(); auto resp = Run({"ACL", "CAT", "nonsense"}); EXPECT_THAT(resp, ErrArg("ERR Unknown category: NONSENSE")); resp = Run({"ACL", "CAT"}); EXPECT_GE(resp.GetVec().size(), 24u); resp = Run({"ACL", "CAT", "STRING"}); EXPECT_THAT(resp.GetVec(), IsSupersetOf({"GETSET", "GETRANGE", "INCRBYFLOAT", "GETDEL", "DECRBY", "PREPEND", "SETEX", "MSET", "SET", "PSETEX", "SUBSTR", "DECR", "STRLEN", "INCR", "INCRBY", "MGET", "GET", "SETNX", "GETEX", "APPEND", "MSETNX", "SETRANGE"})); } TEST_F(AclFamilyTest, TestGetUser) { TestInitAclFam(); auto resp = Run({"ACL", "GETUSER", "kostas"}); EXPECT_THAT(resp, ArgType(RespExpr::NIL)); resp = Run({"ACL", "GETUSER", "default"}); const auto& vec = resp.GetVec(); EXPECT_THAT(vec[0], "flags"); EXPECT_THAT(vec[1].GetVec(), UnorderedElementsAre("on", "nopass")); EXPECT_THAT(vec[2], "passwords"); EXPECT_TRUE(vec[3].GetVec().empty()); EXPECT_THAT(vec[4], "commands"); EXPECT_THAT(vec[5], "+@all"); EXPECT_THAT(vec[6], "keys"); EXPECT_THAT(vec[7], "~*"); 
EXPECT_THAT(vec[8], "channels"); EXPECT_THAT(vec[9], "&*"); resp = Run({"ACL", "SETUSER", "kostas", "+@STRING", "+HSET"}); resp = Run({"ACL", "GETUSER", "kostas"}); const auto& kvec = resp.GetVec(); EXPECT_THAT(kvec[0], "flags"); EXPECT_THAT(kvec[1].GetVec(), UnorderedElementsAre("off")); EXPECT_THAT(kvec[2], "passwords"); EXPECT_TRUE(kvec[3].GetVec().empty()); EXPECT_THAT(kvec[4], "commands"); EXPECT_THAT(kvec[5], "-@all +@string +hset"); EXPECT_THAT(kvec[6], "keys"); EXPECT_THAT(kvec[7], RespArray(ElementsAre())); EXPECT_THAT(kvec[8], "channels"); EXPECT_THAT(kvec[9], "resetchannels"); } TEST_F(AclFamilyTest, TestDryRun) { TestInitAclFam(); auto resp = Run({"ACL", "DRYRUN"}); EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl dryrun' command")); resp = Run({"ACL", "DRYRUN", "default"}); EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl dryrun' command")); resp = Run({"ACL", "DRYRUN", "default", "get", "more"}); EXPECT_THAT(resp, ErrArg("ERR wrong number of arguments for 'acl dryrun' command")); resp = Run({"ACL", "DRYRUN", "kostas", "more"}); EXPECT_THAT(resp, ErrArg("ERR User 'kostas' not found")); resp = Run({"ACL", "DRYRUN", "default", "nope"}); EXPECT_THAT(resp, ErrArg("ERR Command 'NOPE' not found")); resp = Run({"ACL", "DRYRUN", "default", "SET"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "kostas", "+GET"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "DRYRUN", "kostas", "GET"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "DRYRUN", "kostas", "SET"}); EXPECT_THAT(resp, "This user has no permissions to run the 'SET' command"); } TEST_F(AclFamilyTest, AclGenPassTooManyArguments) { TestInitAclFam(); auto resp = Run({"ACL", "GENPASS", "1", "2"}); EXPECT_THAT(resp.GetString(), "ERR Unknown subcommand or wrong number of arguments for 'GENPASS'. 
Try ACL HELP."); } TEST_F(AclFamilyTest, AclGenPassOutOfRange) { std::string expectedError = "ERR ACL GENPASS argument must be the number of bits for the output password, a positive " "number up to 4096"; auto resp = Run({"ACL", "GENPASS", "-1"}); EXPECT_THAT(resp.GetString(), expectedError); resp = Run({"ACL", "GENPASS", "0"}); EXPECT_THAT(resp.GetString(), expectedError); resp = Run({"ACL", "GENPASS", "4097"}); EXPECT_THAT(resp.GetString(), expectedError); } TEST_F(AclFamilyTest, AclGenPass) { auto resp = Run({"ACL", "GENPASS"}); auto actualPassword = resp.GetString(); // should be 256 bits or 64 bytes in hex EXPECT_THAT(actualPassword.length(), 64); // 1 bit - 4 bits should all produce a single hex character for (int i = 1; i <= 4; i++) { resp = Run({"ACL", "GENPASS", std::to_string(i)}); EXPECT_THAT(resp.GetString().length(), 1); } // 5 bits - 8 bits should all produce two hex characters for (int i = 5; i <= 8; i++) { resp = Run({"ACL", "GENPASS", std::to_string(i)}); EXPECT_THAT(resp.GetString().length(), 2); } // and the pattern continues resp = Run({"ACL", "GENPASS", "9"}); EXPECT_THAT(resp.GetString().length(), 3); } TEST_F(AclFamilyTestRename, AclRename) { auto resp = Run({"ACL", "SETUSER", "billy"}); EXPECT_THAT(resp, ErrArg("ERR unknown command `ACL`")); resp = Run({"ROCKS", "SETUSER", "billy", "ON", ">mypass"}); EXPECT_THAT(resp.GetString(), "OK"); resp = Run({"ROCKS", "DELUSER", "billy"}); EXPECT_THAT(resp, IntArg(1)); } TEST_F(AclFamilyTest, TestKeys) { TestInitAclFam(); auto resp = Run({"ACL", "SETUSER", "temp", "~foo", "~bar*"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); auto& vec = resp.GetVec(); EXPECT_THAT(vec[6], "keys"); EXPECT_THAT(vec[7], "~foo ~bar*"); resp = Run({"ACL", "SETUSER", "temp", "~*", "~foo"}); EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~foo': Adding a pattern after the * " "pattern (or the 'allkeys' flag) is not valid and does not have any " "effect. 
Try 'resetkeys' to start with an empty list of patterns")); resp = Run({"ACL", "SETUSER", "temp", "~*"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "SETUSER", "temp", "~foo"}); EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '~foo': Adding a pattern after the * " "pattern (or the 'allkeys' flag) is not valid and does not have any " "effect. Try 'resetkeys' to start with an empty list of patterns")); resp = Run({"ACL", "SETUSER", "temp", "resetkeys"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); EXPECT_TRUE(resp.GetVec()[7].GetVec().empty()); resp = Run({"ACL", "SETUSER", "temp", "%R~foo"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); EXPECT_THAT(resp.GetVec()[7], "%R~foo"); resp = Run({"ACL", "SETUSER", "temp", "resetkeys", "%W~foo"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); EXPECT_THAT(resp.GetVec()[7], "%W~foo"); resp = Run({"ACL", "SETUSER", "temp", "resetkeys", "%RW~foo"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); EXPECT_THAT(resp.GetVec()[7], "~foo"); resp = Run({"ACL", "SETUSER", "temp", "resetkeys", "%K~foo"}); EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter %K~FOO")); resp = Run({"ACL", "SETUSER", "temp", "resetkeys", "%Rfoo"}); EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter %RFOO")); } TEST_F(AclFamilyTest, TestPubSub) { TestInitAclFam(); auto resp = Run({"ACL", "SETUSER", "temp", "&foo", "&b*r"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); auto vec = resp.GetVec(); EXPECT_THAT(vec[8], "channels"); EXPECT_THAT(vec[9], "resetchannels &foo &b*r"); resp = Run({"ACL", "SETUSER", "temp", "allchannels", "&bar"}); EXPECT_THAT(resp, ErrArg("ERR Error in ACL SETUSER modifier '&bar': Adding a pattern after the * " "pattern (or the 'allchannels' flag) is " "not valid and does not have any effect. 
Try 'resetchannels' to start " "with an empty list of channels")); resp = Run({"ACL", "SETUSER", "temp", "allchannels"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); vec = resp.GetVec(); EXPECT_THAT(vec[8], "channels"); EXPECT_THAT(vec[9], "&*"); resp = Run({"ACL", "SETUSER", "temp", "resetchannels", "&foo"}); EXPECT_THAT(resp, "OK"); resp = Run({"ACL", "GETUSER", "temp"}); vec = resp.GetVec(); EXPECT_THAT(vec[8], "channels"); EXPECT_THAT(vec[9], "resetchannels &foo"); resp = Run("ACL setuser demo on resetkeys resetchannels ~app|managed-resources|* " "&app|managed-resources|* +publish +ping >passwd"); resp = Run("AUTH demo passwd"); EXPECT_THAT(resp, "OK"); resp = Run("publish app|managed-resources|xyz test"); EXPECT_THAT(resp, IntArg(0)); } TEST_F(AclFamilyTest, TestAlias) { auto resp = Run({"ACL", "SETUSER", "luke", "+___SET"}); EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter +___SET")); resp = Run({"ACL", "SETUSER", "leia", "-___SET"}); EXPECT_THAT(resp, ErrArg("ERR Unrecognized parameter -___SET")); resp = Run({"ACL", "SETUSER", "anakin", "+SET"}); EXPECT_EQ(resp, "OK"); resp = Run({"ACL", "SETUSER", "jarjar", "allcommands"}); EXPECT_EQ(resp, "OK"); resp = Run({"ACL", "DRYRUN", "jarjar", "___SET"}); EXPECT_THAT(resp, ErrArg("ERR Command '___SET' not found")); EXPECT_EQ(Run({"ACL", "DRYRUN", "jarjar", "SET"}), "OK"); } TEST_F(AclFamilyTest, TestAclLogUB) { auto resp = Run({"ACL", "LOG"}); EXPECT_TRUE(resp.GetVec().empty()); resp = Run({"ACL", "LOG", "2", "RESET"}); EXPECT_THAT(resp, ErrArg("ERR index out of range")); } } // namespace dfly ================================================ FILE: src/server/acl/acl_log.cc ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "server/acl/acl_log.h" #include #include #include "base/flags.h" #include "base/logging.h" #include "facade/dragonfly_connection.h" #include "server/conn_context.h" ABSL_FLAG(uint32_t, acllog_max_len, 32, "Specify the number of log entries. Logs are kept locally for each thread " "and therefore the total number of entries are acllog_max_len * threads"); namespace dfly::acl { AclLog::AclLog() : total_entries_allowed_(absl::GetFlag(FLAGS_acllog_max_len)) { } void AclLog::Add(const ConnectionContext& cntx, std::string object, Reason reason, std::string tried_to_auth) { if (total_entries_allowed_ == 0) { return; } if (log_.size() == total_entries_allowed_) { log_.pop_back(); } std::string username; // We can't use a conditional here because the result is the common type which is a const-ref if (tried_to_auth.empty()) { username = cntx.authed_username; } else { username = std::move(tried_to_auth); } std::string client_info = cntx.conn()->GetClientInfo(); using clock = std::chrono::system_clock; LogEntry entry = {std::move(username), std::move(client_info), std::move(object), reason, clock::now()}; log_.push_front(std::move(entry)); } void AclLog::Reset() { log_.clear(); } AclLog::LogType AclLog::GetLog(size_t number_of_entries) const { auto start = log_.begin(); auto end = log_.size() <= number_of_entries ? log_.end() : std::next(start, number_of_entries); return {start, end}; } void AclLog::SetTotalEntries(size_t total_entries) { if (log_.size() > total_entries) { log_.erase(std::next(log_.begin(), total_entries), log_.end()); } total_entries_allowed_ = total_entries; } } // namespace dfly::acl ================================================ FILE: src/server/acl/acl_log.h ================================================ // Copyright 2022, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #pragma once #include #include #include namespace dfly { class ConnectionContext; namespace acl { class AclLog { public: explicit AclLog(); enum class Reason { COMMAND, AUTH, KEY, PUB_SUB }; struct LogEntry { std::string username; std::string client_info; std::string object; Reason reason; using TimePoint = std::chrono::time_point; TimePoint entry_creation = TimePoint::max(); friend bool operator<(const LogEntry& lhs, const LogEntry& rhs) { return lhs.entry_creation < rhs.entry_creation; } }; void Add(const ConnectionContext& cntx, std::string object, Reason reason, std::string tried_to_auth = ""); void Reset(); using LogType = std::deque; LogType GetLog(size_t number_of_entries) const; void SetTotalEntries(size_t total_entries); private: LogType log_; size_t total_entries_allowed_; }; } // namespace acl } // namespace dfly ================================================ FILE: src/server/acl/user.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
// #include "server/acl/user.h" #include #include #include "absl/container/flat_hash_set.h" #include "absl/strings/escaping.h" #include "core/overloaded.h" namespace dfly::acl { namespace { std::string StringSHA256(std::string_view password) { std::string hash; hash.resize(SHA256_DIGEST_LENGTH); SHA256(reinterpret_cast(password.data()), password.size(), reinterpret_cast(hash.data())); return hash; } } // namespace User::User() { commands_ = std::vector(NumberOfFamilies(), 0); } void User::Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id, const ReverseCategoryIndexTable& reverse_cat, const CategoryToCommandsIndexStore& cat_to_commands) { for (auto& pass : req.passwords) { if (pass.nopass) { SetNopass(); continue; } if (pass.unset) { UnsetPassword(pass.password); continue; } if (pass.reset_password) { password_hashes_.clear(); continue; } SetPasswordHash(pass.password, pass.is_hashed); } auto cat_visitor = [&, this](UpdateRequest::CategoryValueType cat) { auto [sign, category] = cat; if (sign == Sign::PLUS) { SetAclCategoriesAndIncrSeq(category, cat_to_id, reverse_cat, cat_to_commands); return; } UnsetAclCategoriesAndIncrSeq(category, cat_to_id, reverse_cat, cat_to_commands); }; auto cmd_visitor = [this](UpdateRequest::CommandsValueType cmd) { auto [sign, index, bit_index] = cmd; if (sign == Sign::PLUS) { SetAclCommandsAndIncrSeq(index, bit_index); return; } UnsetAclCommandsAndIncrSeq(index, bit_index); }; Overloaded visitor{cat_visitor, cmd_visitor}; for (auto req : req.updates) { std::visit(visitor, req); } if (!req.keys.empty()) { SetKeyGlobs(std::move(req.keys)); } if (!req.pub_sub.empty()) { SetPubSub(std::move(req.pub_sub)); } if (req.is_active) { SetIsActive(*req.is_active); } SetSelectDb(req.select_db); SetNamespace(req.ns); } void User::SetPasswordHash(std::string_view password, bool is_hashed) { nopass_ = false; if (is_hashed) { std::string binary; if (absl::HexStringToBytes(password, &binary)) { password_hashes_.insert(binary); } else { 
LOG(ERROR) << "Invalid password hash: " << password; } return; } password_hashes_.insert(StringSHA256(password)); } void User::UnsetPassword(std::string_view password) { password_hashes_.erase(StringSHA256(password)); } void User::SetNamespace(const std::string& ns) { namespace_ = ns; } void User::SetSelectDb(std::optional db) { if (db) { db_ = *db; } } size_t User::Db() const { return db_; } const std::string& User::Namespace() const { return namespace_; } bool User::HasPassword(std::string_view password) const { if (nopass_) { return true; } return password_hashes_.contains(StringSHA256(password)); } void User::SetAclCategoriesAndIncrSeq(uint32_t cat, const CategoryToIdxStore& cat_to_id, const ReverseCategoryIndexTable& reverse_cat, const CategoryToCommandsIndexStore& cat_to_commands) { acl_categories_ |= cat; if (cat == acl::ALL) { SetAclCommands(std::numeric_limits::max(), 0); } else { auto id = cat_to_id.at(cat); std::string_view name = reverse_cat[id]; const auto& commands_group = cat_to_commands.at(name); for (size_t fam_id = 0; fam_id < commands_group.size(); ++fam_id) { SetAclCommands(fam_id, commands_group[fam_id]); } } CategoryChange change{cat}; cat_changes_[change] = ChangeMetadata{Sign::PLUS, seq_++}; } void User::UnsetAclCategoriesAndIncrSeq(uint32_t cat, const CategoryToIdxStore& cat_to_id, const ReverseCategoryIndexTable& reverse_cat, const CategoryToCommandsIndexStore& cat_to_commands) { acl_categories_ ^= cat; if (cat == acl::ALL) { UnsetAclCommands(std::numeric_limits::max(), 0); } else { auto id = cat_to_id.at(cat); std::string_view name = reverse_cat[id]; const auto& commands_group = cat_to_commands.at(name); for (size_t fam_id = 0; fam_id < commands_group.size(); ++fam_id) { UnsetAclCommands(fam_id, commands_group[fam_id]); } } CategoryChange change{cat}; cat_changes_[change] = ChangeMetadata{Sign::MINUS, seq_++}; } void User::SetAclCommands(size_t index, uint64_t bit_index) { if (index == std::numeric_limits::max()) { for (auto& family : 
commands_) { family = ALL_COMMANDS; } return; } commands_[index] |= bit_index; } void User::SetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index) { SetAclCommands(index, bit_index); CommandChange change{index, bit_index}; cmd_changes_[change] = ChangeMetadata{Sign::PLUS, seq_++}; } void User::UnsetAclCommands(size_t index, uint64_t bit_index) { if (index == std::numeric_limits::max()) { for (auto& family : commands_) { family = NONE_COMMANDS; } return; } SetAclCommands(index, bit_index); commands_[index] ^= bit_index; } void User::UnsetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index) { UnsetAclCommands(index, bit_index); CommandChange change{index, bit_index}; cmd_changes_[change] = ChangeMetadata{Sign::MINUS, seq_++}; } uint32_t User::AclCategory() const { return acl_categories_; } std::vector User::AclCommands() const { return commands_; } const std::vector& User::AclCommandsRef() const { return commands_; } void User::SetIsActive(bool is_active) { is_active_ = is_active; } bool User::IsActive() const { return is_active_; } const absl::flat_hash_set& User::Passwords() const { return password_hashes_; } bool User::HasNopass() const { return nopass_; } const AclKeys& User::Keys() const { return keys_; } const AclPubSub& User::PubSub() const { return pub_sub_; } const User::CategoryChanges& User::CatChanges() const { return cat_changes_; } const User::CommandChanges& User::CmdChanges() const { return cmd_changes_; } void User::SetKeyGlobs(std::vector keys) { for (auto& key : keys) { if (key.all_keys) { keys_.key_globs.clear(); keys_.all_keys = true; } else if (key.reset_keys) { keys_.key_globs.clear(); keys_.all_keys = false; } else { keys_.key_globs.push_back({std::move(key.key), key.op}); } } } void User::SetPubSub(std::vector pub_sub) { for (auto& pattern : pub_sub) { if (pattern.all_channels) { pub_sub_.globs.clear(); pub_sub_.all_channels = true; } else if (pattern.reset_channels) { pub_sub_.globs.clear(); pub_sub_.all_channels = false; } else { 
pub_sub_.globs.push_back({std::move(pattern.pattern), pattern.has_asterisk}); } } } void User::SetNopass() { nopass_ = true; password_hashes_.clear(); } } // namespace dfly::acl ================================================ FILE: src/server/acl/user.h ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. // #pragma once #include #include #include #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "server/acl/acl_commands_def.h" namespace dfly::acl { class User final { public: enum class Sign : int8_t { PLUS, MINUS }; struct UpdateKey { std::string key; KeyOp op; bool all_keys = false; bool reset_keys = false; }; struct UpdatePass { std::string password; // Set to denote remove password bool unset{false}; bool nopass{false}; bool reset_password{false}; bool is_hashed{false}; }; struct UpdatePubSub { std::string pattern; bool has_asterisk{false}; bool all_channels{false}; bool reset_channels{false}; }; struct UpdateRequest { std::vector passwords; std::optional is_active{}; bool is_hashed{false}; // Categories and commands using CategoryValueType = std::pair; // If index s numberic_limits::max() then it's a +all flag using CommandsValueType = std::tuple; using UpdateType = std::vector>; UpdateType updates; // keys std::vector keys; bool reset_all_keys{false}; bool allow_all_keys{false}; // pub/sub std::vector pub_sub; bool reset_channels{false}; bool all_channels{false}; // TODO allow reset all // bool reset_all{false}; // DFLY specific std::optional select_db; std::string ns; }; using CategoryChange = uint32_t; using CommandChange = std::pair; struct ChangeMetadata { Sign sign; size_t seq_no; }; /* Used for default user * password = nopass * acl_categories = +@all * is_active = true; */ User(); User(const User&) = delete; User(User&&) = default; // For single 
step updates void Update(UpdateRequest&& req, const CategoryToIdxStore& cat_to_id, const ReverseCategoryIndexTable& reverse_cat, const CategoryToCommandsIndexStore& cat_to_commands); bool HasPassword(std::string_view password) const; uint32_t AclCategory() const; std::vector AclCommands() const; const std::vector& AclCommandsRef() const; bool IsActive() const; const absl::flat_hash_set& Passwords() const; bool HasNopass() const; // Selector maps a command string (like HSET, SET etc) to // its respective ID within the commands vector. static size_t Selector(std::string_view); const AclKeys& Keys() const; const AclPubSub& PubSub() const; const std::string& Namespace() const; size_t Db() const; using CategoryChanges = absl::flat_hash_map; using CommandChanges = absl::flat_hash_map; const CategoryChanges& CatChanges() const; const CommandChanges& CmdChanges() const; private: void SetAclCategoriesAndIncrSeq(uint32_t cat, const CategoryToIdxStore& cat_to_id, const ReverseCategoryIndexTable& reverse_cat, const CategoryToCommandsIndexStore& cat_to_commands); void UnsetAclCategoriesAndIncrSeq(uint32_t cat, const CategoryToIdxStore& cat_to_id, const ReverseCategoryIndexTable& reverse_cat, const CategoryToCommandsIndexStore& cat_to_commands); // For ACL commands void SetAclCommands(size_t index, uint64_t bit_index); void UnsetAclCommands(size_t index, uint64_t bit_index); void SetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index); void UnsetAclCommandsAndIncrSeq(size_t index, uint64_t bit_index); // For is_active flag void SetIsActive(bool is_active); // For passwords void SetPasswordHash(std::string_view password, bool is_hashed); void UnsetPassword(std::string_view password); // For ACL key globs void SetKeyGlobs(std::vector keys); // For ACL pub/sub void SetPubSub(std::vector pub_sub); void SetNamespace(const std::string& ns); void SetSelectDb(std::optional db); // Set NOPASS and remove all passwords void SetNopass(); // Passwords for each user absl::flat_hash_set 
password_hashes_; // if `nopass` is used bool nopass_ = false; uint32_t acl_categories_{NONE}; // Each element index in the vector corresponds to a familly of commands // Each bit in the uin64_t field at index id, corresponds to a specific // command of that family. Look on TableCommandBuilder and on Service::Register // on how this mapping is built during the startup/registration of commands std::vector commands_; // We also need to track all the explicit changes (ACL SETUSER) of acl's in-order. // To speed up insertion we use the flat_hash_map and a seq_ variable which is a // strictly monotonically increasing number that is used for ordering. Both of these // indexers are merged and then sorted by the seq_ number when for example we print // the ACL rules of each user via ACL LIST. CategoryChanges cat_changes_; CommandChanges cmd_changes_; // Global modification order for changes in rules for acl commands and categories size_t seq_ = 0; // Glob patterns for the keys that a user is allowed to read/write AclKeys keys_; // Glob patterns for pub/sub channels AclPubSub pub_sub_; // if the user is on/off bool is_active_{false}; std::string namespace_; // if db == std::numeric_limits::max() then all db's. // Otherwise user restricted to the value of db_ size_t db_{std::numeric_limits::max()}; }; } // namespace dfly::acl ================================================ FILE: src/server/acl/user_registry.cc ================================================ // Copyright 2023, DragonflyDB authors. All rights reserved. // See LICENSE for licensing terms. 
//

#include "server/acl/user_registry.h"

// NOTE(review): the header names of the next two directives are missing in this
// copy of the file - restore them from upstream before compiling.
#include
#include

#include "base/flags.h"
#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"

ABSL_DECLARE_FLAG(std::string, requirepass);

using namespace util;

namespace dfly::acl {

// Applies `req` to the registry entry for `username` while holding mu_
// exclusively. The entry is obtained with operator[], so presumably a missing
// user is created first (the type of registry_ is not visible here).
void UserRegistry::MaybeAddAndUpdate(std::string_view username, User::UpdateRequest req) {
  std::unique_lock lock(mu_);
  auto& user = registry_[username];
  user.Update(std::move(req), *cat_to_id_table_, *reverse_cat_table_, *cat_to_commands_table_);
}

// Erases `username` under an exclusive lock; the erase() result is converted to bool.
bool UserRegistry::RemoveUser(std::string_view username) {
  std::unique_lock lock(mu_);
  return registry_.erase(username);
}

// Returns a snapshot of the user's categories, command bitmasks, key globs,
// pub/sub globs, namespace and db, taken under a shared lock.
// An unknown user yields a value-initialized UserCredentials.
UserCredentials UserRegistry::GetCredentials(std::string_view username) const {
  std::shared_lock lock(mu_);
  auto it = registry_.find(username);
  if (it == registry_.end()) {
    return {};
  }
  auto& user = it->second;
  return {user.AclCategory(), user.AclCommands(), user.Keys(),
          user.PubSub(),      user.Namespace(),   user.Db()};
}

// False for an unknown user, otherwise the user's on/off flag.
bool UserRegistry::IsUserActive(std::string_view username) const {
  std::shared_lock lock(mu_);
  auto it = registry_.find(username);
  if (it == registry_.end()) {
    return false;
  }
  return it->second.IsActive();
}

// True only if the user exists, is active and User::HasPassword accepts `password`.
bool UserRegistry::AuthUser(std::string_view username, std::string_view password) const {
  std::shared_lock lock(mu_);
  const auto& user = registry_.find(username);
  if (user == registry_.end()) {
    return false;
  }

  return user->second.IsActive() && user->second.HasPassword(password);
}

// Hands out the registry together with the shared lock that was acquired here.
UserRegistry::RegistryViewWithLock UserRegistry::GetRegistryWithLock() const {
  std::shared_lock lock(mu_);
  return {std::move(lock), registry_};
}

// Hands out the registry together with the exclusive lock that was acquired here.
UserRegistry::RegistryWithWriteLock UserRegistry::GetRegistryWithWriteLock() {
  std::unique_lock lock(mu_);
  return {std::move(lock), registry_};
}

// Stores the user reference and the `exists` flag, and takes ownership of `lk`.
// NOTE(review): the template argument of std::unique_lock is missing in this
// copy of the file - confirm the mutex type against user_registry.h.
UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock lk, const User& user,
                                                   bool exists)
    : user(user), exists(exists), registry_lk_(std::move(lk)) {
}

// Builds the request that describes the default user: one password entry with
// `nopass` set, active, +@all, all keys with READ_WRITE, and all channels.
User::UpdateRequest UserRegistry::DefaultUserUpdateRequest() const {
  // Assign field by field to suppress an annoying compiler warning
  User::UpdateRequest req;
  // Positional fields: password, unset, nopass.
  req.passwords = std::vector{{"", false, true}};
  req.is_active = true;
  req.updates = {std::pair{User::Sign::PLUS, acl::ALL}};
  // Positional fields: key, op, all_keys, reset_keys.
  req.keys = {User::UpdateKey{"~*", KeyOp::READ_WRITE, true, false}};
  // Positional fields: pattern, has_asterisk, all_channels, reset_channels.
  req.pub_sub = {User::UpdatePubSub{"", false, true, false}};
  return req;
}

// Stores the three lookup tables and registers the "default" user.
// The default user's password is taken from the first source that is set:
// the --requirepass flag, then the DFLY_PASSWORD environment variable, then
// DFLY_requirepass. If none is set the user stays `nopass`.
void UserRegistry::Init(const CategoryToIdxStore* cat_to_id_table,
                        const ReverseCategoryIndexTable* reverse_cat_table,
                        const CategoryToCommandsIndexStore* cat_to_commands_table) {
  // if there exists an acl file to load from, requirepass
  // will not overwrite the default's user password loaded from
  // that file. Loading the default's user password from a file
  // has higher priority than the deprecated flag
  cat_to_id_table_ = cat_to_id_table;
  reverse_cat_table_ = reverse_cat_table;
  cat_to_commands_table_ = cat_to_commands_table;
  auto default_user = DefaultUserUpdateRequest();
  auto maybe_password = absl::GetFlag(FLAGS_requirepass);
  if (!maybe_password.empty()) {
    default_user.passwords.front().password = std::move(maybe_password);
    default_user.passwords.front().nopass = false;
  } else if (const char* env_var = getenv("DFLY_PASSWORD"); env_var) {
    default_user.passwords.front().password = env_var;
    default_user.passwords.front().nopass = false;
  } else if (const char* env_var = getenv("DFLY_requirepass"); env_var) {
    default_user.passwords.front().password = env_var;
    default_user.passwords.front().nopass = false;
  }
  MaybeAddAndUpdate("default", std::move(default_user));
}

}  // namespace dfly::acl

================================================
FILE: src/server/acl/user_registry.h
================================================
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#pragma once

#include
#include
#include
#include
#include
#include

#include "server/acl/user.h"
#include "util/fibers/synchronization.h"

namespace dfly::acl {

class UserRegistry {
 private:
  template